diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..00f7280af --- /dev/null +++ b/.dockerignore @@ -0,0 +1,106 @@ +# Git files +.git +.gitignore +.gitattributes +.github + +# Documentation (we only need README.md) +docs/ +INSTALL_GUIDE.md +CONTRIBUTING.md +LICENSE +*.md +!README.md + +# Development files +.vscode +.idea +*.swp +*.swo +*~ +.editorconfig + +# Python cache and build artifacts +__pycache__ +*.py[cod] +*$py.class +*.so +.Python +*.egg-info +dist +build +.pytest_cache +.coverage +htmlcov +*.egg +MANIFEST +.mypy_cache +.ruff_cache +.tox +.nox +*.cover +.hypothesis + +# Node (we'll install and build in Docker) +chainforge/react-server/node_modules +chainforge/react-server/.pnp +chainforge/react-server/.pnp.js +chainforge/react-server/coverage +chainforge/react-server/build + +# Testing and development +tests/ +test/ +*.test.js +*.test.ts +*.test.tsx +*.spec.js +*.spec.ts +*.spec.tsx +coverage/ +.coverage +*.log +*.logs + +# Environment files +.env +.env.* +!.env.example +*.local + +# ChainForge specific +chainforge/cache +chainforge/examples/oaievals/ +chainforge_assets/ +packages/ +jobs/ +data/ + +# Docker files (avoid recursive copying) +Dockerfile* +docker-compose*.yml +.dockerignore + +# CI/CD +.circleci +.travis.yml +.gitlab-ci.yml +azure-pipelines.yml + +# IDE and editor files +*.sublime-* +.vscode +.idea +*.iml + +# OS files +.DS_Store +Thumbs.db +Desktop.ini + +# Temporary files +*.tmp +*.temp +*.bak +*.swp +*~ diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml new file mode 100644 index 000000000..aa5077788 --- /dev/null +++ b/.github/workflows/docker-build.yml @@ -0,0 +1,303 @@ +name: Build Docker Images + +on: + pull_request: + branches: + - main + - master + - ragforge + push: + branches: + - main + - master + - ragforge + tags: + - 'v*' + +# Concurrency: cancel in-progress runs when new one starts +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: 
true + +env: + REGISTRY: docker.io + IMAGE_NAME: ${{ secrets.DOCKER_USERNAME }}/chainforge + +jobs: + build-cpu-amd64: + name: Build CPU AMD64 Image + runs-on: ubuntu-latest + # Skip if commit message contains [skip ci] or [ci skip] + if: ${{ !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[ci skip]') }} + permissions: + contents: read + packages: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to Docker Hub + if: github.event_name != 'pull_request' + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Build CPU AMD64 image + uses: docker/build-push-action@v5 + with: + context: . + file: ./Dockerfile + push: false + platforms: linux/amd64 + provenance: false + sbom: false + outputs: type=image,name=${{ env.IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=${{ github.event_name != 'pull_request' }} + id: build-amd64 + + - name: Export digest + if: github.event_name != 'pull_request' + run: | + mkdir -p /tmp/digests + digest="${{ steps.build-amd64.outputs.digest }}" + touch "/tmp/digests/${digest#sha256:}" + + - name: Upload digest + if: github.event_name != 'pull_request' + uses: actions/upload-artifact@v4 + with: + name: digests-cpu-amd64 + path: /tmp/digests/* + if-no-files-found: error + retention-days: 1 + + build-cpu-arm64: + name: Build CPU ARM64 Image + runs-on: ubuntu-latest + # Skip if commit message contains [skip ci] or [ci skip] + if: ${{ !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[ci skip]') }} + permissions: + contents: read + packages: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to Docker Hub + if: 
github.event_name != 'pull_request' + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Build CPU ARM64 image + uses: docker/build-push-action@v5 + with: + context: . + file: ./Dockerfile + push: false + platforms: linux/arm64 + provenance: false + sbom: false + outputs: type=image,name=${{ env.IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=${{ github.event_name != 'pull_request' }} + id: build-arm64 + + - name: Export digest + if: github.event_name != 'pull_request' + run: | + mkdir -p /tmp/digests + digest="${{ steps.build-arm64.outputs.digest }}" + touch "/tmp/digests/${digest#sha256:}" + + - name: Upload digest + if: github.event_name != 'pull_request' + uses: actions/upload-artifact@v4 + with: + name: digests-cpu-arm64 + path: /tmp/digests/* + if-no-files-found: error + retention-days: 1 + + build-gpu-amd64: + name: Build GPU AMD64 Image + runs-on: ubuntu-latest + # Skip if commit message contains [skip ci] or [ci skip] + if: ${{ !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[ci skip]') }} + permissions: + contents: read + packages: write + + steps: + - name: Free up disk space + run: | + echo "Before cleanup:" + df -h + + # Remove unnecessary software to free up space + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/lib/android + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + sudo rm -rf /usr/local/share/boost + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + + # Remove large packages + sudo apt-get remove -y '^aspnetcore-.*' || true + sudo apt-get remove -y '^dotnet-.*' --fix-missing || true + sudo apt-get remove -y '^llvm-.*' --fix-missing || true + sudo apt-get remove -y 'php.*' --fix-missing || true + sudo apt-get remove -y '^mongodb-.*' --fix-missing || true + sudo apt-get remove -y '^mysql-.*' --fix-missing || true + sudo apt-get remove -y azure-cli google-chrome-stable firefox 
powershell mono-devel libgl1-mesa-dri --fix-missing || true + sudo apt-get remove -y google-cloud-sdk --fix-missing || true + sudo apt-get remove -y google-cloud-cli --fix-missing || true + + # Clean up + sudo apt-get autoremove -y + sudo apt-get clean + sudo docker image prune --all --force + sudo rm -rf /var/lib/apt/lists/* + + echo "After cleanup:" + df -h + + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to Docker Hub + if: github.event_name != 'pull_request' + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Build GPU AMD64 image + uses: docker/build-push-action@v5 + with: + context: . + file: ./Dockerfile.gpu + push: false + platforms: linux/amd64 + provenance: false + sbom: false + outputs: type=image,name=${{ env.IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=${{ github.event_name != 'pull_request' }} + id: build-gpu + + - name: Export digest + if: github.event_name != 'pull_request' + run: | + mkdir -p /tmp/digests + digest="${{ steps.build-gpu.outputs.digest }}" + touch "/tmp/digests/${digest#sha256:}" + + - name: Upload digest + if: github.event_name != 'pull_request' + uses: actions/upload-artifact@v4 + with: + name: digests-gpu-amd64 + path: /tmp/digests/* + if-no-files-found: error + retention-days: 1 + + + create-cpu-manifest: + name: Create CPU Multi-Arch Manifest + runs-on: ubuntu-latest + needs: [build-cpu-amd64, build-cpu-arm64] + if: ${{ github.event_name != 'pull_request' && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[ci skip]') }} + permissions: + contents: read + packages: write + + steps: + - name: Download digests + uses: actions/download-artifact@v4 + with: + path: /tmp/digests + pattern: digests-cpu-* + merge-multiple: true + + - name: Set up Docker Buildx + uses: 
docker/setup-buildx-action@v3 + + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.IMAGE_NAME }} + tags: | + type=ref,event=branch + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=raw,value=cpu + type=raw,value=latest + + - name: Create and push multi-arch manifest + working-directory: /tmp/digests + run: | + docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ + $(printf '${{ env.IMAGE_NAME }}@sha256:%s ' *) + env: + DOCKER_METADATA_OUTPUT_JSON: ${{ steps.meta.outputs.json }} + + create-gpu-manifest: + name: Create GPU Manifest + runs-on: ubuntu-latest + needs: [build-gpu-amd64] + if: ${{ github.event_name != 'pull_request' && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[ci skip]') }} + permissions: + contents: read + packages: write + + steps: + - name: Download digests + uses: actions/download-artifact@v4 + with: + path: /tmp/digests + pattern: digests-gpu-* + merge-multiple: true + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_PASSWORD }} + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.IMAGE_NAME }} + tags: | + type=ref,event=branch,suffix=-gpu + type=semver,pattern={{version}},suffix=-gpu + type=semver,pattern={{major}}.{{minor}},suffix=-gpu + type=raw,value=gpu + + - name: Create and push GPU manifest + working-directory: /tmp/digests + run: | + docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) 
| join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ + $(printf '${{ env.IMAGE_NAME }}@sha256:%s ' *) + env: + DOCKER_METADATA_OUTPUT_JSON: ${{ steps.meta.outputs.json }} diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 806719cda..da32a8ccb 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -12,8 +12,8 @@ jobs: strategy: matrix: - os: [ubuntu-latest, windows-latest, macos-latest, macos-13, macos-14] - python-version: [3.11, 3.12, 3.13] + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.10.11", 3.11, 3.12] fail-fast: false steps: @@ -29,17 +29,24 @@ jobs: uses: actions/cache@v4 with: path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('chainforge/requirements.txt') }} + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('chainforge/requirements.txt', 'chainforge/constraints.txt', 'setup.py') }} restore-keys: | ${{ runner.os }}-pip-${{ matrix.python-version }}- + ${{ runner.os }}-pip- - name: Upgrade pip run: python -m pip install --upgrade pip - - name: Install dependencies + - name: Install package + run: | + pip install -e . + + - name: Install with RAG extras + run: | + pip install -e .[rag] + + - name: Install pytest run: | - pip install -r chainforge/requirements.txt - pip install . pip install pytest # TODO: Check that the server runs diff --git a/.gitignore b/.gitignore index 907b83455..631eca426 100644 --- a/.gitignore +++ b/.gitignore @@ -185,3 +185,4 @@ pyrightconfig.json # End of https://www.toptal.com/developers/gitignore/api/python chainforge_assets/ .vscode/ +**/bun.lock \ No newline at end of file diff --git a/DOCKER.md b/DOCKER.md new file mode 100644 index 000000000..c031eb7c4 --- /dev/null +++ b/DOCKER.md @@ -0,0 +1,371 @@ +# ChainForge Docker Setup + +This document describes how to run ChainForge using Docker. 
ChainForge provides two Docker image variants: + +- **CPU (latest)**: Multi-architecture (AMD64 + ARM64) optimized for CPU-only environments +- **GPU**: AMD64-only with CUDA support for GPU acceleration + +## Prerequisites + +- Docker (version 20.10 or later) +- Docker Compose (version 2.0 or later) +- For GPU: NVIDIA Docker runtime (nvidia-docker2) + +## Architecture Support + +**CPU Images**: Available for both AMD64 and ARM64 architectures. Docker automatically pulls the correct architecture for your system. + +**GPU Images**: AMD64 only (ARM64 GPU support is limited and CUDA packages are very large) + +## Quick Start + +### CPU Version (Default) + +Using Docker Compose (Recommended): + +```bash +docker-compose up -d +``` + +Or using Docker CLI: + +```bash +# Pull pre-built image +docker pull gauransh/chainforge:latest + +# Or build locally +docker build -t chainforge:cpu . + +# Run +docker run -d \ + -p 8000:8000 \ + -v chainforge-data:/home/chainforge/.local/share/chainforge \ + --name chainforge \ + --restart unless-stopped \ + gauransh/chainforge:latest +``` + +### GPU Version + +Using Docker Compose (Recommended): + +```bash +docker-compose -f docker-compose.gpu.yml up -d +``` + +Or using Docker CLI: + +```bash +# Pull pre-built image +docker pull gauransh/chainforge:gpu + +# Or build locally +docker build -f Dockerfile.gpu -t chainforge:gpu . 
+ +# Run with GPU support +docker run -d \ + -p 8000:8000 \ + -v chainforge-data:/home/chainforge/.local/share/chainforge \ + --name chainforge-gpu \ + --gpus all \ + --restart unless-stopped \ + gauransh/chainforge:gpu +``` + +Access ChainForge at: http://localhost:8000 + +## Image Variants + +### CPU (latest) +- **Architectures**: AMD64, ARM64 (multi-arch manifest) +- **Size**: Optimized and minimal (~800MB compressed) +- **Dependencies**: Uses `constraints.txt` for version pinning +- **PyTorch**: CPU-only build from https://download.pytorch.org/whl/cpu +- **Tags**: `latest`, `cpu`, branch names (e.g., `main`, `ragforge`), version tags (e.g., `v1.0`, `1.0`) + +### GPU +- **Architecture**: AMD64 only +- **Size**: Larger due to CUDA support (~2-3GB compressed) +- **Dependencies**: Uses `constraints.txt` for version pinning +- **PyTorch**: CUDA 12.1 build from https://download.pytorch.org/whl/cu121 +- **Tags**: `gpu`, branch names with `-gpu` suffix (e.g., `main-gpu`, `ragforge-gpu`), version tags with `-gpu` suffix (e.g., `v1.0-gpu`, `1.0-gpu`) +- **Requirements**: NVIDIA GPU with CUDA support, nvidia-docker runtime + +**When to use GPU variant:** +- You have NVIDIA GPUs available (AMD64 architecture) +- You need GPU acceleration for ML/AI workloads +- You're running compute-intensive models locally + +**When to use CPU variant:** +- Running on CPU-only machines +- Using ARM64 devices (Apple Silicon, Raspberry Pi, etc.) 
+- Deploying to cloud platforms without GPU +- Smaller image size is preferred + +## Managing Containers + +### View logs + +```bash +# CPU version +docker-compose logs -f + +# GPU version +docker-compose -f docker-compose.gpu.yml logs -f +``` + +### Stop containers + +```bash +# CPU version +docker-compose down + +# GPU version +docker-compose -f docker-compose.gpu.yml down +``` + +### Rebuild images + +```bash +# CPU version +docker-compose build --no-cache + +# GPU version +docker-compose -f docker-compose.gpu.yml build --no-cache +``` + +### Remove volumes (delete all data) + +```bash +docker-compose down -v +``` + +## Environment Variables + +You can set API keys and other environment variables by: + +1. Creating a `.env` file in the project root +2. Adding your variables (they're already templated in docker-compose.yml): + +```env +OPENAI_API_KEY=your_key_here +ANTHROPIC_API_KEY=your_key_here +COHERE_API_KEY=your_key_here +GOOGLE_API_KEY=your_key_here +DEEPSEEK_API_KEY=your_key_here +HUGGINGFACE_API_KEY=your_key_here +``` + +3. 
Uncommenting the relevant lines in `docker-compose.yml` + +## Docker Image Structure + +Multi-stage build optimized for minimal size and fast builds: + +### Stage 1: Frontend Builder +- Starts from `node:20-slim` +- Installs npm dependencies and builds React frontend +- Cleans up node_modules and npm cache after build +- Only the `/build` directory is copied to final image + +### Stage 2: Python Builder +- Starts from `python:3.12-slim` +- Installs build dependencies (build-essential, git) +- Installs PyTorch (CPU or CUDA variant) +- Installs Python dependencies from `requirements.txt` with `constraints.txt` +- Installs ChainForge package +- Purges build tools and cleans caches + +### Stage 3: Runtime Image +- Minimal `python:3.12-slim` base +- Only runtime dependencies: git, libgomp1 +- Copies Python packages from builder +- Copies React build from frontend builder +- Aggressive cleanup of unnecessary files (tests, docs, cache, static libs) +- Runs as non-root user (chainforge, uid 1000) + +**Optimizations:** +- Multi-stage build keeps final image small +- No build tools in runtime image +- All RUN commands combined into single layers to minimize layer count +- Aggressive file cleanup reduces image size by ~30% + +## Automated Builds (CI/CD) + +Docker images are automatically built and pushed to Docker Hub via GitHub Actions. + +### Build Triggers + +Builds run on: +- **Pull Requests**: Builds are tested but NOT pushed to Docker Hub +- **Branch Pushes**: `main`, `master`, `ragforge` - Built and pushed with branch-specific tags +- **Version Tags**: `v*` (e.g., `v1.0.0`) - Built and pushed with version tags + +### Build Architecture + +The workflow uses a **parallel multi-architecture build strategy**: + +1. **3 Parallel Build Jobs**: + - `build-cpu-amd64`: Builds CPU variant for AMD64 + - `build-cpu-arm64`: Builds CPU variant for ARM64 + - `build-gpu-amd64`: Builds GPU variant for AMD64 (with aggressive disk cleanup) + +2. 
**Manifest Creation Jobs**: + - `create-cpu-manifest`: Combines AMD64 + ARM64 into multi-arch CPU manifest + - `create-gpu-manifest`: Creates GPU manifest (AMD64 only) + +**Why separate builds?** +- **No QEMU emulation** - Native builds are 3-5x faster than cross-compilation +- **Parallel execution** - All architectures build simultaneously +- **Better caching** - Each architecture has independent build cache + +### Available Tags + +**CPU Images** (multi-arch: amd64 + arm64): +``` +gauransh/chainforge:latest +gauransh/chainforge:cpu +gauransh/chainforge:main +gauransh/chainforge:ragforge +gauransh/chainforge:v1.0.0 +gauransh/chainforge:1.0 +``` + +**GPU Images** (amd64 only): +``` +gauransh/chainforge:gpu +gauransh/chainforge:main-gpu +gauransh/chainforge:ragforge-gpu +gauransh/chainforge:v1.0.0-gpu +gauransh/chainforge:1.0-gpu +``` + +### Storage Optimizations + +To prevent hitting GitHub Actions storage limits: + +1. **No GitHub Actions cache** - Eliminated cache storage overhead +2. **No ARM64 GPU builds** - GPU only builds for AMD64 (50% storage reduction) +3. **Disabled provenance and SBOM** - Reduces metadata size +4. **Aggressive runner cleanup** - GPU builds free ~30-40GB before starting by removing: + - .NET SDK, Android tools, GHC, CodeQL + - LLVM, PHP, MongoDB, MySQL + - Chrome, Firefox, Azure CLI, Google Cloud SDK +5. **Minimal Dockerfile** - All operations combined into single layers + +### Concurrency Control + +The workflow uses concurrency groups to automatically cancel in-progress builds when a new commit is pushed to the same branch: + +```yaml +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true +``` + +### Skip CI + +To skip Docker builds entirely, add `[skip ci]` or `[ci skip]` to your commit message: + +```bash +git commit -m "Update documentation [skip ci]" +git commit -m "Fix typo [ci skip]" +``` + +This prevents all build jobs from running, saving time and resources. 
+ +### GitHub Secrets Required + +To enable automated image publishing, configure these secrets: + +- `DOCKER_USERNAME` - Docker Hub username +- `DOCKER_PASSWORD` - Docker Hub password or personal access token + +**Setup**: Repository Settings → Secrets and variables → Actions → New repository secret + +### Build Workflow Summary + +``` +┌─────────────────────────────────────────────────────────┐ +│ PR or Push Event │ +└─────────────────────────────────────────────────────────┘ + │ + ┌───────────┴───────────┐ + │ Concurrency Check │ + │ (Cancel old builds?) │ + └───────────┬───────────┘ + │ + ┌───────────────────┼───────────────────┐ + │ │ │ + ▼ ▼ ▼ +┌──────────────┐ ┌──────────────┐ ┌──────────────┐ +│ CPU AMD64 │ │ CPU ARM64 │ │ GPU AMD64 │ +│ Build │ │ Build │ │ Build │ +│ │ │ │ │ (w/ cleanup) │ +└──────┬───────┘ └──────┬───────┘ └──────┬───────┘ + │ │ │ + │ upload digest │ upload digest │ upload digest + │ │ │ + └────────┬─────────┴──────────────────┘ + │ + ┌─────────┴──────────┐ + │ │ + ▼ ▼ +┌─────────────┐ ┌─────────────┐ +│ CPU │ │ GPU │ +│ Manifest │ │ Manifest │ +│ (amd64+arm64)│ │ (amd64 only)│ +└─────────────┘ └─────────────┘ + │ │ + └─────────┬──────────┘ + │ + ▼ + Push to Docker Hub +``` + +## Troubleshooting + +### ESLint Config Error + +If you see "ESLint couldn't find the config 'semistandard'": +- This is fixed in the current Dockerfile by installing devDependencies +- Rebuild the image: `docker-compose build --no-cache` + +### Port Already in Use + +If port 8000 is already in use: +- Change the port mapping in `docker-compose.yml` +- Example: `"8080:8000"` to use port 8080 on your host + +### Permission Issues + +If you encounter permission errors: +- The image runs as a non-root user (uid 1000) +- Ensure your volume permissions match this user +- You can adjust the UID in the Dockerfile if needed + +## Performance Tips + +- The multi-stage build creates an optimized image +- The image is optimized to be under 1GB +- Unnecessary files and caches 
are cleaned during build +- Consider allocating more memory to Docker if builds are slow + +## Pushing to Registry + +To push the image to a registry: + +```bash +# Tag the image +docker tag chainforge:latest your-registry/chainforge:tag + +# Push to registry +docker push your-registry/chainforge:tag +``` + +Or let docker-compose handle it: + +```bash +docker-compose build +docker-compose push +``` diff --git a/Dockerfile b/Dockerfile index 89e993283..2a69f30c7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,9 +1,84 @@ -FROM python:3.12-slim AS builder +# Multi-stage build: Stage 1 - Build React frontend +FROM node:20-slim AS frontend-builder -RUN pip install --upgrade pip -RUN pip install chainforge --no-cache-dir +WORKDIR /app + +# Copy package files and install dependencies (including dev for build) +COPY chainforge/react-server/package*.json ./ +RUN npm ci --legacy-peer-deps --prefer-offline + +# Copy source files and build +COPY chainforge/react-server/ ./ +RUN npm run build + +# Stage 2 - Build Python dependencies (CPU version with constraints) +FROM python:3.12-slim AS python-builder + +# Install only the build dependencies we need +RUN apt-get --allow-releaseinfo-change update && \ + apt-get install -y --no-install-recommends \ + build-essential \ + git \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /build + +# Upgrade pip tools first (this layer is highly cacheable) +RUN pip install --no-cache-dir --upgrade pip setuptools wheel + +# Copy requirements first for better layer caching +COPY chainforge/requirements.txt chainforge/constraints.txt ./ + +# Install PyTorch CPU-only FIRST as it's the largest dependency +# This separates the longest-running install into its own layer +RUN pip install --no-cache-dir --prefix=/install \ + --extra-index-url https://download.pytorch.org/whl/cpu \ + torch torchvision torchaudio + +# Install remaining requirements with constraints +# Using --find-links to help pip resolve faster +RUN pip install --no-cache-dir 
--prefix=/install \ + -r requirements.txt \ + -c constraints.txt + +# Copy project files and build the package (smallest layer last) +COPY setup.py README.md ./ +COPY chainforge/ ./chainforge/ +RUN pip install --no-cache-dir --prefix=/install . + +# Stage 3 - Final minimal runtime image +FROM python:3.12-slim + +# Install only runtime dependencies (no build tools) +RUN apt-get --allow-releaseinfo-change update && \ + apt-get install -y --no-install-recommends \ + git \ + libgomp1 \ + && rm -rf /var/lib/apt/lists/* \ + && apt-get clean WORKDIR /chainforge +# Copy Python packages from builder +COPY --from=python-builder /install /usr/local + +# Copy the built React app from the frontend-builder stage to the installed package location +COPY --from=frontend-builder /app/build /usr/local/lib/python3.12/site-packages/chainforge/react-server/build + +# Clean up any unnecessary files to reduce image size +RUN find /usr/local -type d -name "tests" -exec rm -rf {} + 2>/dev/null || true && \ + find /usr/local -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true && \ + find /usr/local -name "*.pyc" -delete && \ + find /usr/local -name "*.pyo" -delete && \ + find /usr/local -name "*.md" -delete 2>/dev/null || true + +# Run as non-root user for security +RUN useradd -m -u 1000 chainforge && \ + mkdir -p /home/chainforge/.local/share/chainforge && \ + chown -R chainforge:chainforge /chainforge /home/chainforge + +USER chainforge + EXPOSE 8000 + ENTRYPOINT [ "chainforge", "serve", "--host", "0.0.0.0" ] diff --git a/Dockerfile.gpu b/Dockerfile.gpu new file mode 100644 index 000000000..26a1687c6 --- /dev/null +++ b/Dockerfile.gpu @@ -0,0 +1,71 @@ +# Multi-stage build: Stage 1 - Build React frontend +FROM node:20-slim AS frontend-builder + +WORKDIR /app + +# Copy everything and build in one layer to minimize size +COPY chainforge/react-server/ ./ +RUN npm ci --legacy-peer-deps --prefer-offline --production=false && \ + npm run build && \ + rm -rf node_modules && \ + npm 
cache clean --force + +# Stage 2 - Build Python dependencies (GPU version with CUDA support) +FROM python:3.12-slim AS python-builder + +WORKDIR /build + +# Copy all requirements and source files +COPY chainforge/requirements.txt chainforge/constraints.txt setup.py README.md ./ +COPY chainforge/ ./chainforge/ + +# Do everything in one layer to minimize build context and layers +RUN apt-get --allow-releaseinfo-change update && \ + apt-get install -y --no-install-recommends build-essential git && \ + pip install --no-cache-dir --upgrade pip setuptools wheel && \ + pip install --no-cache-dir --prefix=/install \ + --extra-index-url https://download.pytorch.org/whl/cu121 \ + torch torchvision torchaudio && \ + pip install --no-cache-dir --prefix=/install \ + -r requirements.txt \ + -c constraints.txt && \ + pip install --no-cache-dir --prefix=/install . && \ + apt-get purge -y --auto-remove build-essential && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* /root/.cache + +# Stage 3 - Final minimal runtime image +FROM python:3.12-slim + +WORKDIR /chainforge + +# Copy Python packages and React build from builder stages +COPY --from=python-builder /install /usr/local +COPY --from=frontend-builder /app/build /usr/local/lib/python3.12/site-packages/chainforge/react-server/build + +# Install runtime deps, clean up, and create user in one layer +RUN apt-get --allow-releaseinfo-change update && \ + apt-get install -y --no-install-recommends git libgomp1 && \ + rm -rf /var/lib/apt/lists/* && \ + apt-get clean && \ + find /usr/local -type d -name "tests" -exec rm -rf {} + 2>/dev/null || true && \ + find /usr/local -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true && \ + find /usr/local -type d -name "*.dist-info" -exec rm -rf {}/RECORD {} + 2>/dev/null || true && \ + find /usr/local -name "*.pyc" -delete && \ + find /usr/local -name "*.pyo" -delete && \ + find /usr/local -name "*.a" -delete && \ + find /usr/local -name "*.md" -delete 2>/dev/null || true && \ 
+ find /usr/local -name "*.txt" -delete 2>/dev/null || true && \ + rm -rf /usr/local/lib/python3.12/site-packages/*/tests && \ + rm -rf /usr/local/lib/python3.12/site-packages/*/test && \ + rm -rf /tmp/* /var/tmp/* && \ + useradd -m -u 1000 chainforge && \ + mkdir -p /home/chainforge/.local/share/chainforge && \ + chown -R chainforge:chainforge /chainforge /home/chainforge + +USER chainforge + +EXPOSE 8000 + +ENTRYPOINT [ "chainforge", "serve", "--host", "0.0.0.0" ] + diff --git a/README.md b/README.md index c78c5a42f..51237252c 100644 --- a/README.md +++ b/README.md @@ -49,20 +49,14 @@ Open [localhost:8000](http://localhost:8000/) in a Google Chrome, Firefox, Micro You can set your API keys by clicking the Settings icon in the top-right corner. If you prefer to not worry about this everytime you open ChainForge, we **highly recommend** that save your OpenAI, Anthropic, Google, etc API keys and/or Amazon AWS credentials to your local environment. For more details, see the [How to Install](https://chainforge.ai/docs/getting_started/). ## Run using Docker +**Quick start with Docker Compose (recommended):** -You can use our [Dockerfile](/Dockerfile) to run `ChainForge` locally using `Docker Desktop`: - -- Build the `Dockerfile`: - ```shell - docker build -t chainforge . - ``` - -- Run the image: - ```shell - docker run -p 8000:8000 chainforge - ``` +```bash +docker compose up -d +``` +Access ChainForge at http://localhost:8000 -Now you can open the browser of your choice and open `http://127.0.0.1:8000`. +For detailed Docker documentation including architecture support, environment variables, and CI/CD setup, see [DOCKER.md](DOCKER.md). 
# Supported providers diff --git a/chainforge/constraints.txt b/chainforge/constraints.txt new file mode 100644 index 000000000..379510e80 --- /dev/null +++ b/chainforge/constraints.txt @@ -0,0 +1,4 @@ +--extra-index-url https://download.pytorch.org/whl/cpu +torch +torchvision +torchaudio \ No newline at end of file diff --git a/chainforge/examples/chainforge-docs.cfzip b/chainforge/examples/chainforge-docs.cfzip new file mode 100644 index 000000000..1fcac88ca Binary files /dev/null and b/chainforge/examples/chainforge-docs.cfzip differ diff --git a/chainforge/examples/custom_provider_cohere.py b/chainforge/examples/custom_provider_cohere.py index 26b03d005..de1bc3451 100644 --- a/chainforge/examples/custom_provider_cohere.py +++ b/chainforge/examples/custom_provider_cohere.py @@ -45,7 +45,8 @@ # Our custom model provider for Cohere's text generation API. @provider(name="Cohere", - emoji="🖇", + emoji="🖇", + category="model", models=['command', 'command-nightly', 'command-light', 'command-light-nightly'], rate_limit="sequential", # enter "sequential" for blocking; an integer N > 0 means N is the max mumber of requests per minute. 
settings_schema=COHERE_SETTINGS_SCHEMA) diff --git a/chainforge/examples/evaluate-rag-pipeline.cfzip b/chainforge/examples/evaluate-rag-pipeline.cfzip new file mode 100644 index 000000000..4db9857a2 Binary files /dev/null and b/chainforge/examples/evaluate-rag-pipeline.cfzip differ diff --git a/chainforge/examples/levenshtein_retriever.py b/chainforge/examples/levenshtein_retriever.py new file mode 100644 index 000000000..8472a678a --- /dev/null +++ b/chainforge/examples/levenshtein_retriever.py @@ -0,0 +1,91 @@ +from typing import List, Dict, Any, Union +from chainforge.providers import provider + +@provider( + name="Levenshtein Retriever", + emoji="🔢", + models=[], + rate_limit="sequential", + settings_schema={ + "settings": { + "top_k": {"type": "integer", "title": "Top K", "default": 5, "minimum": 1} + }, + "ui": { + "top_k": {"ui:widget": "range"}, + } + }, + category="retriever" +) +def LevenshteinRetriever( + chunks: List[Dict[str, Any]], + queries: List[Union[str, Dict[str, Any]]], + settings: Dict[str, Any] +) -> List[Dict[str, Any]]: + """Return top-K chunks per query using plain Levenshtein distance.""" + + def lev(a: str, b: str) -> int: + m, n = len(a), len(b) + dp = [[0]*(n+1) for _ in range(m+1)] + for i in range(m+1): dp[i][0] = i + for j in range(n+1): dp[0][j] = j + for i in range(1, m+1): + ai = a[i-1] + for j in range(1, n+1): + cost = 0 if ai == b[j-1] else 1 + dp[i][j] = min( + dp[i-1][j] + 1, # deletion + dp[i][j-1] + 1, # insertion + dp[i-1][j-1] + cost # substitution + ) + return dp[m][n] + + # 1) Coerce & clamp Top-K + try: + top_k = int(settings.get("top_k", 5)) + except (TypeError, ValueError): + top_k = 5 + if top_k < 1: + top_k = 1 + + results: List[Dict[str, Any]] = [] + + for q in queries: + # normalize to dict with a "text" field, preserving extra keys + if isinstance(q, dict): + prompt = q.get("text") or q.get("query") or "" + query_obj = {**q, "text": prompt} + else: + prompt = str(q) + query_obj = {"text": prompt} + + q_low = 
import re
from chainforge.providers import provider
from typing import List

@provider(
    name="Paragraph Chunker",
    emoji="¶",
    models=[],                # no LLM models needed
    rate_limit="sequential",  # chunkers run sequentially
    settings_schema={},       # no extra settings
    category="chunker"
)
def ParagraphChunker(
    text: str
) -> List[str]:
    """
    Splits the input text into paragraphs at blank lines.

    Args:
        text: the full text to chunk.

    Returns:
        A list of non-empty paragraph strings.
    """
    # Normalize ALL line-ending conventions to "\n": Windows CRLF first, then
    # any remaining lone CR (legacy Mac files) — previously CR-only documents
    # would never be split because the regex only matched "\n".
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    # Split on one-or-more blank lines (blank lines may contain whitespace),
    # dropping empty/whitespace-only fragments.
    paras = [p.strip() for p in re.split(r"\n\s*\n+", text) if p.strip()]
    return paras
def IS_RAG_AVAILABLE():
    """Return True iff every optional RAGForge dependency can be imported.

    Probes each required package with importlib.util.find_spec (no actual
    import is performed, so this is cheap and side-effect free).
    """
    from importlib.util import find_spec

    required = ("pyarrow", "lancedb", "sentence_transformers",
                "chonkie", "rank_bm25", "numpy", "nltk")
    try:
        if any(find_spec(pkg) is None for pkg in required):
            return False
    except ImportError:
        print("You are running ChainForge core. RAGForge dependencies were not detected; hence, RAG features will be disabled.")
        return False
    print("RAGForge dependencies detected. Enabling RAGForge features...")
    return True
RAG_AVAILABLE = IS_RAG_AVAILABLE()
@app.route('/app/checkRagAvailable', methods=['POST'])
def checkRagAvailable():
    """
    Report whether the optional RAG dependency stack is importable.

    Re-runs the dependency probe on each request (rather than reusing the
    module-level RAG_AVAILABLE flag). Responds with JSON:
    {"rag_available": <bool>}, CORS-enabled for any origin.
    """
    rag_available = False
    try:
        rag_available = IS_RAG_AVAILABLE()
    except ImportError:
        # The probe itself failed to import something -> treat as unavailable.
        rag_available = False

    resp = jsonify({"rag_available": rag_available})
    resp.headers.add('Access-Control-Allow-Origin', '*')
    return resp
@app.route("/chunk", methods=["POST"])
def chunk():
    """
    Handles text processing requests, specifically chunking.
    Uses a registry to dispatch to the correct chunking function.
    Expects multipart/form-data with:
    - 'baseMethod' in request.form
    - 'document' in request.files (UTF-8 text as a file/blob)
    - optional additional settings as small form fields

    Returns JSON {"chunks": [...]} on success, or {"error": ...} with an
    appropriate 4xx/5xx status on failure.
    """
    if not request.form and not request.files:
        return jsonify({"error": "Request must be form data"}), 400

    base_method = request.form.get("baseMethod")
    if not base_method:
        return jsonify({"error": "Missing 'baseMethod' in form data"}), 400

    # We now require the text as an uploaded file named "document"
    file = request.files.get("document")
    if file is None:
        return jsonify({"error": "Missing 'document' in form data"}), 400

    # Read and decode the uploaded text file; undecodable bytes are dropped
    # rather than raising, so partially-mangled uploads still chunk.
    raw_bytes = file.read()  # type: bytes
    text = raw_bytes.decode("utf-8", errors="ignore")  # bytes -> str

    # Look up the chunking handler among the built-in methods.
    handler = ChunkingMethodRegistry.get_handler(base_method)

    # If it wasn't a built-in chunker, see if it's a custom provider
    # (custom methods are namespaced with a "__custom/" prefix).
    if not handler and base_method.startswith("__custom/"):
        provider_name = base_method[len("__custom/"):]
        entry = ProviderRegistry.get(provider_name)
        if entry and entry.get("func"):
            handler = entry["func"]

    if not handler:
        return jsonify({"error": f"Unsupported chunking method: {base_method}"}), 400

    # Extract additional settings from form data, converting types carefully.
    # Form values arrive as strings; only the parameters listed below get
    # coerced — everything else is passed through to the handler as-is.
    settings = {}
    known_int_params = {"chunk_size", "chunk_overlap", "n_topics", "min_topic_size", "top_k", "max_features"}
    known_float_params = {"bm25_k1", "bm25_b"}
    known_bool_params = {"keep_separator"}

    for key, value in request.form.items():
        if key == "baseMethod":
            continue
        try:
            if key in known_int_params:
                settings[key] = int(value)
            elif key in known_float_params:
                settings[key] = float(value)
            elif key in known_bool_params:
                # Handle boolean conversion robustly ("true"/"yes" -> True)
                settings[key] = value.lower() in ['true', 'yes']
            else:
                settings[key] = value  # Keep as string if type unknown
        except (ValueError, TypeError):
            print(f"Warning: Could not convert setting '{key}' with value '{value}' to expected type. Using raw value.", file=sys.stderr)
            settings[key] = value  # Fallback to string if conversion fails

    try:
        # Call the registered handler function
        chunks = handler(text, **settings)
        return jsonify({"chunks": chunks}), 200
    except ValueError as ve:  # Catch specific config/setup errors
        print(f"Configuration or setup error during chunking ({base_method}): {ve}", file=sys.stderr)
        return jsonify({"error": f"Setup error: {ve}"}), 400  # Bad Request
    except ImportError as ie:
        print(f"Import error during chunking ({base_method}): {ie}", file=sys.stderr)
        return jsonify({"error": f"Missing library dependency: {ie.name}"}), 500  # Internal Server Error
    except Exception as e:
        # Log the full error for server-side debugging
        print(f"Unexpected error during chunking ({base_method}): {e}", file=sys.stderr)
        # Return a generic error to the client
        return jsonify({"error": "An internal error occurred during text processing."}), 500
@app.route("/retrieve", methods=["POST"])
def retrieve():
    """
    Run every requested retrieval method over the provided chunks and queries.

    Expected JSON body:
    {
        "methods": [{"id", "baseMethod", "methodName", "library",
                     "settings", "embeddingProvider"?, ...}, ...],
        "chunks":  [{"text", "fill_history", "metavars", ...}, ...],
        "queries": [str or {"text", "vars", "fill_history", "metavars", ...}, ...],
        "api_keys": {...},        # optional, forwarded to embedders
        "fusion_enabled": bool,   # optional
        "linked_groups": [...]    # optional; only read when fusion is enabled
    }

    Returns a flat JSON array of ChainForge TemplateVarInfo-shaped objects:
    one per retrieved chunk per query per method, each carrying similarity,
    rank, provenance metavars and latency — plus fused results when fusion
    groups are configured. Progress percentages are written to the
    module-level RETRIEVAL_PROGRESS dict, polled via /getRetrieveProgress.
    """
    global RETRIEVAL_PROGRESS
    data = request.json
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    methods = data.get("methods", [])
    chunks = data.get("chunks", [])
    queries = data.get("queries", [])
    api_keys = data.get("api_keys", [])

    fusion_enabled = bool(data.get("fusion_enabled", False))
    linked_groups = data.get("linked_groups", []) if fusion_enabled else []
    method_name_by_id = {m["id"]: m["methodName"] for m in methods}

    # Normalize queries so every entry is a dict with at least a "text" key.
    queries = [{'text': q} if isinstance(q, str) else q for q in queries]

    print("[DEBUG] ", methods)

    try:
        # Validate inputs
        if not methods:
            return jsonify({"error": "No retrieval methods provided"}), 400
        if not chunks:
            return jsonify({"error": "No chunks provided"}), 400
        if not queries:
            return jsonify({"error": "No queries provided"}), 400

        RETRIEVAL_PROGRESS = {m["methodName"]: 0 for m in methods}

        # Map each method id to its fusion group (only when fusion is on).
        method_id_to_group = {}
        group_cfg = {}
        if fusion_enabled:
            for g in linked_groups:
                gid = g.get("id")
                if not gid:
                    continue
                group_cfg[gid] = g
                for mid in g.get("methodKeys", []):
                    method_id_to_group[mid] = gid

        # (query_text, chunkMethod) -> { methodId -> [ {doc_id, rank, score, obj} ] }
        staging = defaultdict(lambda: defaultdict(list))

        # Resolve every handler up-front so an unknown method fails fast
        # with a 400 instead of partway through processing.
        resolved_handlers = {}
        for method in methods:
            base_method = method.get("baseMethod")

            # 1) Try built-in lookup
            handler = RetrievalMethodRegistry.get_handler(base_method)

            # 2) Fall back to a custom provider ("__custom/<name>" ids)
            if not handler and base_method.startswith("__custom/"):
                provider_name = base_method[len("__custom/"):]
                entry = ProviderRegistry.get(provider_name)
                if entry and entry.get("func"):
                    handler = entry["func"]

            if not handler:
                return jsonify({"error": f"Unknown retrieval method: {base_method}"}), 400

            resolved_handlers[base_method] = handler

        # Group chunks by the chunking method that produced them, keeping
        # the metadata each retriever needs.
        chunks_by_method = defaultdict(list)
        for chunk in chunks:
            chunk_method = chunk.get("fill_history", {}).get("chunkMethod", "unknown")
            chunks_by_method[chunk_method].append({
                "text": chunk.get("text", ""),
                "docTitle": chunk.get("metavars", {}).get("docTitle", ""),
                "chunkId": chunk.get("metavars", {}).get("chunkId", ""),
                "chunkMethod": chunk_method,
                "chunkLibrary": chunk.get("metavars", {}).get("chunkLibrary", ""),
            })

        print(len(chunks_by_method))

        # Partition methods: embedding-based ones are grouped per
        # (provider, model) pair so embeddings are computed only once each.
        embedding_methods = {}  # "provider#model" -> [methods]
        keyword_methods = []    # methods that need no embeddings
        for method in methods:
            embedding_provider = method.get("embeddingProvider", None)
            if embedding_provider:
                embedding_model = method.get("settings", {}).get("embeddingModel", "default")
                full_embedder = f"{embedding_provider}#{embedding_model}"
                embedding_methods.setdefault(full_embedder, []).append(method)
            else:
                keyword_methods.append(method)

        print("[DEBUG] ", embedding_methods)

        flat_results = []

        def build_response(method, query_object, chunk, chunk_method, rank,
                           latency_ms, embedding_model=None):
            """Shape one retrieved chunk as a ChainForge TemplateVarInfo dict."""
            metavars = {
                **query_object.get("metavars", {}),  # original query metavars
                "methodId": method.get("id"),
                "retrievalMethodSignature": method.get("baseMethod"),
                "signature": chunk_method + "-" + method.get("methodName"),
                "docTitle": chunk.get("docTitle", ""),
                "chunkId": chunk.get("chunkId", ""),
                "chunkLibrary": chunk.get("chunkLibrary", ""),
                "latency_ms": f"{latency_ms:.2f}ms",
            }
            if embedding_model is not None:
                metavars["embeddingModel"] = embedding_model
            return {
                "text": chunk["text"],
                "prompt": query_object['text'],
                "eval_res": {
                    "items": [{
                        "similarity": chunk["similarity"],
                        "rank": rank,
                    }],
                    "dtype": "KeyValue_Mixed",
                },
                "vars": {
                    **query_object.get("vars", {}),          # original query vars
                    **query_object.get("fill_history", {}),  # original query fill_history, if any
                    "query": query_object['text'],
                    "retrievalMethod": method.get("methodName"),
                    "chunkMethod": chunk_method,
                },
                "metavars": metavars,
                "llm": chunk.get("llm", "(none)"),  # use chunk's LLM if available
            }

        def stage_and_collect(method, retrieved, chunk_method, latency_ms,
                              embedding_model=None):
            """Convert one handler's output into response objects; stage for fusion."""
            method_id = method.get("id")
            for resp in retrieved:
                query_object = resp.get("query_object", {})
                retrieved_chunks = resp.get("retrieved_chunks", [])
                for i, chunk in enumerate(retrieved_chunks):
                    response_obj = build_response(method, query_object, chunk,
                                                  chunk_method, i + 1,
                                                  latency_ms, embedding_model)
                    if fusion_enabled:
                        staging[(query_object['text'], chunk_method)][method_id].append({
                            "doc_id": chunk.get("chunkId"),
                            "rank": i + 1,
                            "score": float(chunk.get("similarity", 0.0)),
                            "obj": response_obj,
                        })
                    flat_results.append(response_obj)

        # Process each chunking method separately
        for chunk_method, chunk_group in chunks_by_method.items():
            if not chunk_group:
                continue

            # --- Keyword (non-embedding) methods ---
            for method in keyword_methods:
                method_name = method.get("methodName")
                try:
                    handler = resolved_handlers.get(method.get("baseMethod"))
                    if not handler:
                        raise ValueError(f"Unknown method: {method.get('baseMethod')}")
                    RETRIEVAL_PROGRESS[method_name] = 10
                    start_time = time.perf_counter()
                    retrieved = handler(chunk_group, queries, method.get("settings", {}))
                    RETRIEVAL_PROGRESS[method_name] = 70
                    latency_ms = (time.perf_counter() - start_time) * 1000
                    stage_and_collect(method, retrieved, chunk_method, latency_ms)
                    RETRIEVAL_PROGRESS[method_name] = 100
                except Exception as e:
                    # Skip errors — results from this method are simply omitted.
                    print(f"Error with {method_name} on {chunk_method}: {str(e)}")
                    continue

            # --- Embedding-based methods, grouped per embedder ---
            # NOTE: loop variable renamed from `methods` — the original
            # shadowed (and clobbered) the top-level `methods` list.
            for embedder, emb_methods in embedding_methods.items():
                try:
                    provider, model_name = embedder.split("#", 1)
                    embedder_func = EmbeddingMethodRegistry.get_embedder(provider)
                    model_path = next((m.get('settings', {}).get('embeddingLocalPath')
                                       for m in emb_methods
                                       if m.get('settings', {}).get('embeddingLocalPath')), None)

                    if not embedder_func:
                        raise ValueError(f"Unknown embedding model: {model_name}")
                    for m in emb_methods:
                        RETRIEVAL_PROGRESS[m["methodName"]] = 30

                    # Compute embeddings ONCE for all methods sharing this embedder.
                    chunk_texts = [c["text"] for c in chunk_group]
                    chunk_embeddings = embedder_func(chunk_texts, model_name, model_path, api_keys)
                    query_embeddings = embedder_func([query.get("text", "") for query in queries],
                                                     model_name, model_path, api_keys)
                except Exception as e:
                    # Surfaced to the client as a 400 via the RuntimeError handler below.
                    raise RuntimeError(
                        f"Embedding error with {embedder} on {chunk_method}: {e}"
                    )

                # Process each method with the shared embeddings
                for method in emb_methods:
                    method_name = method.get("methodName")
                    # A safe on-disk path for vector stores that need one
                    # (e.g., LanceDB, FAISS or Chroma).
                    db_path = os.path.join(MEDIA_DIR, method.get("id") + ".db")
                    try:
                        handler = resolved_handlers.get(method.get("baseMethod"))
                        if not handler:
                            raise ValueError(f"Unknown method: {method.get('baseMethod')}")
                        RETRIEVAL_PROGRESS[method_name] = 50
                        start_time = time.perf_counter()
                        retrieved = handler(chunk_group, chunk_embeddings, queries,
                                            query_embeddings, method.get("settings", {}), db_path)
                        RETRIEVAL_PROGRESS[method_name] = 80
                        latency_ms = (time.perf_counter() - start_time) * 1000
                        stage_and_collect(method, retrieved, chunk_method, latency_ms,
                                          embedding_model=model_name)
                        RETRIEVAL_PROGRESS[method_name] = 100
                    except Exception as e:
                        print(f"Error with {method_name} on {chunk_method}: {str(e)}")
                        continue

        # === Retrieval Fusion (imported helpers) ===
        if fusion_enabled and linked_groups:
            for (query_txt, chunk_method), per_method in staging.items():
                # Regroup staged result lists by fusion group id.
                groups = defaultdict(dict)
                for mid, items in per_method.items():
                    gid = method_id_to_group.get(mid)
                    if gid:
                        groups[gid][mid] = items

                for gid, method_lists in groups.items():
                    cfg = group_cfg.get(gid, {})
                    settings = cfg.get("fusionSettings") or {}
                    method_keys = cfg.get("methodKeys") or []
                    weights_arr = settings.get("weights") or []
                    # Per-method weights; non-numeric/missing entries are dropped.
                    weights_map = {
                        mid: float(w)
                        for i, mid in enumerate(method_keys)
                        for w in [weights_arr[i] if i < len(weights_arr) else None]
                        if isinstance(w, (int, float))
                    }

                    # BUG FIX: the original used `fmethod in ("reciprocal_rank_fusion")`,
                    # which is a SUBSTRING test (parens without a comma are not a
                    # tuple), so e.g. "rank" would select RRF. Use equality.
                    if cfg.get("fusionMethod") == "reciprocal_rank_fusion":
                        k_val = int(settings.get("k", settings.get("K", 60)))
                        fused = rrf_fuse(method_lists, k=k_val, weights_by_method=weights_map)
                        fusion_sig = "fusion:rrf"
                    else:  # Weighted Average
                        fused = weighted_avg_fuse(
                            method_lists,
                            weights_by_method=weights_map,
                        )
                        fusion_sig = "fusion:weighted_average"

                    group_method_ids = [mid for mid in method_keys if mid in method_lists]
                    pretty_names = [method_name_by_id[mid] for mid in group_method_ids]
                    fused_label = f"Fused ({' + '.join(pretty_names)})"

                    for rank_idx, (doc_id, fused_score, base_obj) in enumerate(fused, start=1):
                        obj = copy.deepcopy(base_obj)
                        obj["eval_res"]["items"] = [{"similarity": fused_score, "rank": rank_idx}]
                        obj["vars"]["retrievalMethod"] = fused_label
                        obj["metavars"].update({
                            "methodId": f"group:{gid}",
                            "retrievalMethodSignature": fusion_sig,
                            "signature": f"{chunk_method}-FUSED-{gid}",
                        })
                        flat_results.append(obj)

        return jsonify(flat_results), 200
    except RuntimeError as e:
        return jsonify({"error": str(e)}), 400
@app.route("/rerank", methods=["POST"])
def rerank():
    """
    Rerank documents using the specified reranking method.

    Expected form data:
    - baseMethod: The reranking method identifier (e.g., "cross_encoder", "cohere_rerank")
    - documents: JSON array of document texts to rerank
    - query: Query text for relevance scoring (optional)
    - api_keys: JSON object containing API keys (optional)
    - Additional method-specific settings as form fields

    Returns:
    - JSON response with reranked documents containing:
        - reranked_documents: List of documents with scores and indices
    """
    if not request.form:
        return jsonify({"error": "Request must be form data"}), 400

    base_method = request.form.get("baseMethod")
    documents_json = request.form.get("documents")
    query = request.form.get("query", "")
    api_keys_json = request.form.get("api_keys", "{}")

    if not base_method:
        return jsonify({"error": "Missing 'baseMethod' in form data"}), 400
    if not documents_json:
        return jsonify({"error": "Missing 'documents' in form data"}), 400

    try:
        # Parse documents JSON
        import json
        documents = json.loads(documents_json)
        if not isinstance(documents, list):
            return jsonify({"error": "Documents must be a JSON array"}), 400
    except (json.JSONDecodeError, ValueError) as e:
        return jsonify({"error": f"Invalid JSON in documents: {e}"}), 400

    # Parse api_keys JSON (defaults to an empty object when absent)
    try:
        api_keys = json.loads(api_keys_json) if api_keys_json else {}
        if not isinstance(api_keys, dict):
            return jsonify({"error": "api_keys must be a JSON object"}), 400
    except (json.JSONDecodeError, ValueError) as e:
        return jsonify({"error": f"Invalid JSON in api_keys: {e}"}), 400

    # Get the reranking handler from the built-in registry
    handler = RerankingMethodRegistry.get_handler(base_method)

    # Check for custom provider if not found in built-in methods
    # (custom methods are namespaced with a "__custom/" prefix)
    if not handler and base_method.startswith("__custom/"):
        provider_name = base_method[len("__custom/"):]
        entry = ProviderRegistry.get(provider_name)
        if entry and entry.get("func"):
            handler = entry["func"]

    if not handler:
        return jsonify({"error": f"Unsupported reranking method: {base_method}"}), 400

    # Extract additional settings from form data. Form values arrive as
    # strings; only the parameters listed below get coerced.
    settings = {}
    known_int_params = {"top_k", "batch_size", "max_chunks_per_doc", "preserve_top_k", "k_param"}
    known_float_params = {"lambda_param", "diversity_threshold"}
    known_bool_params = {"normalize_scores"}

    for key, value in request.form.items():
        if key not in ["baseMethod", "documents", "query", "api_keys"]:
            try:
                if key in known_int_params:
                    settings[key] = int(value)
                elif key in known_float_params:
                    settings[key] = float(value)
                elif key in known_bool_params:
                    settings[key] = value.lower() in ['true', 'yes']
                else:
                    settings[key] = value  # Keep as string if type unknown
            except (ValueError, TypeError):
                print(f"Warning: Could not convert setting '{key}' with value '{value}' to expected type. Using raw value.", file=sys.stderr)
                settings[key] = value  # Fallback to string if conversion fails

    # Add api_keys to settings so handlers that call external APIs can use them
    if api_keys:
        settings['api_keys'] = api_keys

    try:
        # Call the reranking handler
        reranked_results = handler(documents, query, **settings)
        return jsonify({"reranked_documents": reranked_results}), 200

    except ValueError as ve:
        print(f"Configuration or setup error during reranking ({base_method}): {ve}", file=sys.stderr)
        return jsonify({"error": f"Setup error: {ve}"}), 400
    except ImportError as ie:
        print(f"Import error during reranking ({base_method}): {ie}", file=sys.stderr)
        return jsonify({"error": f"Missing library dependency: {ie.name}"}), 500
    except Exception as e:
        print(f"Unexpected error during reranking ({base_method}): {e}", file=sys.stderr)
        return jsonify({"error": "An internal error occurred during reranking."}), 500
class CustomChunkerProtocol(Protocol):
    """A Callable protocol to implement for custom chunker providers."""

    def __call__(self, text: str) -> List[str]:
        """
        Define a call to your custom chunker.

        Parameters:
        - `text`: A string of source text (e.g., a document or article) to be split into smaller segments.

        Returns:
        - A list of string chunks (typically paragraphs or sections) derived from the input text.
        """
        pass

class CustomRetrieverProtocol(Protocol):
    """A Callable protocol to implement for custom retriever providers."""

    def __call__(self,
                 chunks: List[Dict[str, Any]],
                 queries: List[Union[str, Dict[str, Any]]],
                 settings: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Define a call to your custom retriever.

        Parameters:
        - `chunks`: A list of document fragments (e.g., from a chunker) with metadata for context retrieval.
        - `queries`: A list of user queries or prompts, optionally with additional metadata.
        - `settings`: Dictionary of retrieval settings, such as similarity threshold or scoring method.

        Returns:
        - A list of retrieved chunks with associated metadata, typically ranked by relevance to each query.
        """
        pass
class ProviderEntry(TypedDict, total=False):
    # Registry record for one custom provider. `total=False` because extra
    # kwargs from `register` are spliced in verbatim, so all keys are optional.
    name: str                                   # unique display name (registry key)
    func: Any                                   # the callable invoked to run the provider
    script_id: str                              # id of the user script that registered it
    emoji: str                                  # emoji shown in the UI
    models: Optional[List[str]]                 # model ids this provider exposes (may be empty)
    rate_limit: Union[int, Literal["sequential"]]  # max requests/min, or fully sequential
    settings_schema: Optional[SettingsSchema]   # react-jsonschema-form schema for settings UI
    category: Category                          # "model" | "retriever" | "chunker"

"""
A registry for custom providers
"""
class _ProviderRegistry:
    def __init__(self):
        # name -> ProviderEntry; insertion order is registration order.
        self._registry: Dict[str, ProviderEntry] = {}
        # Script id attributed to providers registered from here on.
        self._curr_script_id = '0'
        # name -> previous script_id (or None if newly added); see last_registered.
        self._last_updated: Dict[str, Optional[str]] = {}

    def set_curr_script_id(self, id: str):
        """Record the id of the user script whose registrations follow."""
        self._curr_script_id = id

    def register(self, cls_or_fn: Any, name: str, **kwargs):
        """Insert or overwrite the provider entry stored under `name`.

        The previous entry's script_id (or None for a brand-new name) is
        remembered in `_last_updated` so `last_registered` can report what
        changed since the last `watch_next_registered` call.
        Raises Exception if `name` is not a non-empty string.
        """
        if not name or not isinstance(name, str):
            raise Exception("Cannot register custom provider: Name must be a non-empty string.")
        self._last_updated[name] = self._registry[name]["script_id"] if name in self._registry else None
        self._registry[name] = {
            "name": name,
            "func": cls_or_fn,
            "script_id": self._curr_script_id,
            **kwargs
        }

    def get(self, name: str) -> Optional[ProviderEntry]:
        """Return the entry registered under `name`, or None."""
        return self._registry.get(name)

    def get_all(self) -> List[ProviderEntry]:
        """Return all registered entries, in registration order."""
        return list(self._registry.values())

    def get_all_by_category(self, category: Category) -> List[ProviderEntry]:
        """Return all entries whose "category" field equals `category`."""
        return [e for e in self._registry.values() if e.get("category") == category]

    def has(self, name: str) -> bool:
        """Return True iff a provider is registered under `name`."""
        return name in self._registry

    def remove(self, name: str):
        """Remove the entry under `name`, if present (no-op otherwise)."""
        if self.has(name):
            del self._registry[name]

    def watch_next_registered(self):
        """Reset change tracking; subsequent register() calls are recorded."""
        self._last_updated = {}

    def last_registered(self) -> Dict[str, Optional[str]]:
        """Names registered since watch_next_registered(), mapped to the
        script_id they previously belonged to (None if newly added)."""
        return {k: v for k, v in self._last_updated.items()}

# Global instance of the registry.
ProviderRegistry = _ProviderRegistry()
def last_registered(self) -> Dict[str, Optional[str]]: return {k: v for k, v in self._last_updated.items()} # Global instance of the registry. ProviderRegistry = _ProviderRegistry() +def _ensure_params(fn: Any, required: List[str]) -> None: + try: + params = list(inspect.signature(fn).parameters) + except (TypeError, ValueError): + # skip strict check + return + missing = [p for p in required if p not in params] + if missing: + raise TypeError(f"{getattr(fn, '__name__', fn)} must define params: {', '.join(required)}") + def provider(name: str = 'Custom Provider', emoji: str = '✨', models: Optional[List[str]] = None, rate_limit: Union[int, Literal["sequential"]] = "sequential", - settings_schema: Optional[Dict] = None): + settings_schema: Optional[SettingsSchema] = None, + category: Category = "model"): """ A decorator for registering custom LLM provider methods or classes (Callables) that conform to `CustomProviderProtocol`. @@ -118,9 +192,31 @@ def provider(name: str = 'Custom Provider', NOTE: Only `textarea`, `range`, and enum, and text input UI widgets are properly supported from `react-jsonschema-form`; you can try other widget types, but the CSS may not display property. + - category: "model" | "retriever" | "chunker" + Callable shapes: + - category == "model": + (prompt, model, chat_history, **kwargs) -> str + - category == "retriever": + (chunks, queries, settings) -> List[...] 
# --- Chunk Endpoint ---
import sys
from typing import List, Dict, Any, Callable, Union


class ChunkingMethodRegistry:
    """Registry for text chunking methods, keyed by a string identifier."""
    _methods: Dict[str, Callable] = {}

    @classmethod
    def register(cls, identifier: str):
        """Decorator to register a chunking function under `identifier`.

        Re-registering an identifier overwrites the old handler (with a
        warning on stderr) rather than raising.
        """
        if not isinstance(identifier, str) or not identifier:
            raise ValueError("Method identifier must be a non-empty string.")

        def decorator(handler_func: Callable):
            if not callable(handler_func):
                raise TypeError("Registered handler must be a callable function.")
            if identifier in cls._methods:
                print(f"Warning: Overwriting existing chunking method '{identifier}'.", file=sys.stderr)
            cls._methods[identifier] = handler_func
            return handler_func
        return decorator

    @classmethod
    def get_handler(cls, identifier: str) -> Union[Callable, None]:
        """Get the handler function for a given method identifier (None if absent)."""
        return cls._methods.get(identifier)


# === Chunking Helper Functions ===
def overlapping_token_chunks(tokens: List[int],
                             chunk_size: int,
                             chunk_overlap: int,
                             decode: Callable[[List[int]], str]) -> List[str]:
    """Split token ids into overlapping windows and decode each window to text.
    (Module-internal helper shared by the token-based chunkers below.)

    Parameters:
    - tokens: encoded token ids for the whole document
    - chunk_size: max tokens per window (coerced to >= 1)
    - chunk_overlap: tokens shared between consecutive windows
    - decode: callable turning a token-id slice back into text

    FIX: `chunk_overlap >= chunk_size` previously triggered a mid-loop safety
    break after the first window, silently discarding the rest of the
    document. The overlap is now clamped into [0, chunk_size - 1] (with a
    warning) so the entire text is always chunked.
    """
    if chunk_size <= 0:
        chunk_size = 1
    if chunk_overlap >= chunk_size:
        print(f"Warning: chunk_overlap ({chunk_overlap}) >= chunk_size ({chunk_size}); clamping overlap to {chunk_size - 1}.", file=sys.stderr)
        chunk_overlap = chunk_size - 1
    chunk_overlap = max(0, chunk_overlap)

    result: List[str] = []
    start = 0
    while start < len(tokens):
        end = min(start + chunk_size, len(tokens))  # Prevent overshoot
        # Filter out potential empty strings from decoding edge cases
        decoded_chunk = decode(tokens[start:end]).strip()
        if decoded_chunk:
            result.append(decoded_chunk)
        if end == len(tokens):
            break
        start = end - chunk_overlap  # overlap < size, so start always advances
    return result


@ChunkingMethodRegistry.register("overlapping_openai_tiktoken")
def overlapping_openai_tiktoken(text: str, **kwargs: Any) -> List[str]:
    """Token-based overlapping chunking using OpenAI's tiktoken.

    kwargs: model (default "gpt-3.5-turbo"), chunk_size (200), chunk_overlap (50).
    Returns at least [text] when no non-empty chunk is produced.
    """
    import tiktoken

    model = kwargs.get("model", "gpt-3.5-turbo")
    chunk_size = int(kwargs.get("chunk_size", 200))
    chunk_overlap = int(kwargs.get("chunk_overlap", 50))

    # Resolve the tokenizer: try the model name, then a raw encoding name,
    # then fall back to cl100k_base so chunking never hard-fails here.
    enc = None
    model_error = None
    try:
        enc = tiktoken.encoding_for_model(model)
    except Exception as e:
        model_error = e
        try:
            enc = tiktoken.get_encoding(model)
        except Exception as e2:
            print(f"Warning: Could not resolve tokenizer/model '{model}' via encoding_for_model ({model_error}) or get_encoding ({e2}); falling back to cl100k_base.", file=sys.stderr)
            enc = tiktoken.get_encoding("cl100k_base")

    tokens = enc.encode(text)
    result = overlapping_token_chunks(tokens, chunk_size, chunk_overlap, enc.decode)
    return result if result else [text]


@ChunkingMethodRegistry.register("overlapping_huggingface_tokenizers")
def overlapping_huggingface_tokenizers(text: str, **kwargs: Any) -> List[str]:
    """Token-based overlapping chunking using a HuggingFace tokenizer.

    kwargs: tokenizer (default "bert-base-uncased"), chunk_size (200),
    chunk_overlap (50). Raises ValueError when the tokenizer cannot be loaded.
    """
    from transformers import AutoTokenizer

    tokenizer = kwargs.get("tokenizer", "bert-base-uncased")
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    except Exception as e:
        print(f"Error loading HuggingFace tokenizer model '{tokenizer}': {e}", file=sys.stderr)
        raise ValueError(f"Failed to load HuggingFace tokenizer model {tokenizer}.") from e

    chunk_size = int(kwargs.get("chunk_size", 200))
    chunk_overlap = int(kwargs.get("chunk_overlap", 50))

    # add_special_tokens=False so [CLS]/[SEP]-style markers don't enter the
    # token stream; skip_special_tokens=True keeps them out of decoded text.
    tokens = tokenizer.encode(text, add_special_tokens=False)
    result = overlapping_token_chunks(
        tokens, chunk_size, chunk_overlap,
        lambda ids: tokenizer.decode(ids, skip_special_tokens=True))
    return result if result else [text]
# --- Markdown headings chunker ---
@ChunkingMethodRegistry.register("markdown_header")
def markdown_header(text: str, **kwargs) -> List[str]:
    """
    Splits markdown into sections at each ATX heading (levels 1-6),
    keeping the heading with its section. Returns [text] when no
    non-empty section is found, and [""] for None input.
    """
    import re
    if text is None:
        return [""]

    # normalize CRLF
    text = text.replace("\r\n", "\n")

    # split at lines that start with 1-6 '#' followed by a space; keep the heading (lookahead)
    sections = re.split(r"(?m)(?=^#{1,6}\s+)", text)
    chunks = [sec.strip() for sec in sections if sec and sec.strip()]
    return chunks if chunks else [text]


@ChunkingMethodRegistry.register("syntax_nltk")
def syntax_nltk(text: str, **kwargs: Any) -> List[str]:
    """Sentence-level chunking via NLTK's Punkt sentence tokenizer.

    Downloads the punkt/punkt_tab resources on first use. Raises ValueError
    when a resource download or tokenization fails.
    """
    import nltk
    from nltk.tokenize import sent_tokenize

    # Ensure both punkt and punkt_tab are available (NLTK >= 3.8)
    for resource in ["punkt", "punkt_tab"]:
        try:
            nltk.data.find(f"tokenizers/{resource}")
        except LookupError:
            try:
                nltk.download(resource, quiet=True)
            except Exception as e:
                raise ValueError(f"Error downloading NLTK {resource}: {e}")

    try:
        sents = sent_tokenize(text)
        sents = [s.strip() for s in sents if s.strip()]
        return sents if sents else [text]
    except Exception as e:
        raise ValueError(f"NLTK sent_tokenize error: {e}")


# TextTiling method
# Minimal English stopword list so TextTiling does not require the
# nltk.corpus.stopwords download.
_SIMPLE_EN_STOPWORDS = {
    "i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your", "yours", "yourself", "yourselves",
    "he", "him", "his", "himself", "she", "her", "hers", "herself", "it", "its", "itself", "they", "them", "their",
    "theirs", "themselves", "what", "which", "who", "whom", "this", "that", "these", "those", "am", "is", "are", "was",
    "were", "be", "been", "being", "have", "has", "had", "having", "do", "does", "did", "doing", "a", "an", "the", "and",
    "but", "if", "or", "because", "as", "until", "while", "of", "at", "by", "for", "with", "about", "against", "between",
    "into", "through", "during", "before", "after", "above", "below", "to", "from", "up", "down", "in", "out", "on", "off",
    "over", "under", "again", "further", "then", "once", "here", "there", "when", "where", "why", "how", "all", "any", "both",
    "each", "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very",
    "s", "t", "can", "will", "just", "don", "should", "now"
}


@ChunkingMethodRegistry.register("syntax_texttiling")
def syntax_texttiling(text: str, **kwargs: Any) -> List[str]:
    """Topic-shift chunking via NLTK's TextTiling.

    FIX: accepts **kwargs like every other registered handler, so callers
    passing handler settings no longer get a TypeError.
    """
    from nltk.tokenize import TextTilingTokenizer

    # Pass our own stopwords to avoid touching nltk.corpus.stopwords
    ttt = TextTilingTokenizer(stopwords=_SIMPLE_EN_STOPWORDS)
    chunks = ttt.tokenize(text)
    return chunks if chunks else [text]


"""
    Chonkie Methods
"""
@ChunkingMethodRegistry.register("chonkie_token")
def chonkie_token(text: str, **kwargs: Any) -> List[str]:
    """Token-window chunking via Chonkie's TokenChunker.

    kwargs: tokenizer (default "gpt2"), chunk_size (512), chunk_overlap (0).
    """
    from chonkie import TokenChunker

    tokenizer = kwargs.get("tokenizer", "gpt2")
    chunk_size = int(kwargs.get("chunk_size", 512))
    chunk_overlap = int(kwargs.get("chunk_overlap", 0))

    chunker = TokenChunker(
        tokenizer=tokenizer,        # Supports string identifiers
        chunk_size=chunk_size,      # Maximum tokens per chunk
        chunk_overlap=chunk_overlap,  # Overlap between chunks
    )

    texts = [t.text for t in chunker.chunk(text)]
    return texts if texts else [text]


@ChunkingMethodRegistry.register("chonkie_sentence")
def chonkie_sentence(text: str, **kwargs: Any) -> List[str]:
    """Sentence-grouping chunking via Chonkie's SentenceChunker.

    kwargs: tokenizer_or_token_counter ("gpt2"), chunk_size (1),
    chunk_overlap (0), min_sentences_per_chunk (1),
    min_characters_per_sentence (12), delim (JSON string array),
    include_delim ("prev"/"next"/None).
    """
    from chonkie import SentenceChunker
    import json

    tokenizer_or_token_counter = kwargs.get("tokenizer_or_token_counter", "gpt2")
    chunk_size = int(kwargs.get("chunk_size", 1))
    chunk_overlap = int(kwargs.get("chunk_overlap", 0))
    min_sentences_per_chunk = int(kwargs.get("min_sentences_per_chunk", 1))
    min_characters_per_sentence = int(kwargs.get("min_characters_per_sentence", 12))
    delim = kwargs.get("delim", '[".", "!", "?", "\\n\\n"]')

    include_delim = kwargs.get("include_delim", "prev")
    # FIX: an explicit include_delim=None previously crashed on .strip();
    # treat None and blank strings alike (meaning "do not include delimiters").
    if include_delim is None or not str(include_delim).strip():
        include_delim = None

    try:
        delim = json.loads(delim)  # Validate JSON format
        if not isinstance(delim, list) or not all(isinstance(d, str) for d in delim):
            raise ValueError("Delim must be a JSON parseable string representing an array of characters.")
    except Exception as e:
        print(f"Invalid JSON format for delim: {delim}. Delimeter must be a JSON parseable string representing an array of characters. Skipping custom delimeter. Error: {e}", file=sys.stderr)
        delim = ['.', '!', '?', '\n']

    chunker = SentenceChunker(
        tokenizer_or_token_counter=tokenizer_or_token_counter,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        min_sentences_per_chunk=min_sentences_per_chunk,
        min_characters_per_sentence=min_characters_per_sentence,
        delim=delim,                  # Custom delimiters
        include_delim=include_delim,  # Include delimiters in the chunk
    )

    texts = [t.text for t in chunker.chunk(text)]
    return texts if texts else [text]
@ChunkingMethodRegistry.register("chonkie_recursive")
def chonkie_recursive(text: str, **kwargs: Any) -> List[str]:
    """Hierarchical chunking via Chonkie's RecursiveChunker.

    kwargs: tokenizer_or_token_counter ("gpt2"), chunk_size (512),
    min_characters_per_chunk (12), use_premade_recipe, custom_recipe.
    A custom recipe takes precedence over a premade one; both override the
    default RecursiveRules. Invalid recipes fall back to defaults with a
    warning on stderr rather than raising.
    """
    from chonkie import RecursiveChunker, RecursiveRules
    import json

    tokenizer_or_token_counter = kwargs.get("tokenizer_or_token_counter", "gpt2")
    chunk_size = int(kwargs.get("chunk_size", 512))
    min_characters_per_chunk = int(kwargs.get("min_characters_per_chunk", 12))

    # Premade recipe name, "<name>-<lang>" (e.g., "markdown-en") or a bare
    # language code ("en").
    # https://huggingface.co/datasets/chonkie-ai/recipes
    use_premade_recipe = kwargs.get("use_premade_recipe", None)

    # Custom recipe: JSON string encoding a list of RecursiveLevels dicts.
    # See https://docs.chonkie.ai/chunkers/recursive-chunker
    custom_recipe = kwargs.get("custom_recipe", None)

    rules = RecursiveRules()
    if custom_recipe:
        try:
            rules = RecursiveRules.from_dict({ "levels": json.loads(custom_recipe) })
        except Exception as e:
            print(f"Invalid JSON format for custom recipe: {custom_recipe}. Error: {e}", file=sys.stderr)
            custom_recipe = None
    elif use_premade_recipe:
        try:
            if "-" in use_premade_recipe:
                # Initialize using recipe (e.g., "markdown-en")
                name, lang = use_premade_recipe.split("-")
                rules = RecursiveRules.from_recipe(name=name, lang=lang)
            else:
                # Language-only case (e.g., "en")
                rules = RecursiveRules.from_recipe(lang=use_premade_recipe)
        except Exception as e:
            print(f"Invalid recipe name for use_premade_recipe: {use_premade_recipe}. Error: {e}", file=sys.stderr)
            use_premade_recipe = None

    chunker = RecursiveChunker(
        tokenizer_or_token_counter=tokenizer_or_token_counter,
        chunk_size=chunk_size,
        rules=rules,
        min_characters_per_chunk=min_characters_per_chunk,
    )

    texts = [t.text for t in chunker.chunk(text)]
    return texts if texts else [text]


@ChunkingMethodRegistry.register("chonkie_semantic")
def chonkie_semantic(text: str, **kwargs: Any) -> List[str]:
    """Semantic-similarity chunking via Chonkie's SemanticChunker.

    kwargs: embedding_model / embedding_local_path, chunk_size (512),
    threshold (0.8), similarity_window (1), min_sentences (1),
    min_characters_per_sentence (12), skip_window (0).

    When skip_window > 0 this behaves like the old SDPMChunker; at the
    default 0 it is a standard semantic chunker.

    FIX: dropped the unused `json`/`sys` imports and the dead
    `threshold != 0.8` clause (a str never equals a float, so the condition
    reduced to the isinstance check).
    """
    from chonkie import SemanticChunker

    # 1. Model setup: a local path, when given, wins over the model name.
    embedding_model = kwargs.get("embedding_model", "minishlab/potion-base-8M")
    local_path = kwargs.get("embedding_local_path", '')
    if local_path:
        embedding_model = local_path

    chunk_size = int(kwargs.get("chunk_size", 512))
    threshold = kwargs.get("threshold", 0.8)

    # 2. Advanced params (supported in Chonkie 1.3.1)
    similarity_window = int(kwargs.get("similarity_window", 1))
    min_sentences = int(kwargs.get("min_sentences", 1))
    min_characters_per_sentence = int(kwargs.get("min_characters_per_sentence", 12))

    # 3. SDPM support: skip_window > 0 reproduces the old SDPMChunker behavior.
    skip_window = int(kwargs.get("skip_window", 0))

    # 4. Coerce a string threshold to float, falling back to the default.
    if isinstance(threshold, str):
        try:
            threshold = float(threshold)
        except ValueError:
            threshold = 0.8

    # Note: 'mode', 'threshold_step', 'delim', and 'min_chunk_size' were
    # removed because Chonkie 1.3.1 no longer supports them.
    chunker = SemanticChunker(
        embedding_model=embedding_model,
        threshold=threshold,
        chunk_size=chunk_size,
        similarity_window=similarity_window,
        min_sentences=min_sentences,
        min_characters_per_sentence=min_characters_per_sentence,
        skip_window=skip_window
    )

    texts = [t.text for t in chunker.chunk(text)]
    return texts if texts else [text]
@ChunkingMethodRegistry.register("chonkie_late")
def chonkie_late(text: str, **kwargs: Any) -> List[str]:
    """Late chunking via Chonkie's LateChunker.

    kwargs: embedding_model / embedding_local_path, chunk_size (512),
    min_characters_per_chunk (24), use_premade_recipe, custom_recipe.
    A custom JSON recipe takes precedence over a premade one; invalid
    recipes fall back to default rules with a warning on stderr.
    """
    from sentence_transformers import SentenceTransformer
    from chonkie import LateChunker, RecursiveRules
    import json

    # Chonkie always forwards add_special_tokens=..., which newer
    # sentence-transformers versions reject; wrap encode() once per process
    # to strip that kwarg before delegating.
    base_encode = SentenceTransformer.encode
    if not getattr(base_encode, "_chainforge_patch", False):
        def _encode_sans_special_tokens(self, sentences, **encode_kwargs):
            if "add_special_tokens" in encode_kwargs:
                encode_kwargs = dict(encode_kwargs)
                encode_kwargs.pop("add_special_tokens", None)
            return base_encode(self, sentences, **encode_kwargs)

        _encode_sans_special_tokens._chainforge_patch = True
        SentenceTransformer.encode = _encode_sans_special_tokens

    # Basic parameters; a local path, when provided, wins over the model name.
    embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-MiniLM-L6-v2")
    embedding_path = kwargs.get("embedding_local_path", '')
    if embedding_path != '':
        embedding_model = embedding_path
    chunk_size = int(kwargs.get("chunk_size", 512))
    min_characters_per_chunk = int(kwargs.get("min_characters_per_chunk", 24))

    # Recipe selection (same scheme as the recursive chunker):
    # "<name>-<lang>" premade recipe, or a JSON list of RecursiveLevels dicts.
    use_premade_recipe = kwargs.get("use_premade_recipe", None)
    custom_recipe = kwargs.get("custom_recipe", None)

    rules = RecursiveRules()
    if custom_recipe:
        try:
            rules = RecursiveRules.from_dict({ "levels": json.loads(custom_recipe) })
        except Exception as e:
            print(f"Invalid JSON format for custom recipe: {custom_recipe}. Error: {e}", file=sys.stderr)
            custom_recipe = None
    elif use_premade_recipe:
        try:
            if "-" in use_premade_recipe:
                # e.g., "markdown-en"
                name, lang = use_premade_recipe.split("-")
                rules = RecursiveRules.from_recipe(name=name, lang=lang)
            else:
                # Language-only case, e.g., "en"
                rules = RecursiveRules.from_recipe(lang=use_premade_recipe)
        except Exception as e:
            print(f"Invalid recipe name for use_premade_recipe: {use_premade_recipe}. Error: {e}", file=sys.stderr)
            use_premade_recipe = None

    chunker = LateChunker(
        embedding_model=embedding_model,
        chunk_size=chunk_size,
        rules=rules,
        min_characters_per_chunk=min_characters_per_chunk,
    )

    chunks = chunker.chunk(text)
    return [chunk.text for chunk in chunks] if chunks else [text]
import os

# NOTE: API key names passed in from the ChainForge settings include:
# OpenAI, OpenAI_BaseURL, Anthropic, Google, Azure_OpenAI,
# Azure_OpenAI_Endpoint, HuggingFace, AlephAlpha, AlephAlpha_BaseURL,
# Ollama_BaseURL, AWS_Access_Key_ID, AWS_Secret_Access_Key,
# AWS_Session_Token, AWS_Region, AmazonBedrock, Together.


class EmbeddingMethodRegistry:
    """Registry mapping embedder identifiers (e.g. "openai") to embedding functions."""
    _models = {}

    @classmethod
    def register(cls, model_name):
        """Decorator: register the wrapped embedding function under `model_name`."""
        def decorator(embedding_func):
            cls._models[model_name] = embedding_func
            return embedding_func

        return decorator

    @classmethod
    def get_embedder(cls, model_name):
        """Return the registered embedding function, or None when unknown."""
        return cls._models.get(model_name)

    @classmethod
    def list_models(cls):
        """All registered embedder identifiers."""
        return list(cls._models.keys())


@EmbeddingMethodRegistry.register("huggingface")
def huggingface_embedder(texts, model_name="sentence-transformers/all-mpnet-base-v2", path=None,
                         api_keys=None):
    """
    Generate embeddings using HuggingFace Transformers.

    Args:
        texts: List of text strings to embed
        model_name: HuggingFace model name/path (default: sentence-transformers/all-mpnet-base-v2)
        path: optional local model path; overrides model_name when given
        api_keys: unused here (kept for a uniform embedder signature)

    Returns:
        List of embeddings (one list of floats per text)

    Raises:
        ValueError: when loading the model or embedding fails
    """
    try:
        from transformers import AutoTokenizer, AutoModel
        import torch

        print(f"Using HuggingFace model: {model_name} for {len(texts)} texts")

        if path:
            model_name = path

        # Load model and tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)

        embeddings = []
        batch_size = 32
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            batch_embeddings = []

            # Texts are deliberately encoded one at a time: mean pooling over a
            # padded batch (without an attention mask) would change the vectors.
            for t in batch_texts:
                inputs = tokenizer(t, return_tensors="pt", truncation=True, padding=True,
                                   max_length=512)  # max_length for safety
                with torch.no_grad():
                    outputs = model(**inputs)
                # Use mean pooling by default
                emb = outputs.last_hidden_state.mean(dim=1).squeeze().tolist()
                batch_embeddings.append(emb)

            embeddings.extend(batch_embeddings)

        return embeddings
    except Exception as e:
        print(f"HuggingFace embedder failed: {str(e)}")
        raise ValueError(f"Failed to generate HuggingFace embeddings: {str(e)}")


@EmbeddingMethodRegistry.register("openai")
def openai_embedder(texts, model_name="text-embedding-ada-002", path=None, api_keys=None):
    """
    Generate embeddings using OpenAI Embeddings.

    Args:
        texts: List of text strings to embed
        model_name: OpenAI embedding model to use (default: text-embedding-ada-002)
        path: not used
        api_keys: optional dict holding the "OpenAI" key; falls back to the
            OPENAI_API_KEY environment variable

    Returns:
        List of embeddings (one list of floats per text)

    Raises:
        ValueError: when the API key is missing or the API call fails
    """
    try:
        from openai import OpenAI
        import os

        # Get the OpenAI API key from environment or settings
        openai_api_key = api_keys and api_keys.get("OpenAI") or os.environ.get("OPENAI_API_KEY")
        if not openai_api_key:
            raise ValueError("Missing OpenAI key.")

        # construct client with the key
        client = OpenAI(api_key=openai_api_key)
        print(f"Using OpenAI model: {model_name} for {len(texts)} texts")

        embeddings = []
        # Process in batches of 16 to stay within rate limits
        batch_size = 16
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            # FIX: previously each text was sent as its own request inside the
            # batch loop, defeating the batching. The API accepts a list per
            # request; sort by the returned index to guarantee input order.
            resp = client.embeddings.create(input=batch_texts, model=model_name)
            embeddings.extend(d.embedding for d in sorted(resp.data, key=lambda d: d.index))

        return embeddings
    except Exception as e:
        print(f"OpenAI embedder failed: {str(e)}")
        raise ValueError(f"Failed to generate OpenAI embeddings: {str(e)}")


@EmbeddingMethodRegistry.register("cohere")
def cohere_embedder(texts, model_name="embed-english-v2.0", path=None, api_keys=None):
    """
    Generate embeddings using Cohere Embeddings.

    Args:
        texts: List of text strings to embed
        model_name: Cohere embedding model to use (default: embed-english-v2.0)
        path: not used
        api_keys: optional dict holding the "Cohere" key; falls back to the
            COHERE_API_KEY environment variable

    Returns:
        List of embeddings (one list of floats per text)

    Raises:
        ValueError: when the API key is missing or the API call fails
    """
    try:
        import cohere
        print(f"Using Cohere model: {model_name} for {len(texts)} texts")

        # Get API key from environment or settings
        api_key = api_keys and api_keys.get("Cohere") or os.environ.get("COHERE_API_KEY")
        if not api_key:
            raise ValueError("Cohere API key not found in environment or app config")

        co = cohere.Client(api_key)

        batch_size = 32
        embeddings = []

        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            response = co.embed(texts=batch_texts, model=model_name)
            embeddings.extend(response.embeddings)

        return embeddings
    except Exception as e:
        print(f"Cohere embedder failed: {str(e)}")
        raise ValueError(f"Failed to generate Cohere embeddings: {str(e)}")
@EmbeddingMethodRegistry.register("sentence-transformers")
def sentence_transformers_embedder(texts, model_name="all-MiniLM-L6-v2", path=None, api_keys=None):
    """
    Generate embeddings using Sentence Transformers.

    Args:
        texts: List of text strings to embed
        model_name: Sentence Transformers model name (default: all-MiniLM-L6-v2)
        path: optional local model path; overrides model_name when given
        api_keys: unused here (kept for a uniform embedder signature)

    Returns:
        List of embeddings (one list of floats per text)

    Raises:
        ValueError: when loading the model or encoding fails
    """
    try:
        from sentence_transformers import SentenceTransformer
        print(f"Using SentenceTransformer model: {model_name} for {len(texts)} texts")

        if path:
            model_name = path

        model = SentenceTransformer(model_name)

        # Process in reasonable batch sizes
        batch_size = 32
        embeddings = []

        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            _embs = model.encode(batch_texts).tolist()
            embeddings.extend(_embs)

        return embeddings
    except Exception as e:
        print(f"SentenceTransformer embedder failed: {str(e)}")
        raise ValueError(f"Failed to generate SentenceTransformer embeddings: {str(e)}")


@EmbeddingMethodRegistry.register("azure-openai")
def azure_openai_embedder(texts, model_name="text-embedding-ada-002", deployment_name=None, api_keys=None):
    """
    Generate embeddings using Azure OpenAI Embeddings.

    Args:
        texts: List of text strings to embed
        model_name: OpenAI embedding model to use (default: text-embedding-ada-002)
        deployment_name: Azure deployment name; defaults to model_name when omitted
        api_keys: optional dict holding "Azure_OpenAI" and
            "Azure_OpenAI_Endpoint"; falls back to the AZURE_OPENAI_API_KEY /
            AZURE_OPENAI_ENDPOINT environment variables

    Returns:
        List of embeddings (one list of floats per text)

    Raises:
        ValueError: when credentials are missing or the API call fails
    """
    try:
        from openai import AzureOpenAI
        import concurrent.futures
        from tqdm import tqdm

        print(f"Using Azure OpenAI model: {model_name} for {len(texts)} texts")

        azure_api_key = api_keys and api_keys.get("Azure_OpenAI") or os.environ.get("AZURE_OPENAI_API_KEY")
        azure_endpoint = api_keys and api_keys.get("Azure_OpenAI_Endpoint") or os.environ.get("AZURE_OPENAI_ENDPOINT")

        if not azure_api_key:
            raise ValueError("API key for Azure OpenAI is missing.")
        if not azure_endpoint:
            raise ValueError("Endpoint for Azure OpenAI is missing.")

        client = AzureOpenAI(
            api_key=azure_api_key,
            api_version="2023-05-15",
            azure_endpoint=azure_endpoint
        )

        # FIX: previously the request was sent with model=deployment_name,
        # which defaults to None, so every call failed when no explicit
        # deployment name was supplied. Fall back to the model name (Azure
        # deployments are commonly named after the model).
        deployment = deployment_name or model_name

        embeddings = []
        batch_size = 16

        def get_embedding(t):
            resp = client.embeddings.create(
                input=t,
                model=deployment
            )
            return resp.data[0].embedding

        # tqdm progress bar over the total number of texts
        with tqdm(total=len(texts), desc="Generation of embeddings using Azure OpenAI") as pbar:
            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i:i + batch_size]
                with concurrent.futures.ThreadPoolExecutor() as executor:
                    batch_embeddings = list(executor.map(get_embedding, batch_texts))
                embeddings.extend(batch_embeddings)
                pbar.update(len(batch_texts))

        return embeddings
    except Exception as e:
        print(f"Azure OpenAI embedder failed: {str(e)}")
        raise ValueError(f"Failed to generate OpenAI embeddings: {str(e)}")
import sys
from typing import List, Dict, Any, Callable, Union
from collections import defaultdict
import copy

# === Reranking Registry ===
class RerankingMethodRegistry:
    """Registry for document reranker methods, keyed by a string identifier."""
    _methods: Dict[str, Callable] = {}

    @classmethod
    def register(cls, identifier: str):
        """Decorator to register a reranking function under `identifier`.

        Re-registering overwrites the old handler (with a stderr warning).
        """
        if not isinstance(identifier, str) or not identifier:
            raise ValueError("Method identifier must be a non-empty string.")

        def decorator(handler_func: Callable):
            if not callable(handler_func):
                raise TypeError("Registered handler must be a callable function.")
            if identifier in cls._methods:
                print(f"Warning: Overwriting existing reranking method '{identifier}'.", file=sys.stderr)
            cls._methods[identifier] = handler_func
            return handler_func
        return decorator

    @classmethod
    def get_handler(cls, identifier: str) -> Union[Callable, None]:
        """Get the handler function for a given method identifier (None if absent)."""
        return cls._methods.get(identifier)


# === Reranker Methods ===
@RerankingMethodRegistry.register("cross_encoder")
def cross_encoder_rerank(documents: List[str], query: str = "", **kwargs: Any) -> List[Dict[str, Any]]:
    """
    Rerank documents using a cross-encoder model, using the
    `sentence-transformers` library.

    See https://sbert.net/docs/cross_encoder/pretrained_models.html for available models.

    Args:
        documents: List of document texts to rerank
        query: Query text for relevance scoring
        **kwargs: Additional settings including:
            - model: Cross-encoder model name
            - top_k: Number of top documents to return
            - batch_size: Batch size for processing

    Returns:
        List of dictionaries with 'document', 'score', and 'index' keys
    """
    if not documents:
        return []

    if not query:
        # No query: keep original order with synthetic, decreasing scores.
        return [
            {
                "document": doc,
                "score": 1.0 - (i / len(documents)),
                "index": i
            }
            for i, doc in enumerate(documents)
        ]

    # FIX: import only once a model is actually needed, so the trivial
    # paths above work without sentence-transformers installed.
    try:
        from sentence_transformers import CrossEncoder
    except ImportError:
        raise ImportError("sentence-transformers library is required for cross-encoder reranking")

    model_name = kwargs.get("model", "cross-encoder/ms-marco-MiniLM-L-6-v2")
    top_k = int(kwargs.get("top_k", min(5, len(documents))))
    batch_size = int(kwargs.get("batch_size", 32))

    try:
        # Load the cross-encoder model
        model = CrossEncoder(model_name)

        # Create query-document pairs and get relevance scores
        pairs = [(query, doc) for doc in documents]
        scores = model.predict(pairs, batch_size=batch_size)

        # Create results with scores and original indices
        results = [
            {
                "document": documents[i],
                "score": float(scores[i]),
                "index": i
            }
            for i in range(len(documents))
        ]

        # Sort by score (descending) and return top_k
        results.sort(key=lambda x: x["score"], reverse=True)
        return results[:top_k]

    except Exception as e:
        print(f"Error in cross-encoder reranking: {e}", file=sys.stderr)
        raise


@RerankingMethodRegistry.register("cohere_rerank")
def cohere_rerank(documents: List[str], query: str = "", **kwargs: Any) -> List[Dict[str, Any]]:
    """
    Rerank documents using Cohere's reranking API.

    Args:
        documents: List of document texts to rerank
        query: Query text for relevance scoring
        **kwargs: Additional settings including:
            - model: Cohere model name (e.g., 'rerank-v3.5')
            - top_k: Number of top documents to return
            - max_chunks_per_doc: Maximum chunks per document
            - api_keys: Dictionary containing API keys (optional)

    Returns:
        List of dictionaries with 'document', 'score', and 'index' keys
    """
    if not documents:
        return []

    if not query:
        # No query: keep original order with synthetic, decreasing scores.
        return [
            {
                "document": doc,
                "score": 1.0 - (i / len(documents)),
                "index": i
            }
            for i, doc in enumerate(documents)
        ]

    # FIX: deferred import — only required when the API is actually called.
    try:
        import cohere
    except ImportError:
        raise ImportError("cohere library is required for Cohere reranking")

    model_name = kwargs.get("model", "rerank-v3.5")
    top_k = int(kwargs.get("top_k", min(5, len(documents))))
    max_chunks_per_doc = int(kwargs.get("max_chunks_per_doc", 10))
    api_keys = kwargs.get("api_keys")

    # Get API key from api_keys parameter or environment
    import os
    api_key = api_keys and api_keys.get("Cohere") or os.getenv("COHERE_API_KEY")
    if not api_key:
        raise ValueError("Cohere API key not found in api_keys parameter or COHERE_API_KEY environment variable")

    try:
        # Initialize Cohere client
        co = cohere.ClientV2(api_key)

        # Cap the number of documents sent to the API
        docs_to_rerank = documents[:max_chunks_per_doc * top_k] if len(documents) > max_chunks_per_doc * top_k else documents

        # Call Cohere rerank API
        response = co.rerank(
            model=model_name,
            query=query,
            documents=docs_to_rerank,
            top_n=top_k
        )

        # Format results, preserving each document's original index
        results = []
        for result in response.results:
            original_index = result.index
            results.append({
                "document": docs_to_rerank[original_index],
                "score": float(result.relevance_score),
                "index": original_index
            })

        return results

    except Exception as e:
        print(f"Error in Cohere reranking: {e}", file=sys.stderr)
        raise
file=sys.stderr)
+        raise
+
+# === Retrieval Fusion Methods ===
+
+def _best_obj_for_doc(method_lists, doc_id):
+    best_mid, best_rank = None, 10**9
+    for mid, items in method_lists.items():
+        for it in items:
+            if it["doc_id"] == doc_id and it["rank"] < best_rank:
+                best_rank, best_mid = it["rank"], mid
+    for it in method_lists[best_mid]:
+        if it["doc_id"] == doc_id:
+            return it["obj"]
+    return None
+
+def weighted_avg_fuse(method_lists, weights_by_method=None):
+    """Simple weighted average of raw scores"""
+    weights_by_method = weights_by_method or {}
+
+    # gather all doc ids present in any method list
+    all_doc_ids = set()
+    for items in method_lists.values():
+        for it in items:
+            all_doc_ids.add(it["doc_id"])
+
+    # index raw scores by method -> doc_id -> score
+    raw_score = {
+        mid: {it["doc_id"]: float(it["score"]) for it in items}
+        for mid, items in method_lists.items()
+    }
+
+    fused_scores = {}
+    for d in all_doc_ids:
+        s = 0.0
+        for mid, scores in raw_score.items():
+            w = float(weights_by_method.get(mid, 1.0))
+            s += w * scores.get(d, 0.0)
+        fused_scores[d] = s
+
+    fused = []
+    for d, s in fused_scores.items():
+        base_obj = _best_obj_for_doc(method_lists, d)
+        fused.append((d, s, base_obj))
+    fused.sort(key=lambda x: (-x[1], x[0]))
+    return fused
+
+def rrf_fuse(method_lists, k=60, weights_by_method=None):
+    """RRF uses ranks with the 1/(k + rank) formula; weights apply per method."""
+    weights_by_method = weights_by_method or {}
+    rank_maps = {
+        mid: {it["doc_id"]: int(it["rank"]) for it in items}
+        for mid, items in method_lists.items()
+    }
+    all_docs = set()
+    for items in method_lists.values():
+        for it in items:
+            all_docs.add(it["doc_id"])
+
+    fused = []
+    for d in all_docs:
+        score, contributors = 0.0, []
+        for mid, rmap in rank_maps.items():
+            r = rmap.get(d)
+            if r is not None:
+                w = float(weights_by_method.get(mid, 1.0))
+                score += w * (1.0 / (k + r))
+                contributors.append(mid)
+        best_mid = min(contributors, key=lambda m: rank_maps[m][d])
+
best_obj = next(it["obj"] for it in method_lists[best_mid] if it["doc_id"] == d) + fused.append((d, score, best_obj)) + fused.sort(key=lambda x: (-x[1], x[0])) + return fused + diff --git a/chainforge/rag/retrievers.py b/chainforge/rag/retrievers.py new file mode 100644 index 000000000..ca072fa5e --- /dev/null +++ b/chainforge/rag/retrievers.py @@ -0,0 +1,426 @@ +import math, heapq +from typing import List, Any, Tuple, Dict +import numpy as np +from chainforge.rag.simple_preprocess import simple_preprocess +from chainforge.rag.vector_stores import LancedbVectorStore, FaissVectorStore + + +# Define a registry for retrieval methods +class RetrievalMethodRegistry: + _methods = {} + + @classmethod + def register(cls, method_name): + def decorator(handler_func): + cls._methods[method_name] = handler_func + return handler_func + return decorator + + @classmethod + def get_handler(cls, method_name): + return cls._methods.get(method_name) + +def normalize_query(raw_q: Any) -> Tuple[Dict[str, Any], str]: + """ + Turn any raw_q (dict or other) into: + 1) a normalized query-object dict + 2) the canonical text string to use + """ + if isinstance(raw_q, dict): + q_obj = raw_q + else: + q_obj = {"text": str(raw_q)} + + text = str( + q_obj.get("text") + or q_obj.get("query") + or q_obj.get("prompt", "") + ) + return q_obj, text + +@RetrievalMethodRegistry.register("embedding") +def handle_embedding(chunk_objs, chunk_embeddings, query_objs, query_embeddings, settings, db_path): + """ + Unified embedding-based retrieval handler that delegates to the appropriate + vector store backend (LanceDB or FAISS) based on storage_backend settings. + + The similarity metric is passed through to the vector store, which handles + the actual similarity computation. 
+ """ + storage_backend = settings.get("storage_backend", "lancedb") + similarity_metric = settings.get("similarity_metric", "cosine") + + # Map similarity metric names to backend-specific metric names + # This will be used by the vector store handlers + metric_map = { + "cosine": "cosine", + "euclidean": "l2", + "dot_product": "dot", + } + settings["metric"] = metric_map.get(similarity_metric, "cosine") + + # Route to the appropriate vector store handler + if storage_backend == "lancedb": + handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + if handler: + return handler(chunk_objs, chunk_embeddings, query_objs, query_embeddings, settings, db_path) + raise ValueError("LanceDB handler not found") + elif storage_backend == "faiss": + handler = RetrievalMethodRegistry.get_handler("faiss_vector_store") + if handler: + return handler(chunk_objs, chunk_embeddings, query_objs, query_embeddings, settings, db_path) + raise ValueError("FAISS handler not found") + else: + raise ValueError(f"Unsupported storage backend: {storage_backend}. 
Use 'lancedb' or 'faiss'.") + +@RetrievalMethodRegistry.register("bm25") +def handle_bm25(chunk_objs: List[Dict], query_objs: List[Any], settings: Dict[str, Any]) -> List[Dict]: + from rank_bm25 import BM25Okapi + """Retrieve top-k chunks for each query using BM25.""" + # Build BM25 index + docs = [str(c.get("text", "")) for c in chunk_objs] + tokenized_corpus = [simple_preprocess(doc) for doc in docs] + k1 = float(settings.get("bm25_k1", 1.5)) + b = float(settings.get("bm25_b", 0.75)) + bm25 = BM25Okapi(tokenized_corpus, k1=k1, b=b) + + top_k = int(settings.get("top_k", 5)) + results: List[Dict] = [] + + for raw_q in query_objs: + # Normalize and extract text + q_obj, query_text = normalize_query(raw_q) + # Tokenize for scoring + tokens = simple_preprocess(query_text) + + # Score & normalize + scores = bm25.get_scores(tokens) + if scores.size == 0: + results.append({"query_object": q_obj, "retrieved_chunks": []}) + continue + + max_score = float(scores.max()) or 1.0 + normalized = (scores / max_score).tolist() + + # Pick top-k & build hits + top_idxs = sorted( + range(len(normalized)), + key=lambda i: normalized[i], + reverse=True + )[:top_k] + + hits = [] + for idx in top_idxs: + c = chunk_objs[idx] + hits.append({ + "text": c.get("text", ""), + "similarity": normalized[idx], + "docTitle": c.get("docTitle", ""), + "chunkId": c.get("chunkId", ""), + }) + + results.append({ + "query_object": q_obj, + "retrieved_chunks": hits + }) + + return results + + +@RetrievalMethodRegistry.register("tfidf") +def handle_tfidf(chunk_objs: List[Dict], query_objs: List[Any], settings: Dict[str, Any]) -> List[Dict]: + from sklearn.feature_extraction.text import TfidfVectorizer + """Retrieve top-k chunks for each query using TF-IDF cosine similarity.""" + # Safely cast settings + top_k = int(settings.get("top_k", 5)) + max_features = int(settings.get("max_features", 500)) + + # Prepare the corpus texts + docs = [str(c.get("text", "")) for c in chunk_objs] + + # Fit the TF-IDF 
vectorizer + vectorizer = TfidfVectorizer(stop_words="english", max_features=max_features) + tfidf_matrix = vectorizer.fit_transform(docs) + + results: List[Dict] = [] + for raw_q in query_objs: + # Normalize and extract text + q_obj, query_text = normalize_query(raw_q) + + # Transform query into vector + query_vec = vectorizer.transform([query_text]) + + # Compute raw similarities + sims = (tfidf_matrix * query_vec.T).toarray().flatten() + max_sim = float(sims.max()) if sims.size and sims.max() > 0 else 1.0 + normalized = sims / max_sim + + # Pick top-k indices + top_idxs = sorted( + range(len(normalized)), + key=lambda i: normalized[i], + reverse=True + )[:top_k] + + # Build hits + hits = [] + for idx in top_idxs: + c = chunk_objs[idx] + hits.append({ + "text": c.get("text", ""), + "similarity": float(normalized[idx]), + "docTitle": c.get("docTitle", ""), + "chunkId": c.get("chunkId", ""), + }) + + results.append({ + "query_object": q_obj, + "retrieved_chunks": hits + }) + + return results + +@RetrievalMethodRegistry.register("boolean") +def handle_boolean(chunk_objs: List[Dict], query_objs: List[Any], settings: Dict[str, Any]) -> List[Dict]: + """Retrieve chunks by boolean overlap (minimum token matches).""" + # Cast settings + top_k = int(settings.get("top_k", 5)) + required_match_count = int(settings.get("required_match_count", 1)) + + # Pre-tokenize chunks + chunk_texts = [str(c.get("text", "")) for c in chunk_objs] + tokenized_chunks = [set(simple_preprocess(text)) for text in chunk_texts] + + results: List[Dict] = [] + for raw_q in query_objs: + # Normalize and extract text + q_obj, query_text = normalize_query(raw_q) + + # Tokenize the query + q_tokens = set(simple_preprocess(query_text)) + + # If not enough tokens, no hits + if len(q_tokens) < required_match_count: + results.append({"query_object": q_obj, "retrieved_chunks": []}) + continue + + scored: List[Tuple[int, float]] = [] + for idx, c_tokens in enumerate(tokenized_chunks): + matches = 
len(q_tokens & c_tokens) + if matches >= required_match_count: + score = matches / (len(c_tokens) or 1) + scored.append((idx, score)) + + # Sort & take top_k + scored.sort(key=lambda x: x[1], reverse=True) + + # Build retrieved_chunks + retrieved: List[Dict] = [] + if scored: + top_score = scored[0][1] or 1.0 + for idx, raw_score in scored[:top_k]: + c = chunk_objs[idx] + norm_score = raw_score / top_score + retrieved.append({ + "text": c.get("text", ""), + "similarity": float(norm_score), + "docTitle": c.get("docTitle", ""), + "chunkId": c.get("chunkId", ""), + }) + + results.append({ + "query_object": q_obj, + "retrieved_chunks": retrieved + }) + + return results + + +@RetrievalMethodRegistry.register("overlap") +def handle_keyword_overlap(chunk_objs: List[Dict], query_objs: List[Any], settings: Dict[str, Any]) -> List[Dict]: + """Retrieve chunks by keyword overlap (raw token count).""" + # Settings + top_k = int(settings.get("top_k", 5)) + + # Pre-tokenize chunks + docs = [str(c.get("text", "")) for c in chunk_objs] + tokenized_chunks = [set(simple_preprocess(doc)) for doc in docs] + + results: List[Dict] = [] + for raw_q in query_objs: + # Normalize and extract text + q_obj, query_text = normalize_query(raw_q) + + # Tokenize the query + q_tokens = set(simple_preprocess(query_text)) + + # Score by overlap count + scored: List[Tuple[int, int]] = [] + for idx, c_tokens in enumerate(tokenized_chunks): + overlap = len(q_tokens & c_tokens) + scored.append((idx, overlap)) + + # Sort descending + scored.sort(key=lambda x: x[1], reverse=True) + + # Build retrieved list + retrieved: List[Dict] = [] + if scored and scored[0][1] > 0: + max_overlap = scored[0][1] + for idx, raw_score in scored[:top_k]: + c = chunk_objs[idx] + norm_score = raw_score / max_overlap + retrieved.append({ + "text": c.get("text", ""), + "similarity": float(norm_score), + "docTitle": c.get("docTitle", ""), + "chunkId": c.get("chunkId", ""), + }) + + results.append({ + "query_object": q_obj, + 
"retrieved_chunks": retrieved + }) + + return results + +@RetrievalMethodRegistry.register("clustered") +def handle_clustered(chunk_objs, chunk_embeddings, query_objs, query_embeddings, settings, db_path): + """ + Retrieve chunks using a combination of query similarity and cluster similarity. + """ + from sklearn.metrics.pairwise import cosine_similarity as sklearn_cosine + from sklearn.cluster import KMeans + + top_k = settings.get("top_k", 5) + n_clusters = settings.get("n_clusters", 3) + query_coeff = settings.get("query_coeff", 0.6) + center_coeff = settings.get("center_coeff", 0.4) + results = [] + + # Convert embeddings to numpy array for clustering + X = np.array(chunk_embeddings) + + # Only perform clustering if we have enough samples + if len(X) >= 2: + n_clusters = min(n_clusters, len(X)) + kmeans = KMeans(n_clusters=n_clusters, random_state=42) + labels = kmeans.fit_predict(X) + cluster_centers = kmeans.cluster_centers_ + + for query_obj, query_emb in zip(query_objs, query_embeddings): + min_heap = [] + query_emb_np = np.array(query_emb).reshape(1, -1) + + for i, (chunk, chunk_emb) in enumerate(zip(chunk_objs, chunk_embeddings)): + # Calculate similarity to query + chunk_emb_np = np.array(chunk_emb).reshape(1, -1) + query_sim = float(sklearn_cosine(chunk_emb_np, query_emb_np)[0][0]) + + # Calculate similarity to cluster center + center_sim = float(sklearn_cosine( + chunk_emb_np, + cluster_centers[labels[i]].reshape(1, -1) + )[0][0]) + + # Combined similarity score (weighted) + combined_sim = query_coeff * query_sim + center_coeff * center_sim + + if len(min_heap) < top_k: + heapq.heappush(min_heap, (combined_sim, i)) + elif combined_sim > min_heap[0][0]: + heapq.heappushpop(min_heap, (combined_sim, i)) + + # Convert heap to sorted results + retrieved = [] + for sim, i in sorted(min_heap, reverse=True): + chunk = chunk_objs[i] + retrieved.append({ + "text": chunk.get("text", ""), + "similarity": float(sim), + "docTitle": chunk.get("docTitle", ""), + 
"chunkId": chunk.get("chunkId", ""), + }) + + results.append({'query_object': query_obj, 'retrieved_chunks': retrieved}) + return results + + +@RetrievalMethodRegistry.register("lancedb_vector_store") +def handle_lancedb_vector_store(chunk_objs, chunk_embeddings, query_objs, query_embeddings, settings, db_path): + """ + Retrieve chunks using a local vector store with LanceDB. + """ + + top_k = settings.get("top_k", 5) + user_requested_metric = settings.get("metric", "l2").lower() + if user_requested_metric not in ["l2", "cosine", "dot"]: + print(f"Warning: Invalid FAISS metric '{user_requested_metric}' specified. Defaulting to 'l2'.") + user_requested_metric = "l2" + + # Basic Input Validation + if not chunk_objs or not chunk_embeddings: + raise Exception("Error: chunk_objs or chunk_embeddings are empty.") + if not query_objs or not query_embeddings: + raise Exception("Error: query_objs or query_embeddings are empty.") + + # Create a local vector store (loading an existing one from disk if it exists) + vector_store = LancedbVectorStore( + db_path=db_path, + embedding_func=None, + ) + + # Add chunks to the vector store + # NOTE: This will automatically skip chunks that are already in the store, + # since the store used the hash of the chunk text as the ID. 
+    vector_store.add(
+        texts=[chunk.get("text", "") for chunk in chunk_objs],
+        embeddings=chunk_embeddings,
+        metadata=[{
+            "fill_history": chunk.get("fill_history", {}),
+            "metadata": chunk.get("metadata", {}),
+        } for chunk in chunk_objs],
+    )
+
+    # Perform a similarity search for each query
+    results = []
+    for query_obj, query_emb in zip(query_objs, query_embeddings):
+        # Perform the search
+        res = vector_store.search(
+            query=query_emb if query_emb is not None else query_obj.get("text", ""),
+            k=top_k,
+            metric=user_requested_metric,
+        )
+        results.append({'query_object': query_obj, 'retrieved_chunks': res})
+
+    return results
+
+
+@RetrievalMethodRegistry.register("faiss_vector_store")
+def handle_faiss_vector_store(chunk_objs, chunk_embeddings, query_objs, query_embeddings, settings, db_path):
+    top_k = settings.get("top_k", 5)
+    metric = settings.get("metric", "l2")
+    vector_store = FaissVectorStore(
+        db_path=db_path,
+        embedding_func=None,
+        index_name="index",
+        metric=metric
+    )
+    # Add the documents (if needed)
+    vector_store.add(
+        texts=[chunk.get("text", "") for chunk in chunk_objs],
+        embeddings=chunk_embeddings,
+        metadata=[{
+            "fill_history": chunk.get("fill_history", {}),
+            "metadata": chunk.get("metadata", {}),
+        } for chunk in chunk_objs],
+    )
+    # Search for each query
+    results = []
+    for query_obj, query_emb in zip(query_objs, query_embeddings):
+        res = vector_store.search(
+            query=query_emb if query_emb is not None else query_obj.get("text", ""),
+            k=top_k,
+        )
+        results.append({'query_object': query_obj, 'retrieved_chunks': res})
+    return results
diff --git a/chainforge/rag/simple_preprocess.py b/chainforge/rag/simple_preprocess.py
new file mode 100644
index 000000000..babaca955
--- /dev/null
+++ b/chainforge/rag/simple_preprocess.py
@@ -0,0 +1,95 @@
+"""
+This file provides the `simple_preprocess` function from the Gensim library.
+The `simple_preprocess` function is a utility for converting a document into a list of tokens.
+It lowercases, tokenizes, and de-accents the text (if specified). +It also allows for filtering tokens based on their length. + +The issue is that gensim is no longer actively maintained and +requires a specific version of numpy to work properly (<2.0). +and this version is not compatible with the latest versions of other libraries. +This is a workaround to use the simple_preprocess +function without installing gensim as a dependency. + +The code in this file is copied from the Gensim library. +Hence, the original code, and this specific file, is licensed under the +`GNU LGPLv2.1 license `. +The original code can be found at: https://github.com/piskvorky/gensim + +Copyright (C) 2010 Radim Rehurek +Licensed under the GNU LGPL v2.1 - https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html + +This file is free software; you can redistribute it and/or +modify it under the terms of the GNU Lesser General Public +License as published by the Free Software Foundation; either +version 2.1 of the License, or (at your option) any later version. + +This file is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +Lesser General Public License for more details. +""" +import re +import unicodedata +from six import u + +def to_unicode(text, encoding='utf8', errors='strict'): + """Convert a string (bytestring in `encoding` or unicode), to unicode.""" + if isinstance(text, str): + return text + return str(text, encoding, errors=errors) + +def deaccent(text): + """ + Remove accentuation from the given string. Input text is either a unicode string or utf8 encoded bytestring. + + Return input string with accents removed, as unicode. 
+ + >>> deaccent("Šéf chomutovských komunistů dostal poštou bílý prášek") + u'Sef chomutovskych komunistu dostal postou bily prasek' + + """ + if not isinstance(text, str): + # assume utf8 for byte strings, use default (strict) error handling + text = text.decode('utf8') + norm = unicodedata.normalize("NFD", text) + result = u('').join(ch for ch in norm if unicodedata.category(ch) != 'Mn') + return unicodedata.normalize("NFC", result) + +def tokenize(text, lowercase=False, deacc=False, errors="strict", to_lower=False, lower=False): + """ + Iteratively yield tokens as unicode strings, removing accent marks + and optionally lowercasing the unidoce string by assigning True + to one of the parameters, lowercase, to_lower, or lower. + + Input text may be either unicode or utf8-encoded byte string. + + The tokens on output are maximal contiguous sequences of alphabetic + characters (no digits!). + + >>> list(tokenize('Nic nemůže letět rychlostí vyšší, než 300 tisíc kilometrů za sekundu!', deacc = True)) + [u'Nic', u'nemuze', u'letet', u'rychlosti', u'vyssi', u'nez', u'tisic', u'kilometru', u'za', u'sekundu'] + + """ + PAT_ALPHABETIC = re.compile('(((?![\d])\w)+)', re.UNICODE) + lowercase = lowercase or to_lower or lower + text = to_unicode(text, errors=errors) + if lowercase: + text = text.lower() + if deacc: + text = deaccent(text) + for match in PAT_ALPHABETIC.finditer(text): + yield match.group() + +def simple_preprocess(doc, deacc=False, min_len=2, max_len=15): + """ + Convert a document into a list of tokens. + + This lowercases, tokenizes, de-accents (optional). -- the output are final + tokens = unicode strings, that won't be processed any further. 
+ + """ + tokens = [ + token for token in tokenize(doc, lower=True, deacc=deacc, errors='ignore') + if min_len <= len(token) <= max_len and not token.startswith('_') + ] + return tokens \ No newline at end of file diff --git a/chainforge/rag/vector_stores.py b/chainforge/rag/vector_stores.py new file mode 100644 index 000000000..0f9a53429 --- /dev/null +++ b/chainforge/rag/vector_stores.py @@ -0,0 +1,902 @@ +import shutil +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional, Union +import os +import numpy as np +import pandas as pd +import lancedb +import hashlib, pickle + +# Faiss requires 'swig' to be installed, and the 'faiss-cpu' package (or 'faiss-gpu', built from source). +# Swig is only installable via homebrew on macOS, which makes this dependency difficult to support +# by default. To be safe we soft-fail if it's not installed. +try: + import faiss +except ImportError: + faiss = None + + +class VectorStore(ABC): + """ + Abstract base class for vector stores that store and retrieve embeddings. + + This class defines the common interface that all vector store implementations + should follow, allowing for easy swapping between different backends while + maintaining the same API. + """ + def __init__(self, embedding_func: Optional[callable], db_path, db_mode): + """ + Initialize the vector store. 
+ + Args: + embedding_func: Optional function to generate embeddings + db_path: path for DB + db_mode: if want to create or load db + """ + self.embedding_func = embedding_func + + if db_mode == "create" and db_path is not None: + # Supprimer le contenu du dossier db_path + if os.path.exists(db_path): + # Supprime tout le contenu du dossier, mais garde le dossier lui-même + for filename in os.listdir(db_path): + file_path = os.path.join(db_path, filename) + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + else: + # Crée le dossier s'il n'existe pas + os.makedirs(db_path, exist_ok=True) + + @abstractmethod + def add(self, texts: List[str], embeddings: Optional[List[List[float]]] = None, + metadata: Optional[List[Dict[str, Any]]] = None) -> List[str]: + """ + Add documents and their embeddings to the store. + + Args: + texts: List of text documents + embeddings: List of embedding vectors for each document + metadata: Optional list of metadata dictionaries for each document + + Returns: + List of document IDs for the added documents + """ + pass + + @abstractmethod + def search(self, query: Union[str, List[float]], k: int = 5, + **kwargs) -> List[Dict[str, Any]]: + """ + Search for similar documents based on a query or query embedding. + + Args: + query: The query as a string (if embedding_func was passed on init), or the embedding vector + k: Number of results to return + **kwargs: Additional search parameters (method, filters, etc.) + + Returns: + List of document dictionaries with text, score, and metadata + """ + pass + + @abstractmethod + def get(self, doc_id: str) -> Optional[Dict[str, Any]]: + """ + Get a document by ID. + + Args: + doc_id: Document ID + + Returns: + Document dictionary or None if not found + """ + pass + + @abstractmethod + def delete(self, doc_ids: List[str]) -> bool: + """ + Delete documents by ID. 
+ + Args: + doc_ids: List of document IDs to delete + + Returns: + Boolean indicating success + """ + pass + + @abstractmethod + def update(self, doc_id: str, text: Optional[str] = None, + embedding: Optional[List[float]] = None, + metadata: Optional[Dict[str, Any]] = None) -> bool: + """ + Update a document by ID. + + Args: + doc_id: Document ID + text: New text (if None, text is not updated) + embedding: New embedding (if None, embedding is not updated) + metadata: New metadata (if None, metadata is not updated) + + Returns: + Boolean indicating success + """ + pass + + @abstractmethod + def get_all(self, limit: Optional[int] = None, + offset: int = 0) -> List[Dict[str, Any]]: + """ + Get all documents in the store. + + Args: + limit: Maximum number of documents to return (None for all) + offset: Number of documents to skip + + Returns: + List of document dictionaries + """ + pass + + @abstractmethod + def count(self) -> int: + """ + Get the number of documents in the store. + + Returns: + Number of documents + """ + pass + + @abstractmethod + def clear(self) -> bool: + """ + Clear all documents from the store. + + Returns: + Boolean indicating success + """ + pass + + +class LancedbVectorStore(VectorStore): + """ + Vector store implementation using LanceDB for local vector storage. + + LanceDB is an open-source, local-first vector database that's optimized for + efficient vector similarity search and is particularly suitable for + embeddings storage on a local machine. + """ + + def __init__(self, + db_path: str = "lancedb", + embedding_func: Optional[callable] = None, + table_name: str = "embeddings", + db_mode: str = "create"): + """ + Initialize LanceDB vector store. 
+ + Args: + uri: Path where LanceDB will store the database on disk + table_name: Name of the table to use for vector storage + """ + super().__init__(embedding_func, db_path, db_mode) + # Create directory if it doesn't exist + + # Connect to the database + self.db = lancedb.connect(db_path) + self.table_name = table_name + self.table = None + + # Check if table exists + table_names = list(self.db.table_names()) + if table_name in table_names: + self.table = self.db.open_table(table_name) + + + def _generate_id(self, text: str) -> str: + """Generate SHA256 hash of text to use as document ID""" + return hashlib.sha256(text.encode('utf-8')).hexdigest() + + def add(self, texts: List[str], embeddings: Optional[List[List[float]]] = None, + metadata: Optional[List[Dict[str, Any]]] = None) -> List[str]: + """ + Add documents and their embeddings to the store. + + Args: + texts: List of text documents + embeddings: List of embedding vectors for each document + metadata: Optional list of metadata dictionaries for each document + + Returns: + List of document IDs for the added documents + """ + # Validate inputs + if not texts: + raise ValueError("No texts provided to add") + if metadata is None: + metadata = [{} for _ in range(len(texts))] + elif len(metadata) != len(texts): + raise ValueError("Number of metadata items must match number of texts") + + # Generate IDs for new documents using SHA256 hash + doc_ids = [self._generate_id(text) for text in texts] + + # Check if the hashed id already exists and if so, remove them from the list + # NOTE: We don't use upsert because we want to avoid running the embedding function + # on documents that already exist in the database, which could be expensive. 
+ orig_doc_ids = doc_ids.copy() + if self.table is not None: + doc_ids_str = ",".join([f"'{doc_id}'" for doc_id in doc_ids]) + existing_ids = self.table.search().where(f"id IN ({doc_ids_str})").to_pandas() + existing_ids = set(existing_ids["id"].tolist()) + if existing_ids: + print(f"Found {len(existing_ids)} existing IDs in the database. Removing them from the list.") + # Get the indices of the existing IDs + existing_indices = [i for i, doc_id in enumerate(doc_ids) if doc_id in existing_ids] + # Remove the existing IDs from the lists, before proceeding + doc_ids = [doc_id for i, doc_id in enumerate(doc_ids) if i not in existing_indices] + texts = [text for i, text in enumerate(texts) if i not in existing_indices] + metadata = [meta for i, meta in enumerate(metadata) if i not in existing_indices] + if embeddings is not None: + embeddings = [embedding for i, embedding in enumerate(embeddings) if i not in existing_indices] + + # If no new documents to add, return existing IDs + if not texts: + print("No documents are new. 
Returning existing IDs.") + return orig_doc_ids + + # Sanity check that lengths match + if len(texts) != len(doc_ids) or len(doc_ids) != len(metadata) or (embeddings is not None and len(embeddings) != len(texts)): + raise ValueError("Mismatched lengths of texts, IDs, metadata, and/or embeddings") + + if embeddings is None: + if self.embedding_func is None: + raise ValueError("No embedding function provided and no embeddings given") + + # Generate embeddings using the embedding function + embeddings = self.embedding_func(texts) + if not isinstance(embeddings, list) or not all(isinstance(e, list) for e in embeddings): + raise ValueError("Embeddings must be a list of lists") + + if len(texts) != len(embeddings): + raise ValueError("Number of texts and embeddings must match") + + # Create the table if it doesn't exist + if self.table is None: + if not embeddings: + raise ValueError("Cannot create table with empty embeddings list") + + vector_dimension = len(embeddings[0]) + + # Create schema for the table using PyArrow + import pyarrow as pa + + # Get schema for metadata based on first item if available + # metadata_fields = [] + # if metadata and metadata[0]: + # for key, value in metadata[0].items(): + # if isinstance(value, str): + # metadata_fields.append(pa.field(key, pa.string())) + # elif isinstance(value, int): + # metadata_fields.append(pa.field(key, pa.int64())) + # elif isinstance(value, float): + # metadata_fields.append(pa.field(key, pa.float64())) + # elif isinstance(value, bool): + # metadata_fields.append(pa.field(key, pa.bool_())) + # else: + # # Convert other types to string + # metadata_fields.append(pa.field(key, pa.string())) + + # Import pickle for serialization + + # Create a schema with metadata as a binary field + schema = pa.schema([ + pa.field("id", pa.string()), + pa.field("text", pa.string()), + pa.field("vector", pa.list_(pa.float32(), vector_dimension)), + pa.field("metadata", pa.binary()) # Store as binary data + ]) + self.table = 
self.db.create_table(self.table_name, schema=schema) + + # Create data to add + data = [] + for i, (doc_id, text, embedding, meta) in enumerate(zip(doc_ids, texts, embeddings, metadata)): + doc = { + "id": doc_id, + "text": text, + "vector": embedding, + "metadata": pickle.dumps(meta) # Serialize metadata to binary + } + data.append(doc) + + # Add data to the table + self.table.add(data) + + return orig_doc_ids # Return original IDs, including those that were not added + + def search(self, query: Union[str, List[float]], k: int = 5, + **kwargs) -> List[Dict[str, Any]]: + """ + Search for similar documents based on a query embedding. + + Args: + query_embedding: The query embedding vector + k: Number of results to return + **kwargs: Additional search parameters: + - distance_metric: Distance metric for search ('cosine', 'euclidean', 'dot_product') + - method: Search method ('similarity', 'mmr', 'hybrid') + - lambda_param: Balance between relevance and diversity for MMR (0-1) + - keyword: Keyword for hybrid search + - blend: Balance between vector and keyword scores (0-1) + - filters: Query filters in LanceDB syntax + + Returns: + List of document dictionaries with text, score, and metadata + """ + if self.table is None: + return [] + + # Map schema metric names to LanceDB metric names + distance_metric = kwargs.get("distance_metric", "l2") + metric_mapping = { + "euclidean": "l2", + "dot_product": "dot", + "cosine": "cosine", + # Also support legacy names + "l2": "l2", + "dot": "dot" + } + distance_metric = metric_mapping.get(distance_metric, "l2") + + method = kwargs.get("method", "similarity") + filters = kwargs.get("filters", None) + + # Check if query is a string or embedding, and handle accordingly + if isinstance(query, str): + # If query is a string, generate embedding using the embedding function + if self.embedding_func is None: + raise ValueError("Embedding function not provided for string query") + query_embedding = self.embedding_func([query])[0] + elif 
isinstance(query, list): + # If query is a list, assume it's already an embedding + query_embedding = query + + # Search the table using the query embedding and distance metric + q = self.table.search(query_embedding).metric(distance_metric) + + if filters: + q = q.where(filters) + + if method == "similarity": + # Standard cosine similarity search + results = q.limit(k).to_pandas() + + elif method == "mmr": + # Maximum Marginal Relevance search + lambda_param = kwargs.get("lambda_param", 0.5) + results = q.limit(k * 3).to_pandas() # Get more results for diversity filtering + + # Apply MMR algorithm to rerank + vectors = np.array([r["vector"] for _, r in results.iterrows()]) + query_vec = np.array(query_embedding) + + # Normalize vectors + query_vec = query_vec / np.linalg.norm(query_vec) + vectors = vectors / np.linalg.norm(vectors, axis=1)[:, np.newaxis] + + # Calculate similarities + sims = np.dot(vectors, query_vec) + + # MMR reranking + selected = [] + remaining = list(range(len(vectors))) + + while len(selected) < k and remaining: + best_score = -1 + best_idx = -1 + + for i in remaining: + relevance = sims[i] + + # Calculate diversity component + if selected: + sel_vectors = vectors[selected] + diversity_sim = np.max(np.dot(vectors[i], sel_vectors.T)) + mmr_score = lambda_param * relevance - (1 - lambda_param) * diversity_sim + else: + mmr_score = relevance + + if mmr_score > best_score: + best_score = mmr_score + best_idx = i + + if best_idx != -1: + selected.append(best_idx) + remaining.remove(best_idx) + + results = results.iloc[selected] + + elif method == "hybrid": + # Hybrid search combining vector similarity with keyword matching + keyword = kwargs.get("keyword", "") + blend = kwargs.get("blend", 0.5) + + if not keyword: + return self.search(query_embedding, k, method="similarity") + + # Get vector search results + vector_results = q.limit(k * 2).to_pandas() + + # Get keyword search results + keyword_query = self.table.search().where(f"text LIKE 
'%{keyword}%'").limit(k * 2) + keyword_results = keyword_query.to_pandas() + + # Combine results with blended scoring + all_results = {} + + # Add vector results with blended score + for _, row in vector_results.iterrows(): + doc_id = row["id"] + vector_score = row["_distance"] # LanceDB distance score + all_results[doc_id] = {"row": row, "similarity": blend * vector_score} + + # Add or update with keyword results + for _, row in keyword_results.iterrows(): + doc_id = row["id"] + keyword_score = 1.0 # Binary match score for simplicity + + if doc_id in all_results: + all_results[doc_id]["similarity"] += (1 - blend) * keyword_score + else: + all_results[doc_id] = {"row": row, "similarity": (1 - blend) * keyword_score} + + # Sort by blended score and take top k + sorted_results = sorted(all_results.values(), key=lambda x: x["similarity"], reverse=True)[:k] + results = pd.DataFrame([r["row"] for r in sorted_results]) + + else: + raise ValueError(f"Unknown search method: {method}") + + # Format results + formatted_results = [] + for _, row in results.iterrows(): + # Convert distance to similarity score based on the metric used + distance = row["_distance"] + + if distance_metric == "cosine": + # Cosine distance is in [0, 2], where 0 = identical, 2 = opposite + # Convert to similarity in [0, 1] + similarity = 1.0 - (distance / 2.0) + elif distance_metric == "l2": + # L2 (Euclidean) distance is in [0, ∞), convert to similarity in [0, 1] + # Use inverse distance: similarity = 1 / (1 + distance) + similarity = 1.0 / (1.0 + distance) + elif distance_metric == "dot": + # Dot product in LanceDB is negative (for minimization) + # Negate it to get the actual dot product similarity + similarity = -distance + else: + # Fallback to simple conversion + similarity = 1.0 / (1.0 + distance) + + formatted_results.append({ + "id": row["id"], + "text": row["text"], + "similarity": float(similarity), + "metadata": pickle.loads(row["metadata"]) # Deserialize metadata + }) + + return 
formatted_results + + def get(self, doc_id: str) -> Optional[Dict[str, Any]]: + """ + Get a document by ID. + + Args: + doc_id: Document ID + + Returns: + Document dictionary or None if not found + """ + if self.table is None: + return None + + results = self.table.search().where(f"id = '{doc_id}'").to_pandas() + + if len(results) == 0: + print(f"Document with ID {doc_id} not found") + return None + + row = results.iloc[0] + return { + "id": row["id"], + "text": row["text"], + "embedding": row["vector"], + "metadata": pickle.loads(row["metadata"]) # Deserialize metadata + } + + def delete(self, doc_ids: List[str]) -> bool: + """ + Delete documents by ID. + + Args: + doc_ids: List of document IDs to delete + + Returns: + Boolean indicating success + """ + if self.table is None or not doc_ids: + return True + + # Build OR condition for multiple IDs + conditions = " OR ".join([f"id = '{doc_id}'" for doc_id in doc_ids]) + + try: + self.table.delete(conditions) + return True + except Exception as e: + print(f"Error deleting documents: {e}") + return False + + def update(self, doc_id: str, text: Optional[str] = None, + embedding: Optional[List[float]] = None, + metadata: Optional[Dict[str, Any]] = None) -> Union[str, None]: + """ + Update a document by ID. + + Args: + doc_id: Document ID + text: New text (if None, text is not updated) + embedding: New embedding (if None, embedding is not updated) + metadata: New metadata (if None, metadata is not updated) + + Returns: + The new document ID if updated successfully, None otherwise. + New ID is generated if text is updated. + """ + if self.table is None: + return False + + # Get the current document + current_doc = self.get(doc_id) + if current_doc is None: + return None + + # Delete the old document if it exists + # NOTE: This is a cheap method of updating the document, but it may not be the most efficient. 
+ if not self.delete([doc_id]): + return None + + # Check if text has changed + text_has_changed = text is not None and text != current_doc["text"] + new_text = text if text is not None else current_doc["text"] + new_embedding = embedding if embedding is not None else (None if text_has_changed else current_doc["embedding"]) + new_metadata = metadata if metadata is not None else current_doc["metadata"] + + # Add the updated document + try: + # Add the new document with updated text and/or embedding + new_ids = self.add([new_text], + embeddings=[new_embedding] if new_embedding is not None else None, + metadata=[new_metadata] if new_metadata is not None else None) + return new_ids[0] + except Exception as e: + print(f"Error updating document: {e}") + return None + + def get_all(self, limit: Optional[int] = None, + offset: int = 0) -> List[Dict[str, Any]]: + """ + Get all documents in the store. + + Args: + limit: Maximum number of documents to return (None for all) + offset: Number of documents to skip + + Returns: + List of document dictionaries + """ + if self.table is None: + return [] + + query = self.table.search() + + if offset > 0: + query = query.offset(offset) + + if limit is not None: + query = query.limit(limit) + + results = query.to_pandas() + + formatted_results = [] + for _, row in results.iterrows(): + formatted_results.append({ + "id": row["id"], + "text": row["text"], + "embedding": row["vector"], + "metadata": pickle.loads(row["metadata"]) # Deserialize metadata + }) + + return formatted_results + + def count(self) -> int: + """ + Get the number of documents in the store. + + Returns: + Number of documents + """ + if self.table is None: + return 0 + + # Convert to pandas and get the count + return len(self.table.to_pandas()) + + def clear(self) -> bool: + """ + Clear all documents from the store. 
+ + Returns: + Boolean indicating success + """ + if self.table is None: + return True + + try: + # Delete the table and set to None - will be recreated on next add() + self.db.drop_table(self.table_name) + self.table = None + return True + except Exception as e: + print(f"Error clearing vector store: {e}") + return False + + +class FaissVectorStore(VectorStore): + """ + Vector store implementation using FAISS for local vector storage. + Stores embeddings in a FAISS index and keeps texts/metadata in a sidecar file. + """ + + def __init__(self, + db_path: str = "faissdb", + embedding_func: Optional[callable] = None, + index_name: str = "index", + metric: str = "l2", + db_mode: str = "create"): + """ + Args: + db_path: Directory where FAISS index and metadata are stored + embedding_func: Optional function to generate embeddings + index_name: Name of the FAISS index file (without extension) + metric: 'l2', 'ip', 'euclidean', 'dot_product', or 'cosine' + """ + super().__init__(embedding_func, db_path, db_mode) + + if faiss is None: + raise ImportError("Faiss is not installed. Please install 'faiss-cpu' or 'faiss-gpu' if you would like to use Faiss vector store methods. 
You may need to install 'swig' as well; on MacOS this can be done with 'brew install swig'.") + + self.db_path = db_path + self.index_name = index_name + + # Map schema metric names to FAISS metric names + metric_mapping = { + "euclidean": "l2", + "dot_product": "ip", + "cosine": "ip", # Cosine is IP with normalized vectors + # Also support legacy names + "l2": "l2", + "ip": "ip" + } + self.metric = metric_mapping.get(metric.lower(), "l2") + + self.index_file = os.path.join(db_path, f"{index_name}.faiss") + self.meta_file = os.path.join(db_path, f"{index_name}_meta.pkl") + + self.index = None + self.id_to_meta = {} # id -> dict with text, metadata, vector index + self.ids = [] # list of ids in FAISS order + + self._load() + + def _generate_id(self, text: str) -> str: + return hashlib.sha256(text.encode('utf-8')).hexdigest() + + def _save(self): + if self.index is not None: + faiss.write_index(self.index, self.index_file) + with open(self.meta_file, "wb") as f: + pickle.dump({"id_to_meta": self.id_to_meta, "ids": self.ids}, f) + + def _load(self): + if os.path.exists(self.index_file) and os.path.exists(self.meta_file): + self.index = faiss.read_index(self.index_file) + with open(self.meta_file, "rb") as f: + data = pickle.load(f) + self.id_to_meta = data.get("id_to_meta", {}) + self.ids = data.get("ids", []) + else: + self.index = None + self.id_to_meta = {} + self.ids = [] + + def add(self, texts: List[str], embeddings: Optional[List[List[float]]] = None, + metadata: Optional[List[Dict[str, Any]]] = None) -> List[str]: + if not texts: + raise ValueError("No texts provided to add") + if metadata is None: + metadata = [{} for _ in range(len(texts))] + elif len(metadata) != len(texts): + raise ValueError("Number of metadata items must match number of texts") + + doc_ids = [self._generate_id(text) for text in texts] + + # Remove already existing IDs + new_texts, new_embeddings, new_metadata, new_doc_ids = [], [], [], [] + for i, doc_id in enumerate(doc_ids): + if 
doc_id not in self.id_to_meta: + new_texts.append(texts[i]) + new_metadata.append(metadata[i]) + new_doc_ids.append(doc_id) + if embeddings is not None: + new_embeddings.append(embeddings[i]) + + if not new_texts: + return doc_ids + + # Compute embeddings if not provided + if embeddings is None: + if self.embedding_func is None: + raise ValueError("No embedding function provided and no embeddings given") + new_embeddings = self.embedding_func(new_texts) + if len(new_embeddings) != len(new_texts): + raise ValueError("Number of embeddings and texts must match") + + # Prepare FAISS index + dim = len(new_embeddings[0]) + if self.index is None: + if self.metric == "ip": + self.index = faiss.IndexFlatIP(dim) + else: + self.index = faiss.IndexFlatL2(dim) + + # Add to FAISS + new_embeddings_np = np.array(new_embeddings).astype("float32") + if self.metric == "ip": + faiss.normalize_L2(new_embeddings_np) + self.index.add(new_embeddings_np) + + # Update meta + start_idx = len(self.ids) + for i, doc_id in enumerate(new_doc_ids): + self.ids.append(doc_id) + self.id_to_meta[doc_id] = { + "text": new_texts[i], + "metadata": pickle.dumps(new_metadata[i]), + "vector_index": start_idx + i + } + + self._save() + return doc_ids + + def search(self, query: Union[str, List[float]], k: int = 5, **kwargs) -> List[Dict[str, Any]]: + if self.index is None or not self.ids: + return [] + + similarity_threshold = kwargs.get('similarity_threshold') + # Conversion du seuil en proportion si fourni + if similarity_threshold is not None: + similarity_threshold = float(similarity_threshold) / 100.0 + + # Préparation de l'embedding de la requête + if isinstance(query, str): + if self.embedding_func is None: + raise ValueError("Embedding function not provided for string query") + query_emb = self.embedding_func([query])[0] + else: + query_emb = query + + query_emb = np.array(query_emb, dtype="float32").reshape(1, -1) + if self.metric == "ip": + faiss.normalize_L2(query_emb) + + D, I = 
self.index.search(query_emb, min(k, len(self.ids))) + results = [] + for idx, dist in zip(I[0], D[0]): + if idx < 0 or idx >= len(self.ids): + continue + doc_id = self.ids[idx] + meta = self.id_to_meta[doc_id] + similarity = 1.0 / (1.0 + dist) if self.metric == "l2" else float(dist) + # Filtrage par similarité si le seuil est défini + if similarity_threshold is not None and similarity < similarity_threshold: + continue + results.append({ + "id": doc_id, + "text": meta["text"], + "similarity": similarity, + "metadata": pickle.loads(meta["metadata"]) + }) + results.sort(key=lambda x: x["similarity"], reverse=True) + return results + + def get(self, doc_id: str) -> Optional[Dict[str, Any]]: + meta = self.id_to_meta.get(doc_id) + if meta is None: + return None + return { + "id": doc_id, + "text": meta["text"], + "embedding": None, # Embedding not stored directly + "metadata": pickle.loads(meta["metadata"]) + } + + def delete(self, doc_ids: List[str]) -> bool: + if not doc_ids or self.index is None: + return True + # Remove from meta and ids + indices_to_remove = [self.ids.index(doc_id) for doc_id in doc_ids if doc_id in self.ids] + if not indices_to_remove: + return True + # Remove from FAISS by rebuilding index (FAISS does not support delete) + keep_indices = [i for i in range(len(self.ids)) if i not in indices_to_remove] + if not keep_indices: + self.index = None + self.ids = [] + self.id_to_meta = {} + self._save() + return True + embeddings = self.index.reconstruct_n(0, len(self.ids)) + new_embeddings = np.array([embeddings[i] for i in keep_indices]).astype("float32") + if self.metric == "ip": + faiss.normalize_L2(new_embeddings) + self.index = faiss.IndexFlatIP(new_embeddings.shape[1]) + else: + self.index = faiss.IndexFlatL2(new_embeddings.shape[1]) + self.index.add(new_embeddings) + # Update ids and meta + new_ids = [self.ids[i] for i in keep_indices] + new_id_to_meta = {doc_id: self.id_to_meta[doc_id] for doc_id in new_ids} + # Update vector_index in meta + for 
i, doc_id in enumerate(new_ids): + new_id_to_meta[doc_id]["vector_index"] = i + self.ids = new_ids + self.id_to_meta = new_id_to_meta + self._save() + return True + + def update(self, doc_id: str, text: Optional[str] = None, + embedding: Optional[List[float]] = None, + metadata: Optional[Dict[str, Any]] = None) -> bool: + # Remove and re-add + current = self.get(doc_id) + if current is None: + return False + self.delete([doc_id]) + new_text = text if text is not None else current["text"] + new_embedding = embedding if embedding is not None else None + new_metadata = metadata if metadata is not None else current["metadata"] + self.add([new_text], embeddings=[new_embedding] if new_embedding is not None else None, + metadata=[new_metadata]) + return True + + def get_all(self, limit: Optional[int] = None, offset: int = 0) -> List[Dict[str, Any]]: + all_ids = self.ids[offset:offset + limit if limit is not None else None] + return [self.get(doc_id) for doc_id in all_ids] + + def count(self) -> int: + return len(self.ids) + + def clear(self) -> bool: + self.index = None + self.ids = [] + self.id_to_meta = {} + if os.path.exists(self.index_file): + os.remove(self.index_file) + if os.path.exists(self.meta_file): + os.remove(self.meta_file) + return True \ No newline at end of file diff --git a/chainforge/react-server/src/AiPopover.tsx b/chainforge/react-server/src/AiPopover.tsx index 54a569def..cf29b251a 100644 --- a/chainforge/react-server/src/AiPopover.tsx +++ b/chainforge/react-server/src/AiPopover.tsx @@ -146,6 +146,7 @@ export function AIPopover({ }: { children: React.ReactNode; }) { + const [opened, setOpened] = useState(false); // API keys const apiKeys = useStore((state) => state.apiKeys); const aiFeaturesProvider = useStore((state) => state.aiFeaturesProvider); @@ -219,9 +220,15 @@ export function AIPopover({ shadow={popoverShadow} withinPortal keepMounted + opened={opened} + onChange={setOpened} + clickOutsideEvents={["click"]} > - diff --git 
a/chainforge/react-server/src/App.tsx b/chainforge/react-server/src/App.tsx index e1b5a4f54..501743a2a 100644 --- a/chainforge/react-server/src/App.tsx +++ b/chainforge/react-server/src/App.tsx @@ -43,6 +43,7 @@ import { IconHeart, IconCheckbox, IconTransform, + IconSortAscending, } from "@tabler/icons-react"; import RemoveEdge from "./RemoveEdge"; import TextFieldsNode from "./TextFieldsNode"; // Import a custom node @@ -50,6 +51,7 @@ import PromptNode from "./PromptNode"; import CodeEvaluatorNode from "./CodeEvaluatorNode"; import VisNode from "./VisNode"; import InspectNode from "./InspectorNode"; +import SelectVarsNode from "./SelectVarsNode"; import ScriptNode from "./ScriptNode"; import { AlertModalContext } from "./AlertModal"; import ItemsNode from "./ItemsNode"; @@ -57,12 +59,17 @@ import TabularDataNode from "./TabularDataNode"; import JoinNode from "./JoinNode"; import SplitNode from "./SplitNode"; import CommentNode from "./CommentNode"; +import MultiEvalNode from "./MultiEvalNode"; +import RerankNode from "./RerankNode"; import GlobalSettingsModal, { GlobalSettingsModalRef, } from "./GlobalSettingsModal"; import ExampleFlowsModal, { ExampleFlowsModalRef } from "./ExampleFlowsModal"; import LLMEvaluatorNode from "./LLMEvalNode"; import SimpleEvalNode from "./SimpleEvalNode"; +import UploadNode from "./UploadNode"; +import ChunkNode from "./ChunkNode"; +import RetrievalNode from "./RetrievalNode"; import { getDefaultModelFormData, getDefaultModelSettings, @@ -89,6 +96,7 @@ import { APP_IS_RUNNING_LOCALLY, browserTabIsActive, FLASK_BASE_URL, + RAG_AVAILABLE, } from "./backend/utils"; import { Dict, JSONCompatible, LLMSpec } from "./backend/typing"; import { @@ -113,7 +121,6 @@ import { isChromium, isMobileSafari, } from "react-device-detect"; -import MultiEvalNode from "./MultiEvalNode"; import FlowSidebar from "./FlowSidebar"; import NestedMenu, { NestedMenuItemProps } from "./NestedMenu"; import RequestClarificationModal, { @@ -203,6 +210,7 @@ const 
INITIAL_LLM = () => { const nodeTypes = { textfields: TextFieldsNode, // Register the custom node prompt: PromptNode, + selectvars: SelectVarsNode, chat: PromptNode, simpleval: SimpleEvalNode, evaluator: CodeEvaluatorNode, @@ -217,6 +225,10 @@ const nodeTypes = { join: JoinNode, split: SplitNode, processor: CodeEvaluatorNode, + upload: UploadNode, + chunk: ChunkNode, + retrieval: RetrievalNode, + rerank: RerankNode, media: MediaNode, }; @@ -236,6 +248,10 @@ const nodeEmojis = { comment: "✏️", join: , split: , + upload: "📂", + chunk: "🧩", + retrieval: "🎯", + rerank: , media: "📺", }; @@ -366,8 +382,80 @@ const App = () => { // Add Nodes list const addNodesMenuItems = useMemo(() => { + // RAG-related nodes only if RAG is available + const ragNodes = [ + { + // Menu.Label + key: "RAG", + }, + { + key: "upload", + title: "Upload Docs Node", + icon: nodeEmojis.upload, + tooltip: "Upload documents to the flow, such as text files or PDFs.", + onClick: () => addNode("upload"), + }, + { + key: "chunk", + title: "Chunking Node", + icon: nodeEmojis.chunk, + tooltip: + "Chunk texts into smaller pieces. Compare different chunking methods. Typically used after the Upload Node.", + onClick: () => addNode("chunk"), + }, + { + key: "retrieval", + title: "Retrieval Node", + icon: nodeEmojis.retrieval, + tooltip: + "Given chunks and queries, retrieve relevant chunks for the given query. Compare retrieval methods across queries. 
Retrieval methods include both classical methods like BM25, and vector stores.", + onClick: () => addNode("retrieval"), + }, + { + key: "rerank", + title: "Rerank Node", + icon: nodeEmojis.rerank, + tooltip: "Reranks retrieval outputs.", + onClick: () => addNode("rerank"), + }, + { + key: "divider", + }, + ] as NestedMenuItemProps[]; + + // Misc nodes + const miscNodes: NestedMenuItemProps[] = [ + { + // Menu.Label + key: "Misc", + }, + { + key: "comment", + title: "Comment Node", + icon: nodeEmojis.comment, + tooltip: "Make a comment about your flow.", + onClick: () => addNode("comment"), + }, + { + key: "script", + title: "Global Python Scripts", + icon: nodeEmojis.script, + tooltip: + "Specify directories to load as local packages, so they can be imported in your Python evaluator nodes (add to sys path).", + onClick: () => addNode("scriptNode", "script"), + }, + { + key: "selectvars", + title: "Filter Variables Node", + icon: , + tooltip: + "Filter which variables and metavariables to keep for the next steps.", + onClick: () => addNode("selectVarsNode", "selectvars"), + }, + ]; + // All initial nodes available in ChainForge - const initNodes = [ + let initNodes = [ { // Menu.Label key: "Input Data", @@ -559,27 +647,12 @@ const App = () => { { key: "divider", }, - { - // Menu.Label - key: "Misc", - }, - { - key: "comment", - title: "Comment Node", - icon: nodeEmojis.comment, - tooltip: "Make a comment about your flow.", - onClick: () => addNode("comment"), - }, - { - key: "script", - title: "Global Python Scripts", - icon: nodeEmojis.script, - tooltip: - "Specify directories to load as local packages, so they can be imported in your Python evaluator nodes (add to sys path).", - onClick: () => addNode("scriptNode", "script"), - }, ] as NestedMenuItemProps[]; + // Add RAG nodes to menu if RAG dependencies are installed on the backend + if (RAG_AVAILABLE) initNodes = [...initNodes, ...ragNodes, ...miscNodes]; + else initNodes = [...initNodes, ...miscNodes]; + // Add 
favorite nodes to the menu const favoriteNodes = favorites?.nodes?.map(({ name, value, uid }, idx) => { const type = value.type ?? ""; @@ -610,9 +683,6 @@ const App = () => { }); } - // Favorites - // - return initNodes; }, [favorites, addNode]); @@ -1122,26 +1192,51 @@ const App = () => { ], ); - // Load flow from examples modal - const onSelectExampleFlow = (name: string, example_category?: string) => { - // Trigger the 'loading' modal + // cfzip importer + const importFlowZipFromURL = useCallback( + async (url: string) => { + const res = await fetch(url); + if (!res.ok) throw new Error(`Failed to fetch ${url}: ${res.status}`); + const blob = await res.blob(); + + // ensure filename ends with .cfzip + const urlName = + new URL(url, window.location.origin).pathname.split("/").pop() || + "example.cfzip"; + const fileName = /\.cfzip$/i.test(urlName) ? urlName : `${urlName}.cfzip`; + + const file = new File([blob], fileName, { type: "application/zip" }); + + const { flow, flowName } = await importFlowBundle(file); + importFlowFromJSON(flow); + await safeSetFlowFileName(flowName); + }, + [importFlowFromJSON, safeSetFlowFileName], + ); + + // loader for example flows + const onSelectExampleFlow = async (name: string) => { setIsLoading(true); + try { + const base = FLASK_BASE_URL.replace(/\/$/, ""); + + if (/\.cfzip$/i.test(name)) { + const file = name.endsWith(".cfzip") ? name : `${name}.cfzip`; + const url = name.startsWith("http") ? 
name : `${base}/examples/${file}`; + await importFlowZipFromURL(url); + return; + } - // Detect a special category of the example flow, and use the right loader for it: - if (example_category === "openai-eval") { - importFlowFromOpenAIEval(name); + // treat everything else as .cforge JSON + const baseName = name.replace(/\.cforge$/i, ""); + const flowJSON = await fetchExampleFlow(baseName); + importFlowFromJSON(flowJSON); setFlowFileNameAndCache(`flow-${Date.now()}`); - return; + } catch (err) { + handleError(err as Error); + } finally { + setIsLoading(false); } - - // Fetch the example flow data from the backend - fetchExampleFlow(name) - .then(function (flowJSON) { - // We have the data, import it: - importFlowFromJSON(flowJSON); - setFlowFileNameAndCache(`flow-${Date.now()}`); - }) - .catch(handleError); }; // When the user clicks the 'New Flow' button diff --git a/chainforge/react-server/src/ChunkMethodListComponent.tsx b/chainforge/react-server/src/ChunkMethodListComponent.tsx new file mode 100644 index 000000000..dedf168a8 --- /dev/null +++ b/chainforge/react-server/src/ChunkMethodListComponent.tsx @@ -0,0 +1,277 @@ +import React, { + useState, + useRef, + forwardRef, + useImperativeHandle, + useCallback, + useMemo, +} from "react"; +import { Button, Text, Modal, ScrollArea } from "@mantine/core"; +import { useDisclosure } from "@mantine/hooks"; +import Form from "@rjsf/core"; +import validator from "@rjsf/validator-ajv8"; +import { v4 as uuid } from "uuid"; +import { ChunkMethodSchemas, ChunkMethodGroups } from "./ChunkMethodSchemas"; +import NestedMenu, { NestedMenuItemProps } from "./NestedMenu"; +import LLMItemButtonGroup from "./LLMItemButtonGroup"; +import useStore from "./store"; +import { ensureUniqueName } from "./backend/utils"; + +export interface ChunkMethodSpec { + key: string; + baseMethod: string; + methodType: string; + name: string; + emoji?: string; + settings?: Record; +} + +export interface ChunkMethodListContainerProps { + 
initMethodItems?: ChunkMethodSpec[]; + onItemsChange?: ( + newItems: ChunkMethodSpec[], + oldItems: ChunkMethodSpec[], + ) => void; +} +export type ChunkMethodListContainerRef = Record; + +const ChunkMethodListItem: React.FC<{ + methodItem: ChunkMethodSpec; + onRemove: (key: string) => void; + onSettingsUpdate: (key: string, newSettings: any) => void; +}> = ({ methodItem, onRemove, onSettingsUpdate }) => { + // Fetch the relevant schema + const schemaEntry = useMemo( + () => + ChunkMethodSchemas[methodItem.baseMethod] || { + schema: {}, + uiSchema: {}, + description: "", + fullName: "", + }, + [methodItem], + ); + const schema = useMemo(() => { + const s = schemaEntry?.schema; + const schemaWithShortname = { + ...s, + properties: { + shortname: { + type: "string", + title: "Short Name", + description: "A nickname for this method.", + default: methodItem.name, + }, + ...s.properties, + }, + } as typeof s; + return schemaWithShortname; + }, [schemaEntry]); + const uiSchema = useMemo(() => schemaEntry?.uiSchema, [schemaEntry]); + + const [settingsModalOpen, { open, close }] = useDisclosure(false); + + return ( +
+
+
+ {methodItem.emoji ? methodItem.emoji + " " : ""} + {methodItem.name} +
+ + onRemove(methodItem.key)} + onClickSettings={open} // from useDisclosure(false) + hideTrashIcon={false} + /> +
+ + + {schema && Object.keys(schema).length > 0 ? ( +
onSettingsUpdate(methodItem.key, evt.formData)} + onSubmit={(evt) => { + onSettingsUpdate(methodItem.key, evt.formData); + close(); + }} + validator={validator as any} + liveValidate + noHtml5Validate + > + +
+
+ ) : ( + + (No custom settings for this method.) + + )} +
+
+ ); +}; + +const ChunkMethodListContainer = forwardRef< + ChunkMethodListContainerRef, + ChunkMethodListContainerProps +>((props, ref) => { + const [methodItems, setMethodItems] = useState( + props.initMethodItems || [], + ); + const oldItemsRef = useRef(methodItems); + + useImperativeHandle(ref, () => ({})); + + // If parent node wants to track changes + const notifyItemsChanged = useCallback( + (newItems: ChunkMethodSpec[]) => { + props.onItemsChange?.(newItems, oldItemsRef.current); + oldItemsRef.current = newItems; + }, + [props.onItemsChange], + ); + + // Remove method + const handleRemoveMethod = useCallback( + (key: string) => { + const newItems = methodItems.filter((m) => m.key !== key); + setMethodItems(newItems); + notifyItemsChanged(newItems); + }, + [methodItems, notifyItemsChanged], + ); + + // Update settings + const handleSettingsUpdate = useCallback( + (key: string, newSettings: any) => { + const newItems = methodItems.map((m) => + m.key === key + ? { + ...m, + settings: newSettings, + name: newSettings.shortname ?? 
m.name, + } + : m, + ); + setMethodItems(newItems); + notifyItemsChanged(newItems); + }, + [methodItems, notifyItemsChanged], + ); + + const addMethod = useCallback( + (m: Omit) => { + const uniqueName = ensureUniqueName( + m.name, + methodItems.map((i) => i.settings?.shortname || i.name), + ); + + const newItem: ChunkMethodSpec = { + key: uuid(), + baseMethod: m.baseMethod, + methodType: m.methodType, + name: uniqueName, + emoji: m.emoji, + settings: {}, + }; + const newItems = [...methodItems, newItem]; + setMethodItems(newItems); + notifyItemsChanged(newItems); + }, + [methodItems, notifyItemsChanged], + ); + + // Pull in any dropped‑in chunkers + const customChunkers = useStore((s) => s.customChunkers); + + const addMenuItems: NestedMenuItemProps[] = useMemo(() => { + // Built-in groups as top-level submenus + const builtInGroups: NestedMenuItemProps[] = ChunkMethodGroups.map( + (group) => ({ + key: `group-${group.label}`, + title: group.label, + items: group.items.map((m) => ({ + key: `method-${m.baseMethod}`, + title: m.name, + tooltip: m.description, + icon: m.emoji ? {m.emoji} : undefined, + onClick: () => addMethod(m), + })), + }), + ); + + // Custom chunkers as another top-level submenu (if any) + const customGroup: NestedMenuItemProps[] = + customChunkers.length > 0 + ? [ + { + key: "group-custom", + title: "Custom chunkers", + items: customChunkers.map((item) => ({ + key: `custom-${item.baseMethod}`, + title: item.name, + icon: item.emoji ? {item.emoji} : undefined, + onClick: () => addMethod(item), + })), + }, + ] + : []; + + return [...builtInGroups, ...customGroup]; + }, [ChunkMethodGroups, customChunkers, addMethod]); + + return ( +
+
+ Chunking Methods + +
+ } + /> +
+
+ + {methodItems.length === 0 ? ( +
+ + No chunk methods selected. + +
+ ) : ( +
+
+ + {methodItems.map((item) => ( + + ))} + +
+
+ )} +
+ ); +}); + +ChunkMethodListContainer.displayName = "ChunkMethodListContainer"; +export default ChunkMethodListContainer; diff --git a/chainforge/react-server/src/ChunkMethodSchemas.tsx b/chainforge/react-server/src/ChunkMethodSchemas.tsx new file mode 100644 index 000000000..0792df3a2 --- /dev/null +++ b/chainforge/react-server/src/ChunkMethodSchemas.tsx @@ -0,0 +1,552 @@ +import { ModelSettingsDict } from "./backend/typing"; + +/** + * Overlapping + OpenAI tiktoken + */ +export const OverlappingOpenAITiktokenSchema: ModelSettingsDict = { + fullName: "Overlapping + OpenAI tiktoken", + description: "Chunk text using the OpenAI tiktoken library with overlap.", + schema: { + type: "object", + required: ["model", "chunk_size", "chunk_overlap"], + properties: { + model: { + type: "string", + default: "gpt-3.5-turbo", + title: "Model", + description: + "OpenAI model (e.g. gpt-4o) or direct tiktoken tokenizer name (e.g. cl100k_base); OpenAI models auto-map to the correct tokenizer.", + }, + chunk_size: { + type: "number", + default: 200, + title: "Max tokens per chunk", + }, + chunk_overlap: { + type: "number", + default: 50, + title: "Overlap tokens", + }, + }, + }, + uiSchema: {}, + postprocessors: {}, +}; + +/** + * Overlapping + HuggingFace Tokenizers + */ +export const OverlappingHuggingfaceTokenizerSchema: ModelSettingsDict = { + fullName: "Overlapping + HuggingFace Tokenizers", + description: "Chunk text using HuggingFace tokenizer-based segmentation.", + schema: { + type: "object", + required: ["tokenizer", "chunk_size", "chunk_overlap"], + properties: { + tokenizer: { + type: "string", + default: "bert-base-uncased", + title: "Tokenizer Model", + description: + "Tokenizer model to use for chunking. 
See HuggingFace AutoTokenizer docs for options.", + }, + chunk_size: { + type: "number", + default: 200, + title: "Tokens per chunk", + }, + chunk_overlap: { + type: "number", + default: 50, + title: "Overlap tokens", + }, + }, + }, + uiSchema: { + tokenizer_model: { + "ui:widget": "select", // display as a dropdown + }, + chunk_size: { + "ui:widget": "updown", + "ui:options": { + min: 100, + max: 5000, + step: 50, + }, + }, + }, + postprocessors: {}, +}; + +/** + * Markdown chunker + */ + +export const MarkdownHeaderSchema: ModelSettingsDict = { + fullName: "Markdown Chunker", + description: + "Splits markdown text at #/##/### headings; each section keeps its heading.", + schema: { type: "object", required: [], properties: {} }, + uiSchema: {}, + postprocessors: {}, +}; + +/** + * Syntax-based NLTK + */ +export const SyntaxNltkSchema: ModelSettingsDict = { + fullName: "Syntax-based NLTK", + description: "Splits text into sentences using NLTK's Punkt tokenizer.", + schema: { type: "object", required: [], properties: {} }, + uiSchema: {}, + postprocessors: {}, +}; + +/** + * Syntax-based TextTiling + */ +export const SyntaxTextTilingSchema: ModelSettingsDict = { + fullName: "Syntax-based TextTiling", + description: "Splits text into multi-sentence segments using TextTiling.", + schema: { + type: "object", + required: ["w", "k"], + properties: { + w: { type: "number", default: 20, title: "Window size (w)" }, + k: { type: "number", default: 10, title: "Block comparison size (k)" }, + }, + }, + uiSchema: { + w: { + "ui:widget": "range", + "ui:options": { + min: 5, + max: 50, + step: 5, + }, + }, + k: { + "ui:widget": "range", + "ui:options": { + min: 5, + max: 50, + step: 5, + }, + }, + }, + postprocessors: {}, +}; + +/** + * Chonkie Token Chunker + */ +export const ChonkieTokenSchema: ModelSettingsDict = { + fullName: "Chonkie Token Chunker", + description: "Chunk text using token-based chunking via Chonkie library.", + schema: { + type: "object", + required: 
["tokenizer", "chunk_size", "chunk_overlap"], + properties: { + tokenizer: { + type: "string", + default: "gpt2", + title: "Tokenizer", + description: + "Tokenizer or token counter to use. See Chonkie docs for options.", + }, + chunk_size: { + type: "number", + default: 512, + title: "Chunk Size (tokens)", + }, + chunk_overlap: { + type: "number", + default: 0, + title: "Overlap (tokens)", + }, + }, + }, + uiSchema: {}, + postprocessors: {}, +}; + +/** + * Chonkie Sentence Chunker + */ +export const ChonkieSentenceSchema: ModelSettingsDict = { + fullName: "Chonkie Sentence Chunker", + description: + "Chunk text by sentences with token count awareness via Chonkie library.", + schema: { + type: "object", + required: ["tokenizer_or_token_counter", "chunk_size", "chunk_overlap"], + properties: { + tokenizer_or_token_counter: { + type: "string", + default: "gpt2", + title: "Tokenizer", + description: + "Tokenizer or token counter to use. See Chonkie docs for options.", + }, + chunk_size: { + type: "number", + default: 1, + title: "Max tokens per chunk", + description: + "Default 1 keeps each chunk to a single sentence. 
Increase to group multiple sentences up to the given token count.", + }, + chunk_overlap: { + type: "number", + default: 0, + title: "Overlap (tokens)", + }, + min_sentences_per_chunk: { + type: "number", + default: 1, + title: "Min sentences per chunk", + }, + min_characters_per_sentence: { + type: "number", + default: 12, + title: "Min characters per sentence", + }, + delim: { + type: "string", + default: '[".", "!", "?"]', + title: "Sentence delimiters (JSON array)", + }, + include_delim: { + type: "string", + default: "prev", + title: + "Include delimiters in chunks (prev, next, or leave blank for none)", + }, + }, + }, + uiSchema: { + delim: { + "ui:help": "JSON array of delimiter characters", + }, + }, + postprocessors: { + include_delim: (value: string | number | boolean): string | null => { + if (typeof value !== "string") return null; + if (value !== "prev" && value !== "next") { + return null; + } else { + return value; + } + }, + }, +}; + +/** + * Chonkie Recursive Chunker + */ +export const ChonkieRecursiveSchema: ModelSettingsDict = { + fullName: "Chonkie Recursive Chunker", + description: + "Chunk text recursively with hierarchical splitting via Chonkie library.", + schema: { + type: "object", + required: [ + "tokenizer_or_token_counter", + "chunk_size", + "min_characters_per_chunk", + ], + properties: { + tokenizer_or_token_counter: { + type: "string", + default: "gpt2", + title: "Tokenizer", + description: + "Tokenizer or token counter to use. See Chonkie docs for options.", + }, + chunk_size: { + type: "number", + default: 512, + title: "Max tokens per chunk", + }, + min_characters_per_chunk: { + type: "number", + default: 12, + title: "Min characters per chunk", + }, + use_premade_recipe: { + type: "string", + default: "markdown-en", + title: "Premade recipe (optional)", + description: + "Format: 'name-language' (e.g., 'markdown-en') or just 'language' (e.g., 'en'). 
Defaults to markdown-en, since ChainForge parses documents by default into markdown. See Chonkie Recipes for available options: https://huggingface.co/datasets/chonkie-ai/recipes/viewer/recipes/train?row=5&views%5B%5D=recipes", + }, + custom_recipe: { + type: "string", + default: "", + title: "Custom recipe JSON (optional)", + description: + "JSON array of recursive chunking rules (RecursiveLevel in the Chonkie API). Overrides premade recipe if provided.", + }, + }, + }, + uiSchema: { + custom_recipe: { + "ui:widget": "textarea", + "ui:help": "JSON array of recursive chunking rules", + }, + }, + postprocessors: {}, +}; + +/** + * Chonkie Semantic Chunker (Consolidated) + */ +export const ChonkieSemanticSchema: ModelSettingsDict = { + fullName: "Chonkie Semantic Chunker", + description: + "Chunk text by semantic similarity. Set 'skip_window' > 0 to enable SDPM (Double-Pass Merging).", + schema: { + type: "object", + required: ["embedding_model", "chunk_size", "threshold"], + properties: { + embedding_model: { + type: "string", + default: "minishlab/potion-base-8M", + title: "Embedding Model", + description: + "Model to use for embeddings. See Chonkie docs for options.", + }, + embedding_local_path: { + type: "string", + default: "", + title: "Embedding Local Path", + description: + "Local path for model to use for embeddings (only needed if cant download through Chonkie).", + }, + chunk_size: { + type: "number", + default: 512, + title: "Max tokens per chunk", + }, + threshold: { + type: "number", + default: 0.8, + title: "Similarity threshold", + description: + "Value between 0-1. 
Higher values require sentences to be more similar to stay in the same chunk.", + minimum: 0, + maximum: 1, + step: 0.01, + }, + similarity_window: { + type: "number", + default: 1, + title: "Similarity window", + description: + "Number of sentences to consider for similarity threshold calculation", + }, + min_sentences: { + type: "number", + default: 1, + title: "Min sentences per chunk", + }, + min_characters_per_sentence: { + type: "number", + default: 12, + title: "Min characters per sentence", + }, + skip_window: { + type: "number", + default: 0, + title: "Skip window (SDPM)", + description: + "If set to 0, performs standard Semantic chunking. If > 0, performs SDPM (Semantic Double-Pass Merging).", + }, + }, + }, + uiSchema: { + threshold: { + "ui:widget": "updown", + }, + skip_window: { + "ui:widget": "updown", + }, + }, + postprocessors: { + threshold: (value: string | number | boolean): number => { + if (typeof value === "number") return value; + if (typeof value === "string") return parseFloat(value); + return 0.8; + }, + }, +}; + +/** + * Chonkie Late Chunker + */ +export const ChonkieLateSchema: ModelSettingsDict = { + fullName: "Chonkie Late Chunker", + description: + "Chunk text with embedding-guided hierarchical splitting via Chonkie library.", + schema: { + type: "object", + required: ["embedding_model", "chunk_size", "min_characters_per_chunk"], + properties: { + embedding_model: { + type: "string", + default: "sentence-transformers/all-MiniLM-L6-v2", + title: "Embedding Model", + description: + "Model to use for embeddings. 
See Chonkie docs for options.", + }, + embedding_local_path: { + type: "string", + default: "", + title: "Embedding Local Path", + description: + "Local path for model to use for embeddings (only needed if cant download through Chonkie).", + }, + chunk_size: { + type: "number", + default: 512, + title: "Max tokens per chunk", + }, + min_characters_per_chunk: { + type: "number", + default: 24, + title: "Min characters per chunk", + }, + use_premade_recipe: { + type: "string", + default: "markdown-en", + title: "Premade recipe (optional)", + description: + "Format: 'name-language' (e.g., 'markdown-en') or just 'language' (e.g., 'en'). Defaults to markdown-en, since ChainForge parses documents by default into markdown. See Chonkie Recipes for available options: https://huggingface.co/datasets/chonkie-ai/recipes/viewer/recipes/train?row=5&views%5B%5D=recipes", + }, + custom_recipe: { + type: "string", + default: "", + title: "Custom recipe JSON (optional)", + description: + "JSON array of recursive chunking rules. 
Overrides premade recipe if provided.", + }, + }, + }, + uiSchema: { + custom_recipe: { + "ui:widget": "textarea", + "ui:help": "JSON array of recursive chunking rules", + }, + }, + postprocessors: {}, +}; + +export const ChunkMethodSchemas: { [baseMethod: string]: ModelSettingsDict } = { + overlapping_openai_tiktoken: OverlappingOpenAITiktokenSchema, + overlapping_huggingface_tokenizers: OverlappingHuggingfaceTokenizerSchema, + markdown_header: MarkdownHeaderSchema, + syntax_nltk: SyntaxNltkSchema, + syntax_texttiling: SyntaxTextTilingSchema, + chonkie_token: ChonkieTokenSchema, + chonkie_sentence: ChonkieSentenceSchema, + chonkie_recursive: ChonkieRecursiveSchema, + chonkie_semantic: ChonkieSemanticSchema, + chonkie_late: ChonkieLateSchema, +}; + +export const ChunkMethodGroups = [ + { + label: "Token-Based", + items: [ + { + baseMethod: "chonkie_token", + methodType: "Chonkie", + name: "Token Chunker", + emoji: "🐿️", + description: + "Split text into fixed-size token chunks with optional overlap. Fastest and cheapest option.", + }, + { + baseMethod: "overlapping_openai_tiktoken", + methodType: "Overlapping Chunking", + name: "OpenAI tiktoken", + emoji: "🤖", + description: + "Use OpenAI’s tiktoken to count tokens for chunk sizes and overlaps.", + }, + { + baseMethod: "overlapping_huggingface_tokenizers", + methodType: "Overlapping Chunking", + name: "HuggingFace Tokenizers", + emoji: "🤗", + description: + "Use a HuggingFace tokenizer to count tokens for chunk sizes and overlaps.", + }, + ], + }, + { + label: "Structure-Based", + items: [ + { + baseMethod: "chonkie_sentence", + methodType: "Chonkie", + name: "Sentence Chunker", + emoji: "✂️", + description: + "Split on sentence boundaries. Nice for QA / summarization where you want readable chunks.", + }, + { + baseMethod: "markdown_header", + methodType: "Markdown", + name: "Markdown Chunker", + emoji: "📝", + description: + "Respect markdown headings when splitting (e.g. #, ##). 
Great for docs and notebooks.", + }, + { + baseMethod: "syntax_nltk", + methodType: "Syntax-Based Chunking", + name: "NLTK Sentence Splitter", + emoji: "🐍", + description: + "Sentence splitting powered by NLTK. More robust for messy text.", + }, + { + baseMethod: "syntax_texttiling", + methodType: "Syntax-Based Chunking", + name: "Stopword Chunker", + emoji: "📑", + description: + "Topic-based segmentation using TextTiling. Helps break long text into sections based on lexical shifts.", + }, + { + baseMethod: "chonkie_recursive", + methodType: "Chonkie", + name: "Recursive Chunker", + emoji: "🔄", + description: + "Try large chunks first and recursively split until under a token limit. Good when you want big chunks but must respect model limits.", + }, + ], + }, + { + label: "Semantic / Embedding-Based", + items: [ + { + baseMethod: "chonkie_semantic", + methodType: "Chonkie", + name: "Semantic Chunker", + emoji: "🤖", + description: + "Use embeddings to cut at semantically meaningful boundaries (topic changes, sections). 
More accurate but more expensive.", + }, + { + baseMethod: "chonkie_late", + methodType: "Chonkie", + name: "Late Chunker", + emoji: "⏳", + description: + "Apply length-based chunking at run time instead of precomputing chunks.", + }, + ], + }, +]; diff --git a/chainforge/react-server/src/ChunkNode.tsx b/chainforge/react-server/src/ChunkNode.tsx new file mode 100644 index 000000000..e1ad1eeba --- /dev/null +++ b/chainforge/react-server/src/ChunkNode.tsx @@ -0,0 +1,319 @@ +import React, { + useState, + useEffect, + useCallback, + useRef, + useContext, +} from "react"; +import { Handle, Position } from "reactflow"; +import { Status } from "./StatusIndicatorComponent"; +import { AlertModalContext } from "./AlertModal"; +import BaseNode from "./BaseNode"; +import NodeLabel from "./NodeLabelComponent"; +import useStore from "./store"; + +import LLMResponseInspectorModal, { + LLMResponseInspectorModalRef, +} from "./LLMResponseInspectorModal"; +import InspectFooter from "./InspectFooter"; +import { IconSearch } from "@tabler/icons-react"; + +import ChunkMethodListContainer, { + ChunkMethodSpec, +} from "./ChunkMethodListComponent"; + +import { TemplateVarInfo, LLMResponse } from "./backend/typing"; +import { StringLookup } from "./backend/cache"; +import { FLASK_BASE_URL } from "./backend/utils"; +import { v4 as uuid } from "uuid"; + +interface ChunkNodeProps { + data: { + title?: string; + methods?: ChunkMethodSpec[]; + refresh?: boolean; + }; + id: string; +} + +const ChunkNode: React.FC = ({ data, id }) => { + const nodeDefaultTitle = "Chunk Node"; + const nodeIcon = "🧩"; + + const pullInputData = useStore((s) => s.pullInputData); + const setDataPropsForNode = useStore((s) => s.setDataPropsForNode); + const pingOutputNodes = useStore((s) => s.pingOutputNodes); + + const showAlert = useContext(AlertModalContext); + + const [methodItems, setMethodItems] = useState( + data.methods || [], + ); + const [status, setStatus] = useState(Status.NONE); + const [jsonResponses, 
setJSONResponses] = useState([]); + + const inspectorRef = useRef(null); + + // On refresh + useEffect(() => { + if (data.refresh) { + setDataPropsForNode(id, { refresh: false, fields: [], output: [] }); + setJSONResponses([]); + setStatus(Status.NONE); + } + }, [data.refresh, id, setDataPropsForNode]); + + // Track changes in chunk methods + const handleMethodItemsChange = useCallback( + (newItems: ChunkMethodSpec[], _oldItems: ChunkMethodSpec[]) => { + setMethodItems(newItems); + setDataPropsForNode(id, { methods: newItems }); + if (status === Status.READY) setStatus(Status.WARNING); + }, + [id, status, setDataPropsForNode], + ); + + // Truncate string helper + const truncateString = (str: string, maxLen = 25): string => { + if (!str) return ""; + if (str.length <= maxLen) return str; + return `${str.slice(0, 12)}...${str.slice(-10)}`; + }; + + // The main chunking function + const runChunking = useCallback(async () => { + const handleError = (msg: string, err?: any) => { + console.error(msg, err); + showAlert?.(msg); + setStatus(Status.ERROR); + }; + + if (methodItems.length === 0) { + handleError("No chunk methods selected!"); + return; + } + + // 1) Pull text from upstream (the UploadNode) + let inputData: { text?: TemplateVarInfo[] } = {}; + try { + inputData = pullInputData(["text"], id) as { text?: TemplateVarInfo[] }; + } catch (error) { + handleError("No input text found. Is UploadNode connected?", error); + return; + } + const fileArr = inputData.text || []; + if (fileArr.length === 0) { + handleError( + "No text found. 
Please attach an UploadNode or provide text.", + ); + return; + } + + setStatus(Status.LOADING); + setJSONResponses([]); + + // We'll group by library to call your chunker + const allChunksByMethodName: Record = {}; + const allResponsesByMethodName: Record = {}; + + // Group methods by library + const methodsByName = methodItems.reduce( + (acc, method) => { + if (!acc[method.name]) acc[method.name] = []; + acc[method.name].push(method); + return acc; + }, + {} as Record, + ); + + // 2) For each library and each doc + for (const [name, methods] of Object.entries(methodsByName)) { + allChunksByMethodName[name] = []; + allResponsesByMethodName[name] = []; + + for (const fileInfo of fileArr) { + const docTitle = fileInfo?.metavars?.filename || "Untitled"; + + for (const method of methods) { + try { + const formData = new FormData(); + formData.append("baseMethod", method.baseMethod); + + // Get the full text and pack it as a "file" part instead of a plain field + const fullText = StringLookup.get(fileInfo.text) ?? ""; + const textBlob = new Blob([fullText], { type: "text/plain" }); + + formData.append("document", textBlob); + + // Add the user settings + Object.entries(method.settings ?? 
{}).forEach(([k, v]) => { + formData.append(k, String(v)); + }); + + const res = await fetch(`${FLASK_BASE_URL}chunk`, { + method: "POST", + body: formData, + }); + + if (!res.ok) { + const err = await res.json(); + throw new Error(err.error || "Chunking request failed"); + } + + const json = await res.json(); + const chunks = json.chunks as string[]; + + // We'll build chunk IDs for each doc + const methodSafe = method.methodType.replace(/\W+/g, "_"); + const libSafe = name.replace(/\W+/g, "_"); + + chunks.forEach((cText, index) => { + const cId = uuid(); // `${methodSafe}_${index}_${libSafe}`; + + // Create the chunk object + const chunkVar: TemplateVarInfo = { + text: cText, + prompt: "", + fill_history: { + chunkMethod: `${method.methodType} (${method.name})`, + docTitle, + chunkLibrary: name, + chunkId: index.toString(), + }, + llm: undefined, + metavars: { + docTitle, + chunkLibrary: name, + chunkId: index.toString(), + }, + }; + + allChunksByMethodName[name].push(chunkVar); + + // LLMResponse for inspector + const respObj: LLMResponse = { + uid: cId, + prompt: `Doc: ${docTitle} | Chunk ID: ${truncateString(cId, 25)}`, + vars: { + // chunkMethod: `${method.methodType} (${method.name})`, + chunkId: index.toString(), + docTitle, + // chunkLibrary: name, + }, + responses: [cText], + llm: name, + metavars: chunkVar.metavars || {}, + }; + + allResponsesByMethodName[name].push(respObj); + }); + } catch (err: any) { + handleError( + `Error chunking "${docTitle}" with ${method.name}: ${err.message}`, + err, + ); + return; + } + } + } + } + + // Combine results + const allChunks = Object.values(allChunksByMethodName).flat(); + const allResponses = Object.values(allResponsesByMethodName).flat(); + + // 3) Output data grouped by library + const groupedOutput = Object.entries(allChunksByMethodName).reduce( + (acc, [lib, chunks]) => { + acc[lib] = chunks.map((ch) => ({ + id: ch.metavars?.chunkId, + docTitle: ch.metavars?.docTitle, + method: ch.fill_history?.chunkMethod, + 
text: ch.text, + })); + return acc; + }, + {} as Record, + ); + + setDataPropsForNode(id, { + fields: allChunks, + output: groupedOutput, + }); + pingOutputNodes(id); + + setJSONResponses(allResponses); + setStatus(Status.READY); + }, [ + id, + methodItems, + pullInputData, + setDataPropsForNode, + showAlert, + pingOutputNodes, + ]); + + // Open inspector + const openInspector = () => { + if (jsonResponses.length > 0 && inspectorRef.current) { + inspectorRef.current.trigger(); + } + }; + + return ( + + + + + + + + {jsonResponses && jsonResponses.length > 0 && ( + { + // Do nothing + }} + isDrawerOpen={false} + label={ + <> + Inspect chunks + + } + /> + )} + + {/* The LLM Response Inspector */} + + + + + ); +}; + +export default ChunkNode; diff --git a/chainforge/react-server/src/ExampleFlowsModal.tsx b/chainforge/react-server/src/ExampleFlowsModal.tsx index ff662843a..a296f091d 100644 --- a/chainforge/react-server/src/ExampleFlowsModal.tsx +++ b/chainforge/react-server/src/ExampleFlowsModal.tsx @@ -576,7 +576,19 @@ const ExampleFlowsModal = forwardRef< + + diff --git a/chainforge/react-server/src/GlobalSettingsModal.tsx b/chainforge/react-server/src/GlobalSettingsModal.tsx index 7593b8ff0..714d90772 100644 --- a/chainforge/react-server/src/GlobalSettingsModal.tsx +++ b/chainforge/react-server/src/GlobalSettingsModal.tsx @@ -32,9 +32,11 @@ import { IconBrandPython, IconX, IconSparkles, + IconBrandGithub, + IconBook, } from "@tabler/icons-react"; import { Dropzone, FileWithPath } from "@mantine/dropzone"; -import useStore, { initLLMProviderMenu } from "./store"; +import useStore, { initLLMProviderMenu, initLLMProviders } from "./store"; import { APP_IS_RUNNING_LOCALLY } from "./backend/utils"; import { setCustomProviders } from "./ModelSettingSchemas"; import { getAIFeaturesModelProviders } from "./backend/ai"; @@ -91,6 +93,24 @@ const read_file = ( reader.readAsText(file); }; +// Normalize backend providers to the store shape for *retrievers* (preserves schema) 
+function toCustomRetrieverSpecs(providers: any[]) { + return (providers || []) + .filter((p: any) => p.category === "retriever") + .map((p: any) => ({ + key: `__custom/${p.name}`, + baseMethod: `__custom/${p.name}`, + methodName: p.name, + library: p.name, + emoji: p.emoji ?? "✨", + // optional; safe default + needsEmbeddingModel: !!p.needs_embedding_model, + // keep schema + defaults + settings_schema: p.settings_schema ?? undefined, + default_settings: p.default_settings ?? undefined, + })); +} + interface CustomProviderScriptDropzoneProps { onError: (err: string | Error) => void; onSetProviders: (providers: CustomLLMProviderSpec[]) => void; @@ -104,6 +124,9 @@ const CustomProviderScriptDropzone: React.FC< > = ({ onError, onSetProviders }) => { const theme = useMantineTheme(); const [isLoading, setIsLoading] = useState(false); + const setCustomChunkers = useStore((state) => state.setCustomChunkers); + const setCustomRetrievers = useStore((s) => s.setCustomRetrievers); + const setAvailableLLMs = useStore((s) => s.setAvailableLLMs); return ( p.category === "chunker") + .map((p) => ({ + key: `__custom/${p.name}`, + baseMethod: `__custom/${p.name}`, + methodType: "chunker", + name: p.name, + emoji: p.emoji, + settings: {}, + })), + ); + setCustomRetrievers(toCustomRetrieverSpecs(providers)); }) .catch((err) => { setIsLoading(false); @@ -204,6 +242,8 @@ const GlobalSettingsModal = forwardRef( AWS_Region: "us-east-1", AmazonBedrock: JSON.stringify({ credentials: {}, region: "us-east-1" }), Together: "", + DeepSeek: "", + Cohere: "", }, validate: { @@ -373,6 +413,9 @@ const GlobalSettingsModal = forwardRef( [showAlert], ); + const setCustomChunkers = useStore((s) => s.setCustomChunkers); + const setCustomRetrievers = useStore((s) => s.setCustomRetrievers); + const [customProviders, setLocalCustomProviders] = useState< CustomLLMProviderSpec[] >([]); @@ -398,6 +441,24 @@ const GlobalSettingsModal = forwardRef( setLocalCustomProviders( customProviders.filter((p) => p.name 
!== name), ); + setCustomChunkers( + customProviders + .filter((p) => p.name !== name) // remaining providers + .filter((p) => p.category === "chunker") // only chunkers + .map((p) => ({ + key: `__custom/${p.name}`, + baseMethod: `__custom/${p.name}`, + methodType: "chunker", + name: p.name, + emoji: p.emoji, + settings: {}, + })), + ); + setCustomRetrievers( + toCustomRetrieverSpecs( + customProviders.filter((p) => p.name !== name), + ), + ); refreshLLMProviderLists(); }) .catch(handleError); @@ -420,6 +481,7 @@ const GlobalSettingsModal = forwardRef( // Success; pass custom providers list to store: setCustomProviders(providers); setLocalCustomProviders(providers); + setCustomRetrievers(toCustomRetrieverSpecs(providers)); }) .catch(console.error); } @@ -548,6 +610,13 @@ const GlobalSettingsModal = forwardRef( />
+ +
+ ( see the documentation. - {customProviders.map((p) => ( - - - - {p.emoji} - {p.name} - {p.settings_schema ? ( - - has settings - - ) : ( - <> - )} - - - - - ))} + {["chunker", "retriever", "model"] + .filter((cat) => + customProviders.some((p) => p.category === cat), + ) + .map((cat) => ( + + + {cat[0].toUpperCase() + cat.slice(1)} + + {customProviders + .filter((p) => p.category === cat) + .map((p) => ( + + + + {p.emoji} + {p.name} + {p.settings_schema && ( + + has settings + + )} + + + + + ))} + + ))} { @@ -794,6 +876,30 @@ const GlobalSettingsModal = forwardRef( ); }} /> + + + + + + + + diff --git a/chainforge/react-server/src/LLMResponseInspector.tsx b/chainforge/react-server/src/LLMResponseInspector.tsx index 2ca763f84..6fcc9845d 100644 --- a/chainforge/react-server/src/LLMResponseInspector.tsx +++ b/chainforge/react-server/src/LLMResponseInspector.tsx @@ -353,6 +353,9 @@ export interface LLMResponseInspectorProps { customLLMFieldName?: string; disableBackgroundColor?: boolean; treatLLMFieldAsUnique?: boolean; + ignoreAndHideLLMField?: boolean; // If true, LLM field will not be shown in the table view + ignoreAndHideEvalResField?: boolean; // If true, "Eval Res" column option will not be shown in the table view + defaultTableColVar?: string; } const LLMResponseInspector: React.FC = ({ @@ -362,6 +365,9 @@ const LLMResponseInspector: React.FC = ({ customLLMFieldName, disableBackgroundColor, treatLLMFieldAsUnique, + ignoreAndHideLLMField, + ignoreAndHideEvalResField, + defaultTableColVar, }) => { // Responses const [responseDivs, setResponseDivs] = useState([]); @@ -459,7 +465,7 @@ const LLMResponseInspector: React.FC = ({ }); // The var name to use for columns in the table view - const [tableColVar, setTableColVar] = useState("$LLM"); + const [tableColVar, setTableColVar] = useState(defaultTableColVar ?? 
"$LLM"); const [userSelectedTableCol, setUserSelectedTableCol] = useState(false); // State of the 'only show scores' toggle when eval results are present @@ -536,38 +542,55 @@ const LLMResponseInspector: React.FC = ({ : false; setShowEvalScoreOptions(contains_eval_res); + let effectiveTableColVar = tableColVar; + + if ( + ignoreAndHideLLMField && + !userSelectedTableCol && + tableColVar === "$LLM" + ) { + effectiveTableColVar = "retrievalMethod"; + setTableColVar("retrievalMethod"); + } + // Set the variables accessible in the MultiSelect for 'group by' - const msvars = found_vars - .map((name: string) => + let msvars = found_vars + .map((name: string) => { + let label = name; + if (name === "retrievalMethod") label = "Retrieval method"; // We add a $ prefix to mark this as a prompt parameter, and so // in the future we can add special types of variables without name collisions - ({ value: name, label: name }), - ) + return { value: name, label }; + }) .concat({ value: "$LLM", label: customLLMFieldName || "LLM" }); - if (contains_eval_res && viewFormat === "table") + + if (ignoreAndHideLLMField) { + // If we are ignoring the LLM field, we need to remove it from the msvars + msvars = msvars.filter((v) => v.value !== "$LLM"); + } + + if ( + contains_eval_res && + viewFormat === "table" && + !ignoreAndHideEvalResField + ) msvars.push({ value: "$EVAL_RES", label: "Eval results" }); + setMultiSelectVars(msvars); // If only one LLM is present, and user hasn't manually selected one to plot, // and there's more than one prompt variable as input, default to plotting the // eval scores, or the first found prompt variable as columns instead: if ( + !ignoreAndHideEvalResField && viewFormat === "table" && !userSelectedTableCol && - tableColVar === "$LLM" + effectiveTableColVar === "$LLM" && + (contains_multi_evals || (found_llms.length === 1 && contains_eval_res)) ) { - if ( - contains_multi_evals || - (found_llms.length === 1 && contains_eval_res) - ) { - // Plot eval scores 
on columns - setTableColVar("$EVAL_RES"); - return; - } - // else if (found_llms.length === 1 && found_vars.length > 1) { - // setTableColVar(found_vars[0]); - // return; // useEffect will replot with the new values - // } + // Plot eval scores on columns + setTableColVar("$EVAL_RES"); + return; } // If this is the first time receiving responses, set the multiSelectValue to whatever is the first: @@ -719,27 +742,27 @@ const LLMResponseInspector: React.FC = ({ getColVal: (r: LLMResponse) => string | number | undefined, found_sel_var_vals: string[], eval_res_cols: string[]; - let metavar_cols: string[] = []; // found_metavars; -- Disabling this functionality for now, since it is usually annoying. - if (tableColVar === "$LLM") { + const metavar_cols: string[] = []; // found_metavars; -- Disabling this functionality for now, since it is usually annoying. + if (effectiveTableColVar === "$LLM") { var_cols = found_vars; getColVal = getLLMName; found_sel_var_vals = found_llms; colnames = var_cols.concat(metavar_cols).concat(found_llms); } else { - metavar_cols = []; var_cols = found_vars - .filter((v) => v !== tableColVar) + .filter((v) => v !== effectiveTableColVar) .concat(found_llms.length > 1 ? 
["LLM"] : []); // only add LLM column if num LLMs > 1 - getColVal = (r) => llmResponseDataToString(r.vars[tableColVar]); - colnames = var_cols; + getColVal = (r) => + llmResponseDataToString(r.vars[effectiveTableColVar]); + colnames = metavar_cols.concat(var_cols); found_sel_var_vals = []; } // If the user wants to plot eval results in separate column, OR there's only a single LLM to show - if (tableColVar === "$EVAL_RES") { + if (effectiveTableColVar === "$EVAL_RES") { // Plot evaluation results on separate column(s): eval_res_cols = getEvalResCols(responses); - // if (tableColVar === "$EVAL_RES") { + // if (effectiveTableColVar === "$EVAL_RES") { // This adds a column, "Response", abusing the way getColVal and found_sel_var_vals is used // below by making a dummy value (one giant group with all responses in it). We then // sort the responses by LLM, to give a nicer view. @@ -747,16 +770,13 @@ const LLMResponseInspector: React.FC = ({ getColVal = () => "_"; found_sel_var_vals = ["_"]; responses.sort((a, b) => getLLMName(a).localeCompare(getLLMName(b))); - // } else { - // colnames = colnames.concat(eval_res_cols); - // } - } else if (tableColVar !== "$LLM") { + } else if (effectiveTableColVar !== "$LLM") { // Get the unique values for the selected variable found_sel_var_vals = Array.from( responses.reduce((acc, res_obj) => { acc.add( - tableColVar in res_obj.vars - ? llmResponseDataToString(res_obj.vars[tableColVar]) + effectiveTableColVar in res_obj.vars + ? llmResponseDataToString(res_obj.vars[effectiveTableColVar]) : "(unspecified)", ); return acc; @@ -773,6 +793,13 @@ const LLMResponseInspector: React.FC = ({ v === "LLM" ? getLLMName(r) : StringLookup.get(r.vars[v]) ?? ""; // Then group responses by prompts. 
Each prompt will become a separate row of the table (will be treated as unique) + if (ignoreAndHideLLMField) { + // If we are ignoring the LLM field, we need to remove it from the var_cols + var_cols = var_cols.filter((v) => v !== "LLM"); + // Also, we need to remove the LLM column from the colnames + colnames = colnames.filter((c) => c !== "LLM"); + } + const responses_by_prompt = groupResponsesBy(responses, (r) => { const group = var_cols .map((v) => getVar(r, v)) @@ -808,7 +835,7 @@ const LLMResponseInspector: React.FC = ({ EvaluationScore | undefined, ][][] = []; if (eval_res_cols && eval_res_cols.length > 0) { - // We can assume that there's only one response object, since to + // We can assume that there's only one response object, since // if eval_res_cols is set, there must be only one LLM. eval_cols_vals = eval_res_cols.map((metric_name, metric_idx) => { const items = resp_objs[0].eval_res?.items; @@ -1055,7 +1082,12 @@ const LLMResponseInspector: React.FC = ({ }; else return { - style: { lineHeight: 1.2, ...fz }, + style: { + lineHeight: 1.2, + verticalAlign: "top", + textAlign: "left", + ...fz, + }, }; })(), })) as MRT_ColumnDef[]; diff --git a/chainforge/react-server/src/LLMResponseInspectorModal.tsx b/chainforge/react-server/src/LLMResponseInspectorModal.tsx index 1004ea723..2f64c37c4 100644 --- a/chainforge/react-server/src/LLMResponseInspectorModal.tsx +++ b/chainforge/react-server/src/LLMResponseInspectorModal.tsx @@ -27,6 +27,9 @@ export interface LLMResponseInspectorModalProps { customLLMFieldName?: string; disableBackgroundColor?: boolean; treatLLMFieldAsUnique?: boolean; + ignoreAndHideLLMField?: boolean; // If true, LLM field will not be shown in the table view + ignoreAndHideEvalResField?: boolean; // If true, "Eval Res" column option will not be shown in the table view + defaultTableColVar?: string; } const LLMResponseInspectorModal = forwardRef< @@ -105,6 +108,9 @@ const LLMResponseInspectorModal = forwardRef< 
customLLMFieldName={props.customLLMFieldName} disableBackgroundColor={props.disableBackgroundColor} treatLLMFieldAsUnique={props.treatLLMFieldAsUnique} + ignoreAndHideLLMField={props.ignoreAndHideLLMField} + ignoreAndHideEvalResField={props.ignoreAndHideEvalResField} + defaultTableColVar={props.defaultTableColVar} /> diff --git a/chainforge/react-server/src/ModelSettingSchemas.tsx b/chainforge/react-server/src/ModelSettingSchemas.tsx index 596a84b41..a1808b8cc 100644 --- a/chainforge/react-server/src/ModelSettingSchemas.tsx +++ b/chainforge/react-server/src/ModelSettingSchemas.tsx @@ -2566,8 +2566,12 @@ export function getSettingsSchemaForLLM( [LLMProvider.DeepSeek]: DeepSeekSettings, }; - if (llm_provider === LLMProvider.Custom) return ModelSettings[llm_name]; - else if (llm_provider && llm_provider in provider_to_settings_schema) + if (llm_provider === LLMProvider.Custom) { + return ( + ModelSettings[llm_name] ?? + ModelSettings[llm_name?.endsWith("/") ? llm_name : `${llm_name}/`] + ); + } else if (llm_provider && llm_provider in provider_to_settings_schema) return provider_to_settings_schema[llm_provider]; else if (llm_provider === LLMProvider.Bedrock) { return ModelSettings[llm_name.split("-")[0]]; @@ -2626,21 +2630,25 @@ export const setCustomProvider = ( models?: string[], rate_limit?: number | string, settings_schema?: CustomLLMProviderSpec["settings_schema"], + category?: string, ) => { if (typeof emoji === "string" && (emoji.length === 0 || emoji.length > 2)) throw new Error(`Emoji for a custom provider must have a character.`); + if ((category ?? "model") !== "model") return; const new_provider: Dict = { name }; + new_provider.category = category || "model"; new_provider.emoji = emoji || "✨"; // Each LLM *model* must have a unique name. To avoid name collisions, for custom providers, // the full LLM model name is a path, __custom// // If there's no submodel, it's just __custom/. 
- const base_model = `__custom/${name}/`; + const base_model = `__custom/${name}`; // no trailing slash new_provider.base_model = base_model; new_provider.model = - base_model + - (Array.isArray(models) && models.length > 0 ? `${models[0]}` : ""); + Array.isArray(models) && models.length > 0 + ? `${base_model}/${models[0]}` + : base_model; // Build the settings form schema for this new custom provider const compiled_schema: ModelSettingsDict = { @@ -2667,6 +2675,16 @@ export const setCustomProvider = ( postprocessors: {}, }; + const canon = (id: string) => id.replace(/\/+$/, ""); + + ModelSettings[canon(base_model)] = compiled_schema; + + if (typeof rate_limit === "number" && rate_limit > 0) { + RATE_LIMIT_BY_MODEL[canon(base_model)] = rate_limit; + } else { + MAX_CONCURRENT[canon(base_model)] = 1; + } + // Add a models selector if there's multiple models if (Array.isArray(models) && models.length > 0) { compiled_schema.schema.properties.model = { @@ -2722,16 +2740,19 @@ export const setCustomProvider = ( }; export const setCustomProviders = (providers: CustomLLMProviderSpec[]) => { - for (const p of providers) - setCustomProvider( - p.name, - p.emoji, - p.models, - p.rate_limit, - p.settings_schema, + (providers || []) + .filter((p) => (p?.category ?? 
"model") === "model") + .forEach((p) => + setCustomProvider( + p.name, + p.emoji, + p.models, + p.rate_limit, + p.settings_schema, + p.category, + ), ); }; - export const getTemperatureSpecForModel = (modelName: string) => { if (modelName in ModelSettings) { const temperature_property = diff --git a/chainforge/react-server/src/ModelSettingsModal.tsx b/chainforge/react-server/src/ModelSettingsModal.tsx index 21410f43e..5c45337c8 100644 --- a/chainforge/react-server/src/ModelSettingsModal.tsx +++ b/chainforge/react-server/src/ModelSettingsModal.tsx @@ -30,7 +30,7 @@ import { APP_IS_RUNNING_LOCALLY } from "./backend/utils"; const IS_RUNNING_LOCALLY = APP_IS_RUNNING_LOCALLY(); // Custom UI widgets for react-jsonschema-form -const DatalistWidget = (props: WidgetProps) => { +export const DatalistWidget = (props: WidgetProps) => { const [data, setData] = useState( ( props.options.enumOptions?.map((option, index) => ({ diff --git a/chainforge/react-server/src/MultiEvalNode.tsx b/chainforge/react-server/src/MultiEvalNode.tsx index c19343ac2..e66f973ba 100644 --- a/chainforge/react-server/src/MultiEvalNode.tsx +++ b/chainforge/react-server/src/MultiEvalNode.tsx @@ -706,8 +706,9 @@ const MultiEvalNode: React.FC = ({ data, id }) => { } }); }); - const finalResponses = Object.values(merged_res_objs_by_uid); + console.log("Output length:", finalResponses.length); + console.log("MultiEval Output:", finalResponses[0]?.eval_res?.items[0]); // We now have a dict of the form { uid: LLMResponse } // We need return only the values of this dict: setLastResponses(finalResponses); diff --git a/chainforge/react-server/src/NestedMenu.tsx b/chainforge/react-server/src/NestedMenu.tsx index f52e468a2..e745dc79a 100644 --- a/chainforge/react-server/src/NestedMenu.tsx +++ b/chainforge/react-server/src/NestedMenu.tsx @@ -1,4 +1,4 @@ -import React, { ReactNode, useMemo, useState } from "react"; +import React, { ReactNode, useMemo, useState, useEffect, useRef } from "react"; import { Menu, Tooltip, 
Popover, ActionIcon } from "@mantine/core"; import { IconChevronRight, IconTrash } from "@tabler/icons-react"; import { ContextMenuItemOptions } from "mantine-contextmenu/dist/types"; @@ -45,6 +45,24 @@ export default function NestedMenu({ }) { const [menuOpened, setMenuOpened] = useState(false); const [submenusOpened, setSubmenusOpened] = useState(null); + const menuRef = useRef(null); + + useEffect(() => { + const handleClickOutside = (event: MouseEvent) => { + if (menuRef.current && !menuRef.current.contains(event.target as Node)) { + setMenuOpened(false); + setSubmenusOpened(null); + } + }; + + if (menuOpened) { + document.addEventListener("click", handleClickOutside); + } + + return () => { + document.removeEventListener("click", handleClickOutside); + }; + }, [menuOpened]); const openSubmenu = (key: string) => { if (submenusOpened) { setSubmenusOpened((prev) => [...(prev as string[]), key]); @@ -158,18 +176,19 @@ export default function NestedMenu({ }, [items, submenusOpened]); return ( - - {button(() => setMenuOpened(!menuOpened))} - - {menuItems} - +
+ + {button(() => setMenuOpened(!menuOpened))} + {menuItems} + +
); } diff --git a/chainforge/react-server/src/RerankMethodListComponent.tsx b/chainforge/react-server/src/RerankMethodListComponent.tsx new file mode 100644 index 000000000..8aa0ee625 --- /dev/null +++ b/chainforge/react-server/src/RerankMethodListComponent.tsx @@ -0,0 +1,270 @@ +import React, { + useState, + useRef, + forwardRef, + useImperativeHandle, + useCallback, + useMemo, +} from "react"; +import { Button, Text, Modal, ScrollArea } from "@mantine/core"; +import { useDisclosure } from "@mantine/hooks"; +import Form from "@rjsf/core"; +import validator from "@rjsf/validator-ajv8"; +import { v4 as uuid } from "uuid"; +import { RerankMethodSchemas, rerankMethodGroups } from "./RerankMethodSchemas"; +import NestedMenu, { NestedMenuItemProps } from "./NestedMenu"; +import LLMItemButtonGroup from "./LLMItemButtonGroup"; +import useStore from "./store"; +import { DatalistWidget } from "./ModelSettingsModal"; + +export interface RerankMethodSpec { + key: string; + baseMethod: string; + methodType: string; + name: string; + emoji?: string; + settings?: Record; +} + +export interface RerankMethodListContainerProps { + initMethodItems?: RerankMethodSpec[]; + onItemsChange?: ( + newItems: RerankMethodSpec[], + oldItems: RerankMethodSpec[], + ) => void; +} +export type RerankMethodListContainerRef = Record; + +const RerankMethodListItem: React.FC<{ + methodItem: RerankMethodSpec; + onRemove: (key: string) => void; + onSettingsUpdate: (key: string, newSettings: any) => void; +}> = ({ methodItem, onRemove, onSettingsUpdate }) => { + // Fetch the relevant schema + const schemaEntry = useMemo( + () => + RerankMethodSchemas[methodItem.baseMethod] || { + schema: {}, + uiSchema: {}, + description: "", + fullName: "", + }, + [methodItem], + ); + const schema = useMemo(() => { + return schemaEntry?.schema; + }, [schemaEntry]); + const uiSchema = useMemo(() => schemaEntry?.uiSchema, [schemaEntry]); + + const [settingsModalOpen, { open, close }] = useDisclosure(false); + + return ( +
+
+
+ {methodItem.emoji ? methodItem.emoji + " " : ""} + {methodItem.name} +
+ + onRemove(methodItem.key)} + onClickSettings={open} // from useDisclosure(false) + hideTrashIcon={false} + /> +
+ + + {schema && Object.keys(schema).length > 0 ? ( +
onSettingsUpdate(methodItem.key, evt.formData)} + onSubmit={(evt) => { + onSettingsUpdate(methodItem.key, evt.formData); + close(); + }} + validator={validator as any} + widgets={{ datalist: DatalistWidget } as any} + liveValidate + noHtml5Validate + > + +
+
+ ) : ( + + (No custom settings for this method.) + + )} +
+
+ ); +}; + +const RerankMethodListContainer = forwardRef< + RerankMethodListContainerRef, + RerankMethodListContainerProps +>((props, ref) => { + const [methodItems, setMethodItems] = useState( + props.initMethodItems || [], + ); + const oldItemsRef = useRef(methodItems); + + useImperativeHandle(ref, () => ({})); + + // If parent node wants to track changes + const notifyItemsChanged = useCallback( + (newItems: RerankMethodSpec[]) => { + props.onItemsChange?.(newItems, oldItemsRef.current); + oldItemsRef.current = newItems; + }, + [props.onItemsChange], + ); + + // Remove method + const handleRemoveMethod = useCallback( + (key: string) => { + const newItems = methodItems.filter((m) => m.key !== key); + setMethodItems(newItems); + notifyItemsChanged(newItems); + }, + [methodItems, notifyItemsChanged], + ); + + // Update method settings + const handleUpdateMethodSettings = useCallback( + (key: string, newSettings: any) => { + const newItems = methodItems.map((m) => + m.key === key ? { ...m, settings: newSettings } : m, + ); + setMethodItems(newItems); + notifyItemsChanged(newItems); + }, + [methodItems, notifyItemsChanged], + ); + + // Add method + const handleAddMethod = useCallback( + ( + baseMethod: string, + name: string, + emoji: string, + methodType: string, + customDefaults?: Record, + ) => { + const key = uuid(); + const schemaEntry = RerankMethodSchemas[baseMethod]; + + // Pull defaults from schema + let defaultSettings: Record = {}; + if (schemaEntry?.schema?.properties) { + const schemaProps = schemaEntry.schema.properties; + defaultSettings = Object.entries(schemaProps).reduce( + (acc, [propKey, propDef]) => { + if ( + propDef && + typeof propDef === "object" && + "default" in propDef + ) { + acc[propKey] = (propDef as any).default; + } + return acc; + }, + {} as Record, + ); + } + + // Override with custom defaults if provided + if (customDefaults) { + defaultSettings = { ...defaultSettings, ...customDefaults }; + } + + const newMethod: RerankMethodSpec 
= { + key, + baseMethod, + methodType, + name, + emoji, + settings: defaultSettings, + }; + + const newItems = [...methodItems, newMethod]; + setMethodItems(newItems); + notifyItemsChanged(newItems); + }, + [methodItems, notifyItemsChanged], + ); + + // Build nested menu items + const menuItems = useMemo((): NestedMenuItemProps[] => { + return rerankMethodGroups.map((group) => ({ + key: `group-${group.label}`, + title: group.label, + items: group.items.map((item) => ({ + key: `method-${item.baseMethod}-${item.name}`, + title: `${item.emoji} ${item.name}`, + onClick: () => + handleAddMethod( + item.baseMethod, + item.name, + item.emoji, + item.library, + (item as any).defaultSettings, + ), + })), + })); + }, [handleAddMethod]); + + return ( +
+
+ Reranking Methods + +
+ } + /> +
+
+ + {methodItems.length === 0 ? ( +
+ + No reranking methods selected. + +
+ ) : ( +
+ {/* List of Selected Methods */} + + {methodItems.map((method) => ( + + ))} + +
+ )} +
+ ); +}); + +RerankMethodListContainer.displayName = "RerankMethodListContainer"; + +export default RerankMethodListContainer; diff --git a/chainforge/react-server/src/RerankMethodSchemas.tsx b/chainforge/react-server/src/RerankMethodSchemas.tsx new file mode 100644 index 000000000..92d1afe48 --- /dev/null +++ b/chainforge/react-server/src/RerankMethodSchemas.tsx @@ -0,0 +1,287 @@ +import { ModelSettingsDict } from "./backend/typing"; + +/** + * Cross-encoder Reranking + */ +export const CrossEncoderRerankSchema: ModelSettingsDict = { + fullName: "Cross-encoder Reranker", + description: + "Rerank documents using a cross-encoder model for query-document pairs", + schema: { + type: "object", + required: ["model", "top_k"], + properties: { + shortName: { + type: "string", + default: "Cross-encoder Reranker", + title: "Nickname", + description: + "Unique identifier to appear in ChainForge. Keep it short.", + }, + model: { + type: "string", + default: "cross-encoder/ms-marco-MiniLM-L-6-v2", + title: "Cross-encoder Model", + enum: [ + "cross-encoder/ms-marco-MiniLM-L-6-v2", + "cross-encoder/ms-marco-MiniLM-L-12-v2", + "cross-encoder/ms-marco-TinyBERT-L-2-v2", + "cross-encoder/ms-marco-electra-base", + "BAAI/bge-reranker-base", + "BAAI/bge-reranker-large", + ], + description: "Pre-trained cross-encoder model for reranking", + }, + top_k: { + type: "number", + default: 5, + title: "Top K Results", + minimum: 1, + maximum: 50, + description: "Number of top documents to return after reranking", + }, + batch_size: { + type: "number", + default: 32, + title: "Batch Size", + minimum: 1, + maximum: 128, + description: "Batch size for model inference", + }, + }, + }, + uiSchema: { + shortName: { + "ui:widget": "text", + "ui:options": { + placeholder: "Custom name for your reranking method", + }, + }, + model: { + "ui:widget": "datalist", + }, + top_k: { + "ui:widget": "range", + "ui:options": { + min: 1, + max: 50, + step: 1, + }, + }, + batch_size: { + "ui:widget": "range", + 
"ui:options": { + min: 1, + max: 128, + step: 1, + }, + }, + }, + postprocessors: {}, +}; + +/** + * Cohere Rerank API + */ +export const CohereRerankSchema: ModelSettingsDict = { + fullName: "Cohere Rerank API", + description: "Rerank documents using Cohere's reranking API", + schema: { + type: "object", + required: ["model", "top_k"], + properties: { + shortName: { + type: "string", + default: "Cohere Rerank", + title: "Nickname", + description: + "Unique identifier to appear in ChainForge. Keep it short.", + }, + model: { + type: "string", + default: "rerank-v3.5", + title: "Cohere Model", + enum: [ + "rerank-v3.5", + "rerank-english-v3.0", + "rerank-multilingual-v3.0", + ], + description: "Cohere reranking model to use", + }, + top_k: { + type: "number", + default: 5, + title: "Top K Results", + minimum: 1, + maximum: 100, + description: "Number of top documents to return after reranking", + }, + max_chunks_per_doc: { + type: "number", + default: 10, + title: "Max Chunks per Document", + minimum: 1, + maximum: 100, + description: "Maximum number of chunks to consider per document", + }, + }, + }, + uiSchema: { + shortName: { + "ui:widget": "text", + "ui:options": { + placeholder: "Custom name for your reranking method", + }, + }, + model: { + "ui:widget": "datalist", + }, + top_k: { + "ui:widget": "range", + "ui:options": { + min: 1, + max: 100, + step: 1, + }, + }, + max_chunks_per_doc: { + "ui:widget": "range", + "ui:options": { + min: 1, + max: 100, + step: 1, + }, + }, + }, + postprocessors: {}, +}; + +// Combined schema object for all reranking methods +export const RerankMethodSchemas: { + [baseMethod: string]: ModelSettingsDict; +} = { + cross_encoder: CrossEncoderRerankSchema, + cohere_rerank: CohereRerankSchema, +}; + +// Method groupings for the menu +export const rerankMethodGroups = [ + { + label: "Cross Encoder", + items: [ + { + baseMethod: "cross_encoder", + name: "MiniLM-L-6-v2", + library: "CrossEncoder", + emoji: "🧠", + group: "Cross Encoder", 
+ needsEmbeddingModel: false, + defaultSettings: { + model: "cross-encoder/ms-marco-MiniLM-L-6-v2", + shortName: "MiniLM-L-6-v2", + }, + }, + { + baseMethod: "cross_encoder", + name: "MiniLM-L-12-v2", + library: "CrossEncoder", + emoji: "🧠", + group: "Cross Encoder", + needsEmbeddingModel: false, + defaultSettings: { + model: "cross-encoder/ms-marco-MiniLM-L-12-v2", + shortName: "MiniLM-L-12-v2", + }, + }, + { + baseMethod: "cross_encoder", + name: "TinyBERT-L-2-v2", + library: "CrossEncoder", + emoji: "🧠", + group: "Cross Encoder", + needsEmbeddingModel: false, + defaultSettings: { + model: "cross-encoder/ms-marco-TinyBERT-L-2-v2", + shortName: "TinyBERT-L-2-v2", + }, + }, + { + baseMethod: "cross_encoder", + name: "Electra-Base", + library: "CrossEncoder", + emoji: "🧠", + group: "Cross Encoder", + needsEmbeddingModel: false, + defaultSettings: { + model: "cross-encoder/ms-marco-electra-base", + shortName: "Electra-Base", + }, + }, + { + baseMethod: "cross_encoder", + name: "BGE Reranker Base", + library: "CrossEncoder", + emoji: "🧠", + group: "Cross Encoder", + needsEmbeddingModel: false, + defaultSettings: { + model: "BAAI/bge-reranker-base", + shortName: "BGE Reranker Base", + }, + }, + { + baseMethod: "cross_encoder", + name: "BGE Reranker Large", + library: "CrossEncoder", + emoji: "🧠", + group: "Cross Encoder", + needsEmbeddingModel: false, + defaultSettings: { + model: "BAAI/bge-reranker-large", + shortName: "BGE Reranker Large", + }, + }, + ], + }, + { + label: "Cohere API", + items: [ + { + baseMethod: "cohere_rerank", + name: "Rerank v3.5 (Latest)", + library: "Cohere", + emoji: "💬", + group: "Cohere API", + needsEmbeddingModel: false, + defaultSettings: { + model: "rerank-v3.5", + shortName: "Rerank v3.5", + }, + }, + { + baseMethod: "cohere_rerank", + name: "Rerank English v3.0", + library: "Cohere", + emoji: "💬", + group: "Cohere API", + needsEmbeddingModel: false, + defaultSettings: { + model: "rerank-english-v3.0", + shortName: "Rerank English 
v3.0", + }, + }, + { + baseMethod: "cohere_rerank", + name: "Rerank Multilingual v3.0", + library: "Cohere", + emoji: "💬", + group: "Cohere API", + needsEmbeddingModel: false, + defaultSettings: { + model: "rerank-multilingual-v3.0", + shortName: "Rerank Multilingual v3.0", + }, + }, + ], + }, +]; diff --git a/chainforge/react-server/src/RerankNode.tsx b/chainforge/react-server/src/RerankNode.tsx new file mode 100644 index 000000000..26cec3251 --- /dev/null +++ b/chainforge/react-server/src/RerankNode.tsx @@ -0,0 +1,442 @@ +import React, { + useState, + useEffect, + useCallback, + useRef, + useContext, +} from "react"; +import { Handle, Position } from "reactflow"; +import { Badge } from "@mantine/core"; +import { Status } from "./StatusIndicatorComponent"; +import { AlertModalContext } from "./AlertModal"; +import BaseNode from "./BaseNode"; +import NodeLabel from "./NodeLabelComponent"; +import useStore from "./store"; + +import LLMResponseInspectorModal, { + LLMResponseInspectorModalRef, +} from "./LLMResponseInspectorModal"; +import InspectFooter from "./InspectFooter"; +import { IconSearch, IconSortAscending } from "@tabler/icons-react"; + +import RerankMethodListContainer, { + RerankMethodSpec, +} from "./RerankMethodListComponent"; + +import { TemplateVarInfo, LLMResponse } from "./backend/typing"; +import { StringLookup } from "./backend/cache"; +import { FLASK_BASE_URL } from "./backend/utils"; +import { v4 as uuid } from "uuid"; + +// Constants for handle positioning and styling +const HANDLE_Y_START = 60; // Adjust this value to move the first handle up/down +const HANDLE_Y_GAP = 30; // Adjust this value for spacing between handles +const HANDLE_X_OFFSET = "-14px"; // Nudge handle horizontally if needed (ReactFlow default is centered) + +const handleStyle: React.CSSProperties = { + background: "#555", + position: "absolute", // Necessary for precise positioning relative to wrapper + left: HANDLE_X_OFFSET, +}; +const badgeStyle: React.CSSProperties = { 
textTransform: "none" }; +const handleWrapperBaseStyle: React.CSSProperties = { + // Common style for the div wrapping Badge + Handle + position: "absolute", + left: "10px", // Padding from the node's left edge + display: "flex", + alignItems: "center", // Vertically align Badge and Handle dot + height: "20px", // Define height for alignment reference +}; +const badgeWrapperStyle: React.CSSProperties = { + // Style for the div specifically containing the Badge + marginRight: "8px", // Space between Badge and Handle dot +}; + +interface RerankNodeProps { + data: { + title?: string; + methods?: RerankMethodSpec[]; + refresh?: boolean; + }; + id: string; +} + +const RerankNode: React.FC = ({ data, id }) => { + const nodeDefaultTitle = "Rerank Node"; + const nodeIcon = ; + + const pullInputData = useStore((s) => s.pullInputData); + const setDataPropsForNode = useStore((s) => s.setDataPropsForNode); + const pingOutputNodes = useStore((s) => s.pingOutputNodes); + const apiKeys = useStore((s) => s.apiKeys); + + const showAlert = useContext(AlertModalContext); + + const [methodItems, setMethodItems] = useState( + data.methods || [], + ); + const [status, setStatus] = useState(Status.NONE); + const [jsonResponses, setJSONResponses] = useState([]); + + const inspectorRef = useRef(null); + + // On refresh + useEffect(() => { + if (data.refresh) { + setDataPropsForNode(id, { refresh: false, fields: [], output: [] }); + setJSONResponses([]); + setStatus(Status.NONE); + } + }, [data.refresh, id, setDataPropsForNode]); + + // Track changes in rerank methods + const handleMethodItemsChange = useCallback( + (newItems: RerankMethodSpec[], _oldItems: RerankMethodSpec[]) => { + setMethodItems(newItems); + setDataPropsForNode(id, { methods: newItems }); + if (status === Status.READY) setStatus(Status.WARNING); + }, + [id, status, setDataPropsForNode], + ); + + // The main reranking function + const runReranking = useCallback(async () => { + const handleError = (msg: string, err?: any) 
=> { + console.error(msg, err); + showAlert?.(msg); + setStatus(Status.ERROR); + }; + + if (methodItems.length === 0) { + handleError("No reranking methods selected!"); + return; + } + + // 1) Pull data from upstream (chunks from ChunkNode or RetrievalNode, and query) + let inputData: { + chunks?: TemplateVarInfo[]; + query?: TemplateVarInfo[]; + text?: TemplateVarInfo[]; + } = {}; + + try { + inputData = pullInputData(["chunks", "query", "text"], id) as { + chunks?: TemplateVarInfo[]; + query?: TemplateVarInfo[]; + text?: TemplateVarInfo[]; + }; + } catch (error) { + handleError( + "No input data found. Is a ChunkNode or RetrievalNode connected?", + error, + ); + return; + } + + // Use chunks if available, otherwise fall back to text + const documentsArr = inputData.chunks || inputData.text || []; + const queryArr = inputData.query || []; + + if (documentsArr.length === 0) { + handleError( + "No documents found. Please attach a ChunkNode, RetrievalNode, or provide text.", + ); + return; + } + + // Validate that documents have valid text + const validDocuments = documentsArr.filter( + (doc) => doc && doc.text && StringLookup.get(doc.text), + ); + + if (validDocuments.length === 0) { + handleError( + "No valid documents with text found. 
Please check your input data.", + ); + return; + } + + setStatus(Status.LOADING); + setJSONResponses([]); + + // We'll group by method name to call the reranker + const allReranksByMethodName: Record = {}; + const allResponsesByMethodName: Record = {}; + + // Group methods by name + const methodsByName = methodItems.reduce( + (acc, method) => { + if (!acc[method.name]) acc[method.name] = []; + acc[method.name].push(method); + return acc; + }, + {} as Record, + ); + + // 2) For each method and each query (if available) + for (const [name, methods] of Object.entries(methodsByName)) { + allReranksByMethodName[name] = []; + allResponsesByMethodName[name] = []; + + // If we have queries, rerank for each query + // Otherwise, rerank all documents together + const queriesToProcess = queryArr.length > 0 ? queryArr : [null]; + + for (const queryInfo of queriesToProcess) { + const query = + queryInfo && queryInfo.text + ? StringLookup.get(queryInfo.text) || "" + : ""; + + for (const method of methods) { + try { + const formData = new FormData(); + formData.append("baseMethod", method.baseMethod); + + // Add documents as a JSON array + const documents = validDocuments.map( + (doc) => StringLookup.get(doc.text) || "", + ); + formData.append("documents", JSON.stringify(documents)); + + // Add query if available + if (query) { + formData.append("query", query); + } else { + console.warn( + `Warning: No query found when preparing payload for reranking with method ${method.name}. Proceeding without 'query' component. Results will be suboptimal.`, + ); + } + + // Add the user settings + Object.entries(method.settings ?? 
{}).forEach(([k, v]) => { + formData.append(k, String(v)); + }); + + // Add API keys + if (apiKeys) { + formData.append("api_keys", JSON.stringify(apiKeys)); + } + + const res = await fetch(`${FLASK_BASE_URL}rerank`, { + method: "POST", + body: formData, + }); + + if (!res.ok) { + const err = await res.json(); + throw new Error(err.error || "Reranking request failed"); + } + + const json = await res.json(); + const rerankedResults = + json.reranked_documents || json.results || []; + + // Process reranked results + const methodSafe = method.methodType.replace(/\W+/g, "_"); + const querySafe = query + ? query.slice(0, 20).replace(/\W+/g, "_") + : "no_query"; + + rerankedResults.forEach((result: any, index: number) => { + const rId = uuid(); + + // Extract text and score from result + let resultText = ""; + let score = 0; + + if (typeof result === "string") { + resultText = result; + score = 1.0 - index / rerankedResults.length; // Synthetic score based on rank + } else if (result.document || result.text) { + resultText = result.document || result.text; + score = + result.score || + result.relevance_score || + 1.0 - index / rerankedResults.length; + } else { + resultText = String(result); + score = 1.0 - index / rerankedResults.length; + } + + // Create the reranked document object + const rerankVar: TemplateVarInfo = { + text: resultText, + prompt: query || "N/A", + fill_history: { + rerankMethod: `${method.methodType} (${method.name})`, + query: query || "N/A", + originalRank: index, + score: String(score), + }, + llm: method.name, + metavars: { + query: query || "N/A", + rerankMethod: method.methodType, + originalRank: index, + score: score, + }, + }; + + allReranksByMethodName[name].push(rerankVar); + + // LLMResponse for inspector + const respObj: LLMResponse = { + uid: rId, + prompt: `Query: ${query || "N/A"} | Rank: ${index + 1} | Score: ${score.toFixed(3)}`, + vars: { + query: query || "N/A", + rank: String(index + 1), + score: String(score.toFixed(3)), + }, + 
responses: [resultText], + llm: method.name, + metavars: rerankVar.metavars || {}, + }; + + allResponsesByMethodName[name].push(respObj); + }); + } catch (err: any) { + handleError( + `Error reranking with ${method.name}: ${err.message}`, + err, + ); + return; + } + } + } + } + + // Combine results + const allReranks = Object.values(allReranksByMethodName).flat(); + const allResponses = Object.values(allResponsesByMethodName).flat(); + + // 3) Output data grouped by method + const groupedOutput = Object.entries(allReranksByMethodName).reduce( + (acc, [method, reranks]) => { + acc[method] = reranks.map((rr) => ({ + rank: rr.metavars?.originalRank, + query: rr.metavars?.query, + score: rr.metavars?.score, + method: rr.fill_history?.rerankMethod, + text: rr.text, + })); + return acc; + }, + {} as Record, + ); + + setDataPropsForNode(id, { + fields: allReranks, + output: groupedOutput, + }); + pingOutputNodes(id); + + setJSONResponses(allResponses); + setStatus(Status.READY); + }, [ + id, + methodItems, + pullInputData, + setDataPropsForNode, + showAlert, + pingOutputNodes, + ]); + + // Open inspector + const openInspector = () => { + if (jsonResponses.length > 0 && inspectorRef.current) { + inspectorRef.current.trigger(); + } + }; + + return ( + + + +
+ {/* Labeled Handle for 'chunks' */} +
+
+ + chunks + +
+ +
+ + {/* Labeled Handle for 'query' */} +
+
+ + query + +
+ +
+ + {/* Add margin top to push list below handles */} +
+ +
+
+ + {jsonResponses && jsonResponses.length > 0 && ( + { + // Do nothing + }} + isDrawerOpen={false} + label={ + <> + Inspect reranked docs + + } + /> + )} + + {/* The LLM Response Inspector */} + + + +
+ ); +}; + +export default RerankNode; diff --git a/chainforge/react-server/src/RetrievalMethodListComponent.tsx b/chainforge/react-server/src/RetrievalMethodListComponent.tsx new file mode 100644 index 000000000..abf68728d --- /dev/null +++ b/chainforge/react-server/src/RetrievalMethodListComponent.tsx @@ -0,0 +1,931 @@ +import React, { + useState, + useRef, + forwardRef, + useImperativeHandle, + useCallback, + useMemo, + useEffect, +} from "react"; +import { + Menu, + Button, + Card, + Group, + Text, + ActionIcon, + Modal, + Divider, + Box, + Badge, + Stack, + ScrollArea, +} from "@mantine/core"; +import { useDisclosure } from "@mantine/hooks"; +import { + IconPlus, + IconTrash, + IconSettings, + IconChevronRight, + IconLink, + IconUnlink, + IconGitMerge, +} from "@tabler/icons-react"; +import Form from "@rjsf/core"; +import { RJSFSchema, UiSchema } from "@rjsf/utils"; +import validator from "@rjsf/validator-ajv8"; +import { v4 as uuid } from "uuid"; +import { + RetrievalMethodSchemas, + retrievalMethodGroups, + embeddingProviders, + rankFusionMethods, +} from "./RetrievalMethodSchemas"; +import useStore from "./store"; +import { DatalistWidget } from "./ModelSettingsModal"; +import NestedMenu, { NestedMenuItemProps } from "./NestedMenu"; +import { ensureUniqueName } from "./backend/utils"; + +/** Linked group of methods with fusion settings */ +export interface LinkedMethodGroup { + id: string; + methodKeys: string[]; + fusionMethod: string; + fusionSettings: Record; + groupName?: string; +} + +/** Enhanced method spec to track group membership */ +export interface RetrievalMethodSpec { + key: string; + baseMethod: string; + methodName: string; + library: string; + emoji?: string; + needsEmbeddingModel?: boolean; + embeddingProvider?: string; + settings?: Record; + source?: "builtin" | "custom"; + settingsSchema?: { settings?: Record; ui?: Record }; + groupId?: string; // Links to a LinkedMethodGroup +} + +/** Settings modal */ +interface SettingsModalProps { + 
opened: boolean; + onClose: () => void; + methodItem: RetrievalMethodSpec; + onSettingsUpdate: (settings: any) => void; +} + +const SettingsModal: React.FC = ({ + opened, + onClose, + methodItem, + onSettingsUpdate, +}) => { + const [currentSettings, setCurrentSettings] = React.useState( + methodItem.settings, + ); + + // Update local state when methodItem changes + React.useEffect(() => { + setCurrentSettings(methodItem.settings); + }, [methodItem.settings]); + + const builtin = RetrievalMethodSchemas[methodItem.baseMethod]; + + // With normalized store data, a single custom check suffices. + const isCustom = methodItem.source === "custom"; + + // Prefer custom schema if present; otherwise built-in schema. + const customSchema = methodItem.settingsSchema + ? { + schema: { + type: "object", + properties: methodItem.settingsSchema.settings ?? {}, + }, + uiSchema: methodItem.settingsSchema.ui ?? {}, + } + : null; + + let finalSchema = (customSchema?.schema ?? builtin?.schema) as + | RJSFSchema + | undefined; + const finalUiSchema = (customSchema?.uiSchema ?? builtin?.uiSchema) as + | UiSchema + | undefined; + + // Ensure customs always have a Nickname field + if (isCustom) { + const props = (finalSchema?.properties ?? {}) as Record; + if (!("shortName" in props)) { + finalSchema = { + type: "object", + properties: { + shortName: { + type: "string", + title: "Nickname", + default: methodItem.settings?.shortName ?? methodItem.methodName, + description: + "Unique identifier to appear in ChainForge. 
Keep it short.", + }, + ...props, + }, + } as RJSFSchema; + } + } + + // Built-ins: Update embedding model enum based on selected provider + if ( + !isCustom && + methodItem.needsEmbeddingModel && + builtin && + finalSchema?.properties + ) { + // Get the currently selected provider from current settings + const selectedProvider = + currentSettings?.embeddingProvider || + (finalSchema.properties as any).embeddingProvider?.default || + "huggingface"; + + const provider = embeddingProviders.find( + (p) => p.value === selectedProvider, + ); + + if (provider && provider.models && provider.models.length > 0) { + // Update the embeddingModel enum to match the selected provider's models + finalSchema = { + ...finalSchema, + properties: { + ...finalSchema.properties, + embeddingModel: { + ...(finalSchema.properties as any).embeddingModel, + enum: provider.models, + }, + }, + } as RJSFSchema; + } + } + + const handleSettingsChange = (e: any) => { + const newSettings = e.formData; + + // Check if embeddingProvider changed + if ( + methodItem.needsEmbeddingModel && + currentSettings?.embeddingProvider && + newSettings.embeddingProvider !== currentSettings.embeddingProvider + ) { + // Provider changed - update model to first model of new provider + const newProvider = embeddingProviders.find( + (p) => p.value === newSettings.embeddingProvider, + ); + if (newProvider && newProvider.models && newProvider.models.length > 0) { + newSettings.embeddingModel = newProvider.models[0]; + } + } + + setCurrentSettings(newSettings); + onSettingsUpdate(newSettings); + }; + + const hasProps = + !!finalSchema && Object.keys(finalSchema.properties ?? {}).length > 0; + if (!hasProps) { + return ( + +
+ This method has no configurable settings. +
+
+ ); + } + + return ( + + + schema={finalSchema as RJSFSchema} + uiSchema={(finalUiSchema || {}) as UiSchema} + validator={validator} + formData={currentSettings} + onChange={handleSettingsChange} + widgets={{ datalist: DatalistWidget }} + > + {/* live update via onChange */} + + + + + + )} + + + ); +}; + +/** One row in the list */ +interface RetrievalMethodListItemProps { + methodItem: RetrievalMethodSpec; + onRemove: (key: string) => void; + onSettingsUpdate: (key: string, settings: any) => void; + latency?: string; +} + +const RetrievalMethodListItem: React.FC< + RetrievalMethodListItemProps & { + isLinked?: boolean; + isFirstInGroup?: boolean; + isLastInGroup?: boolean; + onLink?: () => void; + onUnlink?: () => void; + onFusionSettings?: () => void; + } +> = ({ + methodItem, + onRemove, + onSettingsUpdate, + latency, + isLinked = false, + isFirstInGroup = false, + isLastInGroup = false, + onLink, + onUnlink, + onFusionSettings, +}) => { + const [opened, { open, close }] = useDisclosure(false); + + return ( + <> +
+ {/* Title (left) */} +
+ {methodItem.emoji && `${methodItem.emoji} `} + {methodItem.settings?.shortName || methodItem.methodName} + {latency && ( + + {latency} + + )} +
+ + {/* Actions (right) */} +
+ {isFirstInGroup && ( + + + + )} + {!isLinked && onLink && ( + + + + )} + {isLinked && onUnlink && ( + + + + )} + onRemove(methodItem.key)} + title="Remove" + > + + + { + e.preventDefault(); + e.stopPropagation(); + open(); + }} + title="Settings" + > + + +
+
+ + {/* Keep modal mounted alongside the row so open()/close() works */} + + onSettingsUpdate(methodItem.key, settings) + } + /> + + ); +}; + +/** Main container */ +export interface RetrievalMethodListContainerProps { + initMethodItems?: RetrievalMethodSpec[]; + initLinkedGroups?: LinkedMethodGroup[]; + onGroupsChange?: (groups: LinkedMethodGroup[]) => void; + onItemsChange?: ( + newItems: RetrievalMethodSpec[], + oldItems: RetrievalMethodSpec[], + ) => void; + methodResults?: Record; +} + +export const RetrievalMethodListContainer = forwardRef< + any, + RetrievalMethodListContainerProps +>((props, ref) => { + const [methodItems, setMethodItems] = useState( + props.initMethodItems || [], + ); + const linkedGroups: LinkedMethodGroup[] = props.initLinkedGroups ?? []; + + const [fusionModalGroup, setFusionModalGroup] = + useState(null); + const oldItemsRef = useRef(methodItems); + + useImperativeHandle(ref, () => ({ + getMethodItems: () => methodItems, + })); + + const notifyItemsChanged = useCallback( + (newItems: RetrievalMethodSpec[]) => { + props.onItemsChange?.(newItems, oldItemsRef.current); + oldItemsRef.current = newItems; + }, + [props.onItemsChange], + ); + + const handleRemoveMethod = useCallback( + (key: string) => { + const newItems = methodItems.filter((m) => m.key !== key); + setMethodItems(newItems); + notifyItemsChanged(newItems); + }, + [methodItems, notifyItemsChanged], + ); + + const handleSettingsUpdate = useCallback( + (key: string, newSettings: any) => { + const newItems = methodItems.map((m) => + m.key === key ? 
{ ...m, settings: newSettings } : m, + ); + setMethodItems(newItems); + notifyItemsChanged(newItems); + }, + [methodItems, notifyItemsChanged], + ); + + function defaultsFromCustomSchema(s?: { settings?: Record }) { + const props = s?.settings || {}; + const out: Record = {}; + for (const [k, v] of Object.entries(props)) { + if (v && typeof v === "object" && "default" in (v as any)) { + out[k] = (v as any).default; + } + } + return out; + } + + const addMethod = useCallback( + ( + m: Omit, + embeddingProviderValue?: string, + ) => { + const isCustom = m.source === "custom"; + + // Find selected embedding provider (built-ins only) + const provider = embeddingProviderValue + ? embeddingProviders.find((p) => p.value === embeddingProviderValue) + : undefined; + + let defaultSettings: Record = {}; + + const uniqueName = ensureUniqueName( + m.methodName, + methodItems.map((i) => i.settings?.shortName || i.methodName), + ); + + if (isCustom) { + // Pull defaults from normalized custom schema + defaultSettings = defaultsFromCustomSchema(m.settingsSchema); + } else { + const methodSchema = RetrievalMethodSchemas[m.baseMethod]; + if (methodSchema?.schema?.properties) { + const schemaProps = methodSchema.schema.properties; + defaultSettings = Object.entries(schemaProps).reduce( + (acc, [key, prop]) => { + if ("default" in prop) acc[key] = (prop as any).default; + return acc; + }, + {} as Record, + ); + } + // Override embedding provider and model if passed explicitly + if (m.needsEmbeddingModel && embeddingProviderValue) { + defaultSettings.embeddingProvider = embeddingProviderValue; + if (provider?.models?.length) { + defaultSettings.embeddingModel = provider.models[0]; + } + } + } + defaultSettings.shortName = uniqueName; + + const newItem: RetrievalMethodSpec = { + key: uuid(), + baseMethod: m.baseMethod, + methodName: uniqueName, + library: m.library, + emoji: m.emoji, + needsEmbeddingModel: m.needsEmbeddingModel, + ...(m.needsEmbeddingModel && embeddingProviderValue + ? 
{ embeddingProvider: provider?.value || "" } + : {}), + source: isCustom ? "custom" : "builtin", + settingsSchema: isCustom ? m.settingsSchema : undefined, + settings: defaultSettings, + }; + + const newItems = [...methodItems, newItem]; + setMethodItems(newItems); + notifyItemsChanged(newItems); + }, + [methodItems, notifyItemsChanged], + ); + + // Thanks to the unified store normalizer, these are already consistent. + const customRetrievers = useStore((s) => s.customRetrievers || []); + + const handleLinkMethods = useCallback( + (methodKey: string) => { + const currentIndex = methodItems.findIndex((m) => m.key === methodKey); + if (currentIndex === -1 || currentIndex === methodItems.length - 1) + return; + + const nextMethod = methodItems[currentIndex + 1]; + if (nextMethod.groupId) return; + + const newGroupId = uuid(); + const newGroup: LinkedMethodGroup = { + id: newGroupId, + methodKeys: [methodKey, nextMethod.key], + fusionMethod: "reciprocal_rank_fusion", + fusionSettings: { k: 60 }, + }; + + props.onGroupsChange?.([...(linkedGroups || []), newGroup]); + + const newItems = methodItems.map((m) => + m.key === methodKey || m.key === nextMethod.key + ? { ...m, groupId: newGroupId } + : m, + ); + setMethodItems(newItems); + notifyItemsChanged(newItems); + }, + [methodItems, notifyItemsChanged], + ); + + const handleUnlinkMethods = useCallback( + (groupId: string) => { + props.onGroupsChange?.( + (linkedGroups || []).filter((g) => g.id !== groupId), + ); + + const newItems = methodItems.map((m) => + m.groupId === groupId ? { ...m, groupId: undefined } : m, + ); + setMethodItems(newItems); + notifyItemsChanged(newItems); + }, + [methodItems, notifyItemsChanged], + ); + + const handleFusionSettingsUpdate = useCallback( + (groupId: string, settings: any) => { + // Update the source of truth + props.onGroupsChange?.( + (linkedGroups || []).map((g) => + g.id === groupId ? 
{ ...g, fusionSettings: settings } : g, + ), + ); + + // Keep the modal's local state in lockstep so RJSF stays editable + setFusionModalGroup((prev) => + prev && prev.id === groupId + ? { ...prev, fusionSettings: settings } + : prev, + ); + }, + [linkedGroups, props.onGroupsChange], + ); + + const handleFusionMethodChange = useCallback( + (groupId: string, fusionMethod: string, defaultSettings: any) => { + props.onGroupsChange?.( + (linkedGroups || []).map((g) => + g.id === groupId + ? { ...g, fusionMethod, fusionSettings: defaultSettings } + : g, + ), + ); + setFusionModalGroup((prev) => + prev + ? { + ...prev, + fusionMethod, + fusionSettings: defaultSettings, + } + : null, + ); + }, + [linkedGroups, props.onGroupsChange], + ); + + const addMenuItems: NestedMenuItemProps[] = useMemo(() => { + // Built-in retrieval groups + const builtInGroups: NestedMenuItemProps[] = retrievalMethodGroups.map( + (group) => ({ + key: `group-${group.label}`, + title: group.label, + items: group.items.map((m) => ({ + key: `method-${m.baseMethod}-${m.embeddingProvider || "default"}`, + title: m.methodName, + tooltip: m.description, + icon: m.emoji ? {m.emoji} : undefined, + onClick: () => addMethod(m, m.embeddingProvider), + })), + }), + ); + + // Custom retrievers group (if any) + const customGroup: NestedMenuItemProps[] = + customRetrievers.length > 0 + ? [ + { + key: "group-custom", + title: "Custom Providers", + items: customRetrievers.map((prov) => ({ + key: `custom-${prov.key}`, + title: prov.methodName, + icon: prov.emoji ? {prov.emoji} : undefined, + onClick: () => + addMethod({ + baseMethod: prov.baseMethod, + methodName: prov.methodName, + library: prov.library, + emoji: prov.emoji, + needsEmbeddingModel: prov.needsEmbeddingModel, + source: "custom", + settingsSchema: + prov.settingsSchema ?? 
(prov as any).settings_schema, + } as any), + })), + }, + ] + : []; + + return [...builtInGroups, ...customGroup]; + }, [retrievalMethodGroups, customRetrievers, embeddingProviders, addMethod]); + + return ( +
+
+ Retrieval Methods +
+ } + /> +
+
+ +
+ + {methodItems.length === 0 ? ( + + No retrieval methods selected. + + ) : ( + methodItems.map((item) => { + const group = linkedGroups.find((g) => g.id === item.groupId); + const members = methodItems.filter( + (m) => m.groupId === group?.id, + ); + const isLinked = !!group; + const isFirstInGroup = isLinked && members[0]?.key === item.key; + const isLastInGroup = + isLinked && members[members.length - 1]?.key === item.key; + const latency = + props.methodResults?.[item.key]?.metavars?.latency; + + return ( + handleLinkMethods(item.key)} + onUnlink={ + group ? () => handleUnlinkMethods(group.id) : undefined + } + onFusionSettings={ + isFirstInGroup && group + ? () => setFusionModalGroup(group) + : undefined + } + /> + ); + }) + )} + +
+ + {fusionModalGroup && ( + setFusionModalGroup(null)} + group={fusionModalGroup} + methodItems={methodItems} + onSettingsUpdate={(settings) => + handleFusionSettingsUpdate(fusionModalGroup.id, settings) + } + onFusionMethodChange={handleFusionMethodChange} + /> + )} +
+ ); +}); + +RetrievalMethodListContainer.displayName = "RetrievalMethodListContainer"; +export default RetrievalMethodListContainer; diff --git a/chainforge/react-server/src/RetrievalMethodSchemas.tsx b/chainforge/react-server/src/RetrievalMethodSchemas.tsx new file mode 100644 index 000000000..a9da88da0 --- /dev/null +++ b/chainforge/react-server/src/RetrievalMethodSchemas.tsx @@ -0,0 +1,713 @@ +import { ModelSettingsDict } from "./backend/typing"; + +// Available embedding models +export const embeddingProviders = [ + { + label: "🤗 HuggingFace Transformers", + value: "huggingface", + models: [ + "sentence-transformers/all-MiniLM-L6-v2", + "sentence-transformers/all-mpnet-base-v2", + "thenlper/gte-large", + "BAAI/bge-large-en-v1.5", + ], + }, + { + label: "🤖 OpenAI Embeddings", + value: "openai", + models: [ + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large", + ], + }, + { + label: "🔷 Azure OpenAI Embeddings", + value: "azure-openai", + models: [], + }, + { + label: "💬 Cohere Embeddings", + value: "cohere", + models: [ + "embed-english-v2.0", + "embed-multilingual-v2.0", + "embed-english-light-v2.0", + ], + }, + { + label: "🧠 Sentence Transformers", + value: "sentence-transformers", + models: [ + "all-MiniLM-L6-v2", + "all-mpnet-base-v2", + "paraphrase-MiniLM-L3-v2", + "all-distilroberta-v1", + ], + }, +]; + +/** + * BM25 Retrieval + */ +export const BM25Schema: ModelSettingsDict = { + fullName: "BM25 Retrieval", + description: "Retrieves documents using the BM25 ranking algorithm", + schema: { + type: "object", + required: ["top_k", "bm25_k1", "bm25_b"], + properties: { + shortName: { + type: "string", + default: "BM25 Retrieval", + title: "Nickname", + description: + "Unique identifier to appear in ChainForge. 
Keep it short.", + }, + top_k: { + type: "number", + default: 5, + title: "Top K Results", + }, + bm25_k1: { + type: "number", + default: 1.5, + title: "k1 Parameter", + }, + bm25_b: { + type: "number", + default: 0.75, + title: "b Parameter", + }, + }, + }, + uiSchema: { + shortName: { + "ui:widget": "text", + "ui:options": { + placeholder: "Custom name for your retrieval method", + }, + }, + top_k: { + "ui:widget": "range", + "ui:options": { + min: 1, + max: 20, + step: 1, + }, + }, + bm25_k1: { + "ui:widget": "range", + "ui:options": { + min: 0.5, + max: 3.0, + step: 0.1, + }, + }, + bm25_b: { + "ui:widget": "range", + "ui:options": { + min: 0, + max: 1, + step: 0.05, + }, + }, + }, + postprocessors: {}, +}; + +/** + * TF-IDF Retrieval + */ +export const TFIDFSchema: ModelSettingsDict = { + fullName: "TF-IDF Retrieval", + description: "Retrieves documents using TF-IDF scoring", + schema: { + type: "object", + required: ["top_k", "max_features"], + properties: { + shortName: { + type: "string", + default: "TF-IDF Retrieval", + title: "Nickname", + description: + "Unique identifier to appear in ChainForge. Keep it short.", + }, + top_k: { + type: "number", + title: "Top K Results", + }, + max_features: { + type: "number", + title: "Max Features (Vocabulary Size)", // Clarified title + default: 500, + }, + }, + }, + uiSchema: { + shortName: { + "ui:widget": "text", + "ui:options": { + placeholder: "Custom name for your retrieval method", + }, + }, + top_k: { + "ui:widget": "range", + "ui:options": { + min: 1, + max: 50, // Increased max slightly? Adjust as needed + step: 1, + }, + }, + max_features: { + "ui:widget": "range", + "ui:options": { + min: 100, + max: 10000, // Increased max? 
Adjust as needed + step: 100, + }, + }, + }, + postprocessors: {}, +}; + +/** + * Boolean Search + */ +export const BooleanSearchSchema: ModelSettingsDict = { + fullName: "Boolean Search", + description: "Simple boolean keyword matching", + schema: { + type: "object", + required: ["top_k", "required_match_count"], + properties: { + shortName: { + type: "string", + default: "Boolean Search", + title: "Nickname", + description: + "Unique identifier to appear in ChainForge. Keep it short.", + }, + top_k: { + type: "number", + default: 5, + title: "Top K Results", + }, + required_match_count: { + type: "number", + default: 1, + title: "Required Matches", + }, + }, + }, + uiSchema: { + shortName: { + "ui:widget": "text", + "ui:options": { + placeholder: "Custom name for your retrieval method", + }, + }, + top_k: { + "ui:widget": "range", + "ui:options": { + min: 1, + max: 20, + step: 1, + }, + }, + required_match_count: { + "ui:widget": "range", + "ui:options": { + min: 1, + max: 10, + step: 1, + }, + }, + }, + postprocessors: {}, +}; + +/** + * Keyword Overlap + */ +export const KeywordOverlapSchema: ModelSettingsDict = { + fullName: "Keyword Overlap", + description: "Retrieves documents based on keyword overlap ratio", + schema: { + type: "object", + required: ["top_k", "normalization_factor"], + properties: { + shortName: { + type: "string", + default: "Keyword Overlap", + title: "Nickname", + description: + "Unique identifier to appear in ChainForge. 
Keep it short.", + }, + top_k: { + type: "number", + default: 5, + title: "Top K Results", + }, + normalization_factor: { + type: "number", + default: 0.75, + title: "Normalization Factor", + }, + }, + }, + uiSchema: { + shortName: { + "ui:widget": "text", + "ui:options": { + placeholder: "Custom name for your retrieval method", + }, + }, + top_k: { + "ui:widget": "range", + "ui:options": { + min: 1, + max: 20, + step: 1, + }, + }, + normalization_factor: { + "ui:widget": "range", + "ui:options": { + min: 0, + max: 1, + step: 0.05, + }, + }, + }, + postprocessors: {}, +}; + +/** + * Unified Embedding-based Similarity Schema + * Consolidates cosine, manhattan, euclidean, and vector store approaches + */ +export const EmbeddingSimilaritySchema: ModelSettingsDict = { + fullName: "Embedding-based Similarity", + description: + "Retrieves documents using semantic similarity between embeddings", + schema: { + type: "object", + required: [ + "top_k", + "similarity_threshold", + "similarity_metric", + "storage_backend", + ], + properties: { + shortName: { + type: "string", + default: "Embedding Similarity", + title: "Nickname", + description: + "Unique identifier to appear in ChainForge. 
Keep it short.", + }, + embeddingProvider: { + type: "string", + title: "Embedding Provider", + enum: embeddingProviders.map((p) => p.value), + default: "huggingface", + description: "Select the embedding provider to use", + }, + embeddingModel: { + type: "string", + title: "Embedding Model", + default: "sentence-transformers/all-MiniLM-L6-v2", + description: "Select or enter a custom embedding model name", + }, + embeddingLocalPath: { + type: "string", + title: "Local Model Path (optional)", + default: "", + description: + "Only needed if you prefer local files instead of downloading the model automatically.", + }, + top_k: { + type: "number", + default: 5, + title: "Top K Results", + description: "Number of most similar documents to retrieve", + }, + similarity_threshold: { + type: "number", + default: 50, + title: "Similarity Threshold (%)", + minimum: 0, + maximum: 100, + step: 1, + description: + "Minimum similarity percentage (0-100) required for a result to be considered relevant.", + }, + similarity_metric: { + type: "string", + default: "cosine", + title: "Similarity Metric", + enum: ["cosine", "euclidean", "dot_product"], + description: "How to measure similarity between embeddings", + }, + storage_backend: { + type: "string", + default: "lancedb", + title: "Storage Backend", + enum: ["lancedb", "faiss"], + description: + "Where to store and search embeddings. 
LanceDB is simplest for persistence; FAISS for large-scale (requires separate installation) and possibly connecting to a pre-computed FAISS vector store on your local disk.", + }, + // Disable clustering method for now, too complex + // use_clustering: { + // type: "boolean", + // default: false, + // title: "Enable Clustering", + // description: "Pre-cluster documents to improve retrieval on large, diverse corpora", + // }, + // n_clusters: { + // type: "number", + // default: 5, + // title: "Number of Clusters", + // description: "How many clusters to create (only used if clustering is enabled)", + // }, + // LanceDB-specific settings + lancedb_path: { + type: "string", + default: "", + title: "LanceDB Path", + description: + "File path for LanceDB database (required if using LanceDB backend)", + }, + lancedb_table: { + type: "string", + default: "embeddings", + title: "LanceDB Table Name", + description: "Table name within LanceDB", + }, + lancedb_search_method: { + type: "string", + default: "similarity", + title: "LanceDB Search Method", + enum: ["similarity", "mmr", "hybrid"], + description: + "Search strategy: standard similarity, MMR (diverse results), or hybrid (vector + keyword)", + }, + // FAISS-specific settings + faiss_path: { + type: "string", + default: "", + title: "FAISS Index Path", + description: + "File path to save/load FAISS index (required if using FAISS backend)", + }, + faiss_mode: { + type: "string", + default: "create", + title: "FAISS Mode", + enum: ["create", "load"], + description: "Create new FAISS index or load existing one", + }, + }, + }, + uiSchema: { + shortName: { + "ui:widget": "text", + "ui:options": { + placeholder: "Custom name for your retrieval method", + }, + }, + embeddingProvider: { + "ui:widget": "select", + "ui:options": { + enumOptions: embeddingProviders.map((p) => ({ + label: p.label, + value: p.value, + })), + }, + "ui:help": "Choose the embedding provider", + }, + embeddingModel: { + "ui:widget": "datalist", + 
"ui:help": "Select a model or enter a custom model name", + }, + embeddingLocalPath: { + "ui:widget": "text", + "ui:options": { + placeholder: "e.g., ./my_model_directory", + }, + }, + top_k: { + "ui:widget": "range", + "ui:options": { + min: 1, + max: 20, + step: 1, + }, + }, + similarity_threshold: { + "ui:widget": "range", + "ui:options": { + min: 0, + max: 100, + step: 1, + }, + }, + similarity_metric: { + "ui:widget": "select", + "ui:options": { + enumOptions: [ + { label: "Cosine Similarity (standard for RAG)", value: "cosine" }, + { label: "Euclidean Distance (L2)", value: "euclidean" }, + { label: "Dot Product (Inner Product)", value: "dot_product" }, + ], + }, + }, + storage_backend: { + "ui:widget": "select", + "ui:options": { + enumOptions: [ + { label: "In-Memory (simple, no persistence)", value: "memory" }, + { label: "LanceDB (persistent, recommended)", value: "lancedb" }, + { + label: "FAISS (high-performance, requires installation)", + value: "faiss", + }, + ], + }, + }, + // use_clustering: { + // "ui:widget": "checkbox", + // }, + // n_clusters: { + // "ui:widget": "range", + // "ui:options": { + // min: 2, + // max: 20, + // step: 1, + // }, + // }, + lancedb_path: { + "ui:widget": "text", + "ui:options": { + placeholder: "e.g., ./my_lancedb", + }, + }, + lancedb_table: { + "ui:widget": "text", + "ui:options": { + placeholder: "embeddings", + }, + }, + lancedb_search_method: { + "ui:widget": "select", + "ui:options": { + enumOptions: [ + { label: "Standard Similarity", value: "similarity" }, + { + label: "Maximum Marginal Relevance (diverse results)", + value: "mmr", + }, + { label: "Hybrid (vector + keyword)", value: "hybrid" }, + ], + }, + }, + faiss_path: { + "ui:widget": "text", + "ui:options": { + placeholder: "e.g., ./my_index.faiss", + }, + }, + faiss_mode: { + "ui:widget": "select", + "ui:options": { + enumOptions: [ + { label: "Create New Index", value: "create" }, + { label: "Load Existing Index", value: "load" }, + ], + }, + }, + }, + 
postprocessors: {}, +}; + +// Add rank fusion methods +export const rankFusionMethods = [ + { + value: "reciprocal_rank_fusion", + label: "Reciprocal Rank Fusion (RRF)", + description: "Combines rankings using reciprocal rank formula", + schema: { + type: "object", + properties: { + k: { + type: "number", + title: "K Parameter", + default: 60, + description: "Parameter for RRF formula (higher = more democratic)", + }, + weights: { + type: "array", + title: "Method Weights", + items: { type: "number" }, + description: + "Optional weights for each method (leave empty for equal weights)", + }, + }, + }, + }, + { + value: "weighted_average", + label: "Weighted Average", + description: "Simple weighted average of scores", + schema: { + type: "object", + properties: { + normalize_scores: { + type: "boolean", + title: "Normalize Scores", + default: true, + description: "Normalize scores before combining", + }, + }, + }, + }, +]; + +// Combined schema object for all retrieval methods +export const RetrievalMethodSchemas: { + [baseMethod: string]: ModelSettingsDict; +} = { + bm25: BM25Schema, + tfidf: TFIDFSchema, + boolean: BooleanSearchSchema, + overlap: KeywordOverlapSchema, + embedding: EmbeddingSimilaritySchema, + // Deprecated methods (kept for backwards compatibility) + cosine: EmbeddingSimilaritySchema, + euclidean: EmbeddingSimilaritySchema, + clustered: EmbeddingSimilaritySchema, +}; + +// Method groupings for the menu +export const retrievalMethodGroups = [ + { + label: "Keyword-based Retrieval", + items: [ + { + baseMethod: "bm25", + methodName: "BM25 Retrieval", + library: "BM25", + emoji: "📊", + group: "Keyword-based Retrieval", + needsEmbeddingModel: false, + embeddingProvider: undefined, + description: + "Classic keyword ranking using term frequency and document length normalization. 
Great default for keyword-heavy queries.", + }, + { + baseMethod: "tfidf", + methodName: "TF-IDF Retrieval", + library: "TF-IDF", + emoji: "📈", + group: "Keyword-based Retrieval", + needsEmbeddingModel: false, + embeddingProvider: undefined, + description: + "Vector-space retrieval based on term frequency–inverse document frequency. Good for exact words and rare terms.", + }, + { + baseMethod: "boolean", + methodName: "Boolean Search", + library: "Boolean Search", + emoji: "🔍", + group: "Keyword-based Retrieval", + needsEmbeddingModel: false, + embeddingProvider: undefined, + description: + "Keyword retrieval based on minimum token overlap with the query, ranked by how many words they share.", + }, + { + baseMethod: "overlap", + methodName: "Keyword Overlap", + library: "KeywordOverlap", + emoji: "🎯", + group: "Keyword-based Retrieval", + needsEmbeddingModel: false, + embeddingProvider: undefined, + description: + "Score documents by how many query keywords they share. Simple and fast when term overlap is what matters.", + }, + ], + }, + { + label: "Embedding-based Retrieval", + items: [ + { + baseMethod: "embedding", + methodName: "HuggingFace Embedding", + library: "EmbeddingSimilarity", + emoji: "🤗", + group: "Embedding-based Retrieval", + needsEmbeddingModel: true, + embeddingProvider: "huggingface", + description: + "Retrieve documents using HuggingFace transformer embeddings. Fast, open-source, and runs locally.", + }, + { + baseMethod: "embedding", + methodName: "OpenAI Embedding", + library: "EmbeddingSimilarity", + emoji: "🤖", + group: "Embedding-based Retrieval", + needsEmbeddingModel: true, + embeddingProvider: "openai", + description: + "Retrieve documents using OpenAI embeddings (ada-002, text-embedding-3). 
High quality, requires API key.", + }, + { + baseMethod: "embedding", + methodName: "Azure OpenAI Embedding", + library: "EmbeddingSimilarity", + emoji: "🔷", + group: "Embedding-based Retrieval", + needsEmbeddingModel: true, + embeddingProvider: "azure-openai", + description: + "Retrieve documents using Azure OpenAI embeddings. Enterprise-ready with Azure compliance.", + }, + { + baseMethod: "embedding", + methodName: "Cohere Embedding", + library: "EmbeddingSimilarity", + emoji: "💬", + group: "Embedding-based Retrieval", + needsEmbeddingModel: true, + embeddingProvider: "cohere", + description: + "Retrieve documents using Cohere embeddings. Multilingual support and optimized for search.", + }, + { + baseMethod: "embedding", + methodName: "Sentence Transformers Embedding", + library: "EmbeddingSimilarity", + emoji: "🧠", + group: "Embedding-based Retrieval", + needsEmbeddingModel: true, + embeddingProvider: "sentence-transformers", + description: + "Retrieve documents using Sentence Transformers. Optimized for semantic similarity tasks.", + }, + // { + // baseMethod: "clustered", + // methodName: "Clustered Embedding", + // library: "Clustered", + // emoji: "🎲", + // group: "Embedding-based Retrieval", + // needsEmbeddingModel: true, + // description: + // "Cluster documents in embedding space, then retrieve from the most relevant clusters. 
Good for large, heterogeneous corpora.", + // }, + ], + }, +]; diff --git a/chainforge/react-server/src/RetrievalNode.tsx b/chainforge/react-server/src/RetrievalNode.tsx new file mode 100644 index 000000000..efe6a1618 --- /dev/null +++ b/chainforge/react-server/src/RetrievalNode.tsx @@ -0,0 +1,568 @@ +import React, { + useState, + useEffect, + useCallback, + useRef, + useContext, +} from "react"; +import { Handle, Position } from "reactflow"; +import { Badge, Progress } from "@mantine/core"; +import { IconSearch } from "@tabler/icons-react"; +import BaseNode from "./BaseNode"; +import NodeLabel from "./NodeLabelComponent"; +import useStore from "./store"; +import InspectFooter from "./InspectFooter"; +import { AlertModalContext } from "./AlertModal"; +import AreYouSureModal, { AreYouSureModalRef } from "./AreYouSureModal"; +import LLMResponseInspectorModal, { + LLMResponseInspectorModalRef, +} from "./LLMResponseInspectorModal"; +import RetrievalMethodListContainer, { + RetrievalMethodSpec, +} from "./RetrievalMethodListComponent"; +import { LLMResponse, TemplateVarInfo } from "./backend/typing"; +import { FLASK_BASE_URL } from "./backend/utils"; +import type { LinkedMethodGroup } from "./RetrievalMethodListComponent"; +import { Status } from "./StatusIndicatorComponent"; + +interface RetrievalNodeProps { + id: string; + data: { + title?: string; + methods?: RetrievalMethodSpec[]; + results?: Record; + refresh?: boolean; + linked_groups?: LinkedMethodGroup[]; + }; +} + +// Constants for handle positioning and styling +const HANDLE_Y_START = 60; // Adjust this value to move the first handle up/down +const HANDLE_Y_GAP = 30; // Adjust this value for spacing between handles +const HANDLE_X_OFFSET = "-14px"; // Nudge handle horizontally if needed (ReactFlow default is centered) + +const handleStyle: React.CSSProperties = { + background: "#555", + position: "absolute", // Necessary for precise positioning relative to wrapper + left: HANDLE_X_OFFSET, +}; +const 
badgeStyle: React.CSSProperties = { textTransform: "none" }; +const handleWrapperBaseStyle: React.CSSProperties = { + // Common style for the div wrapping Badge + Handle + position: "absolute", + left: "10px", // Padding from the node's left edge + display: "flex", + alignItems: "center", // Vertically align Badge and Handle dot + height: "20px", // Define height for alignment reference +}; +const badgeWrapperStyle: React.CSSProperties = { + // Style for the div specifically containing the Badge + marginRight: "8px", // Space between Badge and Handle dot +}; + +const RetrievalNode: React.FC = ({ id, data }) => { + const nodeDefaultTitle = "Retrieval Node"; + const nodeIcon = "🎯"; + + // Store hooks + const pullInputData = useStore((s) => s.pullInputData); + const setDataPropsForNode = useStore((s) => s.setDataPropsForNode); + const pingOutputNodes = useStore((s) => s.pingOutputNodes); + const apiKeys = useStore((s) => s.apiKeys); + + // Context + const showAlert = useContext(AlertModalContext); + + // State + const [methodItems, setMethodItems] = useState( + data.methods || [], + ); + const [status, setStatus] = useState(Status.NONE); + const [runTooltip, setRunTooltip] = useState("Run Retrieval"); + const [confirmMessage, setConfirmMessage] = useState(""); + const [results, setResults] = useState>( + data.results || {}, + ); + const [jsonResponses, setJsonResponses] = useState([]); + const [progress, setProgress] = useState(undefined); + const [progressAnimated, setProgressAnimated] = useState(true); + const pollIntervalRef = useRef(null); + + // Fusion // wire to the Fusion button + const [linkedGroups, setLinkedGroups] = useState([]); + + // Refs + const inspectorModalRef = useRef(null); + const retrievalConfirmModalRef = useRef(null); + + // Every time we click run, this increments. If we click stop, we increment it + // (invalidating the previous run) and reset the UI. 
+ const runIdRef = useRef(0); + + const handleStopClick = useCallback(() => { + // Invalidate the current run by incrementing the ID + runIdRef.current += 1; + + // Stop the progress polling immediately + if (pollIntervalRef.current) { + clearInterval(pollIntervalRef.current); + pollIntervalRef.current = null; + } + + // Reset UI State immediately + setStatus(Status.NONE); + setProgress(undefined); + setProgressAnimated(false); + }, []); + + // Reset on refresh + useEffect(() => { + if (data.refresh) { + setDataPropsForNode(id, { + refresh: false, + results: {}, + output: [], + }); + setResults({}); + setJsonResponses([]); + } + }, [data.refresh, id, setDataPropsForNode]); + + // Handle method changes + const handleMethodsChange = useCallback( + (newItems: RetrievalMethodSpec[]) => { + setMethodItems(newItems); + setDataPropsForNode(id, { methods: newItems }); + if (status === Status.READY) { + setStatus(Status.WARNING); + } + }, + [id, setDataPropsForNode, status], + ); + + // Confirmation modal for running retrieval + const confirmAndRunRetrieval = () => { + // Pull current input data to check counts + const inputData = pullInputData(["chunks"], id) as { chunks?: any[] }; + const numChunks = inputData.chunks?.length || 0; + + // Check if an embedding model is active + // We check if the baseMethod is 'vector' or if an embedding provider is set + const hasEmbeddingModel = methodItems.some( + (m) => m.baseMethod === "vector" || !!m.embeddingProvider, + ); + + // Construct the base message + let msg = + "⚠️ You're about to run all configured retrieval methods. This may create, load, or modify vector stores."; + + if (hasEmbeddingModel && numChunks > 100) { + msg += + ` (🛑 High Volume Warning: You are running an embedding model on ${numChunks} ` + + "chunks. 
This will generate embeddings for all chunks that haven't already been embedded in previous runs of the " + + "retriever, which may be slow and incur costs.)"; + } + + setConfirmMessage(msg); + retrievalConfirmModalRef.current?.trigger(); + }; + + // Main retrieval function + const runRetrieval = useCallback(async () => { + if (methodItems.length === 0) { + showAlert?.("Please add at least one retrieval method"); + return; + } + const currentRunId = runIdRef.current; + + // Setup UI for loading + setStatus(Status.LOADING); + setProgress(5); // Start at 5% + setProgressAnimated(true); + + // Start Polling the "Faked" Endpoint + pollIntervalRef.current = window.setInterval(async () => { + try { + const resp = await fetch(`${FLASK_BASE_URL}getRetrieveProgress`); + if (currentRunId !== runIdRef.current) return; + const data = await resp.json(); + + let currentProgress = 0; + if (typeof data === "number") { + currentProgress = data; + } else if (data && typeof data === "object") { + // Sum all values in the object (assuming they are numbers representing % completion) + const values = Object.values(data) as number[]; + currentProgress = values.reduce( + (acc, val) => acc + (typeof val === "number" ? 
val : 0), + 0, + ); + } + + // Clamp between 5 and 95 so it doesn't look finished until it actually is + setProgress(Math.min(95, Math.max(5, currentProgress))); + } catch (e) { + console.warn("Could not fetch progress", e); + } + }, 500); + + try { + // Get input data from connected nodes + const inputData = pullInputData(["chunks", "queries"], id) as { + chunks?: any[]; + queries?: any[]; + }; + + // Format methods for the API request + const formattedMethods = methodItems.map((method) => ({ + id: method.key, + baseMethod: method.baseMethod, + methodName: method.methodName, + library: method.library, + embeddingProvider: method.embeddingProvider, + settings: method.settings || {}, + })); + + // Updated error checks for clarity + if (!inputData.chunks || inputData.chunks.length === 0) { + throw new Error("Input 'chunks' is missing or empty."); + } + if (!inputData.queries || inputData.queries.length === 0) { + throw new Error("Input 'queries' is missing or empty."); + } + + console.log("Chunks:", inputData.chunks); + + // Make the API request + const response = await fetch(`${FLASK_BASE_URL}retrieve`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + methods: formattedMethods, + chunks: inputData.chunks, + queries: inputData.queries, + api_keys: apiKeys, + fusion_enabled: linkedGroups.length > 0, + linked_groups: linkedGroups.length > 0 ? linkedGroups : [], + }), + }); + + if (currentRunId !== runIdRef.current) { + console.log("Retrieval result ignored (stopped by user)."); + return; + } + + if (!response.ok) { + const body = await response.json(); + const message = + body && typeof body.error === "string" + ? 
body.error + : `Retrieval failed: ${response.statusText}`; + + throw new Error(message); + } + + // The response is now a flat array of objects + const retrievalResults = await response.json(); + if (currentRunId !== runIdRef.current) return; + + // --- Hide individual members of fused groups; keep only the fused column --- + const fusedMemberIds = new Set( + (linkedGroups || []).flatMap((g) => g.methodKeys || []), + ); + + const filteredResults = + linkedGroups.length > 0 + ? retrievalResults.filter((r: any) => { + const mid = r?.metavars?.methodId; + if (!mid) return true; + if (typeof mid === "string" && mid.startsWith("group:")) + return true; // fused rows + return !fusedMemberIds.has(mid); // drop members of fused groups + }) + : retrievalResults; + + console.warn("Retrieval results:", filteredResults); + + // Convert to proper LLMResponse objects + const llmResponses: LLMResponse[] = filteredResults.map( + (result: any) => ({ + uid: result.uid || `retrieval-${Date.now()}-${Math.random()}`, + prompt: result.prompt, + vars: result.vars || {}, + metavars: result.metavars || {}, + responses: [result.text], + eval_res: result.eval_res || [], + llm: result.vars.retrievalMethod || "Unknown", // We are abusing 'llm' to store the retrieval method + // llm: result.llm || "Unknown Method", + }), + ); + + // Set the responses for the inspector + setJsonResponses(llmResponses); + + // Group results by method for the node's internal state + const resultsByMethod: Record = {}; + + // Process each result to organize by method + filteredResults.forEach((result: any) => { + // Extract method info using nullish coalescing for safety + const methodId = result.metavars?.methodId ?? "unknown_method"; + + if (!resultsByMethod[methodId]) { + resultsByMethod[methodId] = { + retrieved: {}, + metavars: { + retrievalMethod: result.vars?.retrievalMethod ?? 
"Unknown Method", + retrievalMethodSignature: + result.metavars?.retrievalMethodSignature, + embeddingModel: result.metavars?.embeddingModel, + latency: result.metavars?.latency_ms, + }, + }; + } + + // Group by query + const query = result.prompt; + if (!resultsByMethod[methodId].retrieved[query]) { + resultsByMethod[methodId].retrieved[query] = []; + } + + // Add this result to the appropriate query group + resultsByMethod[methodId].retrieved[query].push({ + text: result.text, + similarity: result.eval_res?.items[0]?.similarity, + docTitle: result.metavars?.docTitle, + chunkId: result.metavars?.chunkId, + }); + }); + + // Update results state + setResults(resultsByMethod); + + const outputForDownstream: TemplateVarInfo[] = filteredResults.map( + (result: any) => ({ + text: result.text, + prompt: result.prompt, + fill_history: result.vars || {}, + metavars: result.metavars || {}, + llm: result.llm, // Should we call this 'method' instead? + uid: result.uid || `chunk-${Date.now()}-${Math.random()}`, + }), + ); + + // Update node data + setDataPropsForNode(id, { + methods: methodItems, + results: resultsByMethod, + output: outputForDownstream, + }); + + // Notify downstream nodes + pingOutputNodes(id); + setStatus(Status.READY); + } catch (error) { + // Only show error if we weren't stopped + if (currentRunId === runIdRef.current) { + console.error("Detailed error:", error); + showAlert?.( + error instanceof Error ? 
error.message : "Retrieval failed", + ); + setStatus(Status.ERROR); + } + } finally { + // Only run cleanup if this is still the active run + // (If we stopped, handleStopClick already cleaned up) + if (currentRunId === runIdRef.current) { + if (pollIntervalRef.current) clearInterval(pollIntervalRef.current); + setProgress(100); + setProgressAnimated(false); + setTimeout(() => { + // Check one last time before clearing UI + if (currentRunId === runIdRef.current) { + setProgress(undefined); + } + }, 2000); + } + } + }, [ + methodItems, + id, + pullInputData, + setDataPropsForNode, + pingOutputNodes, + showAlert, + apiKeys, + linkedGroups, + ]); + + // Update stored data when methods change + useEffect(() => { + setDataPropsForNode(id, { + methods: methodItems, + results, + }); + }, [id, methodItems, results, setDataPropsForNode]); + + const handleRunHover = useCallback(() => { + if (status === Status.LOADING) return; + + try { + // Pull data from inputs without processing (just to count) + const inputData = pullInputData(["chunks", "queries"], id) as { + chunks?: any[]; + queries?: any[]; + }; + + const numChunks = inputData.chunks?.length || 0; + const numQueries = inputData.queries?.length || 0; + const numMethods = methodItems.length; + + if (numMethods === 0) { + setRunTooltip("Please add a retrieval method first."); + } else if (numChunks === 0 || numQueries === 0) { + setRunTooltip("Connect 'chunks' and 'queries' inputs."); + } else { + setRunTooltip( + `Will run ${numMethods} method(s) for ${numQueries} queries against ${numChunks} chunks.`, + ); + } + } catch (err) { + console.error(err); + setRunTooltip("Error checking inputs."); + } + }, [pullInputData, id, status, methodItems]); + + return ( + + + +
+ {/* Labeled Handle for 'queries' */} +
+
+ + queries + +
+ +
+ + {/* Labeled Handle for 'chunks' */} +
+
+ + chunks + +
+ +
+ + {/* Add margin top to push list below handles */} +
+ { + setLinkedGroups(groups); + setDataPropsForNode(id, { linked_groups: groups }); + }} + methodResults={results} + /> +
+
+ + {progress !== undefined && ( +
+ + {/* Add a small text label below */} +
+ {Math.round(progress)}% +
+
+ )} + + inspectorModalRef.current?.trigger()} + showDrawerButton={false} + onDrawerClick={() => undefined} + isDrawerOpen={false} + label={ + <> + Inspect results + + } + /> + + + + + + + +
+ ); +}; +export default RetrievalNode; diff --git a/chainforge/react-server/src/SelectVarsNode.tsx b/chainforge/react-server/src/SelectVarsNode.tsx new file mode 100644 index 000000000..2f613ceec --- /dev/null +++ b/chainforge/react-server/src/SelectVarsNode.tsx @@ -0,0 +1,304 @@ +import React, { useCallback, useEffect, useMemo, useState } from "react"; +import { Handle, Position } from "reactflow"; +import { v4 as uuid } from "uuid"; +import { + Box, + Button, + Checkbox, + Flex, + Group, + ScrollArea, + Text, +} from "@mantine/core"; +import { IconFilter } from "@tabler/icons-react"; + +import useStore from "./store"; +import BaseNode from "./BaseNode"; +import NodeLabel from "./NodeLabelComponent"; + +import { generatePrompts } from "./backend/backend"; +import { + extractLLMLookup, + tagMetadataWithLLM, + removeLLMTagFromMetadata, +} from "./backend/utils"; +import { + Dict, + JSONCompatible, + TemplateVarInfo, + LLMResponsesByVarDict, +} from "./backend/typing"; + +const ALWAYS_INCLUDED_KEYS = ["__pt", "id", "signature"]; + +interface SelectVarsNodeProps { + data: { + input?: JSONCompatible; + title?: string; + refresh?: boolean; + selectedKeys?: string[]; + }; + id: string; +} + +function toPlainObject(maybe: any): Record { + if (!maybe) return {}; + if (maybe instanceof Map) return Object.fromEntries(maybe.entries()); + if ( + Array.isArray(maybe) && + maybe.length > 0 && + Array.isArray(maybe[0]) && + maybe[0].length === 2 + ) { + try { + return Object.fromEntries(maybe as any); + } catch { + /* ignore */ + } + } + if (typeof maybe === "object") { + const out: Record = {}; + Object.getOwnPropertyNames(maybe).forEach((k) => { + out[k] = (maybe as any)[k]; + }); + return out; + } + return {}; +} + +const SelectVarsNode: React.FC = ({ data, id }) => { + const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); + const pullInputData = useStore((state) => state.pullInputData); + + const [pastInputs, setPastInputs] = useState([]); + const 
[inputItems, setInputItems] = useState([]); + const [availableKeys, setAvailableKeys] = useState([]); + const [selectedKeys, setSelectedKeys] = useState( + data.selectedKeys ?? [], + ); + + // Toujours inclure les clés spéciales + const selectedSet = useMemo( + () => new Set([...(selectedKeys ?? []), ...ALWAYS_INCLUDED_KEYS]), + [selectedKeys], + ); + + const handleSetAndSave = ( + value: T, + setter: React.Dispatch>, + propName: string, + ) => { + setter(value); + setDataPropsForNode(id, { [propName]: value as any }); + }; + + const handleOnConnect = useCallback(() => { + let input_data: LLMResponsesByVarDict = pullInputData(["__input"], id); + if (!input_data || !input_data.__input) { + setInputItems([]); + setAvailableKeys([]); + setDataPropsForNode(id, { fields: [], output: [] }); + return; + } + + const llm_lookup = extractLLMLookup(input_data); + input_data = tagMetadataWithLLM(input_data); + + generatePrompts( + "{__input}", + input_data as Dict<(TemplateVarInfo | string)[]>, + ) + .then((promptTemplates) => { + const resp_objs = promptTemplates.map((p: any) => { + const sourceVars = + "fill_history" in p && p.fill_history ? p.fill_history : p.vars; + const normalized = toPlainObject(sourceVars); + + const obj: any = { + text: p.toString(), + llm: + p.metavars && "__LLM_key" in p.metavars + ? llm_lookup[p.metavars.__LLM_key] + : undefined, + metavars: removeLLMTagFromMetadata(p.metavars), + uid: uuid(), + fill_history: normalized, // on ne passe jamais 'vars' + }; + + // Toujours conserver les clés spéciales si présentes + ALWAYS_INCLUDED_KEYS.forEach((key) => { + if (p.metavars && key in p.metavars) { + if (!obj.metavars) obj.metavars = {}; + obj.metavars[key] = p.metavars[key]; + } + }); + + return obj; + }); + + setInputItems(resp_objs); + + // Clés disponibles: union des clés de metavars + fill_history (sans ALWAYS_INCLUDED_KEYS) + const keys = new Set(); + resp_objs.forEach((r) => { + Object.keys(r.metavars ?? 
{}).forEach((k) => keys.add(k)); + Object.keys(r.fill_history ?? {}).forEach((k) => keys.add(k)); + }); + ALWAYS_INCLUDED_KEYS.forEach((k) => keys.delete(k)); + const nextAvailable = Array.from(keys).sort(); + setAvailableKeys(nextAvailable); + + // Initialiser/mettre à jour selectedKeys + if ((selectedKeys ?? []).length === 0) { + handleSetAndSave(nextAvailable, setSelectedKeys, "selectedKeys"); + } else { + const merged = Array.from( + new Set([...(selectedKeys ?? []), ...nextAvailable]), + ); + if (merged.length !== (selectedKeys ?? []).length) { + handleSetAndSave(merged, setSelectedKeys, "selectedKeys"); + } + } + }) + .catch((e) => { + console.error(e); + setInputItems([]); + setAvailableKeys([]); + setDataPropsForNode(id, { fields: [], output: [] }); + }); + }, [id, pullInputData, selectedKeys, setDataPropsForNode]); + + if (data.input && data.input !== pastInputs) { + setPastInputs(data.input); + handleOnConnect(); + } + + useEffect(() => { + if (data.refresh) { + setDataPropsForNode(id, { refresh: false }); + handleOnConnect(); + } + }, [data.refresh, id, handleOnConnect, setDataPropsForNode]); + + // Sortie: filtrer metavars ET appliquer le filtre à fill_history + useEffect(() => { + if (inputItems.length === 0) { + setDataPropsForNode(id, { fields: [], output: [] }); + return; + } + + const out = inputItems.map((f) => { + const mv = f.metavars || {}; + const fhAll = toPlainObject(f.fill_history ?? 
{}); + + // 1) Filtrer metavars selon la sélection + ALWAYS_INCLUDED_KEYS + const filteredMetavars: Record = {}; + Object.keys(mv).forEach((k) => { + if (selectedSet.has(k)) filteredMetavars[k] = mv[k]; + }); + ALWAYS_INCLUDED_KEYS.forEach((key) => { + if (key in mv) filteredMetavars[key] = mv[key]; + }); + + // 2) Appliquer le filtre à fill_history (ne garder que selectedSet + ALWAYS_INCLUDED_KEYS) + const filteredFillHistory: Record = {}; + Object.keys(fhAll).forEach((k) => { + if (selectedSet.has(k)) filteredFillHistory[k] = fhAll[k]; + }); + ALWAYS_INCLUDED_KEYS.forEach((key) => { + if (key in fhAll) filteredFillHistory[key] = fhAll[key]; + }); + + // Ne jamais renvoyer 'vars' + const { vars: _omitVars, ...rest } = f; + + return { + ...rest, + metavars: filteredMetavars, + fill_history: filteredFillHistory, + }; + }); + + setDataPropsForNode(id, { fields: out, output: out }); + }, [inputItems, selectedSet, id, setDataPropsForNode]); + + const toggleKey = (k: string) => { + const next = selectedSet.has(k) + ? selectedKeys.filter((x) => x !== k) + : [...selectedKeys, k]; + handleSetAndSave(next, setSelectedKeys, "selectedKeys"); + }; + + const selectAll = () => + handleSetAndSave(availableKeys, setSelectedKeys, "selectedKeys"); + const unselectAll = () => + handleSetAndSave([], setSelectedKeys, "selectedKeys"); + + return ( + + } + /> + + + + + + + + Variable keys: {availableKeys.length} + + + + + + + + + {availableKeys.length === 0 ? ( + + No variables found in input. 
+ + ) : ( + + {availableKeys.map((k) => ( + toggleKey(k)} + /> + ))} + + )} + + + + + + ); +}; + +export default SelectVarsNode; diff --git a/chainforge/react-server/src/SplitNode.tsx b/chainforge/react-server/src/SplitNode.tsx index 37380a3a5..bcd5a4635 100644 --- a/chainforge/react-server/src/SplitNode.tsx +++ b/chainforge/react-server/src/SplitNode.tsx @@ -41,6 +41,7 @@ const formattingOptions = [ { value: ",", label: "commas (,)" }, { value: "code", label: "code blocks" }, { value: "paragraph", label: "paragraphs (md)" }, + { value: ";", label: "semicon (;)" }, ]; /** Flattens markdown AST as dict to text (string) */ @@ -69,6 +70,11 @@ export const splitText = ( return processCSV(s) .map((s) => _escapeBraces(s)) .filter((s) => s.length > 0); + else if (format === ";") + return s + .split(";") + .map((s) => _escapeBraces(s.trim())) + .filter((s) => s.length > 0); // Other formatting rules require markdown parsing: // Parse string as markdown diff --git a/chainforge/react-server/src/UploadFileModal.tsx b/chainforge/react-server/src/UploadFileModal.tsx index d30e352ba..64affd7fe 100644 --- a/chainforge/react-server/src/UploadFileModal.tsx +++ b/chainforge/react-server/src/UploadFileModal.tsx @@ -227,7 +227,7 @@ const UploadFileModal = forwardRef( setIsFetching(true); try { - const proxyUrl = `${FLASK_BASE_URL}/api/proxyImage?url=${encodeURIComponent(url)}`; + const proxyUrl = `${FLASK_BASE_URL}api/proxyImage?url=${encodeURIComponent(url)}`; const response = await fetch(proxyUrl); if (!response.ok) { diff --git a/chainforge/react-server/src/UploadNode.tsx b/chainforge/react-server/src/UploadNode.tsx new file mode 100644 index 000000000..37a45b542 --- /dev/null +++ b/chainforge/react-server/src/UploadNode.tsx @@ -0,0 +1,289 @@ +import React, { + useEffect, + useState, + useRef, + useCallback, + useMemo, + useContext, +} from "react"; +import { Handle, Position } from "reactflow"; +import { + Button, + Group, + Text, + Box, + List, + ThemeIcon, + Flex, + ScrollArea, 
+} from "@mantine/core"; +import { IconUpload, IconTrash } from "@tabler/icons-react"; +import useStore from "./store"; +import BaseNode from "./BaseNode"; +import NodeLabel from "./NodeLabelComponent"; +import { AlertModalContext } from "./AlertModal"; +import { Status } from "./StatusIndicatorComponent"; +import { TemplateVarInfo } from "./backend/typing"; +import { MediaLookup } from "./backend/cache"; + +interface UploadNodeProps { + data: { + title: string; + fields: TemplateVarInfo[]; + refresh: boolean; + }; + id: string; +} + +const UploadNode: React.FC = ({ data, id }) => { + const nodeIcon = useMemo(() => "📁", []); + const nodeDefaultTitle = useMemo(() => "Upload Node", []); + const setDataPropsForNode = useStore((state) => state.setDataPropsForNode); + + const [fields, setFields] = useState(data.fields || []); + const [status, setStatus] = useState(Status.READY); + + const [fileListCollapsed, setFileListCollapsed] = useState( + !(data.fields && data.fields.length > 0), + ); + const toggleFileList = () => setFileListCollapsed((prev) => !prev); + + const showAlert = useContext(AlertModalContext); + const fileInputRef = useRef(null); + + // Handle file uploads + const handleFilesUpload = useCallback( + async (files: FileList) => { + if (files.length === 0) return; + + setStatus(Status.LOADING); + const updatedFields = [...fields]; + + for (const file of Array.from(files)) { + console.log("Uploading file:", file, file.name); + + try { + // Upload the file to the lookup and get its UID + const uid = await MediaLookup.upload(file); + + // Grab the content of the file, in plain text + // TODO: Make this work on the front-end if backend is not available + const text = await MediaLookup.getAsText(uid); + + console.log("File content:", text); + + // Add filename + text content as a new TemplateVarInfo + updatedFields.push({ + text: text, + prompt: "", + fill_history: {}, + llm: undefined, + metavars: { + size: file.size.toString(), + type: file.type, + filename: 
file.name, // important: store doc name + id: uid, + }, + }); + } catch (error: any) { + console.error("Error uploading file:", error); + showAlert?.(`Error uploading ${file.name}: ${error.message}`); + setStatus(Status.ERROR); + } + } + + setFields(updatedFields); + + // Also set the node's output for the flow + setDataPropsForNode(id, { fields: updatedFields, output: updatedFields }); + setStatus(Status.READY); + }, + [fields, id, setDataPropsForNode, showAlert], + ); + + // On file input change + const handleFileInputChange = ( + event: React.ChangeEvent, + ) => { + if (event.target.files) { + handleFilesUpload(event.target.files); + event.target.value = ""; + } + }; + + // Drag & drop + const handleDrop = (event: React.DragEvent) => { + event.preventDefault(); + if (event.dataTransfer.files) { + handleFilesUpload(event.dataTransfer.files); + } + }; + const handleDragOver = (event: React.DragEvent) => { + event.preventDefault(); + }; + + // Remove a file + const handleRemoveFile = (index: number) => { + const fieldToRemove = fields[index]; + + // UID is stored in metavars.id, set when uploading + const uid = + typeof fieldToRemove?.metavars?.id === "string" + ? fieldToRemove.metavars.id + : undefined; + + // Remove from MediaLookup (same idea as MediaNode.handleRemoveMedia) + if (uid) { + try { + MediaLookup.remove(uid); + } catch (error) { + console.error("Error removing file from MediaLookup:", error); + } + } + + // 2) Update local fields + node output + const updatedFields = fields.filter((_, i) => i !== index); + setFields(updatedFields); + setDataPropsForNode(id, { fields: updatedFields, output: updatedFields }); + }; + + // Clear all + const handleClearUploads = useCallback(() => { + // Collect all UIDs before clearing + const uidsToRemove = fields + .map((field) => + typeof field.metavars?.id === "string" ? 
field.metavars.id : undefined, + ) + .filter((x): x is string => !!x); + + // Remove each file from MediaLookup + for (const uid of uidsToRemove) { + try { + MediaLookup.remove(uid); + } catch (error) { + console.error("Error removing file from MediaLookup:", error); + } + } + + setFields([]); + setDataPropsForNode(id, { fields: [], output: [] }); + setStatus(Status.READY); + }, [fields, id, setDataPropsForNode]); + + // Refresh logic + useEffect(() => { + if (data.refresh) { + handleClearUploads(); + setDataPropsForNode(id, { refresh: false }); + } + }, [data.refresh, handleClearUploads, id, setDataPropsForNode]); + + return ( + + + +
fileInputRef.current?.click()} + > + + + Drag & drop files here or click to upload (.pdf, .docx, .txt, .md) + + +
+ + + + + + + Uploaded Files ({fields.length}) + + {fields.length > 0 && ( + + )} + + + {!fileListCollapsed && fields.length > 0 && ( + + + {fields.map((field, index) => ( + + 📄 + + } + > + + + + {typeof field.metavars?.filename === "string" + ? field.metavars.filename + : "Untitled file"} + + {field.text && typeof field.text === "string" && ( + + {field.text.slice(0, 50)} + {field.text.length > 50 ? "..." : ""} + + )} + + + + + ))} + + + )} + +
+ ); +}; + +export default UploadNode; diff --git a/chainforge/react-server/src/VisNode.tsx b/chainforge/react-server/src/VisNode.tsx index 6542f0098..bbea23001 100644 --- a/chainforge/react-server/src/VisNode.tsx +++ b/chainforge/react-server/src/VisNode.tsx @@ -433,14 +433,26 @@ export const VisView = forwardRef( })), ); - // Find all the special 'LLM group' metavars and put them in the 'group by' dropdown: - const available_llm_groups = [{ value: "LLM", label: "LLM" }].concat( - metavars.filter(cleanMetavarsFilterFunc).map((name) => ({ - value: name, - label: `LLMs #${parseInt(name.slice(4)) + 1}`, - })), - ); - if (available_llm_groups.length > 1) + // Find all the special metavars and vars and put them in the 'group by' dropdown: + const available_llm_groups = [{ value: "LLM", label: "LLM" }] + .concat(varnames.map((name) => ({ value: name, label: name }))) + .concat( + metavars.filter(cleanMetavarsFilterFunc).map((name) => { + let label = `${name} (meta)`; + if (name.startsWith("llm_")) { + label = `LLMs #${parseInt(name.slice(4)) + 1}`; + } else if (name === "retriever" || name === "retrieval_method") { + label = "Retrieval methods"; + } else if (name === "chunk") { + label = "Chunks"; + } + return { + value: `__meta_${name}`, + label, + }; + }), + ); + if (available_llm_groups.some((g) => g.value.startsWith("__meta_llm_"))) available_llm_groups[0] = { value: "LLM", label: "LLMs (last)" }; setAvailableLLMGroups(available_llm_groups); @@ -518,7 +530,12 @@ export const VisView = forwardRef( typeof resp_obj.llm === "number" ? StringLookup.get(resp_obj.llm) ?? 
"(LLM lookup failed)" : resp_obj.llm?.name; - else return resp_obj.metavars[selectedLLMGroup] as string; + else if (selectedLLMGroup?.startsWith("__meta_")) { + const meta_key = selectedLLMGroup.slice("__meta_".length); + return resp_obj.metavars[meta_key] as string; + } else { + return resp_obj.vars[selectedLLMGroup as string] as string; + } }; const getLLMsInResponses = (responses: LLMResponse[]) => getUniqueKeysInResponses(responses, get_llm); @@ -776,6 +793,9 @@ export const VisView = forwardRef( } const shortnames = genUniqueShortnames(names); + const yLabelShortnames = genUniqueShortnames( + new Set(responses.map(resp_to_x)), + ); for (const name of names) { let x_items: EvaluationScore[] = []; let text_items: string[] = []; @@ -786,8 +806,10 @@ export const VisView = forwardRef( const eval_res = get_items(r.eval_res).filter( (i) => i === name, ); + const rawLabel = resp_to_x(r); + const yLabel = yLabelShortnames[rawLabel] ?? rawLabel; x_items = x_items.concat( - new Array(eval_res.length).fill(resp_to_x(r)), + new Array(eval_res.length).fill(yLabel), ); }); } else { diff --git a/chainforge/react-server/src/backend/backend.ts b/chainforge/react-server/src/backend/backend.ts index 3633e37b3..6d0831098 100644 --- a/chainforge/react-server/src/backend/backend.ts +++ b/chainforge/react-server/src/backend/backend.ts @@ -467,18 +467,38 @@ async function run_over_responses( } } - // If type is just a processor + // Nouveau traitement pour format {text: ..., metavars: {...}} if (process_type === "processor") { - // Replace response texts in resp_obj with the transformed ones: - resp_obj.responses = processed; + if ( + Array.isArray(processed) && + processed.length > 0 && + processed[0] && + typeof processed[0] === "object" && + "text" in processed[0] && + "metavars" in processed[0] + ) { + // On prend le texte comme réponse, et le second élément comme metavars + resp_obj.responses = [processed[0].text]; + resp_obj.metavars = processed[0].metavars; + } else if ( + 
Array.isArray(processed) && + processed.length > 0 && + Array.isArray(processed[0]) && + processed[0].length === 2 && + typeof processed[0][1] === "object" + ) { + // Ancien cas tuple (texte, metavars) + resp_obj.responses = [processed[0][0]]; + resp_obj.metavars = processed[0][1]; + } else { + // Cas standard + resp_obj.responses = processed; + } } else { // If type is an evaluator - // Check the type of evaluation results - // NOTE: We assume this is consistent across all evaluations, but it may not be. const eval_res_type = check_typeof_vals(processed); if (eval_res_type === MetricType.Numeric) { - // Store items with summary of mean, median, etc resp_obj.eval_res = { items: processed, dtype: (getEnumName(MetricType, eval_res_type) ?? @@ -491,7 +511,6 @@ async function run_over_responses( "Unsupported types found in evaluation results. Only supported types for metrics are: int, float, bool, str.", ); } else { - // Categorical, KeyValue, etc, we just store the items: resp_obj.eval_res = { items: processed, dtype: (getEnumName(MetricType, eval_res_type) ?? 
diff --git a/chainforge/react-server/src/backend/pyodide/exec-py.worker.js b/chainforge/react-server/src/backend/pyodide/exec-py.worker.js index 8fbfee4ad..af1c86091 100644 --- a/chainforge/react-server/src/backend/pyodide/exec-py.worker.js +++ b/chainforge/react-server/src/backend/pyodide/exec-py.worker.js @@ -24,6 +24,11 @@ self.onmessage = async function (event) { try { await self.pyodide.loadPackagesFromImports(python); let results = await self.pyodide.runPythonAsync(python); + // Conversion si possible + if (results && typeof results.toJs === "function") { + results = results.toJs({ dict_converter: Object.fromEntries }); + } + console.log("About to postMessage:", results); self.postMessage({ results, id }); } catch (error) { self.postMessage({ error: error.message, id }); diff --git a/chainforge/react-server/src/backend/typing.ts b/chainforge/react-server/src/backend/typing.ts index 4b10e7280..da75168f1 100644 --- a/chainforge/react-server/src/backend/typing.ts +++ b/chainforge/react-server/src/backend/typing.ts @@ -200,6 +200,7 @@ export type CustomLLMProviderSpec = { settings: Dict>; ui: Dict>; }; + category?: string; }; /** Internal description of model settings, passed to react-json-schema */ diff --git a/chainforge/react-server/src/backend/utils.ts b/chainforge/react-server/src/backend/utils.ts index 01b4bfb77..6298bba2b 100644 --- a/chainforge/react-server/src/backend/utils.ts +++ b/chainforge/react-server/src/backend/utils.ts @@ -63,7 +63,6 @@ const ANTHROPIC_HUMAN_PROMPT = "\n\nHuman:"; const ANTHROPIC_AI_PROMPT = "\n\nAssistant:"; /** Where the ChainForge Flask server is being hosted, if any. 
*/ - export const FLASK_BASE_URL = // @ts-expect-error undefined window.__CF_HOSTNAME !== undefined && window.__CF_PORT !== undefined @@ -114,6 +113,48 @@ export function APP_IS_RUNNING_LOCALLY(): boolean { return _APP_IS_RUNNING_LOCALLY; } +// We cache the RAG availability check to avoid repeated backend calls +export let RAG_AVAILABLE: boolean | undefined; +let _RAG_CHECK_PROMISE: Promise | undefined; + +async function checkRagAvailabilityFromBackend(): Promise { + try { + const response = await call_flask_backend("checkRagAvailable", {}); + return response.rag_available === true; + } catch (error) { + console.warn("Failed to check RAG availability from backend:", error); + return false; + } +} + +export async function isRagAvailable(): Promise { + // First check if the window flag is set + // @ts-expect-error undefined + if (window.__RAG_AVAILABLE !== undefined) { + RAG_AVAILABLE = (window as any).__RAG_AVAILABLE as boolean; + } else if (window?.location.port === "3000") { + // Dev mode -- assume RAG is available for easier testing + RAG_AVAILABLE = true; + } + // If not cached, check with the backend + else if (RAG_AVAILABLE === undefined) { + // Avoid multiple concurrent requests + if (!_RAG_CHECK_PROMISE) { + _RAG_CHECK_PROMISE = checkRagAvailabilityFromBackend(); + } + RAG_AVAILABLE = await _RAG_CHECK_PROMISE; + _RAG_CHECK_PROMISE = undefined; + } + + return RAG_AVAILABLE; +} + +// Check RAG availability immediately upon load +// :: Start the async check but don't block +isRagAvailable().catch((err) => + console.warn("Background RAG availability check failed:", err), +); + /** * Equivalent to a 'fetch' call, but routes it to the backend Flask server in * case we are running a local server and prefer to not deal with CORS issues making API calls client-side. 
diff --git a/chainforge/react-server/src/store.tsx b/chainforge/react-server/src/store.tsx index 5c1d47dfe..89208f5d2 100644 --- a/chainforge/react-server/src/store.tsx +++ b/chainforge/react-server/src/store.tsx @@ -34,6 +34,9 @@ import { TogetherChatSettings } from "./ModelSettingSchemas"; import { NativeLLM } from "./backend/models"; import { StringLookup } from "./backend/cache"; import { saveGlobalConfig } from "./backend/backend"; +import { remove } from "jszip"; +import { ChunkMethodSpec } from "./ChunkMethodListComponent"; +import type { RetrievalMethodSpec } from "./RetrievalMethodListComponent"; const IS_RUNNING_LOCALLY = APP_IS_RUNNING_LOCALLY(); // Initial project settings @@ -94,6 +97,7 @@ const refreshableOutputNodeTypes = new Set([ "simpleval", "join", "split", + "selectvars", ]); export const initLLMProviderMenu: (LLMSpec | LLMGroup)[] = [ @@ -461,6 +465,14 @@ export interface StoreHandles { AvailableLLMs: LLMSpec[]; setAvailableLLMs: (specs: LLMSpec[]) => void; + // Custom chunkers loaded via the Custom Providers dropzone + customChunkers: ChunkMethodSpec[]; + setCustomChunkers: (chunkers: ChunkMethodSpec[]) => void; + + // Custom retrievers loaded via the Custom Providers dropzone + customRetrievers: RetrievalMethodSpec[]; + setCustomRetrievers: (retrievers: RetrievalMethodSpec[]) => void; + // API keys to LLM providers apiKeys: Dict; setAPIKeys: (apiKeys: Dict) => void; @@ -552,6 +564,38 @@ const useStore = create((set, get) => ({ set({ AvailableLLMs: llmProviderList }); }, + customChunkers: [], + setCustomChunkers: (chunkers) => { + set({ customChunkers: chunkers }); + }, + + customRetrievers: [], + setCustomRetrievers: (retrievers: any[]) => { + const items = (retrievers ?? []) + .filter((p) => (p?.category ?? "retriever") === "retriever") + .map((p) => { + const name = + p?.name ?? p?.methodName ?? p?.library ?? 
"Custom Provider"; + const baseMethod = `__custom/${name}`; // enforce canonical + + return { + // whitelist normalized store shape + key: p?.key ?? uuid(), + methodName: name, + library: name, + baseMethod, + emoji: p?.emoji ?? "✨", + needsEmbeddingModel: !!p?.needs_embedding_model, + models: Array.isArray(p?.models) ? p.models : [], + settings_schema: p?.settings_schema ?? undefined, + default_settings: p?.default_settings ?? undefined, + source: "custom" as const, + }; + }); + + set({ customRetrievers: items as any }); + }, + aiFeaturesProvider: "OpenAI", setAIFeaturesProvider: (llmProvider) => { set({ aiFeaturesProvider: llmProvider }); diff --git a/chainforge/react-server/src/styles.css b/chainforge/react-server/src/styles.css index 98a53d806..ddc57b18d 100644 --- a/chainforge/react-server/src/styles.css +++ b/chainforge/react-server/src/styles.css @@ -130,6 +130,7 @@ hr { padding: 2px 8px; overflow-y: auto; max-height: 205px; + width: 100%; } .llm-list-backdrop { margin: 6px 0px 6px 6px; @@ -137,6 +138,9 @@ hr { text-align: left; font-size: 10pt; color: #777; + display: flex; + align-items: center; + justify-content: space-between; } .llm-scorer-container { width: 290px; @@ -154,7 +158,44 @@ html[data-mantine-color-scheme="dark"] .llm-list-backdrop { .llm-list-item { background: white; + padding: 6px 8px; + border-radius: 6px; + box-shadow: + 0 1px 3px rgba(0, 0, 0, 0.12), + 0 1px 2px rgba(0, 0, 0, 0.24); + margin: 0 0 8px 0; + display: grid; + /* grid-gap: 20px; */ + /* grid-template-columns: 1fr auto; */ + align-items: center; + column-gap: 12px; +} + +.llm-list-item > * { + min-width: 0; +} + +.llm-card-header { + font-weight: 500; + font-size: 10pt; + font-family: -apple-system, "Segoe UI", "Roboto", "Oxygen", "Ubuntu", + "Cantarell", "Fira Sans", "Droid Sans", "Helvetica Neue", sans-serif; + text-align: start; + float: left; + margin: 0; + white-space: nowrap; /* keep on one line */ + overflow: hidden; + text-overflow: ellipsis; + line-height: 1.2; +} + 
+.llm-row-actions { + display: flex; + align-items: center; + gap: 6px; + flex-wrap: nowrap; } + html[data-mantine-color-scheme="dark"] .llm-list-item { background: #222; color: #ccc; @@ -1631,7 +1672,7 @@ html[data-mantine-color-scheme="dark"] .saved-flows-footer { .icl { white-space: pre-wrap; - word-break: break-all; + word-break: break-word; } .upload-node { @@ -1666,6 +1707,14 @@ html[data-mantine-color-scheme="dark"] .upload-node-list { .chunk-node { width: 290px; + min-width: 290px; + max-width: 290px; +} + +.retrieval-node { + width: 330px; + min-width: 330px; + max-width: 330px; } /* Classes relative to the Media Node */ @@ -1772,3 +1821,7 @@ html[data-mantine-color-scheme="dark"] .carousel-nav-button:hover { 0.2 ); /* Brighter hover background for dark mode */ } + +.rerank-node { + width: 290px; +} diff --git a/chainforge/requirements-lock.txt b/chainforge/requirements-lock.txt new file mode 100644 index 000000000..a4054b539 --- /dev/null +++ b/chainforge/requirements-lock.txt @@ -0,0 +1,142 @@ +aiohappyeyeballs==2.6.1 +aiohttp==3.11.18 +aiosignal==1.3.2 +annotated-types==0.7.0 +anyio==4.9.0 +asgiref==3.8.1 +async-timeout==4.0.3 +attrs==25.3.0 +beautifulsoup4==4.13.4 +blinker==1.9.0 +blis==1.2.1 +build==1.2.2.post1 +certifi==2025.4.26 +cffi==1.17.1 +charset-normalizer==3.4.2 +chonkie==1.0.7 +click==8.1.8 +cloudpathlib==0.21.0 +cobble==0.1.4 +cohere==5.15.0 +coloredlogs==15.0.1 +confection==0.1.5 +cryptography==44.0.3 +dataclasses-json==0.6.7 +deprecation==2.1.0 +distro==1.9.0 +et_xmlfile==2.0.0 +exceptiongroup==1.2.2 +fastavro==1.10.0 +filelock==3.18.0 +Flask==3.1.0 +flask-cors==5.0.1 +flatbuffers==25.2.10 +frozenlist==1.6.0 +fsspec==2025.3.2 +greenlet==3.2.1 +grpcio==1.44.0 +h11==0.16.0 +httpcore==1.0.9 +httpx==0.28.1 +httpx-sse==0.4.0 +huggingface-hub==0.30.2 +humanfriendly==10.0 +idna==3.10 +iniconfig==2.1.0 +itsdangerous==2.2.0 +Jinja2==3.1.6 +jiter==0.9.0 +joblib==1.5.0 +jsonpatch==1.33 +jsonpointer==3.0.0 +lancedb==0.18.0 +langcodes==3.5.0 
+language_data==1.3.0 +lxml==5.4.0 +magika==0.6.2 +mammoth==1.9.0 +marisa-trie==1.2.1 +markdown-it-py==3.0.0 +markdownify==1.1.0 +markitdown==0.1.1 +MarkupSafe==3.0.2 +marshmallow==3.26.1 +mdurl==0.1.2 +mistune==3.1.3 +model2vec==0.5.0 +mpmath==1.3.0 +multidict==6.4.3 +mypy_extensions==1.1.0 +networkx==3.4.2 +nltk==3.9.1 +numpy==1.26.4 +onnxruntime==1.19.2 +openai==1.77.0 +openpyxl==3.1.5 +orjson==3.10.18 +overrides==7.7.0 +packaging==24.2 +pandas==2.2.3 +pdfminer.six==20250416 +pillow==11.2.1 +pip-tools==7.4.1 +platformdirs==4.3.7 +pluggy==1.5.0 +preshed==3.0.9 +propcache==0.3.1 +protobuf==6.30.2 +pyarrow==16.0.0 +pycparser==2.22 +pydantic==2.11.4 +pydantic-settings==2.9.1 +pydantic_core==2.33.2 +Pygments==2.19.1 +pylance==0.22.0 +PyMuPDF==1.25.5 +pyproject_hooks==1.2.0 +pytest==8.3.5 +python-dateutil==2.9.0.post0 +python-docx==1.1.2 +python-dotenv==1.1.0 +python-pptx==1.0.2 +pytz==2025.2 +PyYAML==6.0.2 +rank-bm25==0.2.2 +regex==2024.11.6 +requests==2.32.3 +requests-toolbelt==1.0.0 +rich==14.0.0 +safetensors==0.5.3 +scikit-learn==1.6.1 +scipy==1.15.2 +sentence-transformers==4.1.0 +shellingham==1.5.4 +six==1.17.0 +smart-open==7.1.0 +sniffio==1.3.1 +soupsieve==2.7 +SQLAlchemy==2.0.40 +sympy==1.14.0 +tenacity==9.1.2 +threadpoolctl==3.6.0 +tiktoken==0.9.0 +tokenizers==0.21.1 +tomli==2.2.1 +torch==2.2.2 +tqdm==4.67.1 +transformers==4.51.3 +typer==0.15.3 +types-requests==2.31.0.6 +types-urllib3==1.26.25.14 +typing-inspect==0.9.0 +typing-inspection==0.4.0 +typing_extensions==4.13.2 +tzdata==2025.2 +urllib3==1.26.6 +Werkzeug==3.1.3 +Whoosh==2.7.4 +wrapt==1.17.2 +xlrd==2.0.1 +XlsxWriter==3.2.3 +yarl==1.20.0 +zstandard==0.23.0 diff --git a/chainforge/requirements-rag-lock.txt b/chainforge/requirements-rag-lock.txt new file mode 100644 index 000000000..87489a118 --- /dev/null +++ b/chainforge/requirements-rag-lock.txt @@ -0,0 +1,104 @@ +annotated-types==0.7.0 +anyio==4.11.0 +asgiref==3.10.0 +beautifulsoup4==4.14.2 +blinker==1.9.0 +certifi==2025.11.12 +cffi==2.0.0 
+charset-normalizer==3.4.4 +chonkie==1.3.1 +click==8.3.1 +cobble==0.1.4 +cohere==5.20.0 +coloredlogs==15.0.1 +cryptography==46.0.3 +defusedxml==0.7.1 +deprecation==2.1.0 +distro==1.9.0 +et_xmlfile==2.0.0 +fastavro==1.12.1 +filelock==3.20.0 +Flask==3.1.2 +flask-cors==6.0.1 +flatbuffers==25.9.23 +fsspec==2025.10.0 +grpcio==1.76.0 +h11==0.16.0 +hf-xet==1.2.0 +httpcore==1.0.9 +httpx==0.28.1 +httpx-sse==0.4.0 +huggingface-hub==0.36.0 +humanfriendly==10.0 +idna==3.11 +itsdangerous==2.2.0 +Jinja2==3.1.6 +jiter==0.12.0 +joblib==1.5.2 +lancedb==0.17.0 +lxml==6.0.2 +magika==0.6.3 +mammoth==1.11.0 +markdown-it-py==4.0.0 +markdownify==1.2.2 +markitdown==0.1.3 +MarkupSafe==3.0.3 +mdurl==0.1.2 +mistune==3.1.4 +model2vec==0.7.0 +mpmath==1.3.0 +networkx==3.5 +nltk==3.9.2 +numpy==1.26.4 +onnxruntime==1.19.2 +openai==2.8.0 +openpyxl==3.1.5 +overrides==7.7.0 +packaging==25.0 +pandas==2.3.3 +pdfminer.six==20251107 +pillow==12.0.0 +platformdirs==4.5.0 +protobuf==6.33.1 +pyarrow==16.0.0 +pycparser==2.23 +pydantic==2.12.4 +pydantic_core==2.41.5 +Pygments==2.19.2 +pylance==0.20.0 +PyMuPDF==1.26.6 +python-dateutil==2.9.0.post0 +python-docx==1.2.0 +python-dotenv==1.2.1 +python-pptx==1.0.2 +pytz==2025.2 +PyYAML==6.0.3 +rank-bm25==0.2.2 +regex==2025.11.3 +requests==2.32.5 +rich==14.2.0 +safetensors==0.6.2 +scikit-learn==1.7.2 +scipy==1.16.3 +sentence-transformers==5.1.2 +setuptools==80.9.0 +six==1.17.0 +sniffio==1.3.1 +soupsieve==2.8 +sympy==1.14.0 +threadpoolctl==3.6.0 +tiktoken==0.12.0 +tokenizers==0.22.1 +torch==2.2.2 +tqdm==4.67.1 +transformers==4.57.1 +types-requests==2.31.0.6 +types-urllib3==1.26.25.14 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +tzdata==2025.2 +urllib3==1.26.6 +Werkzeug==3.1.3 +Whoosh==2.7.4 +xlrd==2.0.2 +xlsxwriter==3.2.9 diff --git a/chainforge/requirements.txt b/chainforge/requirements.txt index d944ff584..cbba8f54f 100644 --- a/chainforge/requirements.txt +++ b/chainforge/requirements.txt @@ -1,10 +1,29 @@ flask>=2.2.3 flask[async] flask_cors +grpcio 
+numpy<2.0 requests openai urllib3==1.26.6 mistune>=2.0 platformdirs cryptography -markitdown[pdf, docx, xlsx, xls, pptx] \ No newline at end of file +pymupdf +python-docx +tiktoken +nltk>=3.8 +transformers +scikit-learn>=1.4.0 +sentence-transformers +rank-bm25 +whoosh +cohere +markitdown[pdf, docx, xlsx, xls, pptx] +chonkie>=1.0 +model2vec>=0.5.0 # required by chonkie +pyarrow>=14.0,<=16.0.0 +lancedb<=0.18.0 +pandas +accelerate +tqdm \ No newline at end of file diff --git a/docker-compose.gpu.yml b/docker-compose.gpu.yml new file mode 100644 index 000000000..97524d222 --- /dev/null +++ b/docker-compose.gpu.yml @@ -0,0 +1,41 @@ +services: + chainforge-gpu: + build: + context: . + dockerfile: Dockerfile.gpu + image: chainforge/chainforge:gpu + container_name: chainforge-gpu + ports: + - "8000:8000" + volumes: + # Mount data directory for persistent storage + - chainforge-data:/home/chainforge/.local/share/chainforge + environment: + - NODE_ENV=production + - FLASK_ENV=production + # Add API keys as environment variables if needed + # - OPENAI_API_KEY=${OPENAI_API_KEY} + # - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY} + # - COHERE_API_KEY=${COHERE_API_KEY} + # - GOOGLE_API_KEY=${GOOGLE_API_KEY} + # - DEEPSEEK_API_KEY=${DEEPSEEK_API_KEY} + # - HUGGINGFACE_API_KEY=${HUGGINGFACE_API_KEY} + restart: unless-stopped + healthcheck: + test: ["CMD-SHELL", "curl -f http://localhost:8000 || exit 1"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + +# Named volume for persistent data +volumes: + chainforge-data: + driver: local diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 000000000..25bfa5aa6 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,34 @@ +services: + chainforge: + build: + context: . 
+ dockerfile: Dockerfile + image: chainforge/chainforge:latest + container_name: chainforge + ports: + - "8000:8000" + volumes: + # Mount data directory for persistent storage + - chainforge-data:/home/chainforge/.local/share/chainforge + environment: + - NODE_ENV=production + - FLASK_ENV=production + # Add API keys as environment variables if needed + # - OPENAI_API_KEY=${OPENAI_API_KEY} + # - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY} + # - COHERE_API_KEY=${COHERE_API_KEY} + # - GOOGLE_API_KEY=${GOOGLE_API_KEY} + # - DEEPSEEK_API_KEY=${DEEPSEEK_API_KEY} + # - HUGGINGFACE_API_KEY=${HUGGINGFACE_API_KEY} + restart: unless-stopped + healthcheck: + test: ["CMD-SHELL", "curl -f http://localhost:8000 || exit 1"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + +# Named volume for persistent data +volumes: + chainforge-data: + driver: local \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 000000000..418d9b342 --- /dev/null +++ b/environment.yml @@ -0,0 +1,211 @@ +name: ragforge +channels: + - conda-forge +dependencies: + - aws-c-auth=0.7.20 + - aws-c-cal=0.6.12 + - aws-c-common=0.9.17 + - aws-c-compression=0.2.18 + - aws-c-event-stream=0.4.2 + - aws-c-http=0.8.1 + - aws-c-io=0.14.8 + - aws-c-mqtt=0.10.4 + - aws-c-s3=0.5.9 + - aws-c-sdkutils=0.1.16 + - aws-checksums=0.1.18 + - aws-crt-cpp=0.26.8 + - aws-sdk-cpp=1.11.267 + - bzip2=1.0.8 + - c-ares=1.34.5 + - ca-certificates=2025.10.5 + - gflags=2.2.2 + - glog=0.7.1 + - icu=75.1 + - krb5=1.21.3 + - libabseil=20240116.2 + - libarrow=16.0.0 + - libarrow-acero=16.0.0 + - libarrow-dataset=16.0.0 + - libarrow-substrait=16.0.0 + - libblas=3.9.0 + - libbrotlicommon=1.1.0 + - libbrotlidec=1.1.0 + - libbrotlienc=1.1.0 + - libcblas=3.9.0 + - libcrc32c=1.1.2 + - libcurl=8.17.0 + - libcxx=21.1.5 + - libedit=3.1.20250104 + - libev=4.33 + - libevent=2.1.12 + - libexpat=2.7.1 + - libffi=3.5.2 + - libgfortran=15.2.0 + - libgfortran5=15.2.0 + - libgoogle-cloud=2.23.0 + - 
libgoogle-cloud-storage=2.23.0 + - libgrpc=1.62.2 + - liblapack=3.9.0 + - liblzma=5.8.1 + - libnghttp2=1.67.0 + - libopenblas=0.3.30 + - libparquet=16.0.0 + - libprotobuf=4.25.3 + - libre2-11=2023.09.01 + - libsqlite=3.51.0 + - libssh2=1.11.1 + - libthrift=0.19.0 + - libutf8proc=2.8.0 + - libzlib=1.3.1 + - llvm-openmp=21.1.5 + - lz4-c=1.9.4 + - ncurses=6.5 + - numpy=1.26.4 + - openssl=3.5.4 + - orc=2.0.0 + - pip=25.2 + - pyarrow=16.0.0 + - pyarrow-core=16.0.0 + - python=3.12.12 + - python_abi=3.12 + - re2=2023.09.01 + - readline=8.2 + - setuptools=80.9.0 + - snappy=1.2.2 + - tk=8.6.13 + - tzdata=2025b + - wheel=0.45.1 + - zstd=1.5.7 + - pip: + - Flask==3.1.2 + - Jinja2==3.1.6 + - MarkupSafe==3.0.3 + - PyMuPDF==1.26.6 + - PyYAML==6.0.3 + - Pygments==2.19.2 + - SQLAlchemy==2.0.44 + - Werkzeug==3.1.3 + - Whoosh==2.7.4 + - accelerate==1.11.0 + - aiohappyeyeballs==2.6.1 + - aiohttp==3.13.2 + - aiosignal==1.4.0 + - annotated-types==0.7.0 + - anyio==4.11.0 + - asgiref==3.10.0 + - attrs==25.4.0 + - beautifulsoup4==4.14.2 + - blinker==1.9.0 + - blis==1.3.0 + - certifi==2025.10.5 + - cffi==2.0.0 + - charset-normalizer==3.4.4 + - chonkie==1.3.1 + - click==8.3.0 + - cloudpathlib==0.23.0 + - cobble==0.1.4 + - cohere==5.20.0 + - coloredlogs==15.0.1 + - confection==0.1.5 + - cryptography==46.0.3 + - dataclasses-json==0.6.7 + - defusedxml==0.7.1 + - deprecation==2.1.0 + - distro==1.9.0 + - et_xmlfile==2.0.0 + - fastavro==1.12.1 + - filelock==3.20.0 + - flask-cors==6.0.1 + - flatbuffers==25.9.23 + - frozenlist==1.8.0 + - fsspec==2025.10.0 + - grpcio==1.76.0 + - h11==0.16.0 + - hf-xet==1.2.0 + - httpcore==1.0.9 + - httpx==0.28.1 + - httpx-sse==0.4.0 + - huggingface-hub==0.36.0 + - humanfriendly==10.0 + - idna==3.11 + - itsdangerous==2.2.0 + - jiter==0.12.0 + - joblib==1.5.2 + - jsonpatch==1.33 + - jsonpointer==3.0.0 + - lancedb==0.18.0 + - lxml==6.0.2 + - magika==0.6.3 + - mammoth==1.11.0 + - markdown-it-py==4.0.0 + - markdownify==1.2.0 + - markitdown==0.1.3 + - marshmallow==3.26.1 
+ - mdurl==0.1.2 + - mistune==3.1.4 + - model2vec==0.7.0 + - mpmath==1.3.0 + - multidict==6.7.0 + - mypy_extensions==1.1.0 + - networkx==3.5 + - nltk==3.9.2 + - onnxruntime==1.23.2 + - openai==2.7.1 + - openpyxl==3.1.5 + - orjson==3.11.4 + - ormsgpack==1.12.0 + - overrides==7.7.0 + - packaging==25.0 + - pandas==2.3.3 + - pdfminer.six==20251107 + - pillow==12.0.0 + - platformdirs==4.5.0 + - preshed==3.0.10 + - propcache==0.4.1 + - protobuf==6.33.0 + - psutil==7.1.3 + - pycparser==2.23 + - pydantic==2.12.4 + - pydantic-settings==2.11.0 + - pydantic_core==2.41.5 + - pylance==0.22.0 + - python-dateutil==2.9.0.post0 + - python-docx==1.2.0 + - python-dotenv==1.2.1 + - python-pptx==1.0.2 + - pytz==2025.2 + - rank-bm25==0.2.2 + - regex==2025.11.3 + - requests==2.32.5 + - requests-toolbelt==1.0.0 + - rich==14.2.0 + - safetensors==0.6.2 + - scikit-learn==1.7.2 + - scipy==1.16.3 + - sentence-transformers==5.1.2 + - six==1.17.0 + - smart_open==7.5.0 + - sniffio==1.3.1 + - soupsieve==2.8 + - sympy==1.14.0 + - tenacity==9.1.2 + - threadpoolctl==3.6.0 + - tiktoken==0.12.0 + - tokenizers==0.22.1 + - torch==2.9.0 + - tqdm==4.67.1 + - transformers==4.57.1 + - typer-slim==0.20.0 + - types-requests==2.31.0.6 + - types-urllib3==1.26.25.14 + - typing-inspect==0.9.0 + - typing-inspection==0.4.2 + - typing_extensions==4.15.0 + - tzdata==2025.2 + - urllib3==1.26.6 + - wrapt==2.0.1 + - xlrd==2.0.2 + - xlsxwriter==3.2.9 + - xxhash==3.6.0 + - yarl==1.22.0 + - zstandard==0.25.0 diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..b8757d170 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +# pytest.ini +[pytest] +pythonpath = . 
diff --git a/setup.py b/setup.py index aea90c5d4..2d790ab57 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,33 @@ from setuptools import setup, find_packages +# Dependency groups +rag_deps = [ + # RAGForge dependencies + "grpcio", + "numpy<2.0", # numpy>=2.0 is not compatible with libraries like torch + "pymupdf", + "python-docx", + "tiktoken", + "nltk>=3.8", + "transformers", + "scikit-learn>=1.4.0", + "sentence-transformers", + "rank-bm25", + "whoosh", + "cohere", + "chonkie>=1.0", + "model2vec>=0.5.0", # required by chonkie + "pyarrow>=14.0,<=16.0.0", # newer versions of pyarrow require CMake 3.25 or higher, which is not compatible with all systems + "lancedb<0.18.0" # pylance requires pyarrow 14 or higher. Later versions of LanceDB give strange errors with pyarrow<=16.0.0. +] + def readme(): with open('README.md', encoding='utf-8') as f: return f.read() setup( name="chainforge", - version="0.3.6.4", + version="0.3.7.0", packages=find_packages(), author="Ian Arawjo", description="A Visual Programming Environment for Prompt Engineering", @@ -16,7 +37,7 @@ def readme(): license="MIT", url="https://github.com/ianarawjo/ChainForge/", install_requires=[ - # Package dependencies + # Core package dependencies (pre-RAGForge) "flask>=2.2.3", "flask[async]", "flask_cors", @@ -28,6 +49,12 @@ def readme(): "mistune>=2.0", # for LLM response markdown parsing "markitdown[pdf, docx, xlsx, xls, pptx]", ], + extras_require={ + # Extra dependencies for functionality like RAGForge, + # which may not be needed by all users + "rag": rag_deps, + "all": rag_deps, + }, entry_points={ "console_scripts": [ "chainforge = chainforge.app:main", diff --git a/tests/test_chunking.py b/tests/test_chunking.py new file mode 100644 index 000000000..ae4ff9137 --- /dev/null +++ b/tests/test_chunking.py @@ -0,0 +1,136 @@ +import pytest +from chainforge.rag.chunkers import ( + chonkie_token, chonkie_sentence, chonkie_recursive, chonkie_semantic, + chonkie_late, chonkie_neural, + 
overlapping_openai_tiktoken, overlapping_huggingface_tokenizers, + syntax_nltk, syntax_texttiling +) + +class TestChonkieChunking: + + @pytest.fixture(autouse=True) + def setup(self): + # A dummy document to use for all tests + self.dummy_document = """ + This is a test document. It contains several sentences. + Each sentence should be treated as a potential chunk boundary. + We have different paragraphs too! + And some more text to make it a bit longer. + This way we can test chunking methods properly. + Let's add even more text to ensure we have enough tokens for meaningful chunking. + Machine learning models often need a certain amount of text to work with. + The quick brown fox jumps over the lazy dog. + Now is the time for all good men to come to the aid of their country. + """ + + def test_chonkie_token(self): + chunker = chonkie_token + chunks = chunker(self.dummy_document) + assert isinstance(chunks, list) + assert len(chunks) > 0 + for chunk in chunks: + assert isinstance(chunk, str) + + def test_chonkie_token_with_parameters(self): + chunker = chonkie_token + chunks = chunker(self.dummy_document, chunk_size=100, chunk_overlap=10, tokenizer="gpt2") + assert isinstance(chunks, list) + assert len(chunks) > 0 + + def test_chonkie_sentence(self): + chunker = chonkie_sentence + chunks = chunker(self.dummy_document) + assert isinstance(chunks, list) + assert len(chunks) > 0 + for chunk in chunks: + assert isinstance(chunk, str) + + def test_chonkie_sentence_with_parameters(self): + chunker = chonkie_sentence + chunks = chunker(self.dummy_document, chunk_size=100, chunk_overlap=10, + min_sentences_per_chunk=2, min_characters_per_sentence=5) + assert isinstance(chunks, list) + assert len(chunks) > 0 + + def test_chonkie_recursive(self): + chunker = chonkie_recursive + chunks = chunker(self.dummy_document) + assert isinstance(chunks, list) + assert len(chunks) > 0 + for chunk in chunks: + assert isinstance(chunk, str) + + def 
test_chonkie_recursive_with_parameters(self): + chunker = chonkie_recursive + chunks = chunker(self.dummy_document, chunk_size=100, + min_characters_per_chunk=10) + assert isinstance(chunks, list) + assert len(chunks) > 0 + + def test_chonkie_semantic(self): + chunker = chonkie_semantic + chunks = chunker(self.dummy_document) + assert isinstance(chunks, list) + assert len(chunks) > 0 + for chunk in chunks: + assert isinstance(chunk, str) + + def test_chonkie_semantic_with_parameters(self): + chunker = chonkie_semantic + chunks = chunker(self.dummy_document, chunk_size=100, threshold=0.5, + min_sentences=2, similarity_window=2) + assert isinstance(chunks, list) + assert len(chunks) > 0 + + + # Late chunker may pose problems because + # its dependencies require numpy>=2.0 yet other libraries + # require numpy<2.0. + def test_chonkie_late(self): + chunker = chonkie_late + if chunker is None or not callable(chunker): + pytest.skip("chonkie_late chunker not fully implemented") + chunks = chunker(self.dummy_document) + assert isinstance(chunks, list) + + def test_overlapping_tiktoken(self): + chunker = overlapping_openai_tiktoken + chunks = chunker(self.dummy_document) + assert isinstance(chunks, list) + assert len(chunks) > 0 + for chunk in chunks: + assert isinstance(chunk, str) + + def test_overlapping_tiktoken_with_parameters(self): + chunker = overlapping_openai_tiktoken + chunks = chunker(self.dummy_document, model="gpt-3.5-turbo", + chunk_size=100, chunk_overlap=25) + assert isinstance(chunks, list) + assert len(chunks) > 0 + + def test_overlapping_huggingface(self): + chunker = overlapping_huggingface_tokenizers + chunks = chunker(self.dummy_document) + assert isinstance(chunks, list) + assert len(chunks) > 0 + for chunk in chunks: + assert isinstance(chunk, str) + + def test_syntax_nltk(self): + chunker = syntax_nltk + chunks = chunker(self.dummy_document) + assert isinstance(chunks, list) + assert len(chunks) > 0 + for chunk in chunks: + assert 
isinstance(chunk, str) + + def test_syntax_texttiling(self): + chunker = syntax_texttiling + # Note: The texttiling chunker may not work well with short texts, so + # we are using a longer dummy document for testing. + chunks = chunker(self.dummy_document + "\n\n" + self.dummy_document) + assert isinstance(chunks, list) + assert len(chunks) > 0 + for chunk in chunks: + assert isinstance(chunk, str) + diff --git a/tests/test_rerankers.py b/tests/test_rerankers.py new file mode 100644 index 000000000..3f31cb9de --- /dev/null +++ b/tests/test_rerankers.py @@ -0,0 +1,114 @@ +import sys +import os +import pytest + +# Add the parent directory to sys.path to import modules +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +# Import after path setup +from chainforge.rag.rerankers import RerankingMethodRegistry + + +def _has_sentence_transformers(): + """Check if sentence-transformers library is available.""" + try: + import sentence_transformers # noqa: F401 + return True + except ImportError: + return False + + +class TestRerankers: + + @pytest.mark.skipif( + not _has_sentence_transformers(), + reason="sentence-transformers library not available" + ) + def test_cross_encoder_rerank(self): + """Test cross-encoder reranking with a small model.""" + # Sample documents and query + documents = [ + "Python is a programming language", + "The weather is nice today", + "Machine learning uses algorithms", + "Cats are cute animals" + ] + query = "programming language" + + # Get the cross-encoder handler + cross_encoder_handler = RerankingMethodRegistry.get_handler("cross_encoder") + assert cross_encoder_handler is not None + + # Test with the smallest available model + results = cross_encoder_handler( + documents=documents, + query=query, + model="cross-encoder/ms-marco-MiniLM-L-6-v2", # Smallest available model + top_k=2 + ) + + # Verify results structure + assert isinstance(results, list) + assert len(results) <= 2 # Should return top_k results 
+ + for result in results: + assert "document" in result + assert "score" in result + assert "index" in result + assert isinstance(result["score"], float) + assert isinstance(result["index"], int) + assert result["document"] in documents + + # Results should be sorted by score (descending) + if len(results) > 1: + assert results[0]["score"] >= results[1]["score"] + + @pytest.mark.skipif( + not os.getenv("COHERE_API_KEY"), + reason="COHERE_API_KEY environment variable not set" + ) + def test_cohere_rerank(self): + """Test Cohere reranking with actual API call.""" + # Sample documents and query + documents = [ + "Python is a programming language", + "The weather is nice today", + "Machine learning uses algorithms", + "Cats are cute animals", + ] + query = "programming language" + + # Get the cohere rerank handler + cohere_handler = RerankingMethodRegistry.get_handler("cohere_rerank") + assert cohere_handler is not None + + # Test reranking with actual API call + results = cohere_handler( + documents=documents, + query=query, + model="rerank-v3.5", + top_k=2 + ) + + # Verify results structure + assert isinstance(results, list) + assert len(results) <= 2 # Should return top_k results or fewer + assert len(results) > 0 # Should return at least one result + + for result in results: + assert "document" in result + assert "score" in result + assert "index" in result + assert isinstance(result["score"], float) + assert isinstance(result["index"], int) + assert result["document"] in documents + assert 0.0 <= result["score"] <= 1.0 # Cohere scores are typically between 0 and 1 + + # Results should be sorted by score (descending) + if len(results) > 1: + assert results[0]["score"] >= results[1]["score"] + + # The most relevant document should be about programming + # (given our query "programming language") + most_relevant = results[0]["document"] + assert "Python" in most_relevant or "programming" in most_relevant diff --git a/tests/test_retrievers.py 
b/tests/test_retrievers.py new file mode 100644 index 000000000..e33a9ea41 --- /dev/null +++ b/tests/test_retrievers.py @@ -0,0 +1,890 @@ +import pytest +import json +import os +import tempfile +import shutil +import numpy as np +from unittest.mock import patch, MagicMock +import sys + +from chainforge.flask_app import app +from chainforge.rag.embeddings import EmbeddingMethodRegistry +from chainforge.rag.retrievers import RetrievalMethodRegistry +from chainforge.rag.vector_stores import LancedbVectorStore, FaissVectorStore + +# Add the parent directory to sys.path to import flask_app +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +@pytest.fixture +def client(): + """Create a test client for the app.""" + app.config['TESTING'] = True + with app.test_client() as client: + yield client + +@pytest.fixture +def temp_db_dir(): + """Create a temporary directory for database storage.""" + temp_dir = tempfile.mkdtemp() + yield temp_dir + # Cleanup after test + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + +@pytest.fixture +def sample_chunks(): + """Sample chunks for testing.""" + return [ + { + "text": "Python is a high-level programming language known for its simplicity and readability.", + "docTitle": "Python Basics", + "chunkId": "chunk1" + }, + { + "text": "Machine learning is a subset of artificial intelligence that enables systems to learn from data.", + "docTitle": "ML Introduction", + "chunkId": "chunk2" + }, + { + "text": "Deep learning uses neural networks with multiple layers to process complex patterns in data.", + "docTitle": "Deep Learning", + "chunkId": "chunk3" + }, + { + "text": "Natural language processing allows computers to understand and generate human language.", + "docTitle": "NLP Overview", + "chunkId": "chunk4" + }, + { + "text": "Data science combines statistics, programming, and domain expertise to extract insights from data.", + "docTitle": "Data Science", + "chunkId": "chunk5" + } + ] + 
+@pytest.fixture +def sample_queries(): + """Sample queries for testing.""" + return [ + {"text": "What is Python?", "metavars": {}}, + {"text": "Tell me about machine learning", "metavars": {}}, + {"text": "How does deep learning work?", "metavars": {}} + ] + +class TestRetrieveEndpoint: + + def test_retrieve_missing_methods(self, client): + """Test the /retrieve endpoint with missing methods.""" + response = client.post('/retrieve', json={ + "chunks": [{"text": "sample text"}], + "queries": [{"text": "sample query"}], + }) + assert response.status_code == 400 + + def test_retrieve_missing_chunks(self, client): + """Test the /retrieve endpoint with missing chunks.""" + response = client.post('/retrieve', json={ + "methods": [{"id": "method1", "baseMethod": "bm25", "methodName": "BM25", "library": "BM25"}], + "queries": [{"text": "sample query"}], + }) + assert response.status_code == 400 + + def test_retrieve_missing_queries(self, client): + """Test the /retrieve endpoint with missing queries.""" + response = client.post('/retrieve', json={ + "methods": [{"id": "method1", "baseMethod": "bm25", "methodName": "BM25", "library": "BM25"}], + "chunks": [{"text": "sample text"}], + }) + assert response.status_code == 400 + + @patch('chainforge.rag.retrievers.RetrievalMethodRegistry.get_handler') + def test_retrieve_bm25(self, mock_get_handler, client): + """Test the /retrieve endpoint with BM25.""" + # Set up mock handler + mock_handler = MagicMock() + mock_response = [ + { + 'query_object': {'text': 'What is Python?'}, + 'retrieved_chunks': [ + { + 'text': 'Python is a programming language.', + 'similarity': 0.95, + 'docTitle': 'Programming Languages', + 'chunkId': 'chunk1' + } + ] + } + ] + mock_handler.return_value = mock_response + mock_get_handler.return_value = mock_handler + + # Make request + request_data = { + "methods": [ + { + "id": "method1", + "baseMethod": "bm25", + "methodName": "BM25", + "library": "BM25", + "settings": {"top_k": 3} + } + ], + "chunks": 
[ + { + "text": "Python is a programming language.", + "prompt": "original query", + "metavars": {"docTitle": "Programming Languages", "chunkId": "chunk1"}, + "fill_history": {"chunkMethod": "test_method"} + } + ], + "queries": [ + { + "text": "What is Python?", + "metavars": {"docTitle": "Questions"} + } + ] + } + + response = client.post('/retrieve', json=request_data) + assert response.status_code == 200 + + # Verify handler was retrieved and called + mock_get_handler.assert_called_with("bm25") + assert mock_handler.called + + # Check response format + result = json.loads(response.data) + assert isinstance(result, list) + assert len(result) > 0 + assert "text" in result[0] + assert "metavars" in result[0] + assert "methodId" in result[0]["metavars"] + assert "retrievalMethodSignature" in result[0]["metavars"] + assert "docTitle" in result[0]["metavars"] + assert "chunkId" in result[0]["metavars"] + + @patch('chainforge.rag.retrievers.RetrievalMethodRegistry.get_handler') + def test_retrieve_tfidf(self, mock_get_handler, client): + """Test the /retrieve endpoint with TF-IDF.""" + # Set up mock handler + mock_handler = MagicMock() + mock_response = [ + { + 'query_object': {'text': 'How to install Python?'}, + 'retrieved_chunks': [ + { + 'text': 'To install Python, download it from python.org.', + 'similarity': 0.88, + 'docTitle': 'Installation Guide', + 'chunkId': 'chunk2' + } + ] + } + ] + mock_handler.return_value = mock_response + mock_get_handler.return_value = mock_handler + + # Make request + request_data = { + "methods": [ + { + "id": "method2", + "baseMethod": "tfidf", + "methodName": "TF-IDF", + "library": "sklearn", + "settings": {"top_k": 3} + } + ], + "chunks": [ + { + "text": "To install Python, download it from python.org.", + "prompt": "original query", + "metavars": {"docTitle": "Installation Guide", "chunkId": "chunk2"}, + "fill_history": {"chunkMethod": "test_method"} + } + ], + "queries": [ + { + "text": "How to install Python?", + "metavars": 
{"docTitle": "Questions"} + } + ] + } + + response = client.post('/retrieve', json=request_data) + assert response.status_code == 200 + + # Verify handler was retrieved and called + mock_get_handler.assert_called_with("tfidf") + assert mock_handler.called + + +# ============================================================================ +# EMBEDDING RETRIEVAL TESTS - COMPREHENSIVE COVERAGE +# ============================================================================ + +class TestEmbeddingRetrievalWithLanceDB: + """Test embedding-based retrieval with LanceDB backend across all providers.""" + + @pytest.mark.parametrize("similarity_metric", ["cosine", "euclidean", "dot_product"]) + def test_lancedb_openai_embeddings(self, temp_db_dir, sample_chunks, sample_queries, similarity_metric): + """Test LanceDB retrieval with OpenAI embeddings and different similarity metrics.""" + # Skip if OpenAI API key not available + if not os.environ.get("OPENAI_API_KEY"): + pytest.skip("OPENAI_API_KEY not set") + + # Get embedder + embedder = EmbeddingMethodRegistry.get_embedder("openai") + api_keys = {"OpenAI": os.environ.get("OPENAI_API_KEY")} + + # Generate embeddings + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="text-embedding-3-small", api_keys=api_keys) + + query_texts = [q["text"] for q in sample_queries] + query_embeddings = embedder(query_texts, model_name="text-embedding-3-small", api_keys=api_keys) + + # Test retrieval + handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + settings = { + "top_k": 3, + "metric": similarity_metric, + } + + results = handler( + sample_chunks, + chunk_embeddings, + sample_queries, + query_embeddings, + settings, + temp_db_dir + ) + + # Assertions + assert len(results) == len(sample_queries) + for result in results: + assert "query_object" in result + assert "retrieved_chunks" in result + assert len(result["retrieved_chunks"]) <= 3 + + # Check structure of retrieved 
chunks + for chunk in result["retrieved_chunks"]: + assert "text" in chunk + assert "similarity" in chunk + assert "id" in chunk + assert isinstance(chunk["similarity"], (float, int)) + assert 0 <= chunk["similarity"] <= 1.1 # Allow slight numerical imprecision + + @pytest.mark.parametrize("top_k", [1, 3, 5]) + def test_lancedb_openai_different_top_k(self, temp_db_dir, sample_chunks, sample_queries, top_k): + """Test LanceDB retrieval with different top_k values.""" + if not os.environ.get("OPENAI_API_KEY"): + pytest.skip("OPENAI_API_KEY not set") + + embedder = EmbeddingMethodRegistry.get_embedder("openai") + api_keys = {"OpenAI": os.environ.get("OPENAI_API_KEY")} + + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="text-embedding-3-small", api_keys=api_keys) + + query_texts = [q["text"] for q in sample_queries] + query_embeddings = embedder(query_texts, model_name="text-embedding-3-small", api_keys=api_keys) + + handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + settings = {"top_k": top_k, "metric": "cosine"} + + results = handler(sample_chunks, chunk_embeddings, sample_queries, query_embeddings, settings, temp_db_dir) + + for result in results: + assert len(result["retrieved_chunks"]) <= top_k + # Check that results are ordered by similarity (descending) + similarities = [chunk["similarity"] for chunk in result["retrieved_chunks"]] + assert similarities == sorted(similarities, reverse=True) + + # Commented out Cohere tests for now due to API availability issues + # def test_lancedb_cohere_embeddings(self, temp_db_dir, sample_chunks, sample_queries): + # """Test LanceDB retrieval with Cohere embeddings.""" + # if not os.environ.get("COHERE_API_KEY"): + # pytest.skip("COHERE_API_KEY not set") + + # embedder = EmbeddingMethodRegistry.get_embedder("cohere") + # api_keys = {"Cohere": os.environ.get("COHERE_API_KEY")} + + # chunk_texts = [c["text"] for c in sample_chunks] + # chunk_embeddings 
= embedder(chunk_texts, model_name="embed-english-v3.0", api_keys=api_keys) + + # query_texts = [q["text"] for q in sample_queries] + # query_embeddings = embedder(query_texts, model_name="embed-english-v3.0", api_keys=api_keys) + + # handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + # settings = {"top_k": 3, "metric": "cosine"} + + # results = handler(sample_chunks, chunk_embeddings, sample_queries, query_embeddings, settings, temp_db_dir) + + # assert len(results) == len(sample_queries) + # for result in results: + # assert len(result["retrieved_chunks"]) <= 3 + # assert all("similarity" in chunk for chunk in result["retrieved_chunks"]) + + def test_lancedb_sentence_transformers(self, temp_db_dir, sample_chunks, sample_queries): + """Test LanceDB retrieval with Sentence Transformers embeddings.""" + embedder = EmbeddingMethodRegistry.get_embedder("sentence-transformers") + + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="all-MiniLM-L6-v2") + + query_texts = [q["text"] for q in sample_queries] + query_embeddings = embedder(query_texts, model_name="all-MiniLM-L6-v2") + + handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + settings = {"top_k": 3, "metric": "cosine"} + + results = handler(sample_chunks, chunk_embeddings, sample_queries, query_embeddings, settings, temp_db_dir) + + assert len(results) == len(sample_queries) + for result in results: + assert "query_object" in result + assert "retrieved_chunks" in result + assert len(result["retrieved_chunks"]) <= 3 + + def test_lancedb_huggingface_embeddings(self, temp_db_dir, sample_chunks, sample_queries): + """Test LanceDB retrieval with HuggingFace embeddings.""" + embedder = EmbeddingMethodRegistry.get_embedder("huggingface") + + # Use a small model for testing + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="sentence-transformers/all-MiniLM-L6-v2") + + 
query_texts = [q["text"] for q in sample_queries] + query_embeddings = embedder(query_texts, model_name="sentence-transformers/all-MiniLM-L6-v2") + + handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + settings = {"top_k": 3, "metric": "cosine"} + + results = handler(sample_chunks, chunk_embeddings, sample_queries, query_embeddings, settings, temp_db_dir) + + assert len(results) == len(sample_queries) + for result in results: + assert len(result["retrieved_chunks"]) <= 3 + + def test_lancedb_persistence(self, temp_db_dir, sample_chunks, sample_queries): + """Test that LanceDB properly persists and reloads data.""" + embedder = EmbeddingMethodRegistry.get_embedder("sentence-transformers") + + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="all-MiniLM-L6-v2") + + query_texts = [q["text"] for q in sample_queries[:1]] + query_embeddings = embedder(query_texts, model_name="all-MiniLM-L6-v2") + + handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + settings = {"top_k": 3, "metric": "cosine"} + + # First retrieval - creates database + results1 = handler(sample_chunks, chunk_embeddings, sample_queries[:1], query_embeddings, settings, temp_db_dir) + + # Second retrieval - should reuse existing database + results2 = handler(sample_chunks, chunk_embeddings, sample_queries[:1], query_embeddings, settings, temp_db_dir) + + # Results should be identical + assert len(results1) == len(results2) + assert results1[0]["retrieved_chunks"][0]["text"] == results2[0]["retrieved_chunks"][0]["text"] + + def test_lancedb_relevance_ordering(self, temp_db_dir, sample_chunks): + """Test that LanceDB returns results in order of relevance.""" + embedder = EmbeddingMethodRegistry.get_embedder("sentence-transformers") + + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="all-MiniLM-L6-v2") + + # Query that should strongly match first chunk + query = 
[{"text": "Python programming language", "metavars": {}}] + query_embeddings = embedder([query[0]["text"]], model_name="all-MiniLM-L6-v2") + + handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + settings = {"top_k": 5, "metric": "cosine"} + + results = handler(sample_chunks, chunk_embeddings, query, query_embeddings, settings, temp_db_dir) + + # First result should be the Python chunk + assert "Python" in results[0]["retrieved_chunks"][0]["text"] + + # Verify descending order + similarities = [chunk["similarity"] for chunk in results[0]["retrieved_chunks"]] + assert similarities == sorted(similarities, reverse=True) + + +class TestEmbeddingRetrievalWithFAISS: + """Test embedding-based retrieval with FAISS backend across all providers.""" + + @pytest.mark.parametrize("similarity_metric", ["l2", "cosine", "dot"]) + def test_faiss_openai_embeddings(self, temp_db_dir, sample_chunks, sample_queries, similarity_metric): + """Test FAISS retrieval with OpenAI embeddings and different similarity metrics.""" + # Skip if FAISS not installed + try: + import faiss + except ImportError: + pytest.skip("FAISS not installed") + + if not os.environ.get("OPENAI_API_KEY"): + pytest.skip("OPENAI_API_KEY not set") + + embedder = EmbeddingMethodRegistry.get_embedder("openai") + api_keys = {"OpenAI": os.environ.get("OPENAI_API_KEY")} + + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="text-embedding-3-small", api_keys=api_keys) + + query_texts = [q["text"] for q in sample_queries] + query_embeddings = embedder(query_texts, model_name="text-embedding-3-small", api_keys=api_keys) + + handler = RetrievalMethodRegistry.get_handler("faiss_vector_store") + settings = {"top_k": 3, "metric": similarity_metric} + + results = handler(sample_chunks, chunk_embeddings, sample_queries, query_embeddings, settings, temp_db_dir) + + assert len(results) == len(sample_queries) + for result in results: + assert "query_object" in 
result + assert "retrieved_chunks" in result + assert len(result["retrieved_chunks"]) <= 3 + + for chunk in result["retrieved_chunks"]: + assert "text" in chunk + assert "similarity" in chunk + assert "id" in chunk + assert isinstance(chunk["similarity"], (float, int)) + + @pytest.mark.parametrize("top_k", [1, 3, 5]) + def test_faiss_different_top_k(self, temp_db_dir, sample_chunks, sample_queries, top_k): + """Test FAISS retrieval with different top_k values.""" + try: + import faiss + except ImportError: + pytest.skip("FAISS not installed") + + if not os.environ.get("OPENAI_API_KEY"): + pytest.skip("OPENAI_API_KEY not set") + + embedder = EmbeddingMethodRegistry.get_embedder("openai") + api_keys = {"OpenAI": os.environ.get("OPENAI_API_KEY")} + + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="text-embedding-3-small", api_keys=api_keys) + + query_texts = [q["text"] for q in sample_queries] + query_embeddings = embedder(query_texts, model_name="text-embedding-3-small", api_keys=api_keys) + + handler = RetrievalMethodRegistry.get_handler("faiss_vector_store") + settings = {"top_k": top_k, "metric": "l2"} + + results = handler(sample_chunks, chunk_embeddings, sample_queries, query_embeddings, settings, temp_db_dir) + + for result in results: + assert len(result["retrieved_chunks"]) <= top_k + + def test_faiss_cohere_embeddings(self, temp_db_dir, sample_chunks, sample_queries): + """Test FAISS retrieval with Cohere embeddings.""" + try: + import faiss + except ImportError: + pytest.skip("FAISS not installed") + + if not os.environ.get("COHERE_API_KEY"): + pytest.skip("COHERE_API_KEY not set") + + embedder = EmbeddingMethodRegistry.get_embedder("cohere") + api_keys = {"Cohere": os.environ.get("COHERE_API_KEY")} + + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="embed-english-v3.0", api_keys=api_keys) + + query_texts = [q["text"] for q in sample_queries] + 
query_embeddings = embedder(query_texts, model_name="embed-english-v3.0", api_keys=api_keys) + + handler = RetrievalMethodRegistry.get_handler("faiss_vector_store") + settings = {"top_k": 3, "metric": "cosine"} + + results = handler(sample_chunks, chunk_embeddings, sample_queries, query_embeddings, settings, temp_db_dir) + + assert len(results) == len(sample_queries) + for result in results: + assert len(result["retrieved_chunks"]) <= 3 + + def test_faiss_sentence_transformers(self, temp_db_dir, sample_chunks, sample_queries): + """Test FAISS retrieval with Sentence Transformers embeddings.""" + try: + import faiss + except ImportError: + pytest.skip("FAISS not installed") + + embedder = EmbeddingMethodRegistry.get_embedder("sentence-transformers") + + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="all-MiniLM-L6-v2") + + query_texts = [q["text"] for q in sample_queries] + query_embeddings = embedder(query_texts, model_name="all-MiniLM-L6-v2") + + handler = RetrievalMethodRegistry.get_handler("faiss_vector_store") + settings = {"top_k": 3, "metric": "cosine"} + + results = handler(sample_chunks, chunk_embeddings, sample_queries, query_embeddings, settings, temp_db_dir) + + assert len(results) == len(sample_queries) + for result in results: + assert len(result["retrieved_chunks"]) <= 3 + + def test_faiss_huggingface_embeddings(self, temp_db_dir, sample_chunks, sample_queries): + """Test FAISS retrieval with HuggingFace embeddings.""" + try: + import faiss + except ImportError: + pytest.skip("FAISS not installed") + + embedder = EmbeddingMethodRegistry.get_embedder("huggingface") + + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="sentence-transformers/all-MiniLM-L6-v2") + + query_texts = [q["text"] for q in sample_queries] + query_embeddings = embedder(query_texts, model_name="sentence-transformers/all-MiniLM-L6-v2") + + handler = 
RetrievalMethodRegistry.get_handler("faiss_vector_store") + settings = {"top_k": 3, "metric": "l2"} + + results = handler(sample_chunks, chunk_embeddings, sample_queries, query_embeddings, settings, temp_db_dir) + + assert len(results) == len(sample_queries) + for result in results: + assert len(result["retrieved_chunks"]) <= 3 + + def test_faiss_persistence(self, temp_db_dir, sample_chunks, sample_queries): + """Test that FAISS properly persists and reloads data.""" + try: + import faiss + except ImportError: + pytest.skip("FAISS not installed") + + embedder = EmbeddingMethodRegistry.get_embedder("sentence-transformers") + + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="all-MiniLM-L6-v2") + + query_texts = [q["text"] for q in sample_queries[:1]] + query_embeddings = embedder(query_texts, model_name="all-MiniLM-L6-v2") + + handler = RetrievalMethodRegistry.get_handler("faiss_vector_store") + settings = {"top_k": 3, "metric": "l2"} + + # First retrieval - creates database + results1 = handler(sample_chunks, chunk_embeddings, sample_queries[:1], query_embeddings, settings, temp_db_dir) + + # Second retrieval - should reuse existing database + results2 = handler(sample_chunks, chunk_embeddings, sample_queries[:1], query_embeddings, settings, temp_db_dir) + + # Results should be consistent + assert len(results1) == len(results2) + assert results1[0]["retrieved_chunks"][0]["text"] == results2[0]["retrieved_chunks"][0]["text"] + + def test_faiss_relevance_ordering(self, temp_db_dir, sample_chunks): + """Test that FAISS returns results in order of relevance.""" + try: + import faiss + except ImportError: + pytest.skip("FAISS not installed") + + embedder = EmbeddingMethodRegistry.get_embedder("sentence-transformers") + + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="all-MiniLM-L6-v2") + + # Query that should strongly match first chunk + query = [{"text": 
"Python programming language", "metavars": {}}] + query_embeddings = embedder([query[0]["text"]], model_name="all-MiniLM-L6-v2") + + handler = RetrievalMethodRegistry.get_handler("faiss_vector_store") + settings = {"top_k": 5, "metric": "cosine"} + + results = handler(sample_chunks, chunk_embeddings, query, query_embeddings, settings, temp_db_dir) + + # First result should be the Python chunk + assert "Python" in results[0]["retrieved_chunks"][0]["text"] + + # Verify descending order + similarities = [chunk["similarity"] for chunk in results[0]["retrieved_chunks"]] + assert similarities == sorted(similarities, reverse=True) + + +class TestEmbeddingRetrievalCrossBackend: + """Test consistency across LanceDB and FAISS backends.""" + + def test_backends_return_similar_results(self, temp_db_dir, sample_chunks, sample_queries): + """Test that LanceDB and FAISS return similar top results.""" + try: + import faiss + except ImportError: + pytest.skip("FAISS not installed") + + embedder = EmbeddingMethodRegistry.get_embedder("sentence-transformers") + + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="all-MiniLM-L6-v2") + + query_texts = [q["text"] for q in sample_queries[:1]] + query_embeddings = embedder(query_texts, model_name="all-MiniLM-L6-v2") + + # Test with LanceDB + lancedb_dir = os.path.join(temp_db_dir, "lancedb") + os.makedirs(lancedb_dir, exist_ok=True) + lance_handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + lance_results = lance_handler( + sample_chunks, chunk_embeddings, sample_queries[:1], query_embeddings, + {"top_k": 3, "metric": "cosine"}, lancedb_dir + ) + + # Test with FAISS + faiss_dir = os.path.join(temp_db_dir, "faiss") + os.makedirs(faiss_dir, exist_ok=True) + faiss_handler = RetrievalMethodRegistry.get_handler("faiss_vector_store") + faiss_results = faiss_handler( + sample_chunks, chunk_embeddings, sample_queries[:1], query_embeddings, + {"top_k": 3, "metric": "cosine"}, 
faiss_dir + ) + + # Both should return the same top result (most relevant chunk) + assert lance_results[0]["retrieved_chunks"][0]["text"] == faiss_results[0]["retrieved_chunks"][0]["text"] + + @pytest.mark.parametrize("embedding_provider", [ + "sentence-transformers", + pytest.param("openai", marks=pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")), + pytest.param("cohere", marks=pytest.mark.skipif(not os.environ.get("COHERE_API_KEY"), reason="COHERE_API_KEY not set")) + ]) + def test_all_providers_with_both_backends(self, temp_db_dir, sample_chunks, sample_queries, embedding_provider): + """Test that all embedding providers work with both backends.""" + try: + import faiss + except ImportError: + pytest.skip("FAISS not installed") + + embedder = EmbeddingMethodRegistry.get_embedder(embedding_provider) + + # Configure model and API keys based on provider + if embedding_provider == "openai": + model_name = "text-embedding-3-small" + api_keys = {"OpenAI": os.environ.get("OPENAI_API_KEY")} + elif embedding_provider == "cohere": + model_name = "embed-english-v3.0" + api_keys = {"Cohere": os.environ.get("COHERE_API_KEY")} + else: + model_name = "all-MiniLM-L6-v2" + api_keys = None + + chunk_texts = [c["text"] for c in sample_chunks] + if api_keys: + chunk_embeddings = embedder(chunk_texts, model_name=model_name, api_keys=api_keys) + else: + chunk_embeddings = embedder(chunk_texts, model_name=model_name) + + query_texts = [q["text"] for q in sample_queries[:1]] + if api_keys: + query_embeddings = embedder(query_texts, model_name=model_name, api_keys=api_keys) + else: + query_embeddings = embedder(query_texts, model_name=model_name) + + # Test LanceDB + lancedb_dir = os.path.join(temp_db_dir, f"lancedb_{embedding_provider}") + os.makedirs(lancedb_dir, exist_ok=True) + lance_handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + lance_results = lance_handler( + sample_chunks, chunk_embeddings, sample_queries[:1], 
query_embeddings, + {"top_k": 3, "metric": "cosine"}, lancedb_dir + ) + assert len(lance_results) == 1 + assert len(lance_results[0]["retrieved_chunks"]) <= 3 + + # Test FAISS + faiss_dir = os.path.join(temp_db_dir, f"faiss_{embedding_provider}") + os.makedirs(faiss_dir, exist_ok=True) + faiss_handler = RetrievalMethodRegistry.get_handler("faiss_vector_store") + faiss_results = faiss_handler( + sample_chunks, chunk_embeddings, sample_queries[:1], query_embeddings, + {"top_k": 3, "metric": "cosine"}, faiss_dir + ) + assert len(faiss_results) == 1 + assert len(faiss_results[0]["retrieved_chunks"]) <= 3 + + +class TestEmbeddingRetrievalEdgeCases: + """Test edge cases and error handling for embedding retrieval.""" + + def test_empty_queries(self, temp_db_dir, sample_chunks): + """Test handling of empty query list.""" + embedder = EmbeddingMethodRegistry.get_embedder("sentence-transformers") + + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="all-MiniLM-L6-v2") + + handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + + with pytest.raises(Exception): + handler(sample_chunks, chunk_embeddings, [], [], {"top_k": 3, "metric": "cosine"}, temp_db_dir) + + def test_empty_chunks(self, temp_db_dir, sample_queries): + """Test handling of empty chunk list.""" + embedder = EmbeddingMethodRegistry.get_embedder("sentence-transformers") + + query_texts = [q["text"] for q in sample_queries] + query_embeddings = embedder(query_texts, model_name="all-MiniLM-L6-v2") + + handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + + with pytest.raises(Exception): + handler([], [], sample_queries, query_embeddings, {"top_k": 3, "metric": "cosine"}, temp_db_dir) + + def test_top_k_larger_than_chunks(self, temp_db_dir, sample_chunks, sample_queries): + """Test that top_k larger than number of chunks works correctly.""" + embedder = EmbeddingMethodRegistry.get_embedder("sentence-transformers") + + chunk_texts 
= [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="all-MiniLM-L6-v2") + + query_texts = [q["text"] for q in sample_queries[:1]] + query_embeddings = embedder(query_texts, model_name="all-MiniLM-L6-v2") + + handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + settings = {"top_k": 100, "metric": "cosine"} # Much larger than 5 chunks + + results = handler(sample_chunks, chunk_embeddings, sample_queries[:1], query_embeddings, settings, temp_db_dir) + + # Should return all chunks + assert len(results[0]["retrieved_chunks"]) == len(sample_chunks) + + def test_single_chunk_single_query(self, temp_db_dir): + """Test with minimal input: one chunk and one query.""" + embedder = EmbeddingMethodRegistry.get_embedder("sentence-transformers") + + chunks = [{"text": "Test chunk", "docTitle": "Test", "chunkId": "1"}] + chunk_embeddings = embedder([chunks[0]["text"]], model_name="all-MiniLM-L6-v2") + + queries = [{"text": "Test query", "metavars": {}}] + query_embeddings = embedder([queries[0]["text"]], model_name="all-MiniLM-L6-v2") + + handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + settings = {"top_k": 1, "metric": "cosine"} + + results = handler(chunks, chunk_embeddings, queries, query_embeddings, settings, temp_db_dir) + + assert len(results) == 1 + assert len(results[0]["retrieved_chunks"]) == 1 + assert results[0]["retrieved_chunks"][0]["text"] == "Test chunk" + + def test_duplicate_chunks(self, temp_db_dir, sample_queries): + """Test handling of duplicate chunks.""" + embedder = EmbeddingMethodRegistry.get_embedder("sentence-transformers") + + # Create chunks with duplicates + chunks = [ + {"text": "Duplicate text", "docTitle": "Doc1", "chunkId": "1"}, + {"text": "Duplicate text", "docTitle": "Doc2", "chunkId": "2"}, + {"text": "Unique text", "docTitle": "Doc3", "chunkId": "3"}, + ] + + chunk_texts = [c["text"] for c in chunks] + chunk_embeddings = embedder(chunk_texts, 
model_name="all-MiniLM-L6-v2") + + query_texts = [q["text"] for q in sample_queries[:1]] + query_embeddings = embedder(query_texts, model_name="all-MiniLM-L6-v2") + + handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + settings = {"top_k": 3, "metric": "cosine"} + + # Should handle duplicates gracefully + results = handler(chunks, chunk_embeddings, sample_queries[:1], query_embeddings, settings, temp_db_dir) + + assert len(results) == 1 + assert len(results[0]["retrieved_chunks"]) >= 1 + + @pytest.mark.parametrize("invalid_metric", ["invalid", "unknown", ""]) + def test_invalid_similarity_metric_lancedb(self, temp_db_dir, sample_chunks, sample_queries, invalid_metric): + """Test handling of invalid similarity metrics.""" + embedder = EmbeddingMethodRegistry.get_embedder("sentence-transformers") + + chunk_texts = [c["text"] for c in sample_chunks] + chunk_embeddings = embedder(chunk_texts, model_name="all-MiniLM-L6-v2") + + query_texts = [q["text"] for q in sample_queries[:1]] + query_embeddings = embedder(query_texts, model_name="all-MiniLM-L6-v2") + + handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + settings = {"top_k": 3, "metric": invalid_metric} + + # Should either handle gracefully or raise appropriate error + try: + results = handler(sample_chunks, chunk_embeddings, sample_queries[:1], query_embeddings, settings, temp_db_dir) + # If it doesn't raise, it should still return valid results + assert len(results) == 1 + except Exception: + # Expected for invalid metrics + pass + + +class TestEmbeddingRetrievalMetadata: + """Test that metadata is properly preserved through retrieval.""" + + def test_metadata_preservation_lancedb(self, temp_db_dir): + """Test that chunk metadata is preserved in LanceDB retrieval.""" + embedder = EmbeddingMethodRegistry.get_embedder("sentence-transformers") + + chunks = [ + { + "text": "Test chunk with metadata", + "docTitle": "Important Doc", + "chunkId": "special-123", + 
"customField": "custom_value" + } + ] + + chunk_embeddings = embedder([chunks[0]["text"]], model_name="all-MiniLM-L6-v2") + + query = [{"text": "test query", "metavars": {}}] + query_embeddings = embedder([query[0]["text"]], model_name="all-MiniLM-L6-v2") + + handler = RetrievalMethodRegistry.get_handler("lancedb_vector_store") + settings = {"top_k": 1, "metric": "cosine"} + + results = handler(chunks, chunk_embeddings, query, query_embeddings, settings, temp_db_dir) + + # Check that the retrieved chunk has correct text + assert results[0]["retrieved_chunks"][0]["text"] == "Test chunk with metadata" + + def test_metadata_preservation_faiss(self, temp_db_dir): + """Test that chunk metadata is preserved in FAISS retrieval.""" + try: + import faiss + except ImportError: + pytest.skip("FAISS not installed") + + embedder = EmbeddingMethodRegistry.get_embedder("sentence-transformers") + + chunks = [ + { + "text": "Test chunk with metadata", + "docTitle": "Important Doc", + "chunkId": "special-456" + } + ] + + chunk_embeddings = embedder([chunks[0]["text"]], model_name="all-MiniLM-L6-v2") + + query = [{"text": "test query", "metavars": {}}] + query_embeddings = embedder([query[0]["text"]], model_name="all-MiniLM-L6-v2") + + handler = RetrievalMethodRegistry.get_handler("faiss_vector_store") + settings = {"top_k": 1, "metric": "l2"} + + results = handler(chunks, chunk_embeddings, query, query_embeddings, settings, temp_db_dir) + + # Check that the retrieved chunk has correct text + assert results[0]["retrieved_chunks"][0]["text"] == "Test chunk with metadata" + \ No newline at end of file diff --git a/tests/test_vector_store.py b/tests/test_vector_store.py new file mode 100644 index 000000000..4573bf5f5 --- /dev/null +++ b/tests/test_vector_store.py @@ -0,0 +1,224 @@ +import pytest +import tempfile +import os +import numpy as np +from chainforge.rag.vector_stores import LancedbVectorStore + +class TestLocalVectorStore: + + @pytest.fixture + def dummy_embedder(self): + """A 
simple mock embedding function that generates random vectors""" + def _embed(texts): + return [np.random.randn(384).tolist() for _ in texts] + return _embed + + @pytest.fixture + def dummy_documents(self): + """Return dummy documents about programming languages""" + return [ + "Python is an interpreted high-level general-purpose programming language. Its design philosophy emphasizes code readability with its use of significant indentation.", + "JavaScript is a programming language that conforms to the ECMAScript specification. JavaScript is high-level, often just-in-time compiled, and multi-paradigm.", + "Java is a high-level, class-based, object-oriented programming language that is designed to have as few implementation dependencies as possible.", + "C++ is a general-purpose programming language created by Bjarne Stroustrup as an extension of the C programming language.", + "Ruby is an interpreted, high-level, general-purpose programming language. It was designed and developed in the mid-1990s by Yukihiro Matsumoto in Japan.", + "Go is a statically typed, compiled programming language designed at Google by Robert Griesemer, Rob Pike, and Ken Thompson.", + "Rust is a multi-paradigm, high-level, general-purpose programming language designed for performance and safety, especially safe concurrency.", + "Swift is a general-purpose, multi-paradigm, compiled programming language developed by Apple Inc. for iOS, iPadOS, macOS, watchOS, tvOS, and Linux." 
+ ] + + @pytest.fixture + def dummy_metadata(self): + """Return metadata for dummy documents""" + return [ + {"language": "Python", "year": 1991, "creator": "Guido van Rossum", "paradigm": "multi-paradigm"}, + {"language": "JavaScript", "year": 1995, "creator": "Brendan Eich", "paradigm": "multi-paradigm"}, + {"language": "Java", "year": 1995, "creator": "James Gosling", "paradigm": "object-oriented"}, + {"language": "C++", "year": 1985, "creator": "Bjarne Stroustrup", "paradigm": "multi-paradigm"}, + {"language": "Ruby", "year": 1995, "creator": "Yukihiro Matsumoto", "paradigm": "multi-paradigm"}, + {"language": "Go", "year": 2009, "creator": "Google", "paradigm": "concurrent"}, + {"language": "Rust", "year": 2010, "creator": "Mozilla", "paradigm": "multi-paradigm"}, + {"language": "Swift", "year": 2014, "creator": "Apple", "paradigm": "multi-paradigm"} + ] + + @pytest.fixture + def vector_store(self, tmp_path, dummy_embedder): + """Create a temporary vector store for testing""" + temp_dir = tmp_path / 'chainforge_test' + temp_dir.mkdir(exist_ok=True) + db_path = os.path.join(temp_dir, 'test_vector_store.db') + store = LancedbVectorStore(db_path=db_path, embedding_func=dummy_embedder) + yield store + # Cleanup + # if os.path.exists(db_path): + # os.remove(db_path) + # if os.path.exists(temp_dir): + # os.rmdir(temp_dir) + + def test_add_and_count(self, vector_store, dummy_documents, dummy_metadata): + """Test adding documents to the vector store and counting them""" + # Add documents + ids = vector_store.add(dummy_documents, metadata=dummy_metadata) + + # Verify correct number of IDs returned + assert len(ids) == len(dummy_documents) + + # Verify count method + assert vector_store.count() == len(dummy_documents) + + # Verify adding same documents again doesn't increase count + ids2 = vector_store.add(dummy_documents, metadata=dummy_metadata) + assert vector_store.count() == len(dummy_documents) + assert set(ids) == set(ids2) # IDs should be the same + + def 
test_get_by_id(self, vector_store, dummy_documents, dummy_metadata): + """Test retrieving documents by ID""" + ids = vector_store.add(dummy_documents, metadata=dummy_metadata) + + # Get Python document + python_doc = vector_store.get(ids[0]) + assert python_doc is not None + assert python_doc["text"] == dummy_documents[0] + assert python_doc["metadata"] == dummy_metadata[0] + assert python_doc["id"] == ids[0] + + # Test non-existent ID + non_existent = vector_store.get("this_id_does_not_exist") + assert non_existent is None + + def test_get_all(self, vector_store, dummy_documents, dummy_metadata): + """Test retrieving all documents""" + vector_store.add(dummy_documents, metadata=dummy_metadata) + + # Get all documents + all_docs = vector_store.get_all() + assert len(all_docs) == len(dummy_documents) + + # Verify data of all documents match the original + for i, doc in enumerate(all_docs): + assert doc["text"] == dummy_documents[i] + assert doc["metadata"] == dummy_metadata[i] + assert doc["embedding"] is not None # Ensure embedding exists + + # Get with limit + limited_docs = vector_store.get_all(limit=3) + assert len(limited_docs) == 3 + + # Get with offset + offset_docs = vector_store.get_all(offset=2) + assert len(offset_docs) == len(dummy_documents) - 2 + + # Combine limit and offset + paged_docs = vector_store.get_all(limit=2, offset=2) + assert len(paged_docs) == 2 + + def test_delete(self, vector_store, dummy_documents, dummy_metadata): + """Test deleting documents""" + ids = vector_store.add(dummy_documents, metadata=dummy_metadata) + + # Delete first two documents + success = vector_store.delete(ids[:2]) + assert success + + # Verify count decreased + assert vector_store.count() == len(dummy_documents) - 2 + + # Verify first document no longer exists + assert vector_store.get(ids[0]) is None + + # But third document still exists + assert vector_store.get(ids[2]) is not None + + def test_update(self, vector_store, dummy_documents, dummy_metadata): + """Test 
updating documents""" + ids = vector_store.add(dummy_documents, metadata=dummy_metadata) + assert vector_store.count() == len(dummy_documents) + + # Update text + updated_text = "Python is an amazing language for data science and machine learning!" + new_id = vector_store.update(ids[0], text=updated_text) + assert new_id is not None # New ID should be returned + + # Verify text was updated + doc = vector_store.get(new_id) + assert doc is not None + assert doc["text"] == updated_text + assert doc["metadata"] == dummy_metadata[0] # Metadata unchanged + + # Update metadata + updated_metadata = {"language": "Python", "year": 1991, + "creator": "Guido van Rossum", + "paradigm": "multi-paradigm", + "usage": "data science, web development, automation"} + new_new_id = vector_store.update(new_id, metadata=updated_metadata) + assert new_new_id is not None # An ID should still be returned + assert new_new_id == new_id # ID should remain the same, because we didn't change text + + # Verify metadata was updated + doc = vector_store.get(new_id) + assert doc["text"] == updated_text # Text remains the same from previous update + assert doc["metadata"] == updated_metadata + + def test_search_with_embedding(self, vector_store, dummy_documents, dummy_metadata, dummy_embedder): + """Test searching with a query embedding""" + vector_store.add(dummy_documents, metadata=dummy_metadata) + + # Generate a random query embedding + query_embedding = np.random.randn(384).tolist() + + # Simple similarity search + results = vector_store.search(query_embedding, k=3) + assert len(results) <= 3 # May be fewer if there are fewer documents + for result in results: + assert "id" in result + assert "text" in result + assert "score" in result + assert "metadata" in result + + # Try different distance_metrics + results_l2 = vector_store.search(query_embedding, k=3, distance_metric="l2") + assert len(results_l2) <= 3 + results_cosine = vector_store.search(query_embedding, k=3, distance_metric="cosine") + 
assert len(results_cosine) <= 3 + results_dot = vector_store.search(query_embedding, k=3, distance_metric="dot") + assert len(results_dot) <= 3 + + # Test MMR search + mmr_results = vector_store.search(query_embedding, k=3, method="mmr") + assert len(mmr_results) <= 3 + + # Test hybrid search + hybrid_results = vector_store.search(query_embedding, k=3, method="hybrid") + assert len(hybrid_results) <= 3 + + def test_clear(self, vector_store, dummy_documents, dummy_metadata): + """Test clearing all documents""" + vector_store.add(dummy_documents, metadata=dummy_metadata) + assert vector_store.count() == len(dummy_documents) + + # Clear the store + success = vector_store.clear() + assert success + + # Verify store is empty + assert vector_store.count() == 0 + assert len(vector_store.get_all()) == 0 + + def test_search_with_embedding_function(self, vector_store, dummy_documents, dummy_metadata): + """Test searching using the embedding function""" + # Add documents using the embedding function + vector_store.add(dummy_documents, metadata=dummy_metadata) + + # Use the same embedding function for a query + query = "Which programming language is best for web development?" + query_embedding = vector_store.embedding_func([query])[0] + + results = vector_store.search(query_embedding, k=3) + assert len(results) <= 3 + + # Results should have all required fields + for result in results: + assert "id" in result + assert "text" in result + assert "score" in result + assert "metadata" in result \ No newline at end of file