diff --git a/.gitignore b/.gitignore
index d7ae2881..f33db3f3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -63,6 +63,7 @@ claude_agent_settings.json
.coverage
htmlcov/
.pytest_cache/
+.test-results/
# Skill test run results (detailed per-task logs with full responses)
.test/skills/*/runs/
diff --git a/.test/tests/integration/test_compute.py b/.test/tests/integration/test_compute.py
new file mode 100644
index 00000000..688d1536
--- /dev/null
+++ b/.test/tests/integration/test_compute.py
@@ -0,0 +1,245 @@
+"""Integration tests for compute.py CLI script.
+
+Tests actual subprocess execution of the compute CLI script.
+"""
+import json
+import subprocess
+import sys
+from pathlib import Path
+
+import pytest
+
+# Get repo root for running scripts
+_repo_root = Path(__file__).resolve().parents[3]
+_compute_script = _repo_root / "databricks-skills" / "databricks-execution-compute" / "scripts" / "compute.py"
+
+
+class TestComputeScriptHelp:
+ """Test compute.py help and basic CLI structure."""
+
+ def test_script_shows_help(self):
+ """Verify script has help output."""
+ result = subprocess.run(
+ [sys.executable, str(_compute_script), "--help"],
+ capture_output=True,
+ text=True,
+ cwd=str(_repo_root),
+ timeout=10
+ )
+
+ assert result.returncode == 0
+ assert "execute-code" in result.stdout
+ assert "list-compute" in result.stdout
+ assert "manage-cluster" in result.stdout
+
+ def test_execute_code_help(self):
+ """Verify execute-code subcommand help."""
+ result = subprocess.run(
+ [sys.executable, str(_compute_script), "execute-code", "--help"],
+ capture_output=True,
+ text=True,
+ cwd=str(_repo_root),
+ timeout=10
+ )
+
+ assert result.returncode == 0
+ assert "--code" in result.stdout
+ assert "--compute-type" in result.stdout
+
+ def test_list_compute_help(self):
+ """Verify list-compute subcommand help."""
+ result = subprocess.run(
+ [sys.executable, str(_compute_script), "list-compute", "--help"],
+ capture_output=True,
+ text=True,
+ cwd=str(_repo_root),
+ timeout=10
+ )
+
+ assert result.returncode == 0
+ assert "--resource" in result.stdout
+
+ def test_manage_cluster_help(self):
+ """Verify manage-cluster subcommand help."""
+ result = subprocess.run(
+ [sys.executable, str(_compute_script), "manage-cluster", "--help"],
+ capture_output=True,
+ text=True,
+ cwd=str(_repo_root),
+ timeout=10
+ )
+
+ assert result.returncode == 0
+ assert "--action" in result.stdout
+
+
+@pytest.mark.integration
+class TestListCompute:
+ """Tests for list-compute command."""
+
+ def test_list_clusters(self):
+ """Should list all clusters."""
+ result = subprocess.run(
+ [sys.executable, str(_compute_script), "list-compute", "--resource", "clusters"],
+ capture_output=True,
+ text=True,
+ cwd=str(_repo_root),
+ timeout=60
+ )
+
+ try:
+ output = json.loads(result.stdout)
+ assert "clusters" in output
+ assert isinstance(output["clusters"], list)
+ except json.JSONDecodeError:
+ pytest.fail(f"Invalid JSON: {result.stdout}\nStderr: {result.stderr}")
+
+ def test_list_node_types(self):
+ """Should list available node types."""
+ result = subprocess.run(
+ [sys.executable, str(_compute_script), "list-compute", "--resource", "node_types"],
+ capture_output=True,
+ text=True,
+ cwd=str(_repo_root),
+ timeout=60
+ )
+
+ try:
+ output = json.loads(result.stdout)
+ assert "node_types" in output
+ assert isinstance(output["node_types"], list)
+ assert len(output["node_types"]) > 0
+ except json.JSONDecodeError:
+ pytest.fail(f"Invalid JSON: {result.stdout}\nStderr: {result.stderr}")
+
+ def test_list_spark_versions(self):
+ """Should list available Spark versions."""
+ result = subprocess.run(
+ [sys.executable, str(_compute_script), "list-compute", "--resource", "spark_versions"],
+ capture_output=True,
+ text=True,
+ cwd=str(_repo_root),
+ timeout=60
+ )
+
+ try:
+ output = json.loads(result.stdout)
+ assert "spark_versions" in output
+ assert isinstance(output["spark_versions"], list)
+ assert len(output["spark_versions"]) > 0
+ except json.JSONDecodeError:
+ pytest.fail(f"Invalid JSON: {result.stdout}\nStderr: {result.stderr}")
+
+
+@pytest.mark.integration
+class TestExecuteCode:
+ """Tests for execute-code command."""
+
+ def test_execute_serverless_simple(self):
+ """Test simple Python execution on serverless."""
+ code = 'print("Hello from compute test"); dbutils.notebook.exit("success")'
+
+ result = subprocess.run(
+ [
+ sys.executable, str(_compute_script),
+ "execute-code",
+ "--code", code,
+ "--compute-type", "serverless",
+ "--timeout", "180"
+ ],
+ capture_output=True,
+ text=True,
+ cwd=str(_repo_root),
+ timeout=300 # 5 min for cold start
+ )
+
+ try:
+ output = json.loads(result.stdout)
+ assert output.get("success", False), f"Execution failed: {output}"
+ except json.JSONDecodeError:
+ pytest.fail(f"Invalid JSON: {result.stdout}\nStderr: {result.stderr}")
+
+ def test_execute_requires_code_or_file(self):
+ """Should return error when neither code nor file provided."""
+ result = subprocess.run(
+ [
+ sys.executable, str(_compute_script),
+ "execute-code",
+ "--compute-type", "serverless"
+ ],
+ capture_output=True,
+ text=True,
+ cwd=str(_repo_root),
+ timeout=30
+ )
+
+ try:
+ output = json.loads(result.stdout)
+ assert output.get("success") is False
+ assert "error" in output
+ except json.JSONDecodeError:
+ pytest.fail(f"Invalid JSON: {result.stdout}\nStderr: {result.stderr}")
+
+
+@pytest.mark.integration
+class TestManageCluster:
+ """Tests for manage-cluster command (read-only operations)."""
+
+ def test_invalid_action(self):
+ """Should return error for invalid action."""
+ result = subprocess.run(
+ [
+ sys.executable, str(_compute_script),
+ "manage-cluster",
+ "--action", "invalid_action"
+ ],
+ capture_output=True,
+ text=True,
+ cwd=str(_repo_root),
+ timeout=30
+ )
+
+ # argparse will fail with invalid choice
+ assert result.returncode != 0 or "error" in result.stdout.lower()
+
+ def test_get_requires_cluster_id(self):
+ """Should return error when cluster_id not provided for get."""
+ result = subprocess.run(
+ [
+ sys.executable, str(_compute_script),
+ "manage-cluster",
+ "--action", "get"
+ ],
+ capture_output=True,
+ text=True,
+ cwd=str(_repo_root),
+ timeout=30
+ )
+
+ try:
+ output = json.loads(result.stdout)
+ assert output.get("success") is False
+ assert "error" in output
+ except json.JSONDecodeError:
+ pytest.fail(f"Invalid JSON: {result.stdout}\nStderr: {result.stderr}")
+
+ def test_create_requires_name(self):
+ """Should return error when name not provided for create."""
+ result = subprocess.run(
+ [
+ sys.executable, str(_compute_script),
+ "manage-cluster",
+ "--action", "create"
+ ],
+ capture_output=True,
+ text=True,
+ cwd=str(_repo_root),
+ timeout=30
+ )
+
+ try:
+ output = json.loads(result.stdout)
+ assert output.get("success") is False
+ assert "error" in output
+ except json.JSONDecodeError:
+ pytest.fail(f"Invalid JSON: {result.stdout}\nStderr: {result.stderr}")
diff --git a/DECOMMISSION_PLAN.md b/DECOMMISSION_PLAN.md
new file mode 100644
index 00000000..93bfa905
--- /dev/null
+++ b/DECOMMISSION_PLAN.md
@@ -0,0 +1,270 @@
+# MCP Server Decommissioning Plan
+
+## Executive Summary
+
+This plan outlines removing `databricks-tools-core` and `databricks-mcp-server` from the main AI Dev Kit project, simplifying the installation to focus on **standalone skills only**.
+
+## Current State Analysis
+
+### What Exists Today
+
+| Component | Purpose | Dependencies |
+|-----------|---------|--------------|
+| `databricks-tools-core/` | Python library with high-level Databricks functions | None (standalone) |
+| `databricks-mcp-server/` | MCP server exposing 50+ tools | Depends on databricks-tools-core |
+| `databricks-skills/` | Markdown skills + self-contained Python scripts | **None** (already standalone) |
+| `databricks-builder-app/` | Full-stack web application | **Depends on BOTH** tools-core and mcp-server |
+
+### Files Referencing MCP/Core
+
+**Shell scripts:**
+- `install.sh` (main installer) - lines 1071, 251, 657, etc.
+- `databricks-mcp-server/setup.sh`
+- `.claude-plugin/setup.sh`
+- `databricks-builder-app/scripts/deploy.sh` (lines 193-195)
+- `databricks-builder-app/scripts/start_local.sh` (lines 205-206)
+
+**Documentation:**
+- `README.md` - references both packages in "What's Included" and "Core Library" sections
+- `SECURITY.md` - mentions packages in installation flow
+- `CONTRIBUTING.md` - setup instructions reference mcp-server
+- `databricks-builder-app/README.md` - architecture diagram includes mcp-server
+
+## builder-app Refactoring (Much Simpler Than Expected!)
+
+### Reference Implementation
+
+A cleaner solution exists in `industry-demo-prompts/app/src/demo_prompt_generator/backend/services/agent.py`.
+
+**Key insight:** MCP tools are NOT needed. Skills + standard SDK tools provide everything:
+
+```python
+# Note: MCP tools removed - ai-dev-kit now uses CLI tools via skills
+allowed_tools = ["Read", "Write", "Edit", "Glob", "Grep", "Bash", "Skill"]
+```
+
+### Current builder-app Dependencies
+
+| File | Import | Can Be Removed? |
+|------|--------|----------------|
+| `server/services/agent.py` | `databricks_tools_core.auth` | Yes - use `databricks.sdk.WorkspaceClient()` directly |
+| `server/services/databricks_tools.py` | `databricks_mcp_server.*` | **DELETE ENTIRE FILE** |
+| `server/services/clusters.py` | `databricks_tools_core.auth` | Yes - use SDK directly |
+| `server/services/warehouses.py` | `databricks_tools_core.auth` | Yes - use SDK directly |
+| `server/services/user.py` | `databricks_tools_core.identity` | Yes - inline constants |
+| `server/db/database.py` | `databricks_tools_core.identity` | Yes - inline constants |
+| `alembic/env.py` | `databricks_tools_core.identity` | Yes - inline constants |
+
+### Refactoring Steps
+
+1. **Delete `databricks_tools.py`** (433 lines) - No longer needed
+2. **Simplify `agent.py`**:
+ - Remove MCP server loading
+ - Use standard SDK tools: `["Read", "Write", "Edit", "Glob", "Grep", "Bash", "Skill"]`
+ - Add `setting_sources=["project"]` to enable skill discovery
+ - Copy client pooling pattern from reference implementation
+3. **Replace auth imports** - Use `databricks.sdk.WorkspaceClient()` directly
+4. **Inline identity constants**:
+ ```python
+ # Instead of: from databricks_tools_core.identity import PRODUCT_NAME, PRODUCT_VERSION
+ PRODUCT_NAME = "databricks-builder-app"
+ PRODUCT_VERSION = "0.1.0"
+ ```
+5. **Update deploy.sh** - Remove package copying steps
+6. **Update pyproject.toml** - Remove `databricks_tools_core*` and `databricks_mcp_server*` from includes
+
+### Code Reduction
+
+| File | Before | After |
+|------|--------|-------|
+| `databricks_tools.py` | 433 lines | **DELETED** |
+| `agent.py` | ~400 lines | ~300 lines |
+| `deploy.sh` | Complex pkg copy | Simple |
+
+**Total: ~500+ lines removed, simpler architecture**
+
+### Phase 2: Simplify Main Project
+
+Once builder-app is self-contained:
+
+#### 2.1 Delete Folders
+```bash
+rm -rf databricks-tools-core/
+rm -rf databricks-mcp-server/
+```
+
+#### 2.2 Simplify install.sh
+
+**Option A: Remove MCP entirely (Recommended)**
+
+Replace the 1790-line `install.sh` with a simplified version that:
+- Only installs skills (like `install_skills.sh` does)
+- Removes all MCP configuration code
+- Removes the Python venv creation for MCP
+
+**Option B: Keep MCP as optional**
+
+Keep `--skills-only` as default, make MCP opt-in via `--with-mcp`:
+- Default behavior = skills only
+- `--with-mcp` = old behavior
+
+#### 2.3 Update Documentation
+
+**README.md changes:**
+- Remove "Core Library" section
+- Remove "MCP Tools Only" from table
+- Remove databricks-tools-core from "What's Included"
+- Update architecture diagram (remove MCP layer)
+
+**Files to update:**
+- `README.md`
+- `SECURITY.md`
+- `CONTRIBUTING.md`
+- `databricks-builder-app/README.md`
+
+#### 2.4 Update Other Files
+
+- `.mcp.json` - Delete or update
+- `.claude-plugin/setup.sh` - Remove core/mcp references
+- `pyproject.toml` (if any) - Update dependencies
+
+## Installation Flow Comparison
+
+### Current Flow (install.sh)
+```
+1. Clone repo to ~/.ai-dev-kit
+2. Create Python venv
+3. pip install databricks-tools-core + databricks-mcp-server
+4. Install skills to .claude/skills/
+5. Write MCP config to claude_desktop_config.json, etc.
+```
+
+### Simplified Flow (after decommissioning)
+```
+1. Install skills to .claude/skills/ (directly from GitHub)
+2. Done!
+```
+
+## Migration Guide for Users
+
+Users who want MCP tools after decommissioning:
+
+1. **Use databricks CLI directly** - Skills now guide users to use CLI commands
+2. **Use databricks SDK** - Skills include Python SDK examples
+3. **Fork the MCP server** - If they really need it, they can fork the repo at the commit before removal
+
+## Risks and Mitigations
+
+| Risk | Mitigation |
+|------|------------|
+| builder-app breaks | Phase 1 must complete before Phase 2 |
+| Users depend on MCP | Document migration path; skills cover same functionality |
+| Lost test coverage | Move relevant tests to databricks-skills/.tests/ |
+
+## File Deletion Summary
+
+**Folders to delete:**
+- `databricks-tools-core/` (~20 Python files, ~15K lines)
+- `databricks-mcp-server/` (~15 Python files, ~10K lines)
+
+**Files to heavily modify:**
+- `install.sh` (reduce from 1790 lines to ~500)
+- `README.md` (remove 4+ sections)
+- `CONTRIBUTING.md` (remove MCP setup)
+- `SECURITY.md` (update installation flow)
+
+**Files to delete:**
+- `.mcp.json` (MCP config example)
+
+## Pre-requisite: Fix Skills Integration Tests
+
+Before proceeding with decommissioning, fix the broken integration tests in `databricks-skills/.tests/`:
+
+### Current Test Status
+
+| Test File | Unit Tests | Integration Tests | Status |
+|-----------|------------|-------------------|--------|
+| `test_agent_bricks_manager.py` | 5 pass | 3 skip (no workspace) | OK |
+| `test_pdf_generator.py` | 13 pass | 3 fail | **NEEDS FIX** |
+
+### Failing Tests (test_pdf_generator.py)
+
+```
+FAILED test_pdf_generator.py::TestPDFGenerationIntegration::test_generate_and_upload_pdf
+FAILED test_pdf_generator.py::TestPDFGenerationIntegration::test_generate_and_upload_pdf_with_folder
+FAILED test_pdf_generator.py::TestPDFGenerationIntegration::test_generate_complex_pdf
+```
+
+**Root cause:** Test volume `ai_dev_kit.test_pdf_generation.raw_data` doesn't exist.
+
+### Fix Required
+
+Update `test_pdf_generator.py` to skip gracefully when test volume is unavailable:
+
+```python
+@pytest.fixture(autouse=True)
+def skip_if_volume_missing(self, test_config):
+ """Skip tests if the required volume doesn't exist."""
+ error = _validate_volume_exists(
+ test_config["catalog"],
+ test_config["schema"],
+ test_config["volume"]
+ )
+ if error:
+ pytest.skip(f"Test volume not available: {error}")
+```
+
+### Additional Integration Tests Needed
+
+For complete coverage, add integration tests for remaining skills with Python files:
+
+| Skill | Python File | Test Status |
+|-------|-------------|-------------|
+| `databricks-agent-bricks` | `mas_manager.py` | Has tests |
+| `databricks-unstructured-pdf-generation` | `pdf_generator.py` | Has tests (needs fix) |
+| Other skills with .py files | Various | Need tests |
+
+## Recommended Execution Order
+
+### Phase 0: Fix Skills Tests
+1. [ ] **Fix broken integration tests** (test_pdf_generator.py skip when volume missing)
+2. [ ] Add integration tests for remaining skills with Python files
+
+### Phase 1: Refactor builder-app (Much Simpler Now!)
+
+**Reference implementation:** `../industry-demo-prompts/app/src/demo_prompt_generator/backend/services/agent.py`
+
+3. [ ] Update `pyproject.toml`:
+ - Bump `claude-agent-sdk>=0.1.50` (from 0.1.19)
+ - Remove `databricks_tools_core*` and `databricks_mcp_server*` from includes
+4. [ ] Delete `server/services/databricks_tools.py` entirely
+5. [ ] Simplify `server/services/agent.py`:
+ - Remove MCP imports and loading
+ - Use standard tools: `["Read", "Write", "Edit", "Glob", "Grep", "Bash", "Skill"]`
+ - Add `setting_sources=["project"]` for skill discovery
+ - Adopt client pooling pattern from reference implementation
+6. [ ] Replace `databricks_tools_core.auth` → `databricks.sdk.WorkspaceClient()`
+7. [ ] Inline `PRODUCT_NAME`, `PRODUCT_VERSION` constants
+8. [ ] Update `deploy.sh` - remove package copying
+9. [ ] Test builder-app locally and deployed
+
+### Phase 2: Simplify Main Project
+10. [ ] Simplify `install.sh` to skills-only (remove MCP setup)
+11. [ ] Update `install.ps1` (Windows) similarly
+12. [ ] Update `README.md`
+13. [ ] Update `CONTRIBUTING.md`
+14. [ ] Update `SECURITY.md`
+
+### Phase 3: Delete and Verify
+15. [ ] Delete `databricks-tools-core/`
+16. [ ] Delete `databricks-mcp-server/`
+17. [ ] Delete `.mcp.json`
+18. [ ] Delete `.claude-plugin/` (or update if needed)
+19. [ ] Test full installation flow (skills-only)
+20. [ ] Test builder-app deployment
+
+## Questions to Resolve
+
+1. **Should we archive MCP in a separate branch?** - For users who want to fork it
+2. **What about install.ps1 (Windows)?** - Same changes needed
+3. **Keep .claude-plugin/ ?** - This also references MCP
diff --git a/README.md b/README.md
index 75a3da18..4fe8e6bd 100644
--- a/README.md
+++ b/README.md
@@ -14,9 +14,6 @@
AI-Driven Development (vibe coding) on Databricks just got a whole lot better. The **AI Dev Kit** gives your AI coding assistant (Claude Code, Cursor, Antigravity, Windsurf, etc.) the trusted sources it needs to build faster and smarter on Databricks.
-
-
-
---
@@ -41,10 +38,9 @@ AI-Driven Development (vibe coding) on Databricks just got a whole lot better. T
|----------------------------------|----------|------------|
| :star: [**Install AI Dev Kit**](#install-in-existing-project) | **Start here!** Follow quick install instructions to add to your existing project folder | [Quick Start (install)](#install-in-existing-project)
| [**Visual Builder App**](#visual-builder-app) | Web-based UI for Databricks development | `databricks-builder-app/` |
-| [**Core Library**](#core-library) | Building custom integrations (LangChain, OpenAI, etc.) | `pip install` |
-| [**Skills Only**](databricks-skills/) | Provide Databricks patterns and best practices (without MCP functions) | Install skills |
+| [**Skills Only**](databricks-skills/) | Provide Databricks patterns and best practices | Install skills |
| [**Genie Code Skills**](databricks-skills/install_skills.sh) | Install skills into your workspace for Genie Code (`--install-to-genie`) | [Genie Code skills (install)](#genie-code-skills) |
-| [**MCP Tools Only**](databricks-mcp-server/) | Just executable actions (no guidance) | Register MCP server |
+| [**Embed in Your App**](#embedding-in-other-apps) | Integrate the agent into your own application | Integration example |
---
## Quick Start
@@ -155,25 +151,25 @@ cd ai-dev-kit/databricks-builder-app
For local development:
```bash
-./scripts/setup.sh # Install dependencies
-# Edit .env.local with your credentials
-./scripts/start_dev.sh # Start locally at http://localhost:3000
+# One command does everything: provisions Lakebase, installs deps, starts servers
+./scripts/start_local.sh --profile
```
See [`databricks-builder-app/`](databricks-builder-app/) for full documentation.
+### Embedding in Other Apps
-### Core Library
-
-Use `databricks-tools-core` directly in your Python projects:
+To embed the Databricks agent into your own application:
-```python
-from databricks_tools_core.sql import execute_sql
-
-results = execute_sql("SELECT * FROM my_catalog.schema.table LIMIT 10")
+```bash
+cd ai-dev-kit/databricks-builder-app/scripts/_integration-example
+./setup.sh
+# Edit .env with your credentials
+python example_integration.py
```
-Works with LangChain, OpenAI Agents SDK, or any Python framework. See [databricks-tools-core/](databricks-tools-core/) for details.
+See [`scripts/_integration-example/`](databricks-builder-app/scripts/_integration-example/) for full integration guide.
+
---
## Genie Code Skills
@@ -213,9 +209,7 @@ This directory is customizable if you wish to only use certain skills or even cr
| Component | Description |
|-----------|-------------|
-| [`databricks-tools-core/`](databricks-tools-core/) | Python library with high-level Databricks functions |
-| [`databricks-mcp-server/`](databricks-mcp-server/) | MCP server exposing 50+ tools for AI assistants |
-| [`databricks-skills/`](databricks-skills/) | 20 markdown skills teaching Databricks patterns |
+| [`databricks-skills/`](databricks-skills/) | 20+ markdown skills teaching Databricks patterns |
| [`databricks-builder-app/`](databricks-builder-app/) | Full-stack web app with Claude Code integration |
---
@@ -243,12 +237,7 @@ The source in this project is provided subject to the [Databricks License](https
| Package | Version | License | Project URL |
|---------|---------|---------|-------------|
-| [fastmcp](https://github.com/jlowin/fastmcp) | ≥0.1.0 | MIT | https://github.com/jlowin/fastmcp |
-| [mcp](https://github.com/modelcontextprotocol/python-sdk) | ≥1.0.0 | MIT | https://github.com/modelcontextprotocol/python-sdk |
-| [sqlglot](https://github.com/tobymao/sqlglot) | ≥20.0.0 | MIT | https://github.com/tobymao/sqlglot |
-| [sqlfluff](https://github.com/sqlfluff/sqlfluff) | ≥3.0.0 | MIT | https://github.com/sqlfluff/sqlfluff |
-| [plutoprint](https://github.com/nicvagn/plutoprint) | ==0.19.0 | MIT | https://github.com/plutoprint/plutoprint |
-| [claude-agent-sdk](https://github.com/anthropics/claude-code) | ≥0.1.19 | MIT | https://github.com/anthropics/claude-code |
+| [claude-agent-sdk](https://github.com/anthropics/claude-code) | ≥0.1.50 | MIT | https://github.com/anthropics/claude-code |
| [fastapi](https://github.com/fastapi/fastapi) | ≥0.115.8 | MIT | https://github.com/fastapi/fastapi |
| [uvicorn](https://github.com/encode/uvicorn) | ≥0.34.0 | BSD-3-Clause | https://github.com/encode/uvicorn |
| [httpx](https://github.com/encode/httpx) | ≥0.28.0 | BSD-3-Clause | https://github.com/encode/httpx |
diff --git a/databricks-builder-app/README.md b/databricks-builder-app/README.md
index ccbbdb74..f90d4335 100644
--- a/databricks-builder-app/README.md
+++ b/databricks-builder-app/README.md
@@ -32,20 +32,14 @@ A web application that provides a Claude Code agent interface with integrated Da
├─────────────────────────────────────────────────────────────────────────────┤
│ Each user message spawns a Claude Code agent session via claude-agent-sdk │
│ │
-│ Built-in Tools: MCP Tools (Databricks): Skills: │
-│ ┌──────────────────┐ ┌─────────────────────────┐ ┌───────────┐ │
-│ │ Read, Write, Edit│ │ execute_sql │ │ sdp │ │
-│ │ Glob, Grep, Skill│ │ create_or_update_pipeline │ dabs │ │
-│ └──────────────────┘ │ upload_folder │ │ sdk │ │
-│ │ execute_code │ │ ... │ │
-│ │ ... │ └───────────┘ │
-│ └─────────────────────────┘ │
-│ │ │
-│ ▼ │
-│ ┌─────────────────────────┐ │
-│ │ databricks-mcp-server │ │
-│ │ (in-process SDK tools) │ │
-│ └─────────────────────────┘ │
+│ Built-in Tools: Skills: │
+│ ┌──────────────────┐ ┌───────────┐ │
+│ │ Read, Write, Edit│ │ sdp │ │
+│ │ Glob, Grep │ Skills provide Databricks │ dabs │ │
+│ │ Bash, Skill │ CLI/SDK guidance and examples │ sdk │ │
+│ └──────────────────┘ │ ... │ │
+│ └───────────┘ │
+│ │
└─────────────────────────────────────────────────────────────────────────────┘
│
▼
@@ -63,20 +57,21 @@ A web application that provides a Claude Code agent interface with integrated Da
When a user sends a message, the backend creates a Claude Code session using the `claude-agent-sdk`:
```python
-from claude_agent_sdk import ClaudeAgentOptions, query
+from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
options = ClaudeAgentOptions(
cwd=str(project_dir), # Project working directory
- allowed_tools=allowed_tools, # Built-in + MCP tools
- permission_mode='bypassPermissions', # Auto-accept all tools including MCP
+ allowed_tools=allowed_tools, # Built-in tools (Read, Write, Edit, Glob, Grep, Bash, Skill)
+ permission_mode='bypassPermissions', # Auto-accept all tools
resume=session_id, # Resume previous conversation
- mcp_servers=mcp_servers, # Databricks MCP server config
system_prompt=system_prompt, # Databricks-focused prompt
setting_sources=['user', 'project'], # Load skills from .claude/skills
)
-async for msg in query(prompt=message, options=options):
- yield msg # Stream to frontend
+async with ClaudeSDKClient(options=options) as client:
+ await client.query(message)
+ async for msg in client.receive_response():
+ yield msg # Stream to frontend
```
Key features:
@@ -125,7 +120,7 @@ The app supports multi-user authentication using per-request credentials:
2. **Auth context set** - Before invoking the agent:
```python
- from databricks_tools_core.auth import set_databricks_auth, clear_databricks_auth
+ from server.services.auth import set_databricks_auth, clear_databricks_auth
set_databricks_auth(workspace_url, user_token)
try:
@@ -142,28 +137,25 @@ The app supports multi-user authentication using per-request credentials:
This ensures each user's requests use their own Databricks credentials, enabling proper access control and audit logging.
-### 3. MCP Integration (Databricks Tools)
-
-Databricks tools are loaded in-process using the Claude Agent SDK's MCP server feature:
-
-```python
-from claude_agent_sdk import tool, create_sdk_mcp_server
+### 3. Databricks Integration (Skills-based)
-# Tools are dynamically loaded from databricks-mcp-server
-server = create_sdk_mcp_server(name='databricks', tools=sdk_tools)
+Databricks capabilities are provided through the Skills system, which gives Claude detailed guidance on using the Databricks CLI and SDK. This approach is simpler and more maintainable than running a separate MCP server.
-options = ClaudeAgentOptions(
- mcp_servers={'databricks': server},
- allowed_tools=['mcp__databricks__execute_sql', ...],
-)
-```
+**How skills work:**
+- Skills are markdown files with CLI/SDK instructions and examples
+- Claude loads skills on-demand using the `Skill` tool
+- Skills provide patterns for common Databricks operations:
+ - SQL execution via `databricks sql` CLI or SDK
+ - Cluster operations and code execution
+ - Pipeline management with SDP (Spark Declarative Pipelines)
+ - File operations with Unity Catalog volumes
+ - Job and workflow management
-Tools are exposed as `mcp__databricks__` and include:
-- SQL execution (`execute_sql`, `execute_sql_multi`)
-- Warehouse management (`list_warehouses`, `get_best_warehouse`)
-- Cluster execution (`execute_code`)
-- Pipeline management (`create_or_update_pipeline`, `start_update`, etc.)
-- File operations (`upload_to_workspace`)
+**Benefits of skills-based approach:**
+- No additional server process needed
+- Skills stay up-to-date with CLI/SDK changes
+- Claude can adapt commands to user context
+- Easier to debug and maintain
### 4. Skills System
@@ -404,7 +396,7 @@ databricks-builder-app/
│ │ └── conversations.py
│ └── services/ # Business logic
│ ├── agent.py # Claude Code session management
-│ ├── databricks_tools.py # MCP tool loading from SDK
+│ ├── auth.py # Databricks auth utilities
│ ├── user.py # User auth (headers/env vars)
│ ├── skills_manager.py
│ ├── backup_manager.py
@@ -597,6 +589,4 @@ This provides a minimal working example with setup instructions for integrating
## Related Packages
-- **databricks-tools-core**: Core MCP functionality and SQL operations
-- **databricks-mcp-server**: MCP server exposing Databricks tools
- **databricks-skills**: Skill definitions for Databricks development
diff --git a/databricks-builder-app/alembic/env.py b/databricks-builder-app/alembic/env.py
index fb2ad67d..036c0f5b 100644
--- a/databricks-builder-app/alembic/env.py
+++ b/databricks-builder-app/alembic/env.py
@@ -64,7 +64,7 @@ def get_url_and_connect_args():
# Generate token using Databricks SDK
import uuid
from databricks.sdk import WorkspaceClient
- from databricks_tools_core.identity import PRODUCT_NAME, PRODUCT_VERSION
+ from server.services.auth import PRODUCT_NAME, PRODUCT_VERSION
w = WorkspaceClient(product=PRODUCT_NAME, product_version=PRODUCT_VERSION)
diff --git a/databricks-builder-app/pyproject.toml b/databricks-builder-app/pyproject.toml
index ec0c7ef1..a0d61936 100644
--- a/databricks-builder-app/pyproject.toml
+++ b/databricks-builder-app/pyproject.toml
@@ -3,14 +3,14 @@ requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
-where = [".", "packages"]
-include = ["server*", "databricks_tools_core*", "databricks_mcp_server*"]
-exclude = ["client*", "tests*", "scripts*"]
+where = ["."]
+include = ["server*"]
+exclude = ["client*", "tests*", "scripts*", "packages*"]
[project]
name = "databricks-builder-app"
version = "0.1.0"
-description = "Claude Code MCP Application"
+description = "Databricks Builder App - AI-powered development assistant"
readme = "README.md"
requires-python = ">=3.11,<3.14"
dependencies = [
@@ -27,12 +27,8 @@ dependencies = [
"greenlet>=3.0.0",
"psycopg2-binary>=2.9.11", # For alembic migrations (sync)
# Claude Agent SDK (successor to claude-code-sdk)
- "claude-agent-sdk>=0.1.19",
+ "claude-agent-sdk>=0.1.50",
"anthropic>=0.42.0",
- # Databricks MCP tools (databricks-mcp-server installed from sibling dir in dev)
- "mcp>=1.0.0",
- "fastmcp==3.1.1",
- # databricks-tools-core and databricks-mcp-server are bundled in packages/ directory
"requests>=2.31.0",
# MLflow 3+ for Claude Code tracing (not mlflow[databricks] to avoid litellm)
"mlflow>=3.9.0",
diff --git a/databricks-builder-app/scripts/_integration-example/README.md b/databricks-builder-app/scripts/_integration-example/README.md
index c64e6bc6..237e0b1b 100644
--- a/databricks-builder-app/scripts/_integration-example/README.md
+++ b/databricks-builder-app/scripts/_integration-example/README.md
@@ -3,8 +3,8 @@
This directory shows how to embed `ai-dev-kit` into your own application.
**What you get:** The same Claude Agent SDK-based agent used by `databricks-builder-app`, with:
-- Databricks MCP tools (SQL, clusters, jobs, pipelines, Unity Catalog, etc.)
-- Skills for guided development (SDP, SDK patterns, MLflow, etc.)
+- Skills for guided Databricks development (SDP, SDK patterns, MLflow, etc.)
+- Built-in tools (Read, Write, Edit, Glob, Grep, Bash, Skill)
- Multi-user auth support via contextvars
## Prerequisites
@@ -64,8 +64,8 @@ python example_integration.py
This integration uses:
- **Claude Agent SDK** (`claude-agent-sdk`) - Anthropic's SDK for running Claude as an agentic assistant
-- **Databricks MCP Tools** (`databricks-mcp-server`) - Tools loaded in-process via MCP protocol
-- **Skills** - Markdown files in `.claude/skills/` that provide domain-specific guidance
+- **Skills** - Markdown files in `.claude/skills/` that provide Databricks CLI/SDK guidance
+- **Built-in tools** - Read, Write, Edit, Glob, Grep, Bash, Skill
The `stream_agent_response` function is the same one used by `databricks-builder-app`.
@@ -76,7 +76,7 @@ The `stream_agent_response` function is the same one used by `databricks-builder
from server.services.agent import stream_agent_response
# Auth utilities for per-request Databricks credentials
-from databricks_tools_core import set_databricks_auth, clear_databricks_auth
+from server.services.auth import set_databricks_auth, clear_databricks_auth
```
### Basic Usage
@@ -85,7 +85,7 @@ from databricks_tools_core import set_databricks_auth, clear_databricks_auth
import asyncio
async def run_agent(message: str):
- # Set Databricks auth for this request (passed to MCP tools)
+ # Set Databricks auth for this request (used by skills for CLI/SDK operations)
set_databricks_auth(
host="https://your-workspace.cloud.databricks.com",
token="dapi..."
diff --git a/databricks-builder-app/scripts/_integration-example/example_integration.py b/databricks-builder-app/scripts/_integration-example/example_integration.py
index 5fc9460f..3832ddf5 100644
--- a/databricks-builder-app/scripts/_integration-example/example_integration.py
+++ b/databricks-builder-app/scripts/_integration-example/example_integration.py
@@ -3,8 +3,8 @@
Minimal example of embedding ai-dev-kit into a custom app.
This runs the same Claude Agent SDK-based agent used by databricks-builder-app,
-with Databricks MCP tools and skills loaded. The agent can:
-- Execute SQL queries via warehouses
+with Databricks skills loaded. The agent uses skills + standard SDK tools to:
+- Execute SQL queries via warehouses (using CLI/SDK)
- Run Python/PySpark on clusters
- Manage Unity Catalog objects
- Create and run jobs/pipelines
@@ -23,9 +23,8 @@
load_dotenv()
# Import the same agent service used by databricks-builder-app
-# This uses claude-agent-sdk with Databricks MCP tools loaded in-process
from server.services.agent import stream_agent_response
-from databricks_tools_core import set_databricks_auth, clear_databricks_auth
+from server.services.auth import set_databricks_auth, clear_databricks_auth
async def run_agent(message: str, project_id: str = "demo") -> None:
@@ -41,7 +40,7 @@ async def run_agent(message: str, project_id: str = "demo") -> None:
print(f"Databricks workspace: {host}")
print(f"Message: {message}")
- print("Running Claude agent with Databricks MCP tools...")
+ print("Running Claude agent with Databricks skills...")
print("-" * 50)
# Set auth context for this request
diff --git a/databricks-builder-app/scripts/_integration-example/requirements.txt b/databricks-builder-app/scripts/_integration-example/requirements.txt
index b4e7d6ee..0c3ab1ce 100644
--- a/databricks-builder-app/scripts/_integration-example/requirements.txt
+++ b/databricks-builder-app/scripts/_integration-example/requirements.txt
@@ -1,13 +1,11 @@
# AI Dev Kit Integration Dependencies
# Install from local ai-dev-kit repository
-# Core packages (from ai-dev-kit repo)
--e ../../../databricks-tools-core
--e ../../../databricks-mcp-server
+# Builder app (includes all dependencies)
-e ../../ # databricks-builder-app
# Claude Agent SDK
-claude-agent-sdk>=0.1.0
+claude-agent-sdk>=0.1.50
# For loading environment variables
python-dotenv>=1.0.0
diff --git a/databricks-builder-app/scripts/deploy.sh b/databricks-builder-app/scripts/deploy.sh
index b2f19647..becad80a 100755
--- a/databricks-builder-app/scripts/deploy.sh
+++ b/databricks-builder-app/scripts/deploy.sh
@@ -188,12 +188,6 @@ echo " Copying frontend build..."
mkdir -p "$STAGING_DIR/client"
cp -r client/out "$STAGING_DIR/client/"
-echo " Copying Databricks packages..."
-mkdir -p "$STAGING_DIR/packages/databricks_tools_core"
-cp -r "$REPO_ROOT/databricks-tools-core/databricks_tools_core/"* "$STAGING_DIR/packages/databricks_tools_core/"
-mkdir -p "$STAGING_DIR/packages/databricks_mcp_server"
-cp -r "$REPO_ROOT/databricks-mcp-server/databricks_mcp_server/"* "$STAGING_DIR/packages/databricks_mcp_server/"
-
if [ "$SKIP_SKILLS" = true ] && [ -d "$SKILLS_CACHE_DIR" ] && [ "$(ls -A "$SKILLS_CACHE_DIR" 2>/dev/null)" ]; then
mkdir -p "$STAGING_DIR/skills"
echo -e " ${GREEN}✓${NC} Reusing cached skills from ${SKILLS_CACHE_DIR} (--skip-skills)"
diff --git a/databricks-builder-app/scripts/start_local.sh b/databricks-builder-app/scripts/start_local.sh
index 7b50baa0..78cc2d5f 100755
--- a/databricks-builder-app/scripts/start_local.sh
+++ b/databricks-builder-app/scripts/start_local.sh
@@ -201,13 +201,6 @@ else
uv sync --quiet
echo -e " ${GREEN}✓${NC} Backend dependencies installed"
fi
-
-if [ -d "$REPO_ROOT/databricks-tools-core" ] && [ -d "$REPO_ROOT/databricks-mcp-server" ]; then
- uv pip install -e "$REPO_ROOT/databricks-tools-core" -e "$REPO_ROOT/databricks-mcp-server" --quiet 2>/dev/null
- echo -e " ${GREEN}✓${NC} Sibling packages installed"
-else
- echo -e " ${YELLOW}⚠${NC} Sibling packages not found at repo root"
-fi
echo ""
# ─────────────────────────────────────────────────────────────────────────────
diff --git a/databricks-builder-app/server/db/database.py b/databricks-builder-app/server/db/database.py
index bcd75a85..1d50c8cd 100644
--- a/databricks-builder-app/server/db/database.py
+++ b/databricks-builder-app/server/db/database.py
@@ -103,7 +103,7 @@ def _get_workspace_client():
try:
import os
from databricks.sdk import WorkspaceClient
- from databricks_tools_core.identity import PRODUCT_NAME, PRODUCT_VERSION
+ from server.services.auth import PRODUCT_NAME, PRODUCT_VERSION
product_kwargs = dict(product=PRODUCT_NAME, product_version=PRODUCT_VERSION)
if _has_oauth_credentials():
diff --git a/databricks-builder-app/server/routers/clusters.py b/databricks-builder-app/server/routers/clusters.py
index 4184fd89..4e522481 100644
--- a/databricks-builder-app/server/routers/clusters.py
+++ b/databricks-builder-app/server/routers/clusters.py
@@ -3,8 +3,8 @@
import logging
from fastapi import APIRouter, Request
-from databricks_tools_core.auth import set_databricks_auth, clear_databricks_auth
+from ..services.auth import set_databricks_auth, clear_databricks_auth
from ..services.clusters import list_clusters_async
from ..services.user import get_current_user, get_current_token, get_workspace_url
diff --git a/databricks-builder-app/server/routers/warehouses.py b/databricks-builder-app/server/routers/warehouses.py
index 961a717c..748b6323 100644
--- a/databricks-builder-app/server/routers/warehouses.py
+++ b/databricks-builder-app/server/routers/warehouses.py
@@ -3,8 +3,8 @@
import logging
from fastapi import APIRouter, Request
-from databricks_tools_core.auth import set_databricks_auth, clear_databricks_auth
+from ..services.auth import set_databricks_auth, clear_databricks_auth
from ..services.warehouses import list_warehouses_async
from ..services.user import get_current_user, get_current_token, get_workspace_url
diff --git a/databricks-builder-app/server/services/agent.py b/databricks-builder-app/server/services/agent.py
index 0238472c..18d7363b 100644
--- a/databricks-builder-app/server/services/agent.py
+++ b/databricks-builder-app/server/services/agent.py
@@ -1,10 +1,10 @@
"""Claude Code Agent service for managing agent sessions.
Uses the claude-agent-sdk to create and manage Claude Code agent sessions
-with directory-scoped file permissions and Databricks tools.
+with directory-scoped file permissions and Databricks skills.
-Databricks tools are loaded in-process from databricks-mcp-server using
-the SDK tool wrapper. Auth is handled via contextvars for multi-user support.
+Skills are loaded from .claude/skills/ directory and provide Databricks
+CLI/SDK guidance. No MCP server needed - skills + standard tools are sufficient.
MLflow Tracing:
Uses ClaudeSDKClient with mlflow.anthropic.autolog() for automatic tracing.
@@ -17,7 +17,6 @@
"""
import asyncio
-import json
import logging
import os
import queue
@@ -44,28 +43,29 @@
ToolUseBlock,
UserMessage,
)
-from databricks_tools_core.auth import set_databricks_auth, clear_databricks_auth
+from .auth import (
+ PRODUCT_NAME,
+ PRODUCT_VERSION,
+ clear_databricks_auth,
+ set_databricks_auth,
+)
from .backup_manager import ensure_project_directory as _ensure_project_directory
-from .databricks_tools import load_databricks_tools, create_filtered_databricks_server
from .system_prompt import get_system_prompt
logger = logging.getLogger(__name__)
-# Built-in Claude Code tools
+# Built-in Claude Code tools (no MCP - skills provide Databricks CLI/SDK guidance)
BUILTIN_TOOLS = [
'Read',
'Write',
'Edit',
-# 'Bash',
'Glob',
'Grep',
+ 'Bash',
+ 'Skill',
]
-# Cached Databricks tools (loaded once)
-_databricks_server = None
-_databricks_tool_names = None
-
# Cached Claude settings (loaded once)
_claude_settings = None
@@ -90,23 +90,6 @@ def _load_claude_settings() -> dict:
return _claude_settings
-def get_databricks_tools(force_reload: bool = False):
- """Get Databricks tools, optionally forcing a reload.
-
- Args:
- force_reload: If True, recreate the MCP server to clear any corrupted state
-
- Returns:
- Tuple of (server, tool_names)
- """
- global _databricks_server, _databricks_tool_names
- if _databricks_server is None or force_reload:
- if force_reload:
- logger.info('Force reloading Databricks MCP server')
- _databricks_server, _databricks_tool_names = load_databricks_tools()
- return _databricks_server, _databricks_tool_names
-
-
def get_project_directory(project_id: str) -> Path:
"""Get the directory path for a project.
@@ -312,33 +295,19 @@ async def stream_agent_response(
set_databricks_auth(databricks_host, databricks_token, force_token=is_cross_workspace)
try:
- # Build allowed tools list
+ # Build allowed tools list (skills provide Databricks CLI/SDK guidance - no MCP needed)
allowed_tools = BUILTIN_TOOLS.copy()
# Sync project skills directory before running agent
- from .skills_manager import sync_project_skills, get_available_skills, get_allowed_mcp_tools
+ from .skills_manager import sync_project_skills, get_available_skills
sync_project_skills(project_dir, enabled_skills=enabled_skills)
- # Get Databricks tools and filter based on enabled skills.
- # We must create a filtered MCP server (not just filter allowed_tools)
- # because bypassPermissions mode exposes all tools in registered MCP servers.
- databricks_server, databricks_tool_names = get_databricks_tools()
- filtered_tool_names = get_allowed_mcp_tools(databricks_tool_names, enabled_skills=enabled_skills)
-
- if len(filtered_tool_names) < len(databricks_tool_names):
- # Some tools are blocked — create a filtered MCP server with only allowed tools
- databricks_server, filtered_tool_names = create_filtered_databricks_server(filtered_tool_names)
- blocked_count = len(databricks_tool_names) - len(filtered_tool_names)
- logger.info(f'Databricks MCP server: {len(filtered_tool_names)} tools allowed, {blocked_count} blocked by disabled skills')
- else:
- logger.info(f'Databricks MCP server configured with {len(filtered_tool_names)} tools')
-
- allowed_tools.extend(filtered_tool_names)
-
- # Only add the Skill tool if there are enabled skills for the agent to use
+ # Log available skills
available = get_available_skills(enabled_skills=enabled_skills)
if available:
- allowed_tools.append('Skill')
+ logger.info(f'Skills available: {len(available)} skills enabled')
+ else:
+ logger.info('No skills enabled')
# Generate system prompt with available skills, cluster, warehouse, and catalog/schema context
system_prompt = get_system_prompt(
@@ -388,8 +357,16 @@ async def stream_agent_response(
logger.info(f'Configured Databricks model serving: {anthropic_base_url} with model {anthropic_model}')
logger.info(f'Claude env vars: BASE_URL={claude_env.get("ANTHROPIC_BASE_URL")}, MODEL={claude_env.get("ANTHROPIC_MODEL")}')
+ # Pass Databricks credentials to Claude subprocess for CLI/SDK commands
+ # This enables `databricks` CLI and Python SDK to authenticate automatically
+ # when Claude runs them via the Bash tool (skills guide Claude to use these)
+ if databricks_host:
+ claude_env['DATABRICKS_HOST'] = databricks_host
+ if databricks_token:
+ claude_env['DATABRICKS_TOKEN'] = databricks_token
+ logger.info(f'Databricks CLI auth configured for: {databricks_host}')
+
# Databricks SDK upstream tracking for subprocess user-agent attribution
- from databricks_tools_core.identity import PRODUCT_NAME, PRODUCT_VERSION
claude_env['DATABRICKS_SDK_UPSTREAM'] = PRODUCT_NAME
claude_env['DATABRICKS_SDK_UPSTREAM_VERSION'] = PRODUCT_VERSION
@@ -431,11 +408,10 @@ async def _keepalive_hook(_input_data, _tool_use_id, _context):
options = ClaudeAgentOptions(
cwd=str(project_dir),
allowed_tools=allowed_tools,
- permission_mode='bypassPermissions', # Auto-accept all tools including MCP
+ permission_mode='bypassPermissions', # Auto-accept all tools
can_use_tool=can_use_tool, # Handle AskUserQuestion gracefully
hooks={"PreToolUse": [HookMatcher(matcher=None, hooks=[_keepalive_hook])]},
resume=session_id, # Resume from previous session if provided
- mcp_servers={'databricks': databricks_server}, # In-process SDK tools
system_prompt=system_prompt, # Databricks-focused system prompt
setting_sources=["user", "project"], # Load Skills from filesystem
env=claude_env, # Pass Databricks auth settings (ANTHROPIC_AUTH_TOKEN, etc.)
diff --git a/databricks-builder-app/server/services/auth.py b/databricks-builder-app/server/services/auth.py
new file mode 100644
index 00000000..02045338
--- /dev/null
+++ b/databricks-builder-app/server/services/auth.py
@@ -0,0 +1,106 @@
+"""Centralized Databricks authentication utilities.
+
+Provides auth context management for multi-user support and a WorkspaceClient
+factory that respects the auth context.
+
+Authentication flows in the Builder App:
+
+1. **Backend SDK calls** (clusters.py, warehouses.py, user.py):
+ - Use `set_databricks_auth()` to set context vars at request start
+ - Call `get_workspace_client()` which reads from context vars
+ - Context vars are per-request (via contextvars)
+
+2. **Claude subprocess CLI commands** (via Bash tool):
+ - Credentials are passed via environment variables (DATABRICKS_HOST, DATABRICKS_TOKEN)
+ - Set in agent.py when building `claude_env` dict
+ - The `databricks` CLI and Python SDK both respect these env vars
+
+The context var approach (set_databricks_auth) works within the same Python process.
+The env var approach works for subprocesses (like Claude running CLI commands).
+"""
+
+import os
+from contextvars import ContextVar
+from typing import Optional
+
+from databricks.sdk import WorkspaceClient
+
+# Product identity for SDK user-agent tracking
+PRODUCT_NAME = 'databricks-builder-app'
+PRODUCT_VERSION = '0.1.0'
+
+# Context variables for Databricks auth (multi-user support)
+_databricks_host_var: ContextVar[str | None] = ContextVar('databricks_host', default=None)
+_databricks_token_var: ContextVar[str | None] = ContextVar('databricks_token', default=None)
+_force_token_var: ContextVar[bool] = ContextVar('force_token', default=False)
+
+
+def set_databricks_auth(host: str | None, token: str | None, force_token: bool = False):
+ """Set Databricks auth credentials in context for the current request.
+
+ Args:
+ host: Databricks workspace URL (e.g., https://company.cloud.databricks.com)
+ token: Access token for authentication
+ force_token: If True, use the provided token even when OAuth M2M is available.
+ Use this for cross-workspace operations where the target workspace
+ differs from the app's workspace.
+ """
+ _databricks_host_var.set(host)
+ _databricks_token_var.set(token)
+ _force_token_var.set(force_token)
+
+
+def clear_databricks_auth():
+ """Clear Databricks auth credentials from context."""
+ _databricks_host_var.set(None)
+ _databricks_token_var.set(None)
+ _force_token_var.set(False)
+
+
+def _has_oauth_credentials() -> bool:
+ """Check if OAuth M2M credentials (Service Principal) are configured."""
+ return bool(
+ os.environ.get('DATABRICKS_CLIENT_ID') and os.environ.get('DATABRICKS_CLIENT_SECRET')
+ )
+
+
+def get_workspace_client() -> WorkspaceClient:
+ """Get a WorkspaceClient with proper auth handling.
+
+ Auth precedence:
+ 1. If auth context is set (via set_databricks_auth) and force_token is True,
+ use the context credentials (for cross-workspace operations)
+ 2. If OAuth M2M credentials exist in environment, use OAuth M2M
+ 3. Fall back to SDK default auth (PAT, Azure CLI, etc.)
+
+ Returns:
+ Configured WorkspaceClient instance
+ """
+ product_kwargs = dict(product=PRODUCT_NAME, product_version=PRODUCT_VERSION)
+
+ # Check for auth context (set by request handler for multi-user support)
+ ctx_host = _databricks_host_var.get()
+ ctx_token = _databricks_token_var.get()
+ force_token = _force_token_var.get()
+
+ # If force_token is set, use the context credentials directly
+ # This is for cross-workspace operations where OAuth M2M won't work
+ if force_token and ctx_host and ctx_token:
+ return WorkspaceClient(host=ctx_host, token=ctx_token, **product_kwargs)
+
+ # In Databricks Apps, prefer OAuth M2M for same-workspace operations
+ if _has_oauth_credentials():
+ return WorkspaceClient(
+ host=os.environ.get('DATABRICKS_HOST', ''),
+ client_id=os.environ.get('DATABRICKS_CLIENT_ID', ''),
+ client_secret=os.environ.get('DATABRICKS_CLIENT_SECRET', ''),
+ auth_type='oauth-m2m',
+ **product_kwargs,
+ )
+
+ # Development mode or no OAuth - use context credentials if available
+ if ctx_host and ctx_token:
+ return WorkspaceClient(host=ctx_host, token=ctx_token, **product_kwargs)
+
+ # Fall back to SDK default auth
+ return WorkspaceClient(**product_kwargs)
diff --git a/databricks-builder-app/server/services/clusters.py b/databricks-builder-app/server/services/clusters.py
index 7e3f24b4..2a00c120 100644
--- a/databricks-builder-app/server/services/clusters.py
+++ b/databricks-builder-app/server/services/clusters.py
@@ -7,9 +7,9 @@
from threading import Lock
from typing import Optional
-from databricks.sdk.config import Config
from databricks.sdk.service.compute import State
-from databricks_tools_core.auth import get_workspace_client
+
+from .auth import get_workspace_client
logger = logging.getLogger(__name__)
diff --git a/databricks-builder-app/server/services/databricks_tools.py b/databricks-builder-app/server/services/databricks_tools.py
deleted file mode 100644
index 11a83e4d..00000000
--- a/databricks-builder-app/server/services/databricks_tools.py
+++ /dev/null
@@ -1,432 +0,0 @@
-"""Dynamic tool loader for Databricks tools.
-
-Scans FastMCP tools from databricks-mcp-server and creates
-in-process SDK tools for the Claude Code Agent SDK.
-
-Includes async handoff for long-running operations to prevent
-Claude connection timeouts. When a tool exceeds SAFE_EXECUTION_THRESHOLD,
-execution continues in background and returns an operation ID for polling.
-"""
-
-import asyncio
-import json
-import logging
-import threading
-import time
-from contextvars import copy_context
-from typing import Any
-
-from claude_agent_sdk import tool, create_sdk_mcp_server
-
-from .operation_tracker import (
- create_operation,
- complete_operation,
- get_operation,
- list_operations,
-)
-
-logger = logging.getLogger(__name__)
-
-# Seconds before switching to async mode to avoid connection timeout
-# Anthropic API has ~50s stream idle timeout, we switch early to keep messages flowing
-# Lower threshold ensures tool results return quickly, preventing cumulative timeout
-SAFE_EXECUTION_THRESHOLD = 10
-
-
-def load_databricks_tools():
- """Dynamically scan FastMCP tools and create in-process SDK MCP server.
-
- Returns:
- Tuple of (server_config, tool_names) where:
- - server_config: McpSdkServerConfig for ClaudeAgentOptions.mcp_servers
- - tool_names: List of tool names in mcp__databricks__* format
- """
- sdk_tools, tool_names = _get_all_sdk_tools()
-
- logger.info(f'Loaded {len(sdk_tools)} Databricks tools: {[n.split("__")[-1] for n in tool_names]}')
-
- server = create_sdk_mcp_server(name='databricks', tools=sdk_tools)
- return server, tool_names
-
-
-# Cached SDK tools (loaded once, reused for filtered server creation)
-_all_sdk_tools = None
-_all_tool_names = None
-
-
-def _get_all_sdk_tools():
- """Load and cache all SDK tool wrappers.
-
- Returns:
- Tuple of (sdk_tools, tool_names)
- """
- global _all_sdk_tools, _all_tool_names
-
- if _all_sdk_tools is not None:
- return _all_sdk_tools, _all_tool_names
-
- # Import triggers @mcp.tool registration
- from databricks_mcp_server.server import mcp
- from databricks_mcp_server.tools import sql, compute, file, pipelines # noqa: F401
-
- sdk_tools = []
- tool_names = []
-
- # Get registered tools from FastMCP (handle different API versions)
- registered_tools = None
-
- # Attempt 1: FastMCP 3.1.1+ with _tool_manager._tools (sync, local dev)
- if hasattr(mcp, '_tool_manager') and hasattr(getattr(mcp, '_tool_manager'), '_tools'):
- registered_tools = mcp._tool_manager._tools
- logger.info('Loaded tools via _tool_manager._tools')
-
- # Attempt 2: Async list_tools() (deployed FastMCP version)
- if registered_tools is None and hasattr(mcp, 'list_tools'):
- import concurrent.futures
- with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
- tools_list = executor.submit(lambda: asyncio.run(mcp.list_tools())).result()
- registered_tools = {t.name: t for t in tools_list}
- logger.info(f'Loaded tools via list_tools(): {list(registered_tools.keys())}')
-
- # Wrap all Databricks MCP tools
- for name, mcp_tool in registered_tools.items():
- input_schema = _convert_schema(mcp_tool.parameters)
- sdk_tool = _make_wrapper(name, mcp_tool.description, input_schema, mcp_tool.fn)
- sdk_tools.append(sdk_tool)
- tool_names.append(f'mcp__databricks__{name}')
-
- # Add operation tracking tools (for async handoff pattern)
- sdk_tools.append(_create_check_operation_status_tool())
- tool_names.append('mcp__databricks__check_operation_status')
-
- sdk_tools.append(_create_list_operations_tool())
- tool_names.append('mcp__databricks__list_operations')
-
- _all_sdk_tools = sdk_tools
- _all_tool_names = tool_names
- return sdk_tools, tool_names
-
-
-def create_filtered_databricks_server(allowed_tool_names: list[str]):
- """Create an MCP server with only the specified tools.
-
- Used to restrict which Databricks tools the agent can access based on
- which skills are enabled.
-
- Args:
- allowed_tool_names: List of tool names in mcp__databricks__* format
-
- Returns:
- Tuple of (server_config, filtered_tool_names)
- """
- all_sdk_tools, all_tool_names = _get_all_sdk_tools()
-
- allowed_set = set(allowed_tool_names)
- filtered_tools = []
- filtered_names = []
-
- for sdk_tool, tool_name in zip(all_sdk_tools, all_tool_names):
- if tool_name in allowed_set:
- filtered_tools.append(sdk_tool)
- filtered_names.append(tool_name)
-
- logger.info(
- f'Created filtered Databricks server: {len(filtered_names)}/{len(all_tool_names)} tools '
- f'({len(all_tool_names) - len(filtered_names)} blocked)'
- )
-
- server = create_sdk_mcp_server(name='databricks', tools=filtered_tools)
- return server, filtered_names
-
-
-def _create_check_operation_status_tool():
- """Create the check_operation_status tool for polling async operations."""
-
- @tool(
- "check_operation_status",
- """Check status of an async operation.
-
-Use this to get results of long-running operations that were moved to
-background execution. When a tool takes longer than 30 seconds, it returns
-an operation_id instead of blocking. Use this tool to poll for the result.
-
-Args:
- operation_id: The operation ID returned by the long-running tool
-
-Returns:
- - status: 'running', 'completed', or 'failed'
- - tool_name: Name of the original tool
- - result: The operation result (if completed)
- - error: Error message (if failed)
- - elapsed_seconds: Time since operation started
-""",
- {"operation_id": str},
- )
- async def check_operation_status(args: dict[str, Any]) -> dict[str, Any]:
- operation_id = args.get("operation_id", "")
-
- op = get_operation(operation_id)
- if not op:
- return {
- "content": [
- {
- "type": "text",
- "text": json.dumps(
- {
- "status": "not_found",
- "error": f"Operation {operation_id} not found. It may have expired (TTL: 1 hour) or never existed.",
- }
- ),
- }
- ]
- }
-
- result = {
- "status": op.status,
- "operation_id": op.operation_id,
- "tool_name": op.tool_name,
- "elapsed_seconds": round(time.time() - op.started_at, 1),
- }
-
- if op.status == "completed":
- result["result"] = op.result
- elif op.status == "failed":
- result["error"] = op.error
-
- return {"content": [{"type": "text", "text": json.dumps(result, default=str)}]}
-
- return check_operation_status
-
-
-def _create_list_operations_tool():
- """Create the list_operations tool for viewing all tracked operations."""
-
- @tool(
- "list_operations",
- """List all tracked async operations.
-
-Use this to see all operations that are running or recently completed.
-Useful for checking what's in progress or finding an operation ID.
-
-Args:
- status: Optional filter - 'running', 'completed', or 'failed'
-
-Returns:
- List of operations with their status and elapsed time
-""",
- {"status": str},
- )
- async def list_ops(args: dict[str, Any]) -> dict[str, Any]:
- status_filter = args.get("status")
- if status_filter == "":
- status_filter = None
-
- ops = list_operations(status_filter)
- return {"content": [{"type": "text", "text": json.dumps(ops, default=str)}]}
-
- return list_ops
-
-
-def _convert_schema(json_schema: dict) -> dict[str, type]:
- """Convert JSON schema to SDK simple format: {"param": type}"""
- type_map = {
- 'string': str,
- 'integer': int,
- 'number': float,
- 'boolean': bool,
- 'array': list,
- 'object': dict,
- }
- result = {}
-
- for param, spec in json_schema.get('properties', {}).items():
- # Handle anyOf (optional types like "string | null")
- if 'anyOf' in spec:
- for opt in spec['anyOf']:
- if opt.get('type') != 'null':
- result[param] = type_map.get(opt.get('type'), str)
- break
- else:
- result[param] = type_map.get(spec.get('type'), str)
-
- return result
-
-
-def _make_wrapper(name: str, description: str, schema: dict, fn):
- """Create SDK tool wrapper for a FastMCP function.
-
- The wrapper runs the sync function in a thread pool to avoid
- blocking the async event loop. It also handles JSON string parsing
- for complex types (lists, dicts) that the Claude agent may pass as strings.
-
- Includes async handoff for long-running operations:
- - Operations completing within SAFE_EXECUTION_THRESHOLD return normally
- - Operations exceeding the threshold switch to background execution
- and return an operation_id for polling via check_operation_status
- """
-
- @tool(name, description, schema)
- async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
- import sys
- import traceback
- import concurrent.futures
-
- start_time = time.time()
- print(f'[MCP TOOL] {name} called with args: {args}', file=sys.stderr, flush=True)
- logger.info(f'[MCP] Tool {name} called with args: {args}')
- try:
- # Parse JSON strings for complex types (Claude agent sometimes sends these as strings)
- parsed_args = {}
- for key, value in args.items():
- if isinstance(value, str) and value.strip().startswith(('[', '{')):
- # Try to parse as JSON if it looks like a list or dict
- try:
- parsed_args[key] = json.loads(value)
- print(f'[MCP TOOL] Parsed {key} from JSON string', file=sys.stderr, flush=True)
- except json.JSONDecodeError:
- # Not valid JSON, keep as string
- parsed_args[key] = value
- else:
- parsed_args[key] = value
-
- # FastMCP tools are sync - run in thread pool with heartbeat
- print(f'[MCP TOOL] Running {name} in thread pool with heartbeat...', file=sys.stderr, flush=True)
-
- # Copy context to propagate Databricks auth contextvars to the thread
- ctx = copy_context()
-
- def run_in_context():
- """Run the tool function within the copied context."""
- return ctx.run(fn, **parsed_args)
-
- # Run tool in executor so we can poll for completion with heartbeat
- # Use executor.submit() to get a concurrent.futures.Future (thread-safe)
- # instead of loop.run_in_executor() which returns an asyncio.Future
- loop = asyncio.get_event_loop()
- executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
- cf_future = executor.submit(run_in_context) # concurrent.futures.Future
- # Wrap in asyncio.Future for async waiting
- future = asyncio.wrap_future(cf_future, loop=loop)
-
- # Heartbeat every 10 seconds while waiting for the tool to complete
- HEARTBEAT_INTERVAL = 10
- heartbeat_count = 0
- while True:
- try:
- # Wait for result with timeout
- result = await asyncio.wait_for(
- asyncio.shield(future),
- timeout=HEARTBEAT_INTERVAL
- )
- # Tool completed successfully
- break
- except asyncio.TimeoutError:
- # Tool still running - emit heartbeat
- heartbeat_count += 1
- elapsed = time.time() - start_time
- print(f'[MCP HEARTBEAT] {name} still running... ({elapsed:.0f}s elapsed, heartbeat #{heartbeat_count})', file=sys.stderr, flush=True)
- logger.debug(f'[MCP] Heartbeat for {name}: {elapsed:.0f}s elapsed')
-
- # Check if we should switch to async mode to avoid connection timeout
- if elapsed > SAFE_EXECUTION_THRESHOLD:
- op_id = create_operation(name, parsed_args)
- print(
- f'[MCP ASYNC] {name} exceeded {SAFE_EXECUTION_THRESHOLD}s, '
- f'switching to async mode (operation_id: {op_id})',
- file=sys.stderr,
- flush=True,
- )
- logger.info(
- f'[MCP] Tool {name} switched to async mode after {elapsed:.0f}s '
- f'(operation_id: {op_id})'
- )
-
- # Start background thread to complete the operation
- # We use threading.Thread instead of asyncio.create_task because
- # the fresh event loop pattern may not keep tasks alive
- def complete_in_background(op_id, cf_future, executor):
- """Background thread to wait for completion and store result."""
- try:
- # Block until the future completes (it's already running)
- # cf_future is a concurrent.futures.Future which is thread-safe
- result = cf_future.result() # This blocks
- complete_operation(op_id, result=result)
- print(
- f'[MCP ASYNC] Operation {op_id} completed successfully',
- file=sys.stderr,
- flush=True,
- )
- except Exception as e:
- import traceback
- error_details = traceback.format_exc()
- complete_operation(op_id, error=str(e))
- print(
- f'[MCP ASYNC] Operation {op_id} failed: {e}',
- file=sys.stderr,
- flush=True,
- )
- print(
- f'[MCP ASYNC] Traceback:\n{error_details}',
- file=sys.stderr,
- flush=True,
- )
- finally:
- executor.shutdown(wait=False)
-
- bg_thread = threading.Thread(
- target=complete_in_background,
- args=(op_id, cf_future, executor),
- daemon=True,
- )
- bg_thread.start()
-
- # Return immediately with operation info
- return {
- 'content': [
- {
- 'type': 'text',
- 'text': json.dumps({
- 'status': 'async',
- 'operation_id': op_id,
- 'tool_name': name,
- 'message': (
- f'Operation is taking longer than {SAFE_EXECUTION_THRESHOLD}s '
- f'and has been moved to background execution. '
- f'Use check_operation_status("{op_id}") to poll for results.'
- ),
- 'elapsed_seconds': round(elapsed, 1),
- }),
- }
- ]
- }
-
- # Continue waiting
- continue
-
- elapsed = time.time() - start_time
- result_str = json.dumps(result, default=str)
- print(f'[MCP TOOL] {name} completed in {elapsed:.2f}s, result length: {len(result_str)}', file=sys.stderr, flush=True)
- logger.info(f'[MCP] Tool {name} completed in {elapsed:.2f}s')
- return {'content': [{'type': 'text', 'text': result_str}]}
- except asyncio.CancelledError:
- elapsed = time.time() - start_time
- error_msg = f'Tool execution cancelled after {elapsed:.2f}s (likely due to stream timeout)'
- print(f'[MCP TOOL] {name} CANCELLED: {error_msg}', file=sys.stderr, flush=True)
- logger.error(f'[MCP] Tool {name} cancelled: {error_msg}')
- return {'content': [{'type': 'text', 'text': f'Error: {error_msg}'}], 'is_error': True}
- except TimeoutError as e:
- elapsed = time.time() - start_time
- error_msg = f'Tool execution timed out after {elapsed:.2f}s: {e}'
- print(f'[MCP TOOL] {name} TIMEOUT: {error_msg}', file=sys.stderr, flush=True)
- logger.error(f'[MCP] Tool {name} timeout: {error_msg}')
- return {'content': [{'type': 'text', 'text': f'Error: {error_msg}'}], 'is_error': True}
- except Exception as e:
- elapsed = time.time() - start_time
- error_details = traceback.format_exc()
- error_msg = f'{type(e).__name__}: {str(e)}'
- print(f'[MCP TOOL] {name} FAILED after {elapsed:.2f}s: {error_msg}', file=sys.stderr, flush=True)
- print(f'[MCP TOOL] Stack trace:\n{error_details}', file=sys.stderr, flush=True)
- logger.exception(f'[MCP] Tool {name} failed after {elapsed:.2f}s: {error_msg}')
- return {'content': [{'type': 'text', 'text': f'Error ({type(e).__name__}): {str(e)}\n\nThis error occurred after {elapsed:.2f}s. If this is a long-running operation, it may have exceeded the stream timeout (50s).'}], 'is_error': True}
-
- return wrapper
diff --git a/databricks-builder-app/server/services/user.py b/databricks-builder-app/server/services/user.py
index 4d3a0709..f2b737d5 100644
--- a/databricks-builder-app/server/services/user.py
+++ b/databricks-builder-app/server/services/user.py
@@ -17,9 +17,10 @@
from typing import Optional
from databricks.sdk import WorkspaceClient
-from databricks_tools_core.identity import PRODUCT_NAME, PRODUCT_VERSION
from fastapi import Request
+from .auth import PRODUCT_NAME, PRODUCT_VERSION
+
logger = logging.getLogger(__name__)
# Cache for dev user to avoid repeated API calls
diff --git a/databricks-builder-app/server/services/warehouses.py b/databricks-builder-app/server/services/warehouses.py
index 7922a7cf..7bdb352f 100644
--- a/databricks-builder-app/server/services/warehouses.py
+++ b/databricks-builder-app/server/services/warehouses.py
@@ -7,7 +7,7 @@
from threading import Lock
from typing import Optional
-from databricks_tools_core.auth import get_workspace_client
+from .auth import get_workspace_client
logger = logging.getLogger(__name__)
diff --git a/databricks-mcp-server/.gitignore b/databricks-mcp-server/.gitignore
deleted file mode 100644
index 6e32d7cd..00000000
--- a/databricks-mcp-server/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-.test-results/
diff --git a/databricks-mcp-server/README.md b/databricks-mcp-server/README.md
deleted file mode 100644
index f3d7b459..00000000
--- a/databricks-mcp-server/README.md
+++ /dev/null
@@ -1,269 +0,0 @@
-# Databricks MCP Server
-
-A simple [FastMCP](https://github.com/jlowin/fastmcp) server that exposes Databricks operations as MCP tools for AI assistants like Claude Code.
-
-## Quick Start
-
-### Step 1: Clone the repository
-
-```bash
-git clone https://github.com/databricks-solutions/ai-dev-kit.git
-cd ai-dev-kit
-```
-
-### Step 2: Install the packages
-
-```bash
-# Install the core library
-uv pip install -e ./databricks-tools-core
-
-# Install the MCP server
-uv pip install -e ./databricks-mcp-server
-```
-
-### Step 3: Configure Databricks authentication
-
-```bash
-# Option 1: Environment variables
-export DATABRICKS_HOST="https://your-workspace.cloud.databricks.com"
-export DATABRICKS_TOKEN="your-token"
-
-# Option 2: Use a profile from ~/.databrickscfg
-export DATABRICKS_CONFIG_PROFILE="your-profile"
-```
-
-### Step 4: Add MCP server to Claude Code
-
-For Claude Code, add to your project's `.mcp.json` (create the file if it doesn't exist).
-For Cursor, add to your project's `.cursor/mcp.json` (create the file if it doesn't exist).
-
-```json
-{
- "mcpServers": {
- "databricks": {
- "command": "uv",
- "args": ["run", "--directory", "/path/to/ai-dev-kit", "python", "databricks-mcp-server/run_server.py"],
- "defer_loading": true
- }
- }
-}
-```
-
-**Replace `/path/to/ai-dev-kit`** with the actual path where you cloned the repo.
-
-**Note:** `"defer_loading": true` improves startup time by not loading all tools upfront.
-
-### Step 5 (Recommended): Install Databricks skills
-
-The MCP server works best with **Databricks skills** that teach Claude best practices:
-
-```bash
-# In your project directory (not ai-dev-kit)
-cd /path/to/your/project
-curl -sSL https://raw.githubusercontent.com/databricks-solutions/ai-dev-kit/main/databricks-skills/install_skills.sh | bash
-```
-
-### Step 6: Start Claude Code
-
-```bash
-cd /path/to/your/project
-claude
-```
-
-Claude now has both:
-- **Skills** (knowledge) - patterns and best practices in `.claude/skills/`
-- **MCP Tools** (actions) - Databricks operations via the MCP server
-
-## Available Tools
-
-### SQL Operations
-
-| Tool | Description |
-|------|-------------|
-| `execute_sql` | Execute a SQL query on a Databricks SQL Warehouse |
-| `execute_sql_multi` | Execute multiple SQL statements with parallel execution |
-| `list_warehouses` | List all SQL warehouses in the workspace |
-| `get_best_warehouse` | Get the ID of the best available warehouse |
-| `get_table_stats_and_schema` | Get table schema and statistics |
-
-### Compute
-
-| Tool | Description |
-|------|-------------|
-| `execute_code` | Execute code on Databricks (serverless or cluster), or run a local file |
-| `manage_cluster` | Create, modify, start, terminate, or delete clusters |
-| `manage_sql_warehouse` | Create, modify, or delete SQL warehouses |
-| `list_compute` | List clusters, node types, or spark versions |
-
-### File Operations
-
-| Tool | Description |
-|------|-------------|
-| `upload_to_workspace` | Upload files/folders to workspace (works like `cp` - handles files, folders, globs) |
-
-### Jobs
-
-| Tool | Description |
-|------|-------------|
-| `create_job` | Create a new job with tasks (serverless by default) |
-| `get_job` | Get detailed job configuration |
-| `list_jobs` | List jobs with optional name filter |
-| `find_job_by_name` | Find job by exact name, returns job ID |
-| `update_job` | Update job configuration |
-| `delete_job` | Delete a job |
-| `run_job_now` | Trigger a job run, returns run ID |
-| `get_run` | Get run status and details |
-| `get_run_output` | Get run output and logs |
-| `list_runs` | List runs with filters |
-| `cancel_run` | Cancel a running job |
-| `wait_for_run` | Wait for run completion |
-
-### Spark Declarative Pipelines (SDP)
-
-| Tool | Description |
-|------|-------------|
-| `create_or_update_pipeline` | Create or update pipeline by name (auto-detects existing) |
-| `get_pipeline` | Get pipeline details by ID or name; enriched with latest update status and events. Omit args to list all. |
-| `delete_pipeline` | Delete a pipeline |
-| `run_pipeline` | Start, stop, or wait for pipeline runs |
-
-### Knowledge Assistants (KA)
-
-| Tool | Description |
-|------|-------------|
-| `manage_ka` | Manage Knowledge Assistants (create/update, get, find by name, delete) |
-
-### Genie Spaces
-
-| Tool | Description |
-|------|-------------|
-| `create_or_update_genie` | Create or update a Genie Space for SQL-based data exploration |
-| `get_genie` | Get Genie Space details by space ID |
-| `find_genie_by_name` | Find Genie Space by name, returns space ID |
-| `delete_genie` | Delete a Genie Space |
-
-### Supervisor Agent (MAS)
-
-| Tool | Description |
-|------|-------------|
-| `manage_mas` | Manage Supervisor Agents (create/update, get, find by name, delete) |
-
-### AI/BI Dashboards
-
-| Tool | Description |
-|------|-------------|
-| `create_or_update_dashboard` | Create or update an AI/BI dashboard from JSON content |
-| `get_dashboard` | Get dashboard details by ID, or list all dashboards (omit dashboard_id) |
-| `delete_dashboard` | Soft-delete a dashboard (moves to trash) |
-| `publish_dashboard` | Publish or unpublish a dashboard (`publish=True/False`) |
-
-### Model Serving
-
-| Tool | Description |
-|------|-------------|
-| `get_serving_endpoint_status` | Get the status of a Model Serving endpoint |
-| `query_serving_endpoint` | Query a Model Serving endpoint with chat or ML model inputs |
-| `list_serving_endpoints` | List all Model Serving endpoints in the workspace |
-
-## Architecture
-
-```
-┌─────────────────────────────────────────────────────────────┐
-│ Claude Code │
-│ │
-│ Skills (knowledge) MCP Tools (actions) │
-│ └── .claude/skills/ └── .claude/mcp.json │
-│ ├── sdp-writer └── databricks server │
-│ ├── databricks-bundles │
-│ └── ... │
-└──────────────────────────────┬──────────────────────────────┘
- │ MCP Protocol (stdio)
- ▼
-┌─────────────────────────────────────────────────────────────┐
-│ databricks-mcp-server (FastMCP) │
-│ │
-│ tools/sql.py ──────────────┐ │
-│ tools/compute.py ──────────┤ │
-│ tools/file.py ─────────────┤ │
-│ tools/jobs.py ─────────────┼──► @mcp.tool decorators │
-│ tools/pipelines.py ────────┤ │
-│ tools/agent_bricks.py ─────┤ │
-│ tools/aibi_dashboards.py ──┤ │
-│ tools/serving.py ──────────┘ │
-└──────────────────────────────┬──────────────────────────────┘
- │ Python imports
- ▼
-┌─────────────────────────────────────────────────────────────┐
-│ databricks-tools-core │
-│ │
-│ sql/ compute/ jobs/ pipelines/ │
-│ └── execute └── run_code └── run/wait └── create/run │
-└──────────────────────────────┬──────────────────────────────┘
- │ Databricks SDK
- ▼
- ┌─────────────────────┐
- │ Databricks │
- │ Workspace │
- └─────────────────────┘
-```
-
-## Development
-
-The server is intentionally simple - each tool file just imports functions from `databricks-tools-core` and decorates them with `@mcp.tool`.
-
-### Running Integration Tests
-
-Integration tests run against a real Databricks workspace. Configure authentication first (see Step 3 above).
-
-```bash
-# Run all tests (excluding slow tests like cluster creation)
-python tests/integration/run_tests.py
-
-# Run all tests including slow tests
-python tests/integration/run_tests.py --all
-
-# Show report from the latest run
-python tests/integration/run_tests.py --report
-
-# Run with fewer parallel workers (default: 8)
-python tests/integration/run_tests.py -j 4
-```
-
-Results are saved to `tests/integration/.test-results//` with logs for each test folder.
-
-See [tests/integration/README.md](tests/integration/README.md) for more details.
-
-To add a new tool:
-
-1. Add the function to `databricks-tools-core`
-2. Create a wrapper in `databricks_mcp_server/tools/`
-3. Import it in `server.py`
-
-Example:
-
-```python
-# tools/my_module.py
-from databricks_tools_core.my_module import my_function as _my_function
-from ..server import mcp
-
-@mcp.tool
-def my_function(arg1: str, arg2: int = 10) -> dict:
- """Tool description shown to the AI."""
- return _my_function(arg1=arg1, arg2=arg2)
-```
-
-## Usage Tracking via Audit Logs
-
-All API calls made through the MCP server are tagged with a custom `User-Agent` header:
-
-```
-databricks-ai-dev-kit/0.1.0 databricks-sdk-py/... project/
-```
-
-The project name is auto-detected from the git remote URL (no configuration needed). This makes every call filterable in the `system.access.audit` system table.
-
-> **Note:** Audit log entries may take 2–10 minutes to appear. The workspace must have Unity Catalog enabled to query `system.access.audit`.
-
-## License
-
-© Databricks, Inc. See [LICENSE.md](../LICENSE.md).
diff --git a/databricks-mcp-server/databricks_mcp_server/__init__.py b/databricks-mcp-server/databricks_mcp_server/__init__.py
deleted file mode 100644
index e4e12c14..00000000
--- a/databricks-mcp-server/databricks_mcp_server/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""Databricks MCP Server - FastMCP-based tools for Databricks operations."""
-
-__version__ = "0.1.0"
diff --git a/databricks-mcp-server/databricks_mcp_server/manifest.py b/databricks-mcp-server/databricks_mcp_server/manifest.py
deleted file mode 100644
index daab9ffd..00000000
--- a/databricks-mcp-server/databricks_mcp_server/manifest.py
+++ /dev/null
@@ -1,180 +0,0 @@
-"""Resource tracking manifest for cross-session continuity.
-
-Tracks Databricks resources created through the MCP server in a local
-`.databricks-resources.json` file. This allows agents to see what was
-created in previous sessions and avoid duplicates.
-"""
-
-import json
-import logging
-import os
-import tempfile
-from datetime import datetime, timezone
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional
-
-logger = logging.getLogger(__name__)
-
-# ---------------------------------------------------------------------------
-# Resource deleter registry
-# ---------------------------------------------------------------------------
-# Each tool module registers a callable that deletes a resource given its ID.
-# This avoids hard-coding every resource type inside the manifest tool layer.
-_RESOURCE_DELETERS: Dict[str, Callable[[str], None]] = {}
-
-
-def register_deleter(resource_type: str, fn: Callable[[str], None]) -> None:
- """Register a delete function for a resource type.
-
- Tool modules call this at import time so the manifest tool layer can
- delete any tracked resource without knowing implementation details.
-
- Args:
- resource_type: The manifest resource type key (e.g. ``"job"``).
- fn: A callable that takes a ``resource_id`` string and deletes
- the corresponding Databricks resource. Should raise on failure.
- """
- _RESOURCE_DELETERS[resource_type] = fn
-
-
-MANIFEST_FILENAME = ".databricks-resources.json"
-MANIFEST_VERSION = 1
-
-
-def _get_manifest_path() -> Path:
- """Get the path to the manifest file.
-
- Looks for ``MANIFEST_FILENAME`` relative to CWD (MCP servers are
- launched from the project root).
- """
- return Path(os.getcwd()) / MANIFEST_FILENAME
-
-
-def _read_manifest() -> Dict[str, Any]:
- """Read the manifest file, returning an empty structure if missing."""
- path = _get_manifest_path()
- if not path.exists():
- return {"version": MANIFEST_VERSION, "resources": []}
- try:
- with open(path, "r") as f:
- data = json.load(f)
- if not isinstance(data, dict) or "resources" not in data:
- return {"version": MANIFEST_VERSION, "resources": []}
- return data
- except (json.JSONDecodeError, OSError) as exc:
- logger.warning("Failed to read manifest %s: %s", path, exc)
- return {"version": MANIFEST_VERSION, "resources": []}
-
-
-def _write_manifest(data: Dict[str, Any]) -> None:
- """Atomically write the manifest file."""
- path = _get_manifest_path()
- try:
- # Write to a temp file in the same directory, then rename
- fd, tmp_path = tempfile.mkstemp(dir=path.parent, prefix=".manifest-tmp-", suffix=".json")
- try:
- with os.fdopen(fd, "w") as f:
- json.dump(data, f, indent=2)
- f.write("\n")
- os.replace(tmp_path, path)
- except Exception:
- # Clean up temp file on failure
- try:
- os.unlink(tmp_path)
- except OSError:
- pass
- raise
- except OSError as exc:
- logger.warning("Failed to write manifest %s: %s", path, exc)
-
-
-def _now_iso() -> str:
- """Return the current UTC time as an ISO 8601 string."""
- return datetime.now(timezone.utc).isoformat()
-
-
-def track_resource(
- resource_type: str,
- name: str,
- resource_id: str,
- url: Optional[str] = None,
-) -> None:
- """Track a created/updated resource in the manifest.
-
- Upsert logic:
- - If a resource with the same type+id exists, update name/url/updated_at.
- - If a resource with the same type+name exists but different id, update the id.
- - Otherwise, append a new entry.
-
- This is best-effort: failures are logged but never raised.
- """
- try:
- data = _read_manifest()
- resources: List[Dict[str, Any]] = data.get("resources", [])
- now = _now_iso()
-
- # Try to find by type+id
- for r in resources:
- if r.get("type") == resource_type and r.get("id") == resource_id:
- r["name"] = name
- if url:
- r["url"] = url
- r["updated_at"] = now
- _write_manifest(data)
- return
-
- # Try to find by type+name (handles ID changes across sessions)
- for r in resources:
- if r.get("type") == resource_type and r.get("name") == name:
- r["id"] = resource_id
- if url:
- r["url"] = url
- r["updated_at"] = now
- _write_manifest(data)
- return
-
- # New resource
- entry: Dict[str, Any] = {
- "type": resource_type,
- "name": name,
- "id": resource_id,
- "created_at": now,
- "updated_at": now,
- }
- if url:
- entry["url"] = url
- resources.append(entry)
- data["resources"] = resources
- _write_manifest(data)
- except Exception as exc:
- logger.warning("Failed to track resource %s/%s: %s", resource_type, name, exc)
-
-
-def remove_resource(resource_type: str, resource_id: str) -> bool:
- """Remove a resource from the manifest by type+id.
-
- Returns True if the resource was found and removed.
- """
- try:
- data = _read_manifest()
- resources = data.get("resources", [])
- original_count = len(resources)
- data["resources"] = [
- r for r in resources if not (r.get("type") == resource_type and r.get("id") == resource_id)
- ]
- if len(data["resources"]) < original_count:
- _write_manifest(data)
- return True
- return False
- except Exception as exc:
- logger.warning("Failed to remove resource %s/%s: %s", resource_type, resource_id, exc)
- return False
-
-
-def list_resources(resource_type: Optional[str] = None) -> List[Dict[str, Any]]:
- """Return tracked resources, optionally filtered by type."""
- data = _read_manifest()
- resources = data.get("resources", [])
- if resource_type:
- resources = [r for r in resources if r.get("type") == resource_type]
- return resources
diff --git a/databricks-mcp-server/databricks_mcp_server/middleware.py b/databricks-mcp-server/databricks_mcp_server/middleware.py
deleted file mode 100644
index 129b26ff..00000000
--- a/databricks-mcp-server/databricks_mcp_server/middleware.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""
-Middleware for the Databricks MCP Server.
-
-Provides cross-cutting concerns like timeout and error handling for all MCP tool calls.
-"""
-
-import anyio
-import json
-import logging
-import traceback
-
-from fastmcp.exceptions import ToolError
-from fastmcp.server.middleware import Middleware, MiddlewareContext, CallNext
-from fastmcp.tools.tool import ToolResult
-from mcp.types import CallToolRequestParams, TextContent
-
-logger = logging.getLogger(__name__)
-
-
-class TimeoutHandlingMiddleware(Middleware):
- """Catches errors from any tool and returns structured results.
-
- This middleware provides two key functions:
-
- 1. **Timeout handling**: When async operations (job runs, pipeline updates,
- resource provisioning) exceed their timeout, converts the exception into
- a JSON response that tells the agent the operation is still in progress
- and should NOT be retried blindly. Without this, agents interpret timeout
- errors as failures and retry — potentially creating duplicate resources.
-
- 2. **Error handling**: Catches all other exceptions and returns them as
- structured JSON responses instead of crashing the MCP server. This ensures
- the server stays up and the agent gets actionable error information.
-
- Note: For timeouts to work on sync tools, the server must wrap sync functions
- in asyncio.to_thread() (see server.py _patch_tool_decorator_for_async).
- """
-
- async def on_call_tool(
- self,
- context: MiddlewareContext[CallToolRequestParams],
- call_next: CallNext[CallToolRequestParams, ToolResult],
- ) -> ToolResult:
- tool_name = context.message.name
- arguments = context.message.arguments
-
- try:
- result = await call_next(context)
-
- # Fix for FastMCP not populating structured_content automatically.
- # When a tool has a return type annotation (e.g., -> Dict[str, Any]),
- # FastMCP generates an outputSchema but doesn't set structured_content.
- # MCP SDK then fails validation: "outputSchema defined but no structured output"
- # We fix this by parsing the JSON text content and setting structured_content.
- if result and not result.structured_content and result.content:
- if len(result.content) == 1 and isinstance(result.content[0], TextContent):
- try:
- parsed = json.loads(result.content[0].text)
- if isinstance(parsed, dict):
- # Create new ToolResult with structured_content populated
- result = ToolResult(
- content=result.content,
- structured_content=parsed,
- )
- except (json.JSONDecodeError, TypeError):
- pass # Not valid JSON, leave as-is
-
- return result
-
- except TimeoutError as e:
- # In Python 3.11+, asyncio.TimeoutError is an alias for TimeoutError,
- # so this single handler catches both
- logger.warning(
- "Tool '%s' timed out. Raising ToolError.",
- tool_name,
- )
- # Raise ToolError so the MCP SDK sets isError=True on the response,
- # which bypasses outputSchema validation. Returning a ToolResult here
- # would be treated as a success and fail validation when outputSchema
- # is defined (e.g., tools with -> Dict[str, Any] return type).
- raise ToolError(json.dumps({
- "error": True,
- "error_type": "timeout",
- "tool": tool_name,
- "message": str(e) or "Operation timed out",
- "action_required": (
- "Operation may still be in progress. "
- "Do NOT retry the same call. "
- "Use the appropriate get/status tool to check current state."
- ),
- })) from e
-
- except anyio.get_cancelled_exc_class():
- # Re-raise CancelledError so MCP SDK's handler catches it and skips
- # calling message.respond(). If we return a result here, the SDK will
- # try to respond, but the request may already be marked as responded
- # by the cancellation handler, causing an AssertionError crash.
- # See: https://github.com/modelcontextprotocol/python-sdk/pull/1153
- logger.warning(
- "Tool '%s' was cancelled. Re-raising to let MCP SDK handle cleanup.",
- tool_name,
- )
- raise
-
- except Exception as e:
- # Log the full traceback for debugging
- logger.error(
- "Tool '%s' raised an exception: %s\n%s",
- tool_name,
- str(e),
- traceback.format_exc(),
- )
-
- # Raise ToolError so the MCP SDK sets isError=True on the response,
- # which bypasses outputSchema validation. Returning a ToolResult here
- # would be treated as a success and fail validation when outputSchema
- # is defined (e.g., tools with -> Dict[str, Any] return type).
- raise ToolError(json.dumps({
- "error": True,
- "error_type": type(e).__name__,
- "tool": tool_name,
- "message": str(e),
- })) from e
diff --git a/databricks-mcp-server/databricks_mcp_server/server.py b/databricks-mcp-server/databricks_mcp_server/server.py
deleted file mode 100644
index 4ee150aa..00000000
--- a/databricks-mcp-server/databricks_mcp_server/server.py
+++ /dev/null
@@ -1,172 +0,0 @@
-"""
-Databricks MCP Server
-
-A FastMCP server that exposes Databricks operations as MCP tools.
-Simply wraps functions from databricks-tools-core.
-"""
-
-import asyncio
-import functools
-import inspect
-import subprocess
-import sys
-from contextlib import asynccontextmanager
-
-from fastmcp import FastMCP
-
-from .middleware import TimeoutHandlingMiddleware
-
-
-# ---------------------------------------------------------------------------
-# Windows fixes — must run BEFORE FastMCP init and tool registration
-# ---------------------------------------------------------------------------
-
-
-def _patch_subprocess_stdin():
- """Monkey-patch subprocess so stdin defaults to DEVNULL on Windows.
-
- When the MCP server runs in stdio mode, stdin IS the JSON-RPC pipe.
- Any subprocess call without explicit stdin lets child processes inherit
- this pipe handle. On Windows the Databricks SDK refreshes auth tokens
- via ``subprocess.run(["databricks", "auth", "token", ...], shell=True)``
- without setting stdin — the spawned ``databricks.exe`` blocks reading
- from the shared pipe, hanging every MCP tool call.
-
- Fix: default stdin to DEVNULL so child processes never touch the pipe.
-
- See: https://github.com/modelcontextprotocol/python-sdk/issues/671
- """
- _original_run = subprocess.run
-
- @functools.wraps(_original_run)
- def _patched_run(*args, **kwargs):
- kwargs.setdefault("stdin", subprocess.DEVNULL)
- return _original_run(*args, **kwargs)
-
- subprocess.run = _patched_run
-
- _OriginalPopen = subprocess.Popen
-
- class _PatchedPopen(_OriginalPopen):
- def __init__(self, *args, **kwargs):
- kwargs.setdefault("stdin", subprocess.DEVNULL)
- super().__init__(*args, **kwargs)
-
- subprocess.Popen = _PatchedPopen
-
-
-
-
-def _wrap_sync_in_thread(fn):
- """Wrap a sync function to run in asyncio.to_thread(), preserving metadata."""
-
- @functools.wraps(fn)
- async def async_wrapper(**kwargs):
- return await asyncio.to_thread(fn, **kwargs)
-
- return async_wrapper
-
-
-# Apply subprocess patch early — before any Databricks SDK import (Windows only)
-if sys.platform == "win32":
- _patch_subprocess_stdin()
-
-
-def _patch_tool_decorator_for_async():
- """Wrap sync tool functions in asyncio.to_thread() on all platforms.
-
- FastMCP's FunctionTool.run() calls sync functions directly on the asyncio
- event loop thread, which blocks the stdio transport's I/O tasks. This causes:
-
- 1. On Windows with ProactorEventLoop: deadlock where all MCP tools hang.
-
- 2. On ALL platforms: cancellation race conditions. When the MCP client
- cancels a request (e.g., timeout), the event loop can't propagate the
- CancelledError to blocking sync code. The sync function eventually
- returns, but the MCP SDK has already responded to the cancellation,
- causing "Request already responded to" assertion errors and crashes.
-
- This patch intercepts @mcp.tool registration to wrap sync functions so they
- run in a thread pool, yielding control back to the event loop for I/O and
- enabling proper cancellation handling via anyio's task cancellation.
- """
- original_tool = mcp.tool
-
- @functools.wraps(original_tool)
- def patched_tool(fn=None, *args, **kwargs):
- # Handle @mcp.tool("name") — returns a decorator
- if fn is None or isinstance(fn, str):
- decorator = original_tool(fn, *args, **kwargs)
-
- @functools.wraps(decorator)
- def wrapper(func):
- if not inspect.iscoroutinefunction(func):
- func = _wrap_sync_in_thread(func)
- return decorator(func)
-
- return wrapper
-
- # Handle @mcp.tool (bare decorator, fn is the function)
- if not inspect.iscoroutinefunction(fn):
- fn = _wrap_sync_in_thread(fn)
- return original_tool(fn, *args, **kwargs)
-
- mcp.tool = patched_tool
-
-# ---------------------------------------------------------------------------
-# Server initialisation
-# ---------------------------------------------------------------------------
-
-# Disable FastMCP's built-in task worker on Windows.
-# The docket worker uses fakeredis XREADGROUP BLOCK which deadlocks
-# the ProactorEventLoop, preventing asyncio.to_thread() callbacks.
-# Belt-and-suspenders: pass tasks=False AND override _docket_lifespan,
-# because tasks=False alone does not prevent the worker from starting.
-_fastmcp_kwargs = {}
-if sys.platform == "win32":
- _fastmcp_kwargs["tasks"] = False
-
-mcp = FastMCP("Databricks MCP Server", **_fastmcp_kwargs)
-
-if sys.platform == "win32":
-
- @asynccontextmanager
- async def _noop_lifespan(*args, **kwargs):
- yield
-
- if hasattr(mcp, "_docket_lifespan"):
- mcp._docket_lifespan = _noop_lifespan
-
-# Register middleware (see middleware.py for details on each)
-mcp.add_middleware(TimeoutHandlingMiddleware())
-
-# Apply async wrapper on ALL platforms to:
-# 1. Prevent event loop deadlocks (critical on Windows)
-# 2. Enable proper cancellation handling (critical on all platforms)
-# Without this, sync tools block the event loop, preventing CancelledError
-# propagation and causing "Request already responded to" crashes.
-# TODO: FastMCP 3.x automatically wraps sync functions in asyncio.to_thread().
-# Test if this patch is still needed with FastMCP 3.x.
-_patch_tool_decorator_for_async()
-
-# Import and register all tools (side-effect imports: each module registers @mcp.tool decorators)
-from .tools import ( # noqa: F401, E402
- sql,
- compute,
- file,
- pipelines,
- jobs,
- agent_bricks,
- aibi_dashboards,
- serving,
- unity_catalog,
- volume_files,
- genie,
- manifest,
- vector_search,
- lakebase,
- user,
- apps,
- workspace,
- pdf,
-)
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/__init__.py b/databricks-mcp-server/databricks_mcp_server/tools/__init__.py
deleted file mode 100644
index 1630e01e..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Tool modules for Databricks MCP Server."""
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/agent_bricks.py b/databricks-mcp-server/databricks_mcp_server/tools/agent_bricks.py
deleted file mode 100644
index 7b8336d2..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/agent_bricks.py
+++ /dev/null
@@ -1,553 +0,0 @@
-"""Agent Bricks tools - Manage Knowledge Assistants (KA) and Supervisor Agents (MAS).
-
-For Genie Space tools, see genie.py
-"""
-
-from typing import Any, Dict, List, Optional
-
-from databricks_tools_core.agent_bricks import (
- AgentBricksManager,
- EndpointStatus,
- get_tile_example_queue,
-)
-from databricks_tools_core.identity import with_description_footer
-
-from ..manifest import register_deleter
-from ..server import mcp
-
-# Singleton manager instance
-_manager: Optional[AgentBricksManager] = None
-
-
-def _get_manager() -> AgentBricksManager:
- """Get or create the singleton AgentBricksManager instance."""
- global _manager
- if _manager is None:
- _manager = AgentBricksManager()
- return _manager
-
-
-def _delete_ka_resource(resource_id: str) -> None:
- _get_manager().delete(resource_id)
-
-
-def _delete_mas_resource(resource_id: str) -> None:
- _get_manager().delete(resource_id)
-
-
-register_deleter("knowledge_assistant", _delete_ka_resource)
-register_deleter("multi_agent_supervisor", _delete_mas_resource)
-
-
-# ============================================================================
-# Internal action handlers
-# ============================================================================
-
-
-def _ka_create_or_update(
- name: str,
- volume_path: str,
- description: str = None,
- instructions: str = None,
- tile_id: str = None,
- add_examples_from_volume: bool = True,
-) -> Dict[str, Any]:
- """Create or update a Knowledge Assistant."""
- if not name:
- return {"error": "Missing required parameter 'name' for create_or_update action"}
- if not volume_path:
- return {"error": "Missing required parameter 'volume_path' for create_or_update action"}
-
- description = with_description_footer(description)
- manager = _get_manager()
-
- # Build knowledge source from volume path
- knowledge_sources = [
- {
- "files_source": {
- "name": f"source_{name.replace(' ', '_').lower()}",
- "type": "files",
- "files": {"path": volume_path},
- }
- }
- ]
-
- # Create or update the KA
- result = manager.ka_create_or_update(
- name=name,
- knowledge_sources=knowledge_sources,
- description=description,
- instructions=instructions,
- tile_id=tile_id,
- )
-
- # Extract info from new flat format
- response_tile_id = result.get("tile_id", "")
- # Map SDK state to endpoint status for backward compatibility
- state = result.get("state", "UNKNOWN")
- endpoint_status = "ONLINE" if state == "ACTIVE" else ("PROVISIONING" if state == "CREATING" else state)
-
- response = {
- "tile_id": response_tile_id,
- "name": result.get("name", name),
- "operation": result.get("operation", "created"),
- "endpoint_status": endpoint_status,
- "examples_queued": 0,
- }
-
- # Scan volume for examples if requested
- if add_examples_from_volume and response_tile_id:
- examples = manager.scan_volume_for_examples(volume_path)
- if examples:
- # If endpoint is ACTIVE, add examples directly
- if state == "ACTIVE":
- created = manager.ka_add_examples_batch(response_tile_id, examples)
- response["examples_added"] = len(created)
- else:
- # Queue examples for when endpoint becomes ready
- queue = get_tile_example_queue()
- queue.enqueue(response_tile_id, manager, examples, tile_type="KA")
- response["examples_queued"] = len(examples)
-
- # Track resource on successful create/update
- try:
- if response_tile_id:
- from ..manifest import track_resource
-
- track_resource(
- resource_type="knowledge_assistant",
- name=response.get("name", name),
- resource_id=response_tile_id,
- )
- except Exception:
- pass # best-effort tracking
-
- return response
-
-
-def _ka_get(tile_id: str) -> Dict[str, Any]:
- """Get a Knowledge Assistant by tile ID."""
- if not tile_id:
- return {"error": "Missing required parameter 'tile_id' for get action"}
-
- manager = _get_manager()
- result = manager.ka_get(tile_id)
-
- if not result:
- return {"error": f"Knowledge Assistant {tile_id} not found"}
-
- # Get examples count (handle failures gracefully)
- try:
- examples_response = manager.ka_list_examples(tile_id)
- examples_count = len(examples_response.get("examples", []))
- except Exception:
- examples_count = 0
-
- # Map SDK state to endpoint status for backward compatibility
- state = result.get("state", "UNKNOWN")
- endpoint_status = "ONLINE" if state == "ACTIVE" else ("PROVISIONING" if state == "CREATING" else state)
-
- return {
- "tile_id": result.get("tile_id", tile_id),
- "name": result.get("name", ""),
- "description": result.get("description", ""),
- "endpoint_status": endpoint_status,
- "endpoint_name": result.get("endpoint_name", ""),
- "knowledge_sources": result.get("sources", []),
- "examples_count": examples_count,
- "instructions": result.get("instructions", ""),
- }
-
-
-def _ka_find_by_name(name: str) -> Dict[str, Any]:
- """Find a Knowledge Assistant by name."""
- if not name:
- return {"error": "Missing required parameter 'name' for find_by_name action"}
-
- manager = _get_manager()
- result = manager.find_by_name(name)
-
- if result is None:
- return {"found": False, "name": name}
-
- # Fetch full details to get endpoint status and name
- full_details = manager.ka_get(result.tile_id)
- endpoint_status = "UNKNOWN"
- endpoint_name = ""
- if full_details:
- state = full_details.get("state", "UNKNOWN")
- endpoint_status = "ONLINE" if state == "ACTIVE" else ("PROVISIONING" if state == "CREATING" else state)
- endpoint_name = full_details.get("endpoint_name", "")
-
- return {
- "found": True,
- "tile_id": result.tile_id,
- "name": result.name,
- "endpoint_name": endpoint_name,
- "endpoint_status": endpoint_status,
- }
-
-
-def _ka_delete(tile_id: str) -> Dict[str, Any]:
- """Delete a Knowledge Assistant."""
- if not tile_id:
- return {"error": "Missing required parameter 'tile_id' for delete action"}
-
- manager = _get_manager()
- try:
- manager.delete(tile_id)
- try:
- from ..manifest import remove_resource
-
- remove_resource(resource_type="knowledge_assistant", resource_id=tile_id)
- except Exception:
- pass
- return {"success": True, "tile_id": tile_id}
- except Exception as e:
- return {"success": False, "tile_id": tile_id, "error": str(e)}
-
-
-def _mas_create_or_update(
- name: str,
- agents: List[Dict[str, str]],
- description: str = None,
- instructions: str = None,
- tile_id: str = None,
- examples: List[Dict[str, str]] = None,
-) -> Dict[str, Any]:
- """Create or update a Supervisor Agent."""
- if not name:
- return {"error": "Missing required parameter 'name' for create_or_update action"}
- if not agents:
- return {"error": "Missing required parameter 'agents' for create_or_update action"}
-
- description = with_description_footer(description)
- manager = _get_manager()
-
- # Validate and build agent list for API
- agent_list = []
- for i, agent in enumerate(agents):
- agent_name = agent.get("name", "")
- if not agent_name:
- return {"error": f"Agent at index {i} is missing required 'name' field"}
-
- agent_description = agent.get("description", "")
- if not agent_description:
- return {"error": f"Agent '{agent_name}' is missing required 'description' field"}
-
- has_endpoint = bool(agent.get("endpoint_name"))
- has_genie = bool(agent.get("genie_space_id"))
- has_ka = bool(agent.get("ka_tile_id"))
- has_uc_function = bool(agent.get("uc_function_name"))
- has_connection = bool(agent.get("connection_name"))
-
- # Count how many agent types are specified
- agent_type_count = sum([has_endpoint, has_genie, has_ka, has_uc_function, has_connection])
- if agent_type_count > 1:
- return {
- "error": (
- f"Agent '{agent_name}' has multiple agent types. "
- "Provide only one of: 'endpoint_name', 'genie_space_id', "
- "'ka_tile_id', 'uc_function_name', or 'connection_name'."
- )
- }
- if agent_type_count == 0:
- return {
- "error": (
- f"Agent '{agent_name}' must have one of: 'endpoint_name', "
- "'genie_space_id', 'ka_tile_id', 'uc_function_name', or 'connection_name'"
- )
- }
-
- agent_config = {
- "name": agent_name,
- "description": agent_description,
- }
-
- if has_genie:
- agent_config["agent_type"] = "genie"
- agent_config["genie_space"] = {"id": agent.get("genie_space_id")}
- elif has_ka:
- # KA tiles are referenced via their serving endpoint
- # Endpoint name uses the first segment of the tile_id
- ka_tile_id = agent.get("ka_tile_id")
- tile_id_prefix = ka_tile_id.split("-")[0]
- agent_config["agent_type"] = "serving_endpoint"
- agent_config["serving_endpoint"] = {"name": f"ka-{tile_id_prefix}-endpoint"}
- elif has_uc_function:
- uc_function_name = agent.get("uc_function_name")
- uc_parts = uc_function_name.split(".")
- if len(uc_parts) != 3:
- return {
- "error": (
- f"Agent '{agent_name}': uc_function_name must be in format "
- f"'catalog.schema.function_name', got '{uc_function_name}'"
- )
- }
- agent_config["agent_type"] = "unity_catalog_function"
- agent_config["unity_catalog_function"] = {
- "uc_path": {
- "catalog": uc_parts[0],
- "schema": uc_parts[1],
- "name": uc_parts[2],
- }
- }
- elif has_connection:
- agent_config["agent_type"] = "external_mcp_server"
- agent_config["external_mcp_server"] = {"connection_name": agent.get("connection_name")}
- else:
- agent_config["agent_type"] = "serving_endpoint"
- agent_config["serving_endpoint"] = {"name": agent.get("endpoint_name")}
-
- agent_list.append(agent_config)
-
- operation = "created"
- response_tile_id = tile_id
-
- if tile_id:
- # Check if exists
- existing = manager.mas_get(tile_id)
- if existing:
- operation = "updated"
- result = manager.mas_update(
- tile_id=tile_id,
- name=name,
- description=description,
- instructions=instructions,
- agents=agent_list,
- )
- else:
- return {"error": f"MAS {tile_id} not found"}
- else:
- # Check if exists by name
- existing = manager.mas_find_by_name(name)
- if existing:
- operation = "updated"
- response_tile_id = existing.tile_id
- result = manager.mas_update(
- tile_id=existing.tile_id,
- name=name,
- description=description,
- instructions=instructions,
- agents=agent_list,
- )
- else:
- # Create new
- result = manager.mas_create(
- name=name,
- agents=agent_list,
- description=description,
- instructions=instructions,
- )
- response_tile_id = result.get("multi_agent_supervisor", {}).get("tile", {}).get("tile_id", "")
-
- # Extract status
- mas_data = result.get("multi_agent_supervisor", {})
- tile_data = mas_data.get("tile", {})
- status_data = mas_data.get("status", {})
- endpoint_status = status_data.get("endpoint_status", "UNKNOWN")
-
- response = {
- "tile_id": response_tile_id or tile_data.get("tile_id", ""),
- "name": tile_data.get("name", name),
- "operation": operation,
- "endpoint_status": endpoint_status,
- "agents_count": len(agents),
- }
-
- # Add examples if provided
- if examples and response["tile_id"]:
- if endpoint_status == EndpointStatus.ONLINE.value:
- created = manager.mas_add_examples_batch(response["tile_id"], examples)
- response["examples_added"] = len(created)
- else:
- # Queue examples for when endpoint becomes ready
- queue = get_tile_example_queue()
- queue.enqueue(response["tile_id"], manager, examples, tile_type="MAS")
- response["examples_queued"] = len(examples)
-
- # Track resource on successful create/update
- try:
- mas_tile_id = response.get("tile_id")
- if mas_tile_id:
- from ..manifest import track_resource
-
- track_resource(
- resource_type="multi_agent_supervisor",
- name=response.get("name", name),
- resource_id=mas_tile_id,
- )
- except Exception:
- pass # best-effort tracking
-
- return response
-
-
-def _mas_get(tile_id: str) -> Dict[str, Any]:
- """Get a Supervisor Agent by tile ID."""
- if not tile_id:
- return {"error": "Missing required parameter 'tile_id' for get action"}
-
- manager = _get_manager()
- result = manager.mas_get(tile_id)
-
- if not result:
- return {"error": f"Supervisor Agent {tile_id} not found"}
-
- mas_data = result.get("multi_agent_supervisor", {})
- tile_data = mas_data.get("tile", {})
- status_data = mas_data.get("status", {})
-
- # Get examples count (handle failures gracefully)
- try:
- examples_response = manager.mas_list_examples(tile_id)
- examples_count = len(examples_response.get("examples", []))
- except Exception:
- examples_count = 0
-
- return {
- "tile_id": tile_data.get("tile_id", tile_id),
- "name": tile_data.get("name", ""),
- "description": tile_data.get("description", ""),
- "endpoint_status": status_data.get("endpoint_status", "UNKNOWN"),
- "agents": mas_data.get("agents", []),
- "examples_count": examples_count,
- "instructions": mas_data.get("instructions", ""),
- }
-
-
-def _mas_find_by_name(name: str) -> Dict[str, Any]:
- """Find a Supervisor Agent by name."""
- if not name:
- return {"error": "Missing required parameter 'name' for find_by_name action"}
-
- manager = _get_manager()
- result = manager.mas_find_by_name(name)
-
- if result is None:
- return {"found": False, "name": name}
-
- # Fetch full details to get endpoint status and agents
- full_details = manager.mas_get(result.tile_id)
- if full_details:
- mas_data = full_details.get("multi_agent_supervisor", {})
- status_data = mas_data.get("status", {})
- return {
- "found": True,
- "tile_id": result.tile_id,
- "name": result.name,
- "endpoint_status": status_data.get("endpoint_status", "UNKNOWN"),
- "agents_count": len(mas_data.get("agents", [])),
- }
-
- return {
- "found": True,
- "tile_id": result.tile_id,
- "name": result.name,
- "endpoint_status": "UNKNOWN",
- "agents_count": 0,
- }
-
-
-def _mas_delete(tile_id: str) -> Dict[str, Any]:
- """Delete a Supervisor Agent."""
- if not tile_id:
- return {"error": "Missing required parameter 'tile_id' for delete action"}
-
- manager = _get_manager()
- try:
- manager.delete(tile_id)
- try:
- from ..manifest import remove_resource
-
- remove_resource(resource_type="multi_agent_supervisor", resource_id=tile_id)
- except Exception:
- pass
- return {"success": True, "tile_id": tile_id}
- except Exception as e:
- return {"success": False, "tile_id": tile_id, "error": str(e)}
-
-
-# ============================================================================
-# Consolidated MCP Tools
-# ============================================================================
-
-
-@mcp.tool(timeout=180)
-def manage_ka(
- action: str,
- name: str = None,
- volume_path: str = None,
- description: str = None,
- instructions: str = None,
- tile_id: str = None,
- add_examples_from_volume: bool = True,
-) -> Dict[str, Any]:
- """Manage Knowledge Assistant (KA) - RAG-based document Q&A.
-
- Actions: create_or_update (name+volume_path), get (tile_id), find_by_name (name), delete (tile_id).
- volume_path: UC Volume path with documents (e.g., /Volumes/catalog/schema/vol/docs).
- description: What this KA does (shown to users). instructions: How KA should answer queries.
- add_examples_from_volume: scan volume for JSON example files with question/guideline pairs.
- See agent-bricks skill for full details.
- Returns: create_or_update={tile_id, operation, endpoint_status}, get={tile_id, knowledge_sources, examples_count},
- find_by_name={found, tile_id, endpoint_name}, delete={success}."""
- action = action.lower()
-
- if action == "create_or_update":
- return _ka_create_or_update(
- name=name,
- volume_path=volume_path,
- description=description,
- instructions=instructions,
- tile_id=tile_id,
- add_examples_from_volume=add_examples_from_volume,
- )
- elif action == "get":
- return _ka_get(tile_id=tile_id)
- elif action == "find_by_name":
- return _ka_find_by_name(name=name)
- elif action == "delete":
- return _ka_delete(tile_id=tile_id)
- else:
- return {"error": f"Invalid action '{action}'. Must be one of: create_or_update, get, find_by_name, delete"}
-
-
-@mcp.tool(timeout=180)
-def manage_mas(
- action: str,
- name: str = None,
- agents: List[Dict[str, str]] = None,
- description: str = None,
- instructions: str = None,
- tile_id: str = None,
- examples: List[Dict[str, str]] = None,
-) -> Dict[str, Any]:
- """Manage Supervisor Agent (MAS) - orchestrates multiple agents for query routing.
-
- Actions: create_or_update (name+agents), get (tile_id), find_by_name (name), delete (tile_id).
- agents: [{name, description (critical for routing), ONE OF: endpoint_name|genie_space_id|ka_tile_id|uc_function_name|connection_name}].
- description: What this MAS does. instructions: Routing rules for the supervisor.
- examples: [{question, guideline}] to train routing behavior.
- See agent-bricks skill for full agent configuration details.
- Returns: create_or_update={tile_id, operation, endpoint_status, agents_count}, get={tile_id, agents, examples_count},
- find_by_name={found, tile_id, agents_count}, delete={success}."""
- action = action.lower()
-
- if action == "create_or_update":
- return _mas_create_or_update(
- name=name,
- agents=agents,
- description=description,
- instructions=instructions,
- tile_id=tile_id,
- examples=examples,
- )
- elif action == "get":
- return _mas_get(tile_id=tile_id)
- elif action == "find_by_name":
- return _mas_find_by_name(name=name)
- elif action == "delete":
- return _mas_delete(tile_id=tile_id)
- else:
- return {"error": f"Invalid action '{action}'. Must be one of: create_or_update, get, find_by_name, delete"}
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/aibi_dashboards.py b/databricks-mcp-server/databricks_mcp_server/tools/aibi_dashboards.py
deleted file mode 100644
index b8b3002e..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/aibi_dashboards.py
+++ /dev/null
@@ -1,154 +0,0 @@
-"""AI/BI Dashboard tools - Create and manage AI/BI dashboards.
-
-Note: AI/BI dashboards were previously known as Lakeview dashboards.
-The SDK/API still uses the 'lakeview' name internally.
-
-Consolidated into 1 tool:
-- manage_dashboard: create_or_update, get, list, delete, publish, unpublish
-"""
-
-import json
-from typing import Any, Dict, Optional, Union
-
-from databricks_tools_core.aibi_dashboards import (
- create_or_update_dashboard as _create_or_update_dashboard,
- get_dashboard as _get_dashboard,
- list_dashboards as _list_dashboards,
- publish_dashboard as _publish_dashboard,
- trash_dashboard as _trash_dashboard,
- unpublish_dashboard as _unpublish_dashboard,
-)
-
-from ..manifest import register_deleter
-from ..server import mcp
-
-
-def _delete_dashboard_resource(resource_id: str) -> None:
- _trash_dashboard(dashboard_id=resource_id)
-
-
-register_deleter("dashboard", _delete_dashboard_resource)
-
-
-@mcp.tool(timeout=120)
-def manage_dashboard(
- action: str,
- # For create_or_update:
- display_name: Optional[str] = None,
- parent_path: Optional[str] = None,
- serialized_dashboard: Optional[Union[str, dict]] = None,
- warehouse_id: Optional[str] = None,
- # For create_or_update publish option:
- publish: bool = True,
- # For get/delete/publish/unpublish:
- dashboard_id: Optional[str] = None,
- # For publish:
- embed_credentials: bool = True,
-) -> Dict[str, Any]:
- """Manage AI/BI dashboards: create, update, get, list, delete, publish.
-
- CRITICAL: Before calling this tool to create or edit a dashboard, you MUST:
- 0. Review the databricks-aibi-dashboards skill to understand widget definitions.
- You must EXACTLY follow the JSON structure detailed in the skill.
- 1. Call get_table_stats_and_schema() to get table schemas for your queries.
- 2. Call execute_sql() to TEST EVERY dataset query before using in dashboard.
- If you skip validation, widgets WILL show errors!
-
- Actions:
- - create_or_update: Create/update dashboard from JSON.
- Requires display_name, parent_path, serialized_dashboard, warehouse_id.
- publish=True (default) auto-publishes after create.
- Returns: {success, dashboard_id, path, url, published, error}.
- - get: Get dashboard details. Requires dashboard_id.
- Returns: dashboard config and metadata.
- - list: List all dashboards.
- Returns: {dashboards: [...]}.
- - delete: Soft-delete (moves to trash). Requires dashboard_id.
- Returns: {status, message}.
- - publish: Publish dashboard. Requires dashboard_id, warehouse_id.
- embed_credentials=True allows users without data access to view.
- Returns: {status, dashboard_id}.
- - unpublish: Unpublish dashboard. Requires dashboard_id.
- Returns: {status, dashboard_id}.
-
- Widget structure rules (for create_or_update):
- - queries is TOP-LEVEL SIBLING of spec (NOT inside spec, NOT named_queries)
- - fields[].name MUST match encodings fieldName exactly
- - Use datasetName (camelCase, not dataSetName)
- - Versions: counter/table/filter=2, bar/line/pie=3
- - Layout: 6-column grid
- - Filter types: filter-multi-select, filter-single-select, filter-date-range-picker
- - Text widget uses textbox_spec (no spec block)"""
- act = action.lower()
-
- if act == "create_or_update":
- if not all([display_name, parent_path, serialized_dashboard, warehouse_id]):
- return {"error": "create_or_update requires: display_name, parent_path, serialized_dashboard, warehouse_id"}
-
- # MCP deserializes JSON params, so serialized_dashboard may arrive as a dict
- if isinstance(serialized_dashboard, dict):
- serialized_dashboard = json.dumps(serialized_dashboard)
-
- result = _create_or_update_dashboard(
- display_name=display_name,
- parent_path=parent_path,
- serialized_dashboard=serialized_dashboard,
- warehouse_id=warehouse_id,
- publish=publish,
- )
-
- # Track resource on successful create/update
- try:
- if result.get("success") and result.get("dashboard_id"):
- from ..manifest import track_resource
-
- track_resource(
- resource_type="dashboard",
- name=display_name,
- resource_id=result["dashboard_id"],
- url=result.get("url"),
- )
- except Exception:
- pass
-
- return result
-
- elif act == "get":
- if not dashboard_id:
- return {"error": "get requires: dashboard_id"}
- return _get_dashboard(dashboard_id=dashboard_id)
-
- elif act == "list":
- return _list_dashboards(page_size=200)
-
- elif act == "delete":
- if not dashboard_id:
- return {"error": "delete requires: dashboard_id"}
- result = _trash_dashboard(dashboard_id=dashboard_id)
- try:
- from ..manifest import remove_resource
- remove_resource(resource_type="dashboard", resource_id=dashboard_id)
- except Exception:
- pass
- return result
-
- elif act == "publish":
- if not dashboard_id:
- return {"error": "publish requires: dashboard_id"}
- if not warehouse_id:
- return {"error": "publish requires: warehouse_id"}
- return _publish_dashboard(
- dashboard_id=dashboard_id,
- warehouse_id=warehouse_id,
- embed_credentials=embed_credentials,
- )
-
- elif act == "unpublish":
- if not dashboard_id:
- return {"error": "unpublish requires: dashboard_id"}
- return _unpublish_dashboard(dashboard_id=dashboard_id)
-
- else:
- return {
- "error": f"Invalid action '{action}'. Valid actions: create_or_update, get, list, delete, publish, unpublish"
- }
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/apps.py b/databricks-mcp-server/databricks_mcp_server/tools/apps.py
deleted file mode 100644
index 34c7474d..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/apps.py
+++ /dev/null
@@ -1,201 +0,0 @@
-"""App tools - Manage Databricks Apps lifecycle.
-
-Consolidated into 1 tool:
-- manage_app: create_or_update, get, list, delete
-"""
-
-import logging
-from typing import Any, Dict, Optional
-
-from databricks_tools_core.apps.apps import (
- create_app as _create_app,
- get_app as _get_app,
- list_apps as _list_apps,
- deploy_app as _deploy_app,
- delete_app as _delete_app,
- get_app_logs as _get_app_logs,
-)
-from databricks_tools_core.identity import with_description_footer
-
-from ..manifest import register_deleter
-from ..server import mcp
-
-logger = logging.getLogger(__name__)
-
-
-def _delete_app_resource(resource_id: str) -> None:
- _delete_app(name=resource_id)
-
-
-register_deleter("app", _delete_app_resource)
-
-
-# ============================================================================
-# Helpers
-# ============================================================================
-
-
-def _find_app_by_name(name: str) -> Optional[Dict[str, Any]]:
- """Find an app by name, returns None if not found."""
- try:
- result = _get_app(name=name)
- if result.get("error"):
- return None
- return result
- except Exception:
- return None
-
-
-# ============================================================================
-# Tool: manage_app
-# ============================================================================
-
-
-@mcp.tool(timeout=180)
-def manage_app(
- action: str,
- # For create_or_update/get/delete:
- name: Optional[str] = None,
- # For create_or_update:
- source_code_path: Optional[str] = None,
- description: Optional[str] = None,
- mode: Optional[str] = None,
- # For get:
- include_logs: bool = False,
- deployment_id: Optional[str] = None,
- # For list:
- name_contains: Optional[str] = None,
-) -> Dict[str, Any]:
- """Manage Databricks Apps: create, deploy, get, list, delete.
-
- Actions:
- - create_or_update: Idempotent create. Deploys if source_code_path provided. Requires name.
- source_code_path: Volume or workspace path to deploy from.
- description: App description. mode: Deployment mode.
- Returns: {name, created: bool, url, status, deployment}.
- - get: Get app details. Requires name.
- include_logs=True for deployment logs. deployment_id for specific deployment.
- Returns: {name, url, status, logs}.
- - list: List all apps. Optional name_contains filter.
- Returns: {apps: [{name, url, status}, ...]}.
- - delete: Delete an app. Requires name.
- Returns: {name, status}.
-
- See databricks-app-python skill for app development guidance."""
- act = action.lower()
-
- if act == "create_or_update":
- if not name:
- return {"error": "create_or_update requires: name"}
- return _create_or_update_app(
- name=name,
- source_code_path=source_code_path,
- description=description,
- mode=mode,
- )
-
- elif act == "get":
- if not name:
- return {"error": "get requires: name"}
- return _get_app_details(
- name=name,
- include_logs=include_logs,
- deployment_id=deployment_id,
- )
-
- elif act == "list":
- return {"apps": _list_apps(name_contains=name_contains)}
-
- elif act == "delete":
- if not name:
- return {"error": "delete requires: name"}
- return _delete_app_by_name(name=name)
-
- else:
- return {"error": f"Invalid action '{action}'. Valid actions: create_or_update, get, list, delete"}
-
-
-# ============================================================================
-# Helper Functions
-# ============================================================================
-
-
-def _create_or_update_app(
- name: str,
- source_code_path: Optional[str],
- description: Optional[str],
- mode: Optional[str],
-) -> Dict[str, Any]:
- """Create app if not exists, optionally deploy."""
- existing = _find_app_by_name(name)
-
- if existing:
- result = {**existing, "created": False}
- else:
- app_result = _create_app(name=name, description=with_description_footer(description))
- result = {**app_result, "created": True}
-
- # Track resource on successful create
- try:
- if result.get("name"):
- from ..manifest import track_resource
-
- track_resource(
- resource_type="app",
- name=result["name"],
- resource_id=result["name"],
- )
- except Exception:
- pass # best-effort tracking
-
- # Deploy if source_code_path provided
- if source_code_path:
- try:
- deployment = _deploy_app(
- app_name=name,
- source_code_path=source_code_path,
- mode=mode,
- )
- result["deployment"] = deployment
- except Exception as e:
- logger.warning("Failed to deploy app '%s': %s", name, e)
- result["deployment_error"] = str(e)
-
- return result
-
-
-def _get_app_details(
- name: str,
- include_logs: bool,
- deployment_id: Optional[str],
-) -> Dict[str, Any]:
- """Get app details with optional logs."""
- result = _get_app(name=name)
-
- if include_logs:
- try:
- logs = _get_app_logs(
- app_name=name,
- deployment_id=deployment_id,
- )
- result["logs"] = logs.get("logs", "")
- result["logs_deployment_id"] = logs.get("deployment_id")
- except Exception as e:
- result["logs_error"] = str(e)
-
- return result
-
-
-def _delete_app_by_name(name: str) -> Dict[str, str]:
- """Delete a Databricks App."""
- result = _delete_app(name=name)
-
- # Remove from tracked resources
- try:
- from ..manifest import remove_resource
-
- remove_resource(resource_type="app", resource_id=name)
- except Exception:
- pass # best-effort tracking
-
- return result
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/compute.py b/databricks-mcp-server/databricks_mcp_server/tools/compute.py
deleted file mode 100644
index 712e01ce..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/compute.py
+++ /dev/null
@@ -1,421 +0,0 @@
-"""Compute tools - Execute code and manage compute resources on Databricks.
-
-Consolidated into 4 tools (down from 19) to reduce LLM parsing overhead:
-- execute_code: Run code on serverless or cluster compute
-- manage_cluster: Create, modify, start, terminate, or delete clusters
-- manage_sql_warehouse: Create, modify, or delete SQL warehouses
-- list_compute: List/inspect clusters, node types, and spark versions
-"""
-
-import json
-from typing import Dict, Any, List, Optional
-
-from databricks_tools_core.compute import (
- # Execution
- list_clusters as _list_clusters,
- get_best_cluster as _get_best_cluster,
- start_cluster as _start_cluster,
- get_cluster_status as _get_cluster_status,
- execute_databricks_command as _execute_databricks_command,
- run_file_on_databricks as _run_file_on_databricks,
- run_code_on_serverless as _run_code_on_serverless,
- NoRunningClusterError,
- # Cluster management
- create_cluster as _create_cluster,
- modify_cluster as _modify_cluster,
- terminate_cluster as _terminate_cluster,
- delete_cluster as _delete_cluster,
- list_node_types as _list_node_types,
- list_spark_versions as _list_spark_versions,
- # SQL warehouse management
- create_sql_warehouse as _create_sql_warehouse,
- modify_sql_warehouse as _modify_sql_warehouse,
- delete_sql_warehouse as _delete_sql_warehouse,
-)
-
-from ..server import mcp
-
-
-def _none_if_empty(value):
- """Convert empty strings to None (Claude agent sometimes passes '' instead of null)."""
- return None if value == "" else value
-
-
-# ---------------------------------------------------------------------------
-# Tool 1: execute_code
-# ---------------------------------------------------------------------------
-
-
-@mcp.tool
-def execute_code(
- code: str = None,
- file_path: str = None,
- compute_type: str = "auto",
- cluster_id: str = None,
- context_id: str = None,
- language: str = "python",
- timeout: int = None,
- destroy_context_on_completion: bool = False,
- workspace_path: str = None,
- run_name: str = None,
- job_extra_params: Dict[str, Any] = None,
-) -> Dict[str, Any]:
- """Execute code on Databricks via serverless or cluster compute.
-
- Modes:
- - auto (default): Serverless unless cluster_id/context_id given or language is scala/r
- - serverless: No cluster needed, ~30s cold start, best for batch/one-off tasks
- - cluster: State persists via context_id, best for interactive work (but slow ~2min one-off cluster startup)
-
- - Cluster mode returns context_id. REUSE IT for subsequent calls to skip context creation (Variables/imports persist across calls).
- - Serverless has no context reuse (~30s cold start each time).
-
- file_path: Run local file (.py/.scala/.sql/.r), auto-detects language.
- workspace_path: Save as notebook in workspace (omit for ephemeral).
- .ipynb: Pass raw JSON with serverless, auto-detected.
- job_extra_params: Extra job params (serverless only). For dependencies:
- {"environments": [{"environment_key": "env", "spec": {"client": "4", "dependencies": ["pandas", "sklearn"]}}]}
-
- Timeouts: serverless=1800s, cluster=120s, file=600s.
- Returns: {success, output, error, cluster_id, context_id} or {run_id, run_url}."""
- # Normalize empty strings to None
- code = _none_if_empty(code)
- file_path = _none_if_empty(file_path)
- cluster_id = _none_if_empty(cluster_id)
- context_id = _none_if_empty(context_id)
- language = _none_if_empty(language) or "python"
- workspace_path = _none_if_empty(workspace_path)
- run_name = _none_if_empty(run_name)
-
- if not code and not file_path:
- return {"success": False, "error": "Either 'code' or 'file_path' must be provided."}
-
- # Resolve "auto" compute type
- if compute_type == "auto":
- if cluster_id or context_id:
- compute_type = "cluster"
- elif file_path and language and language.lower() in ("scala", "r"):
- compute_type = "cluster"
- elif language and language.lower() in ("scala", "r"):
- compute_type = "cluster"
- else:
- compute_type = "serverless"
-
- # --- File-based execution ---
- if file_path:
- if compute_type == "serverless":
- # Read file and run on serverless
- try:
- with open(file_path, "r", encoding="utf-8") as f:
- code = f.read()
- except FileNotFoundError:
- return {"success": False, "error": f"File not found: {file_path}"}
- except Exception as e:
- return {"success": False, "error": f"Failed to read file: {e}"}
- # Fall through to serverless execution below
- else:
- # Run file on cluster
- default_timeout = timeout if timeout is not None else 600
- try:
- result = _run_file_on_databricks(
- file_path=file_path,
- cluster_id=cluster_id,
- context_id=context_id,
- language=language if language != "python" else None, # let it auto-detect
- timeout=default_timeout,
- destroy_context_on_completion=destroy_context_on_completion,
- workspace_path=workspace_path,
- )
- return result.to_dict()
- except NoRunningClusterError as e:
- return _no_cluster_error_response(e)
-
- # --- Serverless execution ---
- if compute_type == "serverless":
- default_timeout = timeout if timeout is not None else 1800
- result = _run_code_on_serverless(
- code=code,
- language=language,
- timeout=default_timeout,
- run_name=run_name,
- cleanup=workspace_path is None,
- workspace_path=workspace_path,
- job_extra_params=job_extra_params,
- )
- return result.to_dict()
-
- # --- Cluster execution ---
- default_timeout = timeout if timeout is not None else 120
- try:
- result = _execute_databricks_command(
- code=code,
- cluster_id=cluster_id,
- context_id=context_id,
- language=language,
- timeout=default_timeout,
- destroy_context_on_completion=destroy_context_on_completion,
- )
- return result.to_dict()
- except NoRunningClusterError as e:
- return _no_cluster_error_response(e)
-
-
-def _no_cluster_error_response(e: NoRunningClusterError) -> Dict[str, Any]:
- """Build a structured error response when no running cluster is available."""
- return {
- "success": False,
- "output": None,
- "error": str(e),
- "cluster_id": None,
- "context_id": None,
- "context_destroyed": True,
- "message": None,
- "suggestions": e.suggestions,
- "startable_clusters": e.startable_clusters,
- "skipped_clusters": e.skipped_clusters,
- "available_clusters": e.available_clusters,
- }
-
-
-# ---------------------------------------------------------------------------
-# Tool 2: manage_cluster
-# ---------------------------------------------------------------------------
-
-
-@mcp.tool
-def manage_cluster(
- action: str,
- cluster_id: str = None,
- name: str = None,
- num_workers: int = None,
- spark_version: str = None,
- node_type_id: str = None,
- autotermination_minutes: int = None,
- data_security_mode: str = None,
- spark_conf: str = None,
- autoscale_min_workers: int = None,
- autoscale_max_workers: int = None,
-) -> Dict[str, Any]:
- """Create, modify, start, terminate, or delete a cluster.
-
- Actions:
- - create: Requires name. Auto-picks DBR, node type, SINGLE_USER, 120min auto-stop.
- - modify: Requires cluster_id. Only specified params change. Running clusters restart.
- - start: Requires cluster_id. ASK USER FIRST (costs money, 3-8min startup).
- - terminate: Reversible stop. Requires cluster_id.
- - get: returns cluster details. Requires cluster_id.
- - delete: PERMANENT. CONFIRM WITH USER. Requires cluster_id.
-
- num_workers default 1, ignored if autoscale set. spark_conf: JSON string.
- Returns: {cluster_id, cluster_name, state, message}."""
- action = action.lower().strip()
-
- # Normalize empty strings
- cluster_id = _none_if_empty(cluster_id)
- name = _none_if_empty(name)
- spark_version = _none_if_empty(spark_version)
- node_type_id = _none_if_empty(node_type_id)
- data_security_mode = _none_if_empty(data_security_mode)
-
- if action == "create":
- if not name:
- return {"success": False, "error": "name is required for create action."}
-
- # Parse spark_conf JSON
- parsed_spark_conf = None
- if spark_conf and spark_conf.strip():
- parsed_spark_conf = json.loads(spark_conf)
-
- kwargs = {}
- if spark_version:
- kwargs["spark_version"] = spark_version
- if node_type_id:
- kwargs["node_type_id"] = node_type_id
- if data_security_mode:
- kwargs["data_security_mode"] = data_security_mode
- if parsed_spark_conf:
- kwargs["spark_conf"] = parsed_spark_conf
- if autoscale_min_workers is not None:
- kwargs["autoscale_min_workers"] = autoscale_min_workers
- if autoscale_max_workers is not None:
- kwargs["autoscale_max_workers"] = autoscale_max_workers
-
- return _create_cluster(
- name=name,
- num_workers=num_workers if num_workers is not None else 1,
- autotermination_minutes=autotermination_minutes if autotermination_minutes is not None else 120,
- **kwargs,
- )
-
- elif action == "modify":
- if not cluster_id:
- return {"success": False, "error": "cluster_id is required for modify action."}
-
- kwargs = {}
- if name:
- kwargs["name"] = name
- if num_workers is not None:
- kwargs["num_workers"] = num_workers
- if spark_version:
- kwargs["spark_version"] = spark_version
- if node_type_id:
- kwargs["node_type_id"] = node_type_id
- if autotermination_minutes is not None:
- kwargs["autotermination_minutes"] = autotermination_minutes
- if autoscale_min_workers is not None:
- kwargs["autoscale_min_workers"] = autoscale_min_workers
- if autoscale_max_workers is not None:
- kwargs["autoscale_max_workers"] = autoscale_max_workers
- if spark_conf and spark_conf.strip():
- kwargs["spark_conf"] = json.loads(spark_conf)
-
- return _modify_cluster(cluster_id=cluster_id, **kwargs)
-
- elif action == "start":
- if not cluster_id:
- return {"success": False, "error": "cluster_id is required for start action."}
- return _start_cluster(cluster_id)
-
- elif action == "terminate":
- if not cluster_id:
- return {"success": False, "error": "cluster_id is required for terminate action."}
- return _terminate_cluster(cluster_id)
-
- elif action == "delete":
- if not cluster_id:
- return {"success": False, "error": "cluster_id is required for delete action."}
- return _delete_cluster(cluster_id)
-
- elif action == "get":
- if not cluster_id:
- return {"success": False, "error": "cluster_id is required for get action."}
- try:
- return _get_cluster_status(cluster_id)
- except Exception as e:
- # Handle case where cluster doesn't exist (e.g., after deletion)
- if "does not exist" in str(e).lower():
- return {"success": True, "cluster_id": cluster_id, "state": "DELETED", "exists": False}
- return {"success": False, "error": str(e)}
-
- else:
- return {
- "success": False,
- "error": f"Unknown action: {action!r}. Must be one of: create, modify, start, terminate, delete, get.",
- }
-
-
-# ---------------------------------------------------------------------------
-# Tool 3: manage_sql_warehouse
-# ---------------------------------------------------------------------------
-
-
-@mcp.tool
-def manage_sql_warehouse(
- action: str,
- warehouse_id: str = None,
- name: str = None,
- size: str = None,
- min_num_clusters: int = None,
- max_num_clusters: int = None,
- auto_stop_mins: int = None,
- warehouse_type: str = None,
- enable_serverless: bool = None,
-) -> Dict[str, Any]:
- """Create, modify, or delete a SQL warehouse.
-
- Actions:
- - create: Requires name. Defaults: serverless PRO, Small, 120min auto-stop.
- - modify: Requires warehouse_id. Only specified params change.
- - delete: PERMANENT. CONFIRM WITH USER. Requires warehouse_id.
-
- size: "2X-Small" to "4X-Large". Use list_warehouses to list existing.
- Returns: {warehouse_id, name, state, message}."""
- action = action.lower().strip()
-
- warehouse_id = _none_if_empty(warehouse_id)
- name = _none_if_empty(name)
- size = _none_if_empty(size)
- warehouse_type = _none_if_empty(warehouse_type)
-
- if action == "create":
- if not name:
- return {"success": False, "error": "name is required for create action."}
-
- return _create_sql_warehouse(
- name=name,
- size=size or "Small",
- min_num_clusters=min_num_clusters if min_num_clusters is not None else 1,
- max_num_clusters=max_num_clusters if max_num_clusters is not None else 1,
- auto_stop_mins=auto_stop_mins if auto_stop_mins is not None else 120,
- warehouse_type=warehouse_type or "PRO",
- enable_serverless=enable_serverless if enable_serverless is not None else True,
- )
-
- elif action == "modify":
- if not warehouse_id:
- return {"success": False, "error": "warehouse_id is required for modify action."}
-
- kwargs = {}
- if name:
- kwargs["name"] = name
- if size:
- kwargs["size"] = size
- if min_num_clusters is not None:
- kwargs["min_num_clusters"] = min_num_clusters
- if max_num_clusters is not None:
- kwargs["max_num_clusters"] = max_num_clusters
- if auto_stop_mins is not None:
- kwargs["auto_stop_mins"] = auto_stop_mins
-
- return _modify_sql_warehouse(warehouse_id=warehouse_id, **kwargs)
-
- elif action == "delete":
- if not warehouse_id:
- return {"success": False, "error": "warehouse_id is required for delete action."}
- return _delete_sql_warehouse(warehouse_id)
-
- else:
- return {
- "success": False,
- "error": f"Unknown action: {action!r}. Must be one of: create, modify, delete.",
- }
-
-
-# ---------------------------------------------------------------------------
-# Tool 4: list_compute
-# ---------------------------------------------------------------------------
-
-
-@mcp.tool
-def list_compute(
- resource: str = "clusters",
- cluster_id: str = None,
- auto_select: bool = False,
-) -> Dict[str, Any]:
- """List compute resources: clusters, node types, or spark versions.
-
- resource: "clusters" (default), "node_types", or "spark_versions".
- cluster_id: Get specific cluster status (use to poll after starting).
- auto_select: Return best running cluster (prefers "shared" > "demo" in name)."""
- resource = resource.lower().strip()
- cluster_id = _none_if_empty(cluster_id)
-
- if resource == "clusters":
- if cluster_id:
- return _get_cluster_status(cluster_id)
- if auto_select:
- best = _get_best_cluster()
- return {"cluster_id": best}
- return {"clusters": _list_clusters()}
-
- elif resource == "node_types":
- return {"node_types": _list_node_types()}
-
- elif resource == "spark_versions":
- return {"spark_versions": _list_spark_versions()}
-
- else:
- return {
- "success": False,
- "error": f"Unknown resource: {resource!r}. Must be one of: clusters, node_types, spark_versions.",
- }
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/file.py b/databricks-mcp-server/databricks_mcp_server/tools/file.py
deleted file mode 100644
index 6c7f6130..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/file.py
+++ /dev/null
@@ -1,77 +0,0 @@
-"""File tools - Upload and delete files and folders in Databricks workspace.
-
-Consolidated into 1 tool:
-- manage_workspace_files: upload, delete
-"""
-
-from typing import Any, Dict, Optional
-
-from databricks_tools_core.file import (
- delete_from_workspace as _delete_from_workspace,
- upload_to_workspace as _upload_to_workspace,
-)
-
-from ..server import mcp
-
-
-@mcp.tool(timeout=120)
-def manage_workspace_files(
- action: str,
- workspace_path: str,
- # For upload:
- local_path: Optional[str] = None,
- max_workers: int = 10,
- overwrite: bool = True,
- # For delete:
- recursive: bool = False,
-) -> Dict[str, Any]:
- """Manage workspace files: upload, delete.
-
- Actions:
- - upload: Upload files/folders to workspace. Requires local_path, workspace_path.
- Supports files, folders, globs, tilde expansion.
- max_workers: Parallel upload threads (default 10). overwrite: Replace existing (default True).
- Returns: {local_folder, remote_folder, total_files, successful, failed, success, failed_uploads}.
- - delete: Delete file/folder from workspace. Requires workspace_path.
- recursive=True for non-empty folders. Has safety checks for protected paths.
- Returns: {workspace_path, success, error}.
-
- workspace_path format: /Workspace/Users/user@example.com/path/to/files"""
- act = action.lower()
-
- if act == "upload":
- if not local_path:
- return {"error": "upload requires: local_path"}
- result = _upload_to_workspace(
- local_path=local_path,
- workspace_path=workspace_path,
- max_workers=max_workers,
- overwrite=overwrite,
- )
- return {
- "local_folder": result.local_folder,
- "remote_folder": result.remote_folder,
- "total_files": result.total_files,
- "successful": result.successful,
- "failed": result.failed,
- "success": result.success,
- "failed_uploads": [
- {"local_path": r.local_path, "error": r.error} for r in result.get_failed_uploads()
- ]
- if result.failed > 0
- else [],
- }
-
- elif act == "delete":
- result = _delete_from_workspace(
- workspace_path=workspace_path,
- recursive=recursive,
- )
- return {
- "workspace_path": result.workspace_path,
- "success": result.success,
- "error": result.error,
- }
-
- else:
- return {"error": f"Invalid action '{action}'. Valid actions: upload, delete"}
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/genie.py b/databricks-mcp-server/databricks_mcp_server/tools/genie.py
deleted file mode 100644
index 40852173..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/genie.py
+++ /dev/null
@@ -1,546 +0,0 @@
-"""Genie tools - Create, manage, and query Databricks Genie Spaces.
-
-Consolidated into 2 tools:
-- manage_genie: create_or_update, get, list, delete, export, import
-- ask_genie: query (hot path - kept separate)
-"""
-
-from datetime import timedelta
-from typing import Any, Dict, List, Optional
-
-from databricks_tools_core.agent_bricks import AgentBricksManager
-from databricks_tools_core.auth import get_workspace_client
-from databricks_tools_core.identity import with_description_footer
-
-from ..manifest import register_deleter
-from ..server import mcp
-
-# Singleton manager instance for space management operations
-_manager: Optional[AgentBricksManager] = None
-
-
-def _get_manager() -> AgentBricksManager:
- """Get or create the singleton AgentBricksManager instance."""
- global _manager
- if _manager is None:
- _manager = AgentBricksManager()
- return _manager
-
-
-def _delete_genie_resource(resource_id: str) -> None:
- """Delete a genie space using SDK."""
- w = get_workspace_client()
- w.genie.trash_space(space_id=resource_id)
-
-
-register_deleter("genie_space", _delete_genie_resource)
-
-
-def _find_space_by_name(name: str) -> Optional[Any]:
- """Find a Genie Space by name using SDK's list_spaces.
-
- Returns the GenieSpaceInfo if found, None otherwise.
- """
- w = get_workspace_client()
- page_token = None
- while True:
- response = w.genie.list_spaces(page_size=200, page_token=page_token)
- if response.spaces:
- for space in response.spaces:
- if space.title == name:
- return space
- if response.next_page_token:
- page_token = response.next_page_token
- else:
- break
- return None
-
-
-# ============================================================================
-# Tool 1: manage_genie
-# ============================================================================
-
-
-@mcp.tool(timeout=60)
-def manage_genie(
- action: str,
- # For create_or_update:
- display_name: Optional[str] = None,
- table_identifiers: Optional[List[str]] = None,
- warehouse_id: Optional[str] = None,
- description: Optional[str] = None,
- sample_questions: Optional[List[str]] = None,
- serialized_space: Optional[str] = None,
- # For get/delete/export:
- space_id: Optional[str] = None,
- # For get:
- include_serialized_space: bool = False,
- # For import:
- title: Optional[str] = None,
- parent_path: Optional[str] = None,
-) -> Dict[str, Any]:
- """Manage Genie Spaces: create, update, get, list, delete, export, import.
-
- Actions:
- - create_or_update: Idempotent by name. Requires display_name, table_identifiers.
- warehouse_id auto-detected if omitted. description: Explains space purpose.
- sample_questions: Example questions shown to users.
- serialized_space: Full config from export (preserves instructions/SQL examples).
- Returns: {space_id, display_name, operation: created|updated, warehouse_id, table_count}.
- - get: Get space details. Requires space_id.
- include_serialized_space=True for full config export.
- Returns: {space_id, display_name, description, warehouse_id, table_identifiers, sample_questions}.
- - list: List all spaces.
- Returns: {spaces: [{space_id, title, description}, ...]}.
- - delete: Delete a space. Requires space_id.
- Returns: {success, space_id}.
- - export: Export space config for migration/backup. Requires space_id.
- Returns: {space_id, title, description, warehouse_id, serialized_space}.
- - import: Import space from serialized_space. Requires warehouse_id, serialized_space.
- Optional title, description, parent_path overrides.
- Returns: {space_id, title, description, operation: imported}.
-
- See databricks-genie skill for configuration details."""
- act = action.lower()
-
- if act == "create_or_update":
- # For updates with space_id, display_name is optional
- if not space_id and not display_name:
- return {"error": "create_or_update requires: display_name (or space_id for updates)"}
- if not space_id and not table_identifiers and not serialized_space:
- return {"error": "create_or_update requires: table_identifiers (or serialized_space)"}
-
- return _create_or_update_genie_space(
- display_name=display_name,
- table_identifiers=table_identifiers or [],
- warehouse_id=warehouse_id,
- description=description,
- sample_questions=sample_questions,
- space_id=space_id,
- serialized_space=serialized_space,
- )
-
- elif act == "get":
- if not space_id:
- return {"error": "get requires: space_id"}
- return _get_genie_space(space_id=space_id, include_serialized_space=include_serialized_space)
-
- elif act == "list":
- return _list_genie_spaces()
-
- elif act == "delete":
- if not space_id:
- return {"error": "delete requires: space_id"}
- return _delete_genie_space(space_id=space_id)
-
- elif act == "export":
- if not space_id:
- return {"error": "export requires: space_id"}
- return _export_genie_space(space_id=space_id)
-
- elif act == "import":
- if not warehouse_id or not serialized_space:
- return {"error": "import requires: warehouse_id, serialized_space"}
- return _import_genie_space(
- warehouse_id=warehouse_id,
- serialized_space=serialized_space,
- title=title,
- description=description,
- parent_path=parent_path,
- )
-
- else:
- return {"error": f"Invalid action '{action}'. Valid actions: create_or_update, get, list, delete, export, import"}
-
-
-# ============================================================================
-# Tool 2: ask_genie (HOT PATH - kept separate for performance)
-# ============================================================================
-
-
-@mcp.tool(timeout=120)
-def ask_genie(
- space_id: str,
- question: str,
- conversation_id: Optional[str] = None,
- timeout_seconds: int = 120,
-) -> Dict[str, Any]:
- """Ask natural language question to Genie Space. Pass conversation_id for follow-ups.
-
- Returns: {question, conversation_id, message_id, status, sql, description, columns, data, row_count, text_response, error}."""
- try:
- w = get_workspace_client()
-
- if conversation_id:
- result = w.genie.create_message_and_wait(
- space_id=space_id,
- conversation_id=conversation_id,
- content=question,
- timeout=timedelta(seconds=timeout_seconds),
- )
- else:
- result = w.genie.start_conversation_and_wait(
- space_id=space_id,
- content=question,
- timeout=timedelta(seconds=timeout_seconds),
- )
-
- return _format_genie_response(question, result, space_id, w)
- except TimeoutError:
- return {
- "question": question,
- "conversation_id": conversation_id,
- "status": "TIMEOUT",
- "error": f"Genie response timed out after {timeout_seconds}s",
- }
- except Exception as e:
- return {
- "question": question,
- "conversation_id": conversation_id,
- "status": "ERROR",
- "error": str(e),
- }
-
-
-# ============================================================================
-# Helper Functions
-# ============================================================================
-
-
-def _create_or_update_genie_space(
- display_name: str,
- table_identifiers: List[str],
- warehouse_id: Optional[str],
- description: Optional[str],
- sample_questions: Optional[List[str]],
- space_id: Optional[str],
- serialized_space: Optional[str],
-) -> Dict[str, Any]:
- """Create or update a Genie Space."""
- try:
- description = with_description_footer(description)
- manager = _get_manager()
-
- # Auto-detect warehouse if not provided
- if warehouse_id is None:
- warehouse_id = manager.get_best_warehouse_id()
- if warehouse_id is None:
- return {"error": "No SQL warehouses available. Please provide a warehouse_id or create a warehouse."}
-
- operation = "created"
-
- # When serialized_space is provided
- if serialized_space:
- w = get_workspace_client()
- if space_id:
- # Update existing space with serialized config using SDK
- w.genie.update_space(
- space_id=space_id,
- serialized_space=serialized_space,
- title=display_name,
- description=description,
- warehouse_id=warehouse_id,
- )
- operation = "updated"
- else:
- # Check if exists by name, then create or update
- existing = _find_space_by_name(display_name)
- if existing:
- operation = "updated"
- space_id = existing.space_id
- # Update existing space with serialized config using SDK
- w.genie.update_space(
- space_id=space_id,
- serialized_space=serialized_space,
- title=display_name,
- description=description,
- warehouse_id=warehouse_id,
- )
- else:
- # Create new space with serialized config using SDK
- w = get_workspace_client()
- space = w.genie.create_space(
- warehouse_id=warehouse_id,
- serialized_space=serialized_space,
- title=display_name,
- description=description,
- )
- space_id = space.space_id or ""
-
- # When serialized_space is not provided
- else:
- if space_id:
- # Update existing space by ID using SDK for proper partial updates
- w = get_workspace_client()
- try:
- # Use SDK's update_space which supports partial updates
- w.genie.update_space(
- space_id=space_id,
- description=description,
- title=display_name,
- warehouse_id=warehouse_id,
- )
- operation = "updated"
- # Handle sample questions separately if provided
- if sample_questions is not None:
- manager.genie_update_sample_questions(space_id, sample_questions)
- # Handle table_identifiers if provided (requires full update via manager)
- if table_identifiers:
- manager.genie_update(
- space_id=space_id,
- display_name=display_name,
- description=description,
- warehouse_id=warehouse_id,
- table_identifiers=table_identifiers,
- )
- except Exception as e:
- return {"error": f"Genie space {space_id} not found or update failed: {e}"}
- else:
- # Check if exists by name first using SDK
- existing = _find_space_by_name(display_name)
- if existing:
- operation = "updated"
- manager.genie_update(
- space_id=existing.space_id,
- display_name=display_name,
- description=description,
- warehouse_id=warehouse_id,
- table_identifiers=table_identifiers,
- sample_questions=sample_questions,
- )
- space_id = existing.space_id
- else:
- # Create new
- result = manager.genie_create(
- display_name=display_name,
- warehouse_id=warehouse_id,
- table_identifiers=table_identifiers,
- description=description,
- )
- space_id = result.get("space_id", "")
-
- # Add sample questions if provided
- if sample_questions and space_id:
- manager.genie_add_sample_questions_batch(space_id, sample_questions)
-
- response = {
- "space_id": space_id,
- "display_name": display_name,
- "operation": operation,
- "warehouse_id": warehouse_id,
- "table_count": len(table_identifiers),
- }
-
- try:
- if space_id:
- from ..manifest import track_resource
-
- track_resource(
- resource_type="genie_space",
- name=display_name,
- resource_id=space_id,
- )
- except Exception:
- pass
-
- return response
-
- except Exception as e:
- return {"error": f"Failed to create/update Genie space '{display_name}': {e}"}
-
-
-def _get_genie_space(space_id: str, include_serialized_space: bool) -> Dict[str, Any]:
- """Get a Genie Space by ID using SDK."""
- try:
- w = get_workspace_client()
- # Use SDK's include_serialized_space parameter if needed
- space = w.genie.get_space(space_id=space_id, include_serialized_space=include_serialized_space)
-
- if not space:
- return {"error": f"Genie space {space_id} not found"}
-
- # Get sample questions using manager (SDK doesn't have this method)
- manager = _get_manager()
- questions_response = manager.genie_list_questions(space_id, question_type="SAMPLE_QUESTION")
- sample_questions = [q.get("question_text", "") for q in questions_response.get("curated_questions", [])]
-
- # Extract table identifiers from serialized_space if available
- # The SDK's GenieSpace doesn't expose tables as a direct attribute
- table_identifiers = []
- if space.serialized_space:
- try:
- import json
- serialized = json.loads(space.serialized_space)
- for table in serialized.get("tables", []):
- if table.get("table_identifier"):
- table_identifiers.append(table["table_identifier"])
- except (json.JSONDecodeError, KeyError):
- pass # Tables will remain empty
-
- response = {
- "space_id": space.space_id or space_id,
- "display_name": space.title or "",
- "description": space.description or "",
- "warehouse_id": space.warehouse_id or "",
- "table_identifiers": table_identifiers,
- "sample_questions": sample_questions,
- }
-
- if include_serialized_space:
- response["serialized_space"] = space.serialized_space or ""
-
- return response
-
- except Exception as e:
- return {"error": f"Failed to get Genie space {space_id}: {e}"}
-
-
-def _list_genie_spaces() -> Dict[str, Any]:
- """List all Genie Spaces with pagination."""
- try:
- w = get_workspace_client()
- spaces = []
- page_token = None
-
- while True:
- response = w.genie.list_spaces(page_size=200, page_token=page_token)
- if response.spaces:
- for space in response.spaces:
- spaces.append(
- {
- "space_id": space.space_id,
- "title": space.title or "",
- "description": space.description or "",
- }
- )
- # Check for next page
- if response.next_page_token:
- page_token = response.next_page_token
- else:
- break
-
- return {"spaces": spaces}
- except Exception as e:
- return {"error": str(e)}
-
-
-def _delete_genie_space(space_id: str) -> Dict[str, Any]:
- """Delete a Genie Space using SDK."""
- try:
- w = get_workspace_client()
- w.genie.trash_space(space_id=space_id)
- try:
- from ..manifest import remove_resource
-
- remove_resource(resource_type="genie_space", resource_id=space_id)
- except Exception:
- pass
- return {"success": True, "space_id": space_id}
- except Exception as e:
- return {"success": False, "space_id": space_id, "error": str(e)}
-
-
-def _export_genie_space(space_id: str) -> Dict[str, Any]:
- """Export a Genie Space for migration/backup using SDK."""
- try:
- w = get_workspace_client()
- space = w.genie.get_space(space_id=space_id, include_serialized_space=True)
- return {
- "space_id": space.space_id or space_id,
- "title": space.title or "",
- "description": space.description or "",
- "warehouse_id": space.warehouse_id or "",
- "serialized_space": space.serialized_space or "",
- }
- except Exception as e:
- return {"error": str(e), "space_id": space_id}
-
-
-def _import_genie_space(
- warehouse_id: str,
- serialized_space: str,
- title: Optional[str],
- description: Optional[str],
- parent_path: Optional[str],
-) -> Dict[str, Any]:
- """Import a Genie Space from serialized config using SDK."""
- try:
- w = get_workspace_client()
- space = w.genie.create_space(
- warehouse_id=warehouse_id,
- serialized_space=serialized_space,
- title=title,
- description=description,
- parent_path=parent_path,
- )
- imported_space_id = space.space_id or ""
-
- if imported_space_id:
- try:
- from ..manifest import track_resource
-
- track_resource(
- resource_type="genie_space",
- name=title or space.title or imported_space_id,
- resource_id=imported_space_id,
- )
- except Exception:
- pass
-
- return {
- "space_id": imported_space_id,
- "title": space.title or title or "",
- "description": space.description or description or "",
- "operation": "imported",
- }
- except Exception as e:
- return {"error": str(e)}
-
-
-def _format_genie_response(question: str, genie_message: Any, space_id: str, w: Any) -> Dict[str, Any]:
- """Format a Genie SDK response into a clean dictionary."""
- result = {
- "question": question,
- "conversation_id": genie_message.conversation_id,
- "message_id": genie_message.id,
- "status": str(genie_message.status.value) if genie_message.status else "UNKNOWN",
- }
-
- # Extract data from attachments
- if genie_message.attachments:
- for attachment in genie_message.attachments:
- # Query attachment (SQL and results)
- if attachment.query:
- result["sql"] = attachment.query.query or ""
- result["description"] = attachment.query.description or ""
-
- # Get row count from metadata
- if attachment.query.query_result_metadata:
- result["row_count"] = attachment.query.query_result_metadata.row_count
-
- # Fetch actual data (columns and rows)
- if attachment.attachment_id:
- try:
- data_result = w.genie.get_message_query_result_by_attachment(
- space_id=space_id,
- conversation_id=genie_message.conversation_id,
- message_id=genie_message.id,
- attachment_id=attachment.attachment_id,
- )
- if data_result.statement_response:
- sr = data_result.statement_response
- # Get columns
- if sr.manifest and sr.manifest.schema and sr.manifest.schema.columns:
- result["columns"] = [c.name for c in sr.manifest.schema.columns]
- # Get data
- if sr.result and sr.result.data_array:
- result["data"] = sr.result.data_array
- except Exception:
- # If data fetch fails, continue without it
- pass
-
- # Text attachment (explanation)
- if attachment.text:
- result["text_response"] = attachment.text.content or ""
-
- return result
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/jobs.py b/databricks-mcp-server/databricks_mcp_server/tools/jobs.py
deleted file mode 100644
index f1bce352..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/jobs.py
+++ /dev/null
@@ -1,255 +0,0 @@
-"""
-Jobs MCP Tools
-
-Consolidated MCP tools for Databricks Jobs operations.
-2 tools covering: job CRUD and job run management.
-"""
-
-from typing import Any, Dict, List
-
-from databricks_tools_core.identity import get_default_tags
-from databricks_tools_core.jobs import (
- list_jobs as _list_jobs,
- get_job as _get_job,
- find_job_by_name as _find_job_by_name,
- create_job as _create_job,
- update_job as _update_job,
- delete_job as _delete_job,
- run_job_now as _run_job_now,
- get_run as _get_run,
- get_run_output as _get_run_output,
- cancel_run as _cancel_run,
- list_runs as _list_runs,
- wait_for_run as _wait_for_run,
-)
-
-from ..manifest import register_deleter
-from ..server import mcp
-
-
-def _delete_job_resource(resource_id: str) -> None:
- _delete_job(job_id=int(resource_id))
-
-
-register_deleter("job", _delete_job_resource)
-
-
-# =============================================================================
-# Tool 1: manage_jobs
-# =============================================================================
-
-
-@mcp.tool(timeout=60)
-def manage_jobs(
- action: str,
- job_id: int = None,
- name: str = None,
- tasks: List[Dict[str, Any]] = None,
- job_clusters: List[Dict[str, Any]] = None,
- environments: List[Dict[str, Any]] = None,
- tags: Dict[str, str] = None,
- timeout_seconds: int = None,
- max_concurrent_runs: int = None,
- email_notifications: Dict[str, Any] = None,
- webhook_notifications: Dict[str, Any] = None,
- notification_settings: Dict[str, Any] = None,
- schedule: Dict[str, Any] = None,
- queue: Dict[str, Any] = None,
- run_as: Dict[str, Any] = None,
- git_source: Dict[str, Any] = None,
- parameters: List[Dict[str, Any]] = None,
- health: Dict[str, Any] = None,
- deployment: Dict[str, Any] = None,
- limit: int = 25,
- expand_tasks: bool = False,
-) -> Dict[str, Any]:
- """Manage Databricks jobs: create, get, list, find_by_name, update, delete.
-
- create: requires name+tasks, serverless default, idempotent (returns existing if same name).
- get/update/delete: require job_id. find_by_name: returns job_id.
- tasks: [{task_key, notebook_task|spark_python_task|..., job_cluster_key or environment_key}].
- job_clusters: Shared cluster definitions tasks can reference. environments: Serverless env configs.
- schedule: {quartz_cron_expression, timezone_id}. git_source: {git_url, git_provider, git_branch}.
- See databricks-jobs skill for task configuration details.
- Returns: create={job_id}, get=full config, list={items}, find_by_name={job_id}, update/delete={status, job_id}."""
- act = action.lower()
-
- if act == "create":
- # Idempotency guard: check if a job with this name already exists.
- # Prevents duplicate creation when agents retry after MCP timeouts.
- existing_job_id = _find_job_by_name(name=name)
- if existing_job_id is not None:
- return {
- "job_id": existing_job_id,
- "already_exists": True,
- "message": (
- f"Job '{name}' already exists with job_id={existing_job_id}. "
- "Returning existing job instead of creating a duplicate. "
- "Use manage_jobs(action='update') to modify it, or "
- "manage_jobs(action='delete') first to recreate."
- ),
- }
-
- # Auto-inject default tags; user-provided tags take precedence
- merged_tags = {**get_default_tags(), **(tags or {})}
- result = _create_job(
- name=name,
- tasks=tasks,
- job_clusters=job_clusters,
- environments=environments,
- tags=merged_tags,
- timeout_seconds=timeout_seconds,
- max_concurrent_runs=max_concurrent_runs or 1,
- email_notifications=email_notifications,
- webhook_notifications=webhook_notifications,
- notification_settings=notification_settings,
- schedule=schedule,
- queue=queue,
- run_as=run_as,
- git_source=git_source,
- parameters=parameters,
- health=health,
- deployment=deployment,
- )
-
- # Track resource on successful create
- try:
- job_id_val = result.get("job_id") if isinstance(result, dict) else None
- if job_id_val:
- from ..manifest import track_resource
-
- track_resource(
- resource_type="job",
- name=name,
- resource_id=str(job_id_val),
- )
- except Exception:
- pass # best-effort tracking
-
- return result
-
- elif act == "get":
- return _get_job(job_id=job_id)
-
- elif act == "list":
- return {"items": _list_jobs(name=name, limit=limit, expand_tasks=expand_tasks)}
-
- elif act == "find_by_name":
- return {"job_id": _find_job_by_name(name=name)}
-
- elif act == "update":
- _update_job(
- job_id=job_id,
- name=name,
- tasks=tasks,
- job_clusters=job_clusters,
- environments=environments,
- tags=tags,
- timeout_seconds=timeout_seconds,
- max_concurrent_runs=max_concurrent_runs,
- email_notifications=email_notifications,
- webhook_notifications=webhook_notifications,
- notification_settings=notification_settings,
- schedule=schedule,
- queue=queue,
- run_as=run_as,
- git_source=git_source,
- parameters=parameters,
- health=health,
- deployment=deployment,
- )
- return {"status": "updated", "job_id": job_id}
-
- elif act == "delete":
- _delete_job(job_id=job_id)
- try:
- from ..manifest import remove_resource
-
- remove_resource(resource_type="job", resource_id=str(job_id))
- except Exception:
- pass
- return {"status": "deleted", "job_id": job_id}
-
- raise ValueError(f"Invalid action: '{action}'. Valid: create, get, list, find_by_name, update, delete")
-
-
-# =============================================================================
-# Tool 2: manage_job_runs
-# =============================================================================
-
-
-@mcp.tool(timeout=300)
-def manage_job_runs(
- action: str,
- job_id: int = None,
- run_id: int = None,
- idempotency_token: str = None,
- jar_params: List[str] = None,
- notebook_params: Dict[str, str] = None,
- python_params: List[str] = None,
- spark_submit_params: List[str] = None,
- python_named_params: Dict[str, str] = None,
- pipeline_params: Dict[str, Any] = None,
- sql_params: Dict[str, str] = None,
- dbt_commands: List[str] = None,
- queue: Dict[str, Any] = None,
- active_only: bool = False,
- completed_only: bool = False,
- limit: int = 25,
- offset: int = 0,
- start_time_from: int = None,
- start_time_to: int = None,
- timeout: int = 3600,
- poll_interval: int = 10,
-) -> Dict[str, Any]:
- """Manage job runs: run_now, get, get_output, cancel, list, wait.
-
- run_now: requires job_id, returns {run_id}. get/get_output/cancel/wait: require run_id.
- list: filter by job_id/active_only/completed_only. wait: blocks until complete (timeout default 3600s).
- Returns: run_now={run_id}, get=run details, get_output=logs+results, cancel={status}, list={items}, wait=full result."""
- act = action.lower()
-
- if act == "run_now":
- run_id_result = _run_job_now(
- job_id=job_id,
- idempotency_token=idempotency_token,
- jar_params=jar_params,
- notebook_params=notebook_params,
- python_params=python_params,
- spark_submit_params=spark_submit_params,
- python_named_params=python_named_params,
- pipeline_params=pipeline_params,
- sql_params=sql_params,
- dbt_commands=dbt_commands,
- queue=queue,
- )
- return {"run_id": run_id_result}
-
- elif act == "get":
- return _get_run(run_id=run_id)
-
- elif act == "get_output":
- return _get_run_output(run_id=run_id)
-
- elif act == "cancel":
- _cancel_run(run_id=run_id)
- return {"status": "cancelled", "run_id": run_id}
-
- elif act == "list":
- return {
- "items": _list_runs(
- job_id=job_id,
- active_only=active_only,
- completed_only=completed_only,
- limit=limit,
- offset=offset,
- start_time_from=start_time_from,
- start_time_to=start_time_to,
- )
- }
-
- elif act == "wait":
- result = _wait_for_run(run_id=run_id, timeout=timeout, poll_interval=poll_interval)
- return result.to_dict()
-
- raise ValueError(f"Invalid action: '{action}'. Valid: run_now, get, get_output, cancel, list, wait")
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/lakebase.py b/databricks-mcp-server/databricks_mcp_server/tools/lakebase.py
deleted file mode 100644
index c82667bd..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/lakebase.py
+++ /dev/null
@@ -1,532 +0,0 @@
-"""Lakebase tools - Manage Lakebase databases (Provisioned and Autoscaling).
-
-Consolidated into 4 tools:
-- manage_lakebase_database: create_or_update, get, list, delete
-- manage_lakebase_branch: create_or_update, delete
-- manage_lakebase_sync: create_or_update, delete
-- generate_lakebase_credential: Generate OAuth tokens
-"""
-
-import logging
-from typing import Any, Dict, List, Optional
-
-# Provisioned core functions
-from databricks_tools_core.lakebase import (
- create_lakebase_instance as _create_instance,
- get_lakebase_instance as _get_instance,
- list_lakebase_instances as _list_instances,
- update_lakebase_instance as _update_instance,
- delete_lakebase_instance as _delete_instance,
- generate_lakebase_credential as _generate_provisioned_credential,
- create_lakebase_catalog as _create_catalog,
- get_lakebase_catalog as _get_catalog,
- delete_lakebase_catalog as _delete_catalog,
- create_synced_table as _create_synced_table,
- get_synced_table as _get_synced_table,
- delete_synced_table as _delete_synced_table,
-)
-
-# Autoscale core functions
-from databricks_tools_core.lakebase_autoscale import (
- create_project as _create_project,
- get_project as _get_project,
- list_projects as _list_projects,
- update_project as _update_project,
- delete_project as _delete_project,
- create_branch as _create_branch,
- list_branches as _list_branches,
- update_branch as _update_branch,
- delete_branch as _delete_branch,
- create_endpoint as _create_endpoint,
- list_endpoints as _list_endpoints,
- update_endpoint as _update_endpoint,
- generate_credential as _generate_autoscale_credential,
-)
-
-from ..server import mcp
-
-logger = logging.getLogger(__name__)
-
-
-# ============================================================================
-# Helpers
-# ============================================================================
-
-
-def _find_instance_by_name(name: str) -> Optional[Dict[str, Any]]:
- """Find a provisioned instance by name, returns None if not found."""
- try:
- return _get_instance(name=name)
- except Exception:
- return None
-
-
-def _find_project_by_name(name: str) -> Optional[Dict[str, Any]]:
- """Find an autoscale project by name, returns None if not found."""
- try:
- return _get_project(name=name)
- except Exception:
- return None
-
-
-def _find_branch(project_name: str, branch_id: str) -> Optional[Dict[str, Any]]:
- """Find a branch in a project, returns None if not found."""
- try:
- branches = _list_branches(project_name=project_name)
- for branch in branches:
- branch_name = branch.get("name", "")
- if branch_name.endswith(f"/branches/{branch_id}"):
- return branch
- except Exception:
- pass
- return None
-
-
-# ============================================================================
-# Tool 1: manage_lakebase_database
-# ============================================================================
-
-
-@mcp.tool(timeout=120)
-def manage_lakebase_database(
- action: str,
- name: Optional[str] = None,
- type: str = "provisioned",
- # For create_or_update:
- capacity: str = "CU_1",
- stopped: bool = False,
- display_name: Optional[str] = None,
- pg_version: str = "17",
- # For delete:
- force: bool = False,
-) -> Dict[str, Any]:
- """Manage Lakebase PostgreSQL databases: create, update, get, list, delete.
-
- Actions:
- - create_or_update: Idempotent create/update. Requires name.
- type: "provisioned" (fixed capacity CU_1/2/4/8) or "autoscale" (auto-scaling with branches).
- capacity: For provisioned only. pg_version: For autoscale only.
- Returns: {created: bool, type, ...connection info}.
- - get: Get database details. Requires name.
- For autoscale, includes branches and endpoints.
- Returns: {name, type, state, ...}.
- - list: List all databases. Optional type filter.
- Returns: {databases: [{name, type, ...}]}.
- - delete: Delete database. Requires name.
- force=True cascades to children (provisioned). Autoscale deletes all branches/computes/data.
- Returns: {status, ...}.
-
- See databricks-lakebase-provisioned or databricks-lakebase-autoscale skill for details."""
- act = action.lower()
-
- if act == "create_or_update":
- if not name:
- return {"error": "create_or_update requires: name"}
- return _create_or_update_database(
- name=name, type=type, capacity=capacity, stopped=stopped,
- display_name=display_name, pg_version=pg_version,
- )
-
- elif act == "get":
- if not name:
- return {"error": "get requires: name"}
- return _get_database(name=name, type=type)
-
- elif act == "list":
- return _list_databases(type=type if type != "provisioned" else None)
-
- elif act == "delete":
- if not name:
- return {"error": "delete requires: name"}
- return _delete_database(name=name, type=type, force=force)
-
- else:
- return {"error": f"Invalid action '{action}'. Valid actions: create_or_update, get, list, delete"}
-
-
-# ============================================================================
-# Tool 2: manage_lakebase_branch
-# ============================================================================
-
-
-@mcp.tool(timeout=120)
-def manage_lakebase_branch(
- action: str,
- # For create_or_update:
- project_name: Optional[str] = None,
- branch_id: Optional[str] = None,
- source_branch: Optional[str] = None,
- ttl_seconds: Optional[int] = None,
- no_expiry: bool = False,
- is_protected: Optional[bool] = None,
- endpoint_type: str = "ENDPOINT_TYPE_READ_WRITE",
- autoscaling_limit_min_cu: Optional[float] = None,
- autoscaling_limit_max_cu: Optional[float] = None,
- scale_to_zero_seconds: Optional[int] = None,
- # For delete:
- name: Optional[str] = None,
-) -> Dict[str, Any]:
- """Manage Autoscale branches: create, update, delete.
-
- Branches are isolated copy-on-write environments with their own compute endpoints.
-
- Actions:
- - create_or_update: Idempotent create/update. Requires project_name, branch_id.
- source_branch: Branch to fork from (default: production).
- ttl_seconds: Auto-delete after N seconds. is_protected: Prevent accidental deletion.
- autoscaling_limit_min/max_cu: Compute unit limits. scale_to_zero_seconds: Idle time before scaling to zero.
- Returns: {branch details, endpoint connection info, created: bool}.
- - delete: Delete branch and endpoints. Requires name (full branch name).
- Permanently deletes data/databases/roles. Cannot delete protected branches.
- Returns: {status, ...}.
-
- See databricks-lakebase-autoscale skill for branch workflows."""
- act = action.lower()
-
- if act == "create_or_update":
- if not project_name or not branch_id:
- return {"error": "create_or_update requires: project_name, branch_id"}
- return _create_or_update_branch(
- project_name=project_name, branch_id=branch_id, source_branch=source_branch,
- ttl_seconds=ttl_seconds, no_expiry=no_expiry, is_protected=is_protected,
- endpoint_type=endpoint_type, autoscaling_limit_min_cu=autoscaling_limit_min_cu,
- autoscaling_limit_max_cu=autoscaling_limit_max_cu, scale_to_zero_seconds=scale_to_zero_seconds,
- )
-
- elif act == "delete":
- if not name:
- return {"error": "delete requires: name (full branch name)"}
- return _delete_branch(name=name)
-
- else:
- return {"error": f"Invalid action '{action}'. Valid actions: create_or_update, delete"}
-
-
-# ============================================================================
-# Tool 3: manage_lakebase_sync
-# ============================================================================
-
-
-@mcp.tool(timeout=120)
-def manage_lakebase_sync(
- action: str,
- # For create_or_update:
- instance_name: Optional[str] = None,
- source_table_name: Optional[str] = None,
- target_table_name: Optional[str] = None,
- catalog_name: Optional[str] = None,
- database_name: str = "databricks_postgres",
- primary_key_columns: Optional[List[str]] = None,
- scheduling_policy: str = "TRIGGERED",
- # For delete:
- table_name: Optional[str] = None,
-) -> Dict[str, Any]:
- """Manage Lakebase sync (reverse ETL): create, delete.
-
- Actions:
- - create_or_update: Set up reverse ETL from Delta table to Lakebase.
- Requires instance_name, source_table_name, target_table_name.
- Creates catalog if needed, then synced table.
- source_table_name: Delta table (catalog.schema.table). target_table_name: Postgres destination.
- primary_key_columns: Required for incremental sync.
- scheduling_policy: TRIGGERED/SNAPSHOT/CONTINUOUS.
- Returns: {catalog, synced_table, created}.
- - delete: Remove synced table, optionally UC catalog. Source Delta table unaffected.
- Requires table_name. Optional catalog_name to also delete catalog.
- Returns: {synced_table, catalog (if deleted)}.
-
- See databricks-lakebase-provisioned skill for sync workflows."""
- act = action.lower()
-
- if act == "create_or_update":
- if not all([instance_name, source_table_name, target_table_name]):
- return {"error": "create_or_update requires: instance_name, source_table_name, target_table_name"}
- return _create_or_update_sync(
- instance_name=instance_name, source_table_name=source_table_name,
- target_table_name=target_table_name, catalog_name=catalog_name,
- database_name=database_name, primary_key_columns=primary_key_columns,
- scheduling_policy=scheduling_policy,
- )
-
- elif act == "delete":
- if not table_name:
- return {"error": "delete requires: table_name"}
- return _delete_sync(table_name=table_name, catalog_name=catalog_name)
-
- else:
- return {"error": f"Invalid action '{action}'. Valid actions: create_or_update, delete"}
-
-
-# ============================================================================
-# Tool 4: generate_lakebase_credential
-# ============================================================================
-
-
-@mcp.tool(timeout=30)
-def generate_lakebase_credential(
- instance_names: Optional[List[str]] = None,
- endpoint: Optional[str] = None,
-) -> Dict[str, Any]:
- """Generate OAuth token (~1hr) for Lakebase connection. Use as password with sslmode=require.
-
- Provide instance_names (provisioned) or endpoint (autoscale)."""
- if instance_names:
- return _generate_provisioned_credential(instance_names=instance_names)
- elif endpoint:
- return _generate_autoscale_credential(endpoint=endpoint)
- else:
- return {"error": "Provide either instance_names (provisioned) or endpoint (autoscale)."}
-
-
-# ============================================================================
-# Helper Functions
-# ============================================================================
-
-
-def _create_or_update_database(
- name: str, type: str, capacity: str, stopped: bool,
- display_name: Optional[str], pg_version: str,
-) -> Dict[str, Any]:
- """Create or update a Lakebase database."""
- db_type = type.lower()
-
- if db_type == "provisioned":
- existing = _find_instance_by_name(name)
- if existing:
- result = _update_instance(name=name, capacity=capacity, stopped=stopped)
- return {**result, "created": False, "type": "provisioned"}
- else:
- result = _create_instance(name=name, capacity=capacity, stopped=stopped)
- try:
- from ..manifest import track_resource
- track_resource(resource_type="lakebase_instance", name=name, resource_id=name)
- except Exception:
- pass
- return {**result, "created": True, "type": "provisioned"}
-
- elif db_type == "autoscale":
- existing = _find_project_by_name(name)
- # Check if project actually exists (not just a NOT_FOUND response)
- project_exists = existing and "error" not in existing and existing.get("state") != "NOT_FOUND"
- if project_exists:
- result = _update_project(name=name, display_name=display_name)
- return {**result, "created": False, "type": "autoscale"}
- else:
- result = _create_project(
- project_id=name,
- display_name=display_name,
- pg_version=pg_version,
- )
- try:
- from ..manifest import track_resource
- track_resource(resource_type="lakebase_project", name=name, resource_id=name)
- except Exception:
- pass
- return {**result, "created": True, "type": "autoscale"}
-
- else:
- return {"error": f"Invalid type '{type}'. Use 'provisioned' or 'autoscale'."}
-
-
-def _get_database(name: str, type: Optional[str]) -> Dict[str, Any]:
- """Get a database by name."""
- result = None
- if type is None or type.lower() == "provisioned":
- result = _find_instance_by_name(name)
- if result:
- result["type"] = "provisioned"
-
- if result is None and (type is None or type.lower() == "autoscale"):
- result = _find_project_by_name(name)
- if result:
- result["type"] = "autoscale"
- try:
- result["branches"] = _list_branches(project_name=name)
- except Exception:
- pass
- try:
- for branch in result.get("branches", []):
- branch_name = branch.get("name", "")
- branch["endpoints"] = _list_endpoints(branch_name=branch_name)
- except Exception:
- pass
-
- if result is None:
- return {"error": f"Database '{name}' not found."}
- return result
-
-
-def _list_databases(type: Optional[str]) -> Dict[str, Any]:
- """List all databases."""
- databases = []
-
- if type is None or type.lower() == "provisioned":
- try:
- for inst in _list_instances():
- inst["type"] = "provisioned"
- databases.append(inst)
- except Exception as e:
- logger.warning("Failed to list provisioned instances: %s", e)
-
- if type is None or type.lower() == "autoscale":
- try:
- for proj in _list_projects():
- proj["type"] = "autoscale"
- databases.append(proj)
- except Exception as e:
- logger.warning("Failed to list autoscale projects: %s", e)
-
- return {"databases": databases}
-
-
-def _delete_database(name: str, type: str, force: bool) -> Dict[str, Any]:
- """Delete a database."""
- db_type = type.lower()
-
- if db_type == "provisioned":
- return _delete_instance(name=name, force=force, purge=True)
- elif db_type == "autoscale":
- return _delete_project(name=name)
- else:
- return {"error": f"Invalid type '{type}'. Use 'provisioned' or 'autoscale'."}
-
-
-def _create_or_update_branch(
- project_name: str, branch_id: str, source_branch: Optional[str],
- ttl_seconds: Optional[int], no_expiry: bool, is_protected: Optional[bool],
- endpoint_type: str, autoscaling_limit_min_cu: Optional[float],
- autoscaling_limit_max_cu: Optional[float], scale_to_zero_seconds: Optional[int],
-) -> Dict[str, Any]:
- """Create or update a branch with compute endpoint."""
- existing = _find_branch(project_name, branch_id)
-
- if existing:
- branch_name = existing.get("name", f"{project_name}/branches/{branch_id}")
- branch_result = _update_branch(
- name=branch_name,
- is_protected=is_protected,
- ttl_seconds=ttl_seconds,
- no_expiry=no_expiry if no_expiry else None,
- )
-
- # Update endpoint if scaling params provided
- endpoint_result = None
- if any(v is not None for v in [autoscaling_limit_min_cu, autoscaling_limit_max_cu, scale_to_zero_seconds]):
- try:
- endpoints = _list_endpoints(branch_name=branch_name)
- if endpoints:
- ep_name = endpoints[0].get("name", "")
- endpoint_result = _update_endpoint(
- name=ep_name,
- autoscaling_limit_min_cu=autoscaling_limit_min_cu,
- autoscaling_limit_max_cu=autoscaling_limit_max_cu,
- scale_to_zero_seconds=scale_to_zero_seconds,
- )
- except Exception as e:
- logger.warning("Failed to update endpoint: %s", e)
-
- result = {**branch_result, "created": False}
- if endpoint_result:
- result["endpoint"] = endpoint_result
- return result
-
- else:
- branch_result = _create_branch(
- project_name=project_name,
- branch_id=branch_id,
- source_branch=source_branch,
- ttl_seconds=ttl_seconds,
- no_expiry=no_expiry,
- )
-
- # Create compute endpoint on the new branch
- branch_name = branch_result.get("name", f"{project_name}/branches/{branch_id}")
- endpoint_result = None
- try:
- endpoint_result = _create_endpoint(
- branch_name=branch_name,
- endpoint_id=f"{branch_id}-ep",
- endpoint_type=endpoint_type,
- autoscaling_limit_min_cu=autoscaling_limit_min_cu,
- autoscaling_limit_max_cu=autoscaling_limit_max_cu,
- scale_to_zero_seconds=scale_to_zero_seconds,
- )
- except Exception as e:
- logger.warning("Failed to create endpoint on branch: %s", e)
-
- result = {**branch_result, "created": True}
- if endpoint_result:
- result["endpoint"] = endpoint_result
- return result
-
-
-def _create_or_update_sync(
- instance_name: str, source_table_name: str, target_table_name: str,
- catalog_name: Optional[str], database_name: str,
- primary_key_columns: Optional[List[str]], scheduling_policy: str,
-) -> Dict[str, Any]:
- """Create or update a sync configuration."""
- # Derive catalog name from target table if not provided
- if not catalog_name:
- parts = target_table_name.split(".")
- if len(parts) >= 1:
- catalog_name = parts[0]
- else:
- return {"error": "Cannot derive catalog_name from target_table_name. Provide catalog_name explicitly."}
-
- # Ensure catalog registration exists
- catalog_result = None
- try:
- catalog_result = _get_catalog(name=catalog_name)
- except Exception:
- try:
- catalog_result = _create_catalog(
- name=catalog_name,
- instance_name=instance_name,
- database_name=database_name,
- )
- except Exception as e:
- return {"error": f"Failed to create catalog '{catalog_name}': {e}"}
-
- # Check if synced table already exists
- try:
- existing = _get_synced_table(table_name=target_table_name)
- return {
- "catalog": catalog_result,
- "synced_table": existing,
- "created": False,
- }
- except Exception:
- pass
-
- # Create synced table
- sync_result = _create_synced_table(
- instance_name=instance_name,
- source_table_name=source_table_name,
- target_table_name=target_table_name,
- primary_key_columns=primary_key_columns,
- scheduling_policy=scheduling_policy,
- )
-
- return {
- "catalog": catalog_result,
- "synced_table": sync_result,
- "created": True,
- }
-
-
-def _delete_sync(table_name: str, catalog_name: Optional[str]) -> Dict[str, Any]:
- """Delete a sync configuration."""
- result = {}
-
- sync_result = _delete_synced_table(table_name=table_name)
- result["synced_table"] = sync_result
-
- if catalog_name:
- try:
- catalog_result = _delete_catalog(name=catalog_name)
- result["catalog"] = catalog_result
- except Exception as e:
- result["catalog"] = {"error": str(e)}
-
- return result
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/manifest.py b/databricks-mcp-server/databricks_mcp_server/tools/manifest.py
deleted file mode 100644
index 740d18c8..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/manifest.py
+++ /dev/null
@@ -1,78 +0,0 @@
-"""Resource tracking manifest tools.
-
-Exposes the resource manifest as MCP tools so agents can list and clean up
-resources created across sessions.
-"""
-
-import logging
-from typing import Any, Dict, Optional
-
-from ..manifest import _RESOURCE_DELETERS, list_resources, remove_resource
-from ..server import mcp
-
-logger = logging.getLogger(__name__)
-
-
-def _delete_from_databricks(resource_type: str, resource_id: str) -> Optional[str]:
- """Delete a resource from Databricks using the registered deleter.
-
- Returns error string or None on success.
- """
- deleter = _RESOURCE_DELETERS.get(resource_type)
- if not deleter:
- return f"Unsupported resource type for deletion: {resource_type}"
- try:
- deleter(resource_id)
- return None
- except Exception as exc:
- return str(exc)
-
-
-@mcp.tool(timeout=30)
-def list_tracked_resources(type: Optional[str] = None) -> Dict[str, Any]:
- """List resources tracked in project manifest (dashboards, jobs, pipelines, genie_space, etc.).
-
- type: Filter by resource type (optional). Returns: {resources: [...], count}."""
- resources = list_resources(resource_type=type)
- return {
- "resources": resources,
- "count": len(resources),
- }
-
-
-@mcp.tool(timeout=60)
-def delete_tracked_resource(
- type: str,
- resource_id: str,
- delete_from_databricks: bool = False,
-) -> Dict[str, Any]:
- """Delete resource from manifest, optionally from Databricks too.
-
- delete_from_databricks: If True, deletes from Databricks first (default: False, manifest-only).
- Returns: {success, removed_from_manifest, deleted_from_databricks, error}."""
- result: Dict[str, Any] = {
- "success": True,
- "removed_from_manifest": False,
- "deleted_from_databricks": False,
- "error": None,
- }
-
- # Optionally delete from Databricks first
- if delete_from_databricks:
- error = _delete_from_databricks(type, resource_id)
- if error:
- result["error"] = f"Databricks deletion failed: {error}"
- result["success"] = False
- # Still remove from manifest even if Databricks deletion failed
- else:
- result["deleted_from_databricks"] = True
-
- # Remove from manifest
- removed = remove_resource(resource_type=type, resource_id=resource_id)
- result["removed_from_manifest"] = removed
-
- if not removed and not result.get("error"):
- result["error"] = f"Resource {type}/{resource_id} not found in manifest"
- result["success"] = result.get("deleted_from_databricks", False)
-
- return result
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/pdf.py b/databricks-mcp-server/databricks_mcp_server/tools/pdf.py
deleted file mode 100644
index 863d0b32..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/pdf.py
+++ /dev/null
@@ -1,35 +0,0 @@
-"""PDF tools - Convert HTML to PDF and upload to Unity Catalog volumes."""
-
-from typing import Any, Dict, Optional
-
-from databricks_tools_core.pdf import generate_and_upload_pdf as _generate_and_upload_pdf
-
-from ..server import mcp
-
-
-@mcp.tool
-def generate_and_upload_pdf(
- html_content: str,
- filename: str,
- catalog: str,
- schema: str,
- volume: str = "raw_data",
- folder: Optional[str] = None,
-) -> Dict[str, Any]:
- """Convert complete HTML (with styles) to PDF and upload to Unity Catalog volume.
-
- Returns: {success, volume_path, error}."""
- result = _generate_and_upload_pdf(
- html_content=html_content,
- filename=filename,
- catalog=catalog,
- schema=schema,
- volume=volume,
- folder=folder,
- )
-
- return {
- "success": result.success,
- "volume_path": result.volume_path,
- "error": result.error,
- }
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/pipelines.py b/databricks-mcp-server/databricks_mcp_server/tools/pipelines.py
deleted file mode 100644
index 91d61621..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/pipelines.py
+++ /dev/null
@@ -1,271 +0,0 @@
-"""Pipeline tools - Manage Spark Declarative Pipelines (SDP).
-
-Consolidated into 2 tools:
-- manage_pipeline: create, create_or_update, get, update, delete, find_by_name
-- manage_pipeline_run: start, get, stop, get_events
-"""
-
-from typing import List, Dict, Any, Optional
-
-from databricks_tools_core.identity import get_default_tags
-from databricks_tools_core.spark_declarative_pipelines.pipelines import (
- create_pipeline as _create_pipeline,
- get_pipeline as _get_pipeline,
- update_pipeline as _update_pipeline,
- delete_pipeline as _delete_pipeline,
- start_update as _start_update,
- get_update as _get_update,
- stop_pipeline as _stop_pipeline,
- get_pipeline_events as _get_pipeline_events,
- create_or_update_pipeline as _create_or_update_pipeline,
- find_pipeline_by_name as _find_pipeline_by_name,
-)
-
-from ..manifest import register_deleter
-from ..server import mcp
-
-
-def _delete_pipeline_resource(resource_id: str) -> None:
- _delete_pipeline(pipeline_id=resource_id)
-
-
-register_deleter("pipeline", _delete_pipeline_resource)
-
-
-# ============================================================================
-# Tool 1: manage_pipeline
-# ============================================================================
-
-
-@mcp.tool(timeout=300)
-def manage_pipeline(
- action: str,
- # For create/create_or_update/find_by_name:
- name: Optional[str] = None,
- # For create/create_or_update:
- root_path: Optional[str] = None,
- catalog: Optional[str] = None,
- schema: Optional[str] = None,
- workspace_file_paths: Optional[List[str]] = None,
- extra_settings: Optional[Dict[str, Any]] = None,
- # For create_or_update only:
- start_run: bool = False,
- wait_for_completion: bool = False,
- full_refresh: bool = True,
- timeout: int = 1800,
- # For get/update/delete:
- pipeline_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """Manage Spark Declarative Pipelines: create, update, get, delete, find.
-
- Actions:
- - create: New pipeline. Requires name, root_path, catalog, schema, workspace_file_paths.
- Returns: {pipeline_id}.
- - create_or_update: Idempotent by name. Same params as create.
- start_run=True triggers run after create/update. wait_for_completion=True blocks until done.
- full_refresh=True reprocesses all data. Returns: {pipeline_id, created, success, state}.
- - get: Get pipeline details. Requires pipeline_id. Returns: full pipeline config.
- - update: Modify config. Requires pipeline_id + fields to change. Returns: {status}.
- - delete: Remove pipeline. Requires pipeline_id. Returns: {status}.
- - find_by_name: Find by name. Requires name. Returns: {found, pipeline_id}.
-
- root_path: Workspace folder for pipeline files (e.g., /Workspace/Users/me/pipelines).
- workspace_file_paths: List of notebook/file paths to include in pipeline.
- extra_settings: Additional config dict (clusters, photon, channel, continuous, etc).
- See databricks-spark-declarative-pipelines skill for configuration details."""
- act = action.lower()
-
- if act == "create":
- if not all([name, root_path, catalog, schema, workspace_file_paths]):
- return {"error": "create requires: name, root_path, catalog, schema, workspace_file_paths"}
-
- # Auto-inject default tags
- settings = extra_settings or {}
- settings.setdefault("tags", {})
- settings["tags"] = {**get_default_tags(), **settings["tags"]}
-
- result = _create_pipeline(
- name=name,
- root_path=root_path,
- catalog=catalog,
- schema=schema,
- workspace_file_paths=workspace_file_paths,
- extra_settings=settings,
- )
-
- # Track resource
- try:
- if result.pipeline_id:
- from ..manifest import track_resource
- track_resource(resource_type="pipeline", name=name, resource_id=result.pipeline_id)
- except Exception:
- pass
-
- return {"pipeline_id": result.pipeline_id}
-
- elif act == "create_or_update":
- if not all([name, root_path, catalog, schema, workspace_file_paths]):
- return {"error": "create_or_update requires: name, root_path, catalog, schema, workspace_file_paths"}
-
- # Auto-inject default tags
- settings = extra_settings or {}
- settings.setdefault("tags", {})
- settings["tags"] = {**get_default_tags(), **settings["tags"]}
-
- result = _create_or_update_pipeline(
- name=name,
- root_path=root_path,
- catalog=catalog,
- schema=schema,
- workspace_file_paths=workspace_file_paths,
- start_run=start_run,
- wait_for_completion=wait_for_completion,
- full_refresh=full_refresh,
- timeout=timeout,
- extra_settings=settings,
- )
-
- # Track resource
- try:
- result_dict = result.to_dict()
- pid = result_dict.get("pipeline_id")
- if pid:
- from ..manifest import track_resource
- track_resource(resource_type="pipeline", name=name, resource_id=pid)
- except Exception:
- pass
-
- return result.to_dict()
-
- elif act == "get":
- if not pipeline_id:
- return {"error": "get requires: pipeline_id"}
- result = _get_pipeline(pipeline_id=pipeline_id)
- return result.as_dict() if hasattr(result, "as_dict") else vars(result)
-
- elif act == "update":
- if not pipeline_id:
- return {"error": "update requires: pipeline_id"}
- _update_pipeline(
- pipeline_id=pipeline_id,
- name=name,
- root_path=root_path,
- catalog=catalog,
- schema=schema,
- workspace_file_paths=workspace_file_paths,
- extra_settings=extra_settings,
- )
- return {"status": "updated", "pipeline_id": pipeline_id}
-
- elif act == "delete":
- if not pipeline_id:
- return {"error": "delete requires: pipeline_id"}
- _delete_pipeline(pipeline_id=pipeline_id)
- try:
- from ..manifest import remove_resource
- remove_resource(resource_type="pipeline", resource_id=pipeline_id)
- except Exception:
- pass
- return {"status": "deleted", "pipeline_id": pipeline_id}
-
- elif act == "find_by_name":
- if not name:
- return {"error": "find_by_name requires: name"}
- pid = _find_pipeline_by_name(name=name)
- return {"found": pid is not None, "pipeline_id": pid, "name": name}
-
- else:
- return {
- "error": f"Invalid action '{action}'. Valid actions: create, create_or_update, get, update, delete, find_by_name"
- }
-
-
-# ============================================================================
-# Tool 2: manage_pipeline_run
-# ============================================================================
-
-
-@mcp.tool(timeout=300)
-def manage_pipeline_run(
- action: str,
- pipeline_id: str,
- # For start:
- refresh_selection: Optional[List[str]] = None,
- full_refresh: bool = False,
- full_refresh_selection: Optional[List[str]] = None,
- validate_only: bool = False,
- wait: bool = True,
- timeout: int = 300,
- # For get:
- update_id: Optional[str] = None,
- include_config: bool = False,
- full_error_details: bool = False,
- # For get_events:
- max_results: int = 5,
- event_log_level: str = "WARN",
-) -> Dict[str, Any]:
- """Manage pipeline runs: start, monitor, stop, get events.
-
- Actions:
- - start: Trigger pipeline update. Requires pipeline_id.
- wait=True (default) blocks until complete. validate_only=True checks without running.
- full_refresh=True reprocesses all data. refresh_selection: specific tables to refresh.
- Returns: {update_id, state, success, error_summary}.
- - get: Get run status. Requires pipeline_id, update_id.
- include_config=True includes pipeline config. full_error_details=True for verbose errors.
- Returns: {update_id, state, success, error_summary}.
- - stop: Stop running pipeline. Requires pipeline_id.
- Returns: {status}.
- - get_events: Get events/logs for debugging. Requires pipeline_id.
- event_log_level: ERROR, WARN (default), INFO. max_results: number of events (default 5).
- update_id: filter to specific run.
- Returns: list of event dicts.
-
- See databricks-spark-declarative-pipelines skill for run management details."""
- act = action.lower()
-
- if act == "start":
- return _start_update(
- pipeline_id=pipeline_id,
- refresh_selection=refresh_selection,
- full_refresh=full_refresh,
- full_refresh_selection=full_refresh_selection,
- validate_only=validate_only,
- wait=wait,
- timeout=timeout,
- full_error_details=full_error_details,
- )
-
- elif act == "get":
- if not update_id:
- return {"error": "get requires: update_id"}
- return _get_update(
- pipeline_id=pipeline_id,
- update_id=update_id,
- include_config=include_config,
- full_error_details=full_error_details,
- )
-
- elif act == "stop":
- _stop_pipeline(pipeline_id=pipeline_id)
- return {"status": "stopped", "pipeline_id": pipeline_id}
-
- elif act == "get_events":
- # Convert log level to filter expression
- level_filters = {
- "ERROR": "level='ERROR'",
- "WARN": "level in ('ERROR', 'WARN')",
- "INFO": "", # No filter = all events
- }
- filter_expr = level_filters.get(event_log_level.upper(), level_filters["WARN"])
-
- events = _get_pipeline_events(
- pipeline_id=pipeline_id,
- max_results=max_results,
- filter=filter_expr,
- update_id=update_id,
- )
- return {"events": [e.as_dict() if hasattr(e, "as_dict") else vars(e) for e in events]}
-
- else:
- return {"error": f"Invalid action '{action}'. Valid actions: start, get, stop, get_events"}
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/serving.py b/databricks-mcp-server/databricks_mcp_server/tools/serving.py
deleted file mode 100644
index 9add7360..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/serving.py
+++ /dev/null
@@ -1,75 +0,0 @@
-"""Model Serving tools - Query and manage serving endpoints.
-
-Consolidated into 1 tool:
-- manage_serving_endpoint: get, list, query
-"""
-
-from typing import Any, Dict, List, Optional
-
-from databricks_tools_core.serving import (
- get_serving_endpoint_status as _get_serving_endpoint_status,
- query_serving_endpoint as _query_serving_endpoint,
- list_serving_endpoints as _list_serving_endpoints,
-)
-
-from ..server import mcp
-
-
-@mcp.tool(timeout=120)
-def manage_serving_endpoint(
- action: str,
- # For get/query:
- name: Optional[str] = None,
- # For query (use one input format):
- messages: Optional[List[Dict[str, str]]] = None,
- inputs: Optional[Dict[str, Any]] = None,
- dataframe_records: Optional[List[Dict[str, Any]]] = None,
- # For query options:
- max_tokens: Optional[int] = None,
- temperature: Optional[float] = None,
- # For list:
- limit: int = 50,
-) -> Dict[str, Any]:
- """Manage Model Serving endpoints: get status, list, query.
-
- Actions:
- - get: Get endpoint status. Requires name.
- Returns: {name, state (READY/NOT_READY/NOT_FOUND), config_update, served_entities, error}.
- - list: List all endpoints. Optional limit (default 50).
- Returns: {endpoints: [{name, state, creation_timestamp, creator, served_entities_count}, ...]}.
- - query: Query an endpoint. Requires name + one input format.
- Input formats (use one):
- - messages: Chat/agent endpoints. Format: [{"role": "user", "content": "..."}]
- - inputs: Custom pyfunc models (dict matching model signature)
- - dataframe_records: ML models. Format: [{"feature1": 1.0, ...}]
- max_tokens, temperature: Optional for chat endpoints.
- Returns: {choices: [...]} for chat or {predictions: [...]} for ML.
-
- See databricks-model-serving skill for endpoint configuration."""
- act = action.lower()
-
- if act == "get":
- if not name:
- return {"error": "get requires: name"}
- return _get_serving_endpoint_status(name=name)
-
- elif act == "list":
- endpoints = _list_serving_endpoints(limit=limit)
- return {"endpoints": endpoints}
-
- elif act == "query":
- if not name:
- return {"error": "query requires: name"}
- if not any([messages, inputs, dataframe_records]):
- return {"error": "query requires one of: messages, inputs, dataframe_records"}
- return _query_serving_endpoint(
- name=name,
- messages=messages,
- inputs=inputs,
- dataframe_records=dataframe_records,
- max_tokens=max_tokens,
- temperature=temperature,
- )
-
- else:
- return {"error": f"Invalid action '{action}'. Valid actions: get, list, query"}
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/sql.py b/databricks-mcp-server/databricks_mcp_server/tools/sql.py
deleted file mode 100644
index efafd84d..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/sql.py
+++ /dev/null
@@ -1,189 +0,0 @@
-"""SQL tools - Execute SQL queries and get table information.
-
-Tools:
-- execute_sql: Single SQL query
-- execute_sql_multi: Multiple SQL statements with parallel execution
-- manage_warehouse: list, get_best
-- get_table_stats_and_schema: Schema and stats for tables
-- get_volume_folder_details: Schema for volume files
-"""
-
-from typing import Any, Dict, List, Optional, Union
-
-from databricks_tools_core.sql import (
- execute_sql as _execute_sql,
- execute_sql_multi as _execute_sql_multi,
- list_warehouses as _list_warehouses,
- get_best_warehouse as _get_best_warehouse,
- get_table_stats_and_schema as _get_table_stats_and_schema,
- get_volume_folder_details as _get_volume_folder_details,
- TableStatLevel,
-)
-
-from ..server import mcp
-
-
-def _format_results_markdown(rows: List[Dict[str, Any]]) -> str:
- """Format SQL results as a markdown table.
-
- Markdown tables state column names once in the header instead of repeating
- them on every row (as JSON does), reducing token usage by ~50%.
-
- Args:
- rows: List of row dicts from the SQL executor.
-
- Returns:
- Markdown table string, or "(no results)" if empty.
- """
- if not rows:
- return "(no results)"
-
- columns = list(rows[0].keys())
-
- # Build header
- header = "| " + " | ".join(columns) + " |"
- separator = "| " + " | ".join("---" for _ in columns) + " |"
-
- # Build rows — convert None to empty string, stringify everything
- data_lines = []
- for row in rows:
- cells = []
- for col in columns:
- val = row.get(col)
- cell = "" if val is None else str(val)
- # Escape pipe characters inside cell values
- cell = cell.replace("|", "\\|")
- cells.append(cell)
- data_lines.append("| " + " | ".join(cells) + " |")
-
- parts = [header, separator] + data_lines
- # Append row count for awareness
- parts.append(f"\n({len(rows)} row{'s' if len(rows) != 1 else ''})")
- return "\n".join(parts)
-
-
-@mcp.tool(timeout=60)
-def execute_sql(
- sql_query: str,
- warehouse_id: str = None,
- catalog: str = None,
- schema: str = None,
- timeout: int = 180,
- query_tags: str = None,
- output_format: str = "markdown",
-) -> Union[str, List[Dict[str, Any]]]:
- """Execute SQL query on Databricks warehouse. Auto-selects warehouse if not provided.
-
- Use for SELECT/INSERT/UPDATE/table DDL. For catalog/schema/volume DDL, use manage_uc_objects.
- output_format: "markdown" (default, 50% smaller) or "json"."""
- rows = _execute_sql(
- sql_query=sql_query,
- warehouse_id=warehouse_id,
- catalog=catalog,
- schema=schema,
- timeout=timeout,
- query_tags=query_tags,
- )
- if output_format == "json":
- return rows
- return _format_results_markdown(rows)
-
-
-@mcp.tool(timeout=120)
-def execute_sql_multi(
- sql_content: str,
- warehouse_id: str = None,
- catalog: str = None,
- schema: str = None,
- timeout: int = 180,
- max_workers: int = 4,
- query_tags: str = None,
- output_format: str = "markdown",
-) -> Dict[str, Any]:
- """Execute multiple SQL statements with dependency-aware parallelism. Independent queries run in parallel.
-
- For catalog/schema/volume DDL, use manage_uc_objects instead."""
- result = _execute_sql_multi(
- sql_content=sql_content,
- warehouse_id=warehouse_id,
- catalog=catalog,
- schema=schema,
- timeout=timeout,
- max_workers=max_workers,
- query_tags=query_tags,
- )
- # Format sample_results in each query result if markdown requested
- if output_format != "json" and "results" in result:
- for query_result in result["results"].values():
- sample = query_result.get("sample_results")
- if sample and isinstance(sample, list) and len(sample) > 0:
- query_result["sample_results"] = _format_results_markdown(sample)
- return result
-
-
-@mcp.tool(timeout=30)
-def manage_warehouse(
- action: str = "get_best",
-) -> Union[str, List[Dict[str, Any]], Dict[str, Any]]:
- """Manage SQL warehouses: list, get_best.
-
- Actions:
- - list: List all SQL warehouses.
- Returns: {warehouses: [{id, name, state, size, ...}]}.
- - get_best: Get best available warehouse ID. Prefers running, then starting, smaller sizes.
- Returns: {warehouse_id} or {warehouse_id: null, error}."""
- act = action.lower()
-
- if act == "list":
- return {"warehouses": _list_warehouses()}
-
- elif act == "get_best":
- warehouse_id = _get_best_warehouse()
- if warehouse_id:
- return {"warehouse_id": warehouse_id}
- return {"warehouse_id": None, "error": "No available warehouses found"}
-
- else:
- return {"error": f"Invalid action '{action}'. Valid actions: list, get_best"}
-
-
-@mcp.tool(timeout=60)
-def get_table_stats_and_schema(
- catalog: str,
- schema: str,
- table_names: List[str] = None,
- table_stat_level: str = "SIMPLE",
- warehouse_id: str = None,
-) -> Dict[str, Any]:
- """Get schema and stats for tables. table_stat_level: NONE (schema only), SIMPLE (default, +row count), DETAILED (+cardinality/min/max/histograms).
-
- table_names: list or glob patterns, None=all tables."""
- # Convert string to enum
- level = TableStatLevel[table_stat_level.upper()]
- result = _get_table_stats_and_schema(
- catalog=catalog,
- schema=schema,
- table_names=table_names,
- table_stat_level=level,
- warehouse_id=warehouse_id,
- )
- # Convert to dict for JSON serialization
- return result.model_dump(exclude_none=True) if hasattr(result, "model_dump") else result
-
-
-@mcp.tool(timeout=60)
-def get_volume_folder_details(
- volume_path: str,
- format: str = "parquet",
- table_stat_level: str = "SIMPLE",
- warehouse_id: str = None,
-) -> Dict[str, Any]:
- """Get schema/stats for data files in Volume folder. format: parquet/csv/json/delta/file."""
- level = TableStatLevel[table_stat_level.upper()]
- result = _get_volume_folder_details(
- volume_path=volume_path,
- format=format,
- table_stat_level=level,
- warehouse_id=warehouse_id,
- )
- return result.model_dump(exclude_none=True) if hasattr(result, "model_dump") else result
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/unity_catalog.py b/databricks-mcp-server/databricks_mcp_server/tools/unity_catalog.py
deleted file mode 100644
index 8c6c704e..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/unity_catalog.py
+++ /dev/null
@@ -1,1004 +0,0 @@
-"""
-Unity Catalog MCP Tools
-
-Consolidated MCP tools for Unity Catalog operations.
-8 tools covering: objects, grants, storage, connections,
-tags, security policies, monitors, and sharing.
-"""
-
-import logging
-from typing import Any, Dict, List
-
-from databricks_tools_core.identity import get_default_tags
-from databricks_tools_core.unity_catalog import (
- # Metric Views
- create_metric_view as _create_metric_view,
- alter_metric_view as _alter_metric_view,
- drop_metric_view as _drop_metric_view,
- describe_metric_view as _describe_metric_view,
- query_metric_view as _query_metric_view,
- grant_metric_view as _grant_metric_view,
-)
-from databricks_tools_core.unity_catalog import (
- # Catalogs
- list_catalogs as _list_catalogs,
- get_catalog as _get_catalog,
- create_catalog as _create_catalog,
- update_catalog as _update_catalog,
- delete_catalog as _delete_catalog,
- # Schemas
- list_schemas as _list_schemas,
- get_schema as _get_schema,
- create_schema as _create_schema,
- update_schema as _update_schema,
- delete_schema as _delete_schema,
- # Volumes
- list_volumes as _list_volumes,
- get_volume as _get_volume,
- create_volume as _create_volume,
- update_volume as _update_volume,
- delete_volume as _delete_volume,
- # Functions
- list_functions as _list_functions,
- get_function as _get_function,
- delete_function as _delete_function,
- # Grants
- grant_privileges as _grant_privileges,
- revoke_privileges as _revoke_privileges,
- get_grants as _get_grants,
- get_effective_grants as _get_effective_grants,
- # Storage
- list_storage_credentials as _list_storage_credentials,
- get_storage_credential as _get_storage_credential,
- create_storage_credential as _create_storage_credential,
- update_storage_credential as _update_storage_credential,
- delete_storage_credential as _delete_storage_credential,
- validate_storage_credential as _validate_storage_credential,
- list_external_locations as _list_external_locations,
- get_external_location as _get_external_location,
- create_external_location as _create_external_location,
- update_external_location as _update_external_location,
- delete_external_location as _delete_external_location,
- # Connections
- list_connections as _list_connections,
- get_connection as _get_connection,
- create_connection as _create_connection,
- update_connection as _update_connection,
- delete_connection as _delete_connection,
- create_foreign_catalog as _create_foreign_catalog,
- # Tags
- set_tags as _set_tags,
- unset_tags as _unset_tags,
- set_comment as _set_comment,
- query_table_tags as _query_table_tags,
- query_column_tags as _query_column_tags,
- # Security policies
- create_security_function as _create_security_function,
- set_row_filter as _set_row_filter,
- drop_row_filter as _drop_row_filter,
- set_column_mask as _set_column_mask,
- drop_column_mask as _drop_column_mask,
- # Monitors
- create_monitor as _create_monitor,
- get_monitor as _get_monitor,
- run_monitor_refresh as _run_monitor_refresh,
- list_monitor_refreshes as _list_monitor_refreshes,
- delete_monitor as _delete_monitor,
- # Sharing
- list_shares as _list_shares,
- get_share as _get_share,
- create_share as _create_share,
- add_table_to_share as _add_table_to_share,
- remove_table_from_share as _remove_table_from_share,
- delete_share as _delete_share,
- grant_share_to_recipient as _grant_share_to_recipient,
- revoke_share_from_recipient as _revoke_share_from_recipient,
- list_recipients as _list_recipients,
- get_recipient as _get_recipient,
- create_recipient as _create_recipient,
- rotate_recipient_token as _rotate_recipient_token,
- delete_recipient as _delete_recipient,
- list_providers as _list_providers,
- get_provider as _get_provider,
- list_provider_shares as _list_provider_shares,
-)
-
-from ..manifest import register_deleter
-from ..server import mcp
-
-logger = logging.getLogger(__name__)
-
-
-def _delete_catalog_resource(resource_id: str) -> None:
- _delete_catalog(catalog_name=resource_id, force=True)
-
-
-def _delete_schema_resource(resource_id: str) -> None:
- _delete_schema(full_schema_name=resource_id)
-
-
-def _delete_volume_resource(resource_id: str) -> None:
- _delete_volume(full_volume_name=resource_id)
-
-
-register_deleter("catalog", _delete_catalog_resource)
-register_deleter("schema", _delete_schema_resource)
-register_deleter("volume", _delete_volume_resource)
-
-
-def _auto_tag(object_type: str, full_name: str) -> None:
- """Best-effort: apply default tags to a newly created UC object.
-
- Tags are set individually so that a tag-policy violation on one key
- does not prevent the remaining tags from being applied.
- """
- for key, value in get_default_tags().items():
- try:
- _set_tags(object_type=object_type, full_name=full_name, tags={key: value})
- except Exception:
- logger.warning("Failed to set tag %s=%s on %s '%s'", key, value, object_type, full_name, exc_info=True)
-
-
-def _to_dict(obj: Any) -> Dict[str, Any]:
- """Convert SDK objects to serializable dicts."""
- if isinstance(obj, dict):
- return obj
- if hasattr(obj, "as_dict"):
- return obj.as_dict()
- if hasattr(obj, "model_dump"):
- return obj.model_dump(exclude_none=True)
- return vars(obj)
-
-
-def _to_dict_list(items: list) -> List[Dict[str, Any]]:
- """Convert a list of SDK objects to serializable dicts."""
- return [_to_dict(item) for item in items]
-
-
-# =============================================================================
-# Tool 1: manage_uc_objects
-# =============================================================================
-
-
-@mcp.tool(timeout=60)
-def manage_uc_objects(
- object_type: str,
- action: str,
- name: str = None,
- full_name: str = None,
- catalog_name: str = None,
- schema_name: str = None,
- comment: str = None,
- owner: str = None,
- storage_root: str = None,
- volume_type: str = None,
- storage_location: str = None,
- new_name: str = None,
- properties: Dict[str, str] = None,
- isolation_mode: str = None,
- force: bool = False,
-) -> Dict[str, Any]:
- """Manage UC namespace objects: catalog/schema/volume/function.
-
- object_type: "catalog", "schema", "volume", or "function".
- action: "create", "get", "list", "update", "delete" (function: no create, use SQL).
-
- Parameters by object_type:
- - catalog: create(name, comment?, storage_root?, properties?), get/update/delete(full_name or name).
- update supports: new_name, comment, owner, isolation_mode (OPEN/ISOLATED).
- - schema: create(catalog_name, name, comment?), get/update/delete(full_name).
- list(catalog_name). update supports: new_name, comment, owner.
- - volume: create(catalog_name, schema_name, name, volume_type?, comment?, storage_location?).
- volume_type: MANAGED (default) or EXTERNAL. storage_location required for EXTERNAL.
- list(catalog_name, schema_name). get/update/delete(full_name).
- - function: get/delete(full_name), list(catalog_name, schema_name). force=True for delete.
-
- full_name format: "catalog" or "catalog.schema" or "catalog.schema.object".
- Returns: list={items}, get/create/update=object details, delete={status}."""
- otype = object_type.lower()
-
- if otype == "catalog":
- if action == "create":
- result = _to_dict(
- _create_catalog(
- name=name,
- comment=comment,
- storage_root=storage_root,
- properties=properties,
- )
- )
- _auto_tag("catalog", name)
- try:
- from ..manifest import track_resource
-
- track_resource(
- resource_type="catalog",
- name=name,
- resource_id=result.get("name", name),
- )
- except Exception:
- pass
- return result
- elif action == "get":
- return _to_dict(_get_catalog(catalog_name=full_name or name))
- elif action == "list":
- return {"items": _to_dict_list(_list_catalogs())}
- elif action == "update":
- return _to_dict(
- _update_catalog(
- catalog_name=full_name or name,
- new_name=new_name,
- comment=comment,
- owner=owner,
- isolation_mode=isolation_mode,
- )
- )
- elif action == "delete":
- _delete_catalog(catalog_name=full_name or name, force=force)
- try:
- from ..manifest import remove_resource
-
- remove_resource(resource_type="catalog", resource_id=full_name or name)
- except Exception:
- pass
- return {"status": "deleted", "catalog": full_name or name}
-
- elif otype == "schema":
- if action == "create":
- result = _to_dict(_create_schema(catalog_name=catalog_name, schema_name=name, comment=comment))
- _auto_tag("schema", f"{catalog_name}.{name}")
- try:
- from ..manifest import track_resource
-
- full_schema = result.get("full_name") or f"{catalog_name}.{name}"
- track_resource(resource_type="schema", name=full_schema, resource_id=full_schema)
- except Exception:
- logger.warning("Failed to track schema in manifest", exc_info=True)
- return result
- elif action == "get":
- return _to_dict(_get_schema(full_schema_name=full_name))
- elif action == "list":
- return {"items": _to_dict_list(_list_schemas(catalog_name=catalog_name))}
- elif action == "update":
- return _to_dict(
- _update_schema(
- full_schema_name=full_name,
- new_name=new_name,
- comment=comment,
- owner=owner,
- )
- )
- elif action == "delete":
- _delete_schema(full_schema_name=full_name)
- try:
- from ..manifest import remove_resource
-
- remove_resource(resource_type="schema", resource_id=full_name)
- except Exception:
- pass
- return {"status": "deleted", "schema": full_name}
-
- elif otype == "volume":
- if action == "create":
- result = _to_dict(
- _create_volume(
- catalog_name=catalog_name,
- schema_name=schema_name,
- name=name,
- volume_type=volume_type or "MANAGED",
- comment=comment,
- storage_location=storage_location,
- )
- )
- _auto_tag("volume", f"{catalog_name}.{schema_name}.{name}")
- try:
- from ..manifest import track_resource
-
- full_vol = result.get("full_name") or f"{catalog_name}.{schema_name}.{name}"
- track_resource(resource_type="volume", name=full_vol, resource_id=full_vol)
- except Exception:
- pass
- return result
- elif action == "get":
- return _to_dict(_get_volume(full_volume_name=full_name))
- elif action == "list":
- return {"items": _to_dict_list(_list_volumes(catalog_name=catalog_name, schema_name=schema_name))}
- elif action == "update":
- return _to_dict(
- _update_volume(
- full_volume_name=full_name,
- new_name=new_name,
- comment=comment,
- owner=owner,
- )
- )
- elif action == "delete":
- _delete_volume(full_volume_name=full_name)
- try:
- from ..manifest import remove_resource
-
- remove_resource(resource_type="volume", resource_id=full_name)
- except Exception:
- pass
- return {"status": "deleted", "volume": full_name}
-
- elif otype == "function":
- if action == "create":
- return {
- "error": """Functions cannot be created via SDK. Use manage_uc_security_policies tool with
- action='create_security_function' or execute_sql with a CREATE FUNCTION statement."""
- }
- elif action == "get":
- return _to_dict(_get_function(full_function_name=full_name))
- elif action == "list":
- return {"items": _to_dict_list(_list_functions(catalog_name=catalog_name, schema_name=schema_name))}
- elif action == "delete":
- _delete_function(full_function_name=full_name, force=force)
- return {"status": "deleted", "function": full_name}
-
- raise ValueError(f"Invalid object_type='{object_type}' or action='{action}'")
-
-
-# =============================================================================
-# Tool 2: manage_uc_grants
-# =============================================================================
-
-
-@mcp.tool(timeout=60)
-def manage_uc_grants(
- action: str,
- securable_type: str,
- full_name: str,
- principal: str = None,
- privileges: List[str] = None,
-) -> Dict[str, Any]:
- """Manage UC permissions: grant/revoke/get/get_effective.
-
- action: "grant", "revoke", "get", "get_effective".
- securable_type: catalog/schema/table/volume/function/storage_credential/external_location/connection/share.
- full_name: Full UC name (e.g., "catalog.schema.table").
- principal: User, group, or service principal (e.g., "user@example.com", "group_name").
- privileges: List of privileges to grant/revoke. Common values:
- - catalog: USE_CATALOG, CREATE_SCHEMA, ALL_PRIVILEGES
- - schema: USE_SCHEMA, CREATE_TABLE, CREATE_FUNCTION, ALL_PRIVILEGES
- - table: SELECT, MODIFY, ALL_PRIVILEGES
- - volume: READ_VOLUME, WRITE_VOLUME, ALL_PRIVILEGES
- - function: EXECUTE, ALL_PRIVILEGES
- Returns: get/get_effective={privilege_assignments: [...]}, grant/revoke={status}."""
- act = action.lower()
-
- if act == "grant":
- return _grant_privileges(
- securable_type=securable_type,
- full_name=full_name,
- principal=principal,
- privileges=privileges,
- )
- elif act == "revoke":
- return _revoke_privileges(
- securable_type=securable_type,
- full_name=full_name,
- principal=principal,
- privileges=privileges,
- )
- elif act == "get":
- return _get_grants(securable_type=securable_type, full_name=full_name, principal=principal)
- elif act == "get_effective":
- return _get_effective_grants(securable_type=securable_type, full_name=full_name, principal=principal)
-
- raise ValueError(f"Invalid action: '{action}'. Valid: grant, revoke, get, get_effective")
-
-
-# =============================================================================
-# Tool 3: manage_uc_storage
-# =============================================================================
-
-
-@mcp.tool(timeout=60)
-def manage_uc_storage(
- resource_type: str,
- action: str,
- name: str = None,
- aws_iam_role_arn: str = None,
- azure_access_connector_id: str = None,
- url: str = None,
- credential_name: str = None,
- read_only: bool = False,
- comment: str = None,
- owner: str = None,
- new_name: str = None,
- force: bool = False,
-) -> Dict[str, Any]:
- """Manage storage credentials and external locations.
-
- resource_type: "credential" or "external_location".
-
- credential actions:
- - create: name + (aws_iam_role_arn OR azure_access_connector_id), comment?, read_only?.
- - get/delete: name. delete supports force=True.
- - update: name, new_name?, comment?, owner?, aws_iam_role_arn?, azure_access_connector_id?.
- - validate: name, url (cloud path to validate access).
- - list: no params.
-
- external_location actions:
- - create: name, url (cloud path), credential_name, comment?, read_only?.
- - get/delete: name. delete supports force=True.
- - update: name, new_name?, url?, credential_name?, comment?, owner?, read_only?.
- - list: no params.
-
- Returns: get/create/update=resource details, list={items}, delete={status}, validate={results}."""
- rtype = resource_type.lower().replace(" ", "_").replace("-", "_")
-
- if rtype == "credential":
- if action == "create":
- return _to_dict(
- _create_storage_credential(
- name=name,
- comment=comment,
- aws_iam_role_arn=aws_iam_role_arn,
- azure_access_connector_id=azure_access_connector_id,
- read_only=read_only,
- )
- )
- elif action == "get":
- return _to_dict(_get_storage_credential(name=name))
- elif action == "list":
- return {"items": _to_dict_list(_list_storage_credentials())}
- elif action == "update":
- return _to_dict(
- _update_storage_credential(
- name=name,
- new_name=new_name,
- comment=comment,
- owner=owner,
- aws_iam_role_arn=aws_iam_role_arn,
- azure_access_connector_id=azure_access_connector_id,
- )
- )
- elif action == "delete":
- _delete_storage_credential(name=name, force=force)
- return {"status": "deleted", "credential": name}
- elif action == "validate":
- return _validate_storage_credential(name=name, url=url)
-
- elif rtype == "external_location":
- if action == "create":
- return _to_dict(
- _create_external_location(
- name=name,
- url=url,
- credential_name=credential_name,
- comment=comment,
- read_only=read_only,
- )
- )
- elif action == "get":
- return _to_dict(_get_external_location(name=name))
- elif action == "list":
- return {"items": _to_dict_list(_list_external_locations())}
- elif action == "update":
- return _to_dict(
- _update_external_location(
- name=name,
- new_name=new_name,
- url=url,
- credential_name=credential_name,
- comment=comment,
- owner=owner,
- read_only=read_only,
- )
- )
- elif action == "delete":
- _delete_external_location(name=name, force=force)
- return {"status": "deleted", "external_location": name}
-
- raise ValueError(f"Invalid resource_type='{resource_type}' or action='{action}'")
-
-
-# =============================================================================
-# Tool 4: manage_uc_connections
-# =============================================================================
-
-
-@mcp.tool(timeout=60)
-def manage_uc_connections(
- action: str,
- name: str = None,
- connection_type: str = None,
- options: Dict[str, str] = None,
- comment: str = None,
- owner: str = None,
- new_name: str = None,
- connection_name: str = None,
- catalog_name: str = None,
- catalog_options: Dict[str, str] = None,
- warehouse_id: str = None,
-) -> Dict[str, Any]:
- """Manage Lakehouse Federation foreign connections.
-
- action: "create", "get", "list", "update", "delete", "create_foreign_catalog".
- connection_type: SNOWFLAKE, POSTGRESQL, MYSQL, SQLSERVER, BIGQUERY, REDSHIFT, SQLDW (Azure Synapse).
-
- Parameters by action:
- - create: name, connection_type, options (dict with connection details), comment?.
- options format varies by type. Example for POSTGRESQL:
- {"host": "...", "port": "5432", "user": "...", "password": "..."}.
- - get/delete: name.
- - update: name, options?, new_name?, owner?.
- - list: no params.
- - create_foreign_catalog: Creates UC catalog from external connection.
- Requires: catalog_name (new UC catalog name), connection_name (existing connection).
- Optional: catalog_options (dict, e.g., {"database": "mydb"}), comment, warehouse_id.
-
- Returns: get/create/update=connection details, list={items}, delete={status}."""
- act = action.lower()
-
- if act == "create":
- return _to_dict(
- _create_connection(
- name=name,
- connection_type=connection_type,
- options=options,
- comment=comment,
- )
- )
- elif act == "get":
- return _to_dict(_get_connection(name=name))
- elif act == "list":
- return {"items": _to_dict_list(_list_connections())}
- elif act == "update":
- return _to_dict(_update_connection(name=name, options=options, new_name=new_name, owner=owner))
- elif act == "delete":
- _delete_connection(name=name)
- return {"status": "deleted", "connection": name}
- elif act == "create_foreign_catalog":
- return _create_foreign_catalog(
- catalog_name=catalog_name,
- connection_name=connection_name,
- catalog_options=catalog_options,
- comment=comment,
- warehouse_id=warehouse_id,
- )
-
- raise ValueError(f"Invalid action: '{action}'")
-
-
-# =============================================================================
-# Tool 5: manage_uc_tags
-# =============================================================================
-
-
-@mcp.tool(timeout=60)
-def manage_uc_tags(
- action: str,
- object_type: str = None,
- full_name: str = None,
- column_name: str = None,
- tags: Dict[str, str] = None,
- tag_names: List[str] = None,
- comment_text: str = None,
- catalog_filter: str = None,
- tag_name_filter: str = None,
- tag_value_filter: str = None,
- table_name_filter: str = None,
- limit: int = 100,
- warehouse_id: str = None,
-) -> Dict[str, Any]:
- """Manage UC tags and comments.
-
- action: "set_tags", "unset_tags", "set_comment", "query_table_tags", "query_column_tags".
-
- Parameters by action:
- - set_tags: object_type (catalog/schema/table/column), full_name, tags (dict of key-value pairs).
- For columns: also set column_name. warehouse_id? for SQL-based tagging.
- - unset_tags: object_type, full_name, tag_names (list of keys to remove).
- For columns: also set column_name. warehouse_id?.
- - set_comment: object_type, full_name, comment_text. For columns: column_name. warehouse_id?.
- - query_table_tags: Search tables by tags. catalog_filter?, tag_name_filter?, tag_value_filter?, limit? (default 100).
- - query_column_tags: Search columns by tags. catalog_filter?, table_name_filter?, tag_name_filter?, tag_value_filter?, limit?.
-
- Returns: set/unset={status}, query={data: [...]}."""
- act = action.lower()
-
- if act == "set_tags":
- return _set_tags(
- object_type=object_type,
- full_name=full_name,
- tags=tags,
- column_name=column_name,
- warehouse_id=warehouse_id,
- )
- elif act == "unset_tags":
- return _unset_tags(
- object_type=object_type,
- full_name=full_name,
- tag_names=tag_names,
- column_name=column_name,
- warehouse_id=warehouse_id,
- )
- elif act == "set_comment":
- return _set_comment(
- object_type=object_type,
- full_name=full_name,
- comment_text=comment_text,
- column_name=column_name,
- warehouse_id=warehouse_id,
- )
- elif act == "query_table_tags":
- return {
- "data": _query_table_tags(
- catalog_filter=catalog_filter,
- tag_name=tag_name_filter,
- tag_value=tag_value_filter,
- limit=limit,
- warehouse_id=warehouse_id,
- )
- }
- elif act == "query_column_tags":
- return {
- "data": _query_column_tags(
- catalog_filter=catalog_filter,
- table_name=table_name_filter,
- tag_name=tag_name_filter,
- tag_value=tag_value_filter,
- limit=limit,
- warehouse_id=warehouse_id,
- )
- }
-
- raise ValueError(f"Invalid action: '{action}'")
-
-
-# =============================================================================
-# Tool 6: manage_uc_security_policies
-# =============================================================================
-
-
-@mcp.tool(timeout=60)
-def manage_uc_security_policies(
- action: str,
- table_name: str = None,
- column_name: str = None,
- filter_function: str = None,
- filter_columns: List[str] = None,
- mask_function: str = None,
- function_name: str = None,
- function_body: str = None,
- parameter_name: str = None,
- parameter_type: str = None,
- return_type: str = None,
- function_comment: str = None,
- warehouse_id: str = None,
-) -> Dict[str, Any]:
- """Manage row-level security and column masking.
-
- action: "set_row_filter", "drop_row_filter", "set_column_mask", "drop_column_mask", "create_security_function".
-
- Parameters by action:
- - set_row_filter: table_name (full name), filter_function (UDF name), filter_columns (list of columns to pass).
- Example: filter_function="main.default.row_filter_fn", filter_columns=["user_id"].
- - drop_row_filter: table_name.
- - set_column_mask: table_name, column_name, mask_function (UDF that returns masked value).
- - drop_column_mask: table_name, column_name.
- - create_security_function: Creates a UDF for row filtering or column masking.
- Requires: function_name (full name), parameter_name, parameter_type, return_type, function_body.
- Example: function_name="main.default.my_filter", parameter_name="user_id", parameter_type="STRING",
- return_type="BOOLEAN", function_body="return user_id = current_user()".
-
- All actions accept optional warehouse_id for SQL execution.
- Returns: {status, message} or function details for create."""
- act = action.lower()
-
- if act == "set_row_filter":
- return _set_row_filter(
- table_name=table_name,
- filter_function=filter_function,
- filter_columns=filter_columns,
- warehouse_id=warehouse_id,
- )
- elif act == "drop_row_filter":
- return _drop_row_filter(table_name=table_name, warehouse_id=warehouse_id)
- elif act == "set_column_mask":
- return _set_column_mask(
- table_name=table_name,
- column_name=column_name,
- mask_function=mask_function,
- warehouse_id=warehouse_id,
- )
- elif act == "drop_column_mask":
- return _drop_column_mask(table_name=table_name, column_name=column_name, warehouse_id=warehouse_id)
- elif act == "create_security_function":
- return _create_security_function(
- function_name=function_name,
- parameter_name=parameter_name,
- parameter_type=parameter_type,
- return_type=return_type,
- function_body=function_body,
- comment=function_comment,
- warehouse_id=warehouse_id,
- )
-
- raise ValueError(f"Invalid action: '{action}'")
-
-
-# =============================================================================
-# Tool 7: manage_uc_monitors
-# =============================================================================
-
-
-@mcp.tool(timeout=60)
-def manage_uc_monitors(
- action: str,
- table_name: str,
- output_schema_name: str = None,
- schedule_cron: str = None,
- schedule_timezone: str = "UTC",
- assets_dir: str = None,
-) -> Dict[str, Any]:
- """Manage Lakehouse quality monitors for data quality tracking.
-
- action: "create", "get", "run_refresh", "list_refreshes", "delete".
- table_name: Full table name (required for all actions).
-
- Parameters by action:
- - create: table_name, output_schema_name (where metrics tables are stored).
- Optional: assets_dir (for dashboard assets), schedule_cron (e.g., "0 0 * * *"),
- schedule_timezone (default "UTC").
- - get: table_name. Returns monitor config and status.
- - run_refresh: table_name. Triggers a new monitor refresh.
- - list_refreshes: table_name. Returns {refreshes: [...]}.
- - delete: table_name. Removes the monitor.
-
- Returns: create/get=monitor details, run_refresh={status}, list_refreshes={refreshes}, delete={status}."""
- act = action.lower()
-
- if act == "create":
- return _create_monitor(
- table_name=table_name,
- output_schema_name=output_schema_name,
- assets_dir=assets_dir,
- schedule_cron=schedule_cron,
- schedule_timezone=schedule_timezone,
- )
- elif act == "get":
- return _get_monitor(table_name=table_name)
- elif act == "run_refresh":
- return _run_monitor_refresh(table_name=table_name)
- elif act == "list_refreshes":
- return {"refreshes": _list_monitor_refreshes(table_name=table_name)}
- elif act == "delete":
- _delete_monitor(table_name=table_name)
- return {"status": "deleted", "table_name": table_name}
-
- raise ValueError(f"Invalid action: '{action}'")
-
-
-# =============================================================================
-# Tool 8: manage_uc_sharing
-# =============================================================================
-
-
-@mcp.tool(timeout=60)
-def manage_uc_sharing(
- resource_type: str,
- action: str,
- name: str = None,
- comment: str = None,
- table_name: str = None,
- shared_as: str = None,
- partition_spec: str = None,
- authentication_type: str = None,
- sharing_id: str = None,
- ip_access_list: List[str] = None,
- share_name: str = None,
- recipient_name: str = None,
- include_shared_data: bool = True,
-) -> Dict[str, Any]:
- """Manage Delta Sharing: shares, recipients, and providers.
-
- resource_type: "share", "recipient", or "provider".
-
- SHARE actions (for data providers to share tables):
- - create: name, comment?. Creates an empty share.
- - get: name, include_shared_data? (default True).
- - list: no params. Returns {items: [...]}.
- - delete: name.
- - add_table: name (or share_name), table_name (full UC name), shared_as? (alias), partition_spec?.
- - remove_table: name (or share_name), table_name.
- - grant_to_recipient: name (or share_name), recipient_name.
- - revoke_from_recipient: name (or share_name), recipient_name.
-
- RECIPIENT actions (for data providers to manage share consumers):
- - create: name, authentication_type? (TOKEN/DATABRICKS), sharing_id?, comment?, ip_access_list?.
- - get: name. list: no params. delete: name.
- - rotate_token: name. Generates new access token for TOKEN-based recipients.
-
- PROVIDER actions (for data consumers to view available shares):
- - get: name. list: no params.
- - list_shares: name (provider name). Lists shares available from this provider.
-
- Returns: create/get=details, list={items}, delete={status}."""
- rtype = resource_type.lower()
- act = action.lower()
-
- if rtype == "share":
- if act == "create":
- return _create_share(name=name, comment=comment)
- elif act == "get":
- return _get_share(name=name, include_shared_data=include_shared_data)
- elif act == "list":
- return {"items": _list_shares()}
- elif act == "delete":
- _delete_share(name=name)
- return {"status": "deleted", "share": name}
- elif act == "add_table":
- return _add_table_to_share(
- share_name=name or share_name,
- table_name=table_name,
- shared_as=shared_as,
- partition_spec=partition_spec,
- )
- elif act == "remove_table":
- return _remove_table_from_share(share_name=name or share_name, table_name=table_name)
- elif act == "grant_to_recipient":
- return _grant_share_to_recipient(share_name=name or share_name, recipient_name=recipient_name)
- elif act == "revoke_from_recipient":
- return _revoke_share_from_recipient(share_name=name or share_name, recipient_name=recipient_name)
-
- elif rtype == "recipient":
- if act == "create":
- return _create_recipient(
- name=name,
- authentication_type=authentication_type or "TOKEN",
- sharing_id=sharing_id,
- comment=comment,
- ip_access_list=ip_access_list,
- )
- elif act == "get":
- return _get_recipient(name=name)
- elif act == "list":
- return {"items": _list_recipients()}
- elif act == "delete":
- _delete_recipient(name=name)
- return {"status": "deleted", "recipient": name}
- elif act == "rotate_token":
- return _rotate_recipient_token(name=name)
-
- elif rtype == "provider":
- if act == "get":
- return _get_provider(name=name)
- elif act == "list":
- return {"items": _list_providers()}
- elif act == "list_shares":
- return {"items": _list_provider_shares(name=name)}
-
- raise ValueError(f"Invalid resource_type='{resource_type}' or action='{action}'")
-
-
-# =============================================================================
-# Tool 9: manage_metric_views
-# =============================================================================
-
-
-@mcp.tool(timeout=60)
-def manage_metric_views(
- action: str,
- full_name: str,
- source: str = None,
- dimensions: List[Dict[str, str]] = None,
- measures: List[Dict[str, str]] = None,
- version: str = "1.1",
- comment: str = None,
- filter_expr: str = None,
- joins: List[Dict[str, Any]] = None,
- materialization: Dict[str, Any] = None,
- or_replace: bool = False,
- query_measures: List[str] = None,
- query_dimensions: List[str] = None,
- where: str = None,
- order_by: str = None,
- limit: int = None,
- principal: str = None,
- privileges: List[str] = None,
- warehouse_id: str = None,
-) -> Dict[str, Any]:
- """Manage UC metric views (reusable business metrics). Requires DBR 17.2+.
-
- action: "create", "alter", "describe", "query", "drop", "grant".
- full_name: Full metric view name (catalog.schema.metric_view).
-
- Parameters by action:
- - create: full_name, source (table/view name), dimensions, measures.
- dimensions: List of dicts [{name: "dim_name", expr: "column_or_expr"}, ...].
- measures: List of dicts [{name: "measure_name", expr: "SUM(amount)"}, ...] (aggregate functions).
- Optional: version (default "1.1"), comment, filter_expr, joins, materialization, or_replace.
- - alter: Same params as create except or_replace. Updates existing metric view.
- - describe: full_name. Returns metric view definition and metadata.
- - query: full_name, query_measures (list of measure names to retrieve).
- Optional: query_dimensions (list of dimension names), where, order_by, limit.
- - drop: full_name. Deletes the metric view.
- - grant: full_name, principal, privileges (list, e.g., ["SELECT"]).
-
- All actions accept optional warehouse_id for SQL execution.
- Returns: create/alter/describe/grant=details, query={data: [...]}, drop={status}."""
- act = action.lower()
-
- if act == "create":
- result = _create_metric_view(
- full_name=full_name,
- source=source,
- dimensions=dimensions,
- measures=measures,
- version=version,
- comment=comment,
- filter_expr=filter_expr,
- joins=joins,
- materialization=materialization,
- or_replace=or_replace,
- warehouse_id=warehouse_id,
- )
- _auto_tag("metric_view", full_name)
- try:
- from ..manifest import track_resource
-
- track_resource(
- resource_type="metric_view",
- name=full_name,
- resource_id=full_name,
- )
- except Exception:
- pass
- return result
- elif act == "alter":
- return _alter_metric_view(
- full_name=full_name,
- source=source,
- dimensions=dimensions,
- measures=measures,
- version=version,
- comment=comment,
- filter_expr=filter_expr,
- joins=joins,
- materialization=materialization,
- warehouse_id=warehouse_id,
- )
- elif act == "describe":
- return _describe_metric_view(
- full_name=full_name,
- warehouse_id=warehouse_id,
- )
- elif act == "query":
- if not query_measures:
- raise ValueError("query_measures is required for query action")
- return {
- "data": _query_metric_view(
- full_name=full_name,
- measures=query_measures,
- dimensions=query_dimensions,
- where=where,
- order_by=order_by,
- limit=limit,
- warehouse_id=warehouse_id,
- )
- }
- elif act == "drop":
- result = _drop_metric_view(
- full_name=full_name,
- warehouse_id=warehouse_id,
- )
- try:
- from ..manifest import remove_resource
-
- remove_resource(resource_type="metric_view", resource_id=full_name)
- except Exception:
- pass
- return result
- elif act == "grant":
- return _grant_metric_view(
- full_name=full_name,
- principal=principal,
- privileges=privileges,
- warehouse_id=warehouse_id,
- )
-
- raise ValueError(f"Invalid action: '{action}'. Valid: create, alter, describe, query, drop, grant")
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/user.py b/databricks-mcp-server/databricks_mcp_server/tools/user.py
deleted file mode 100644
index 112cf566..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/user.py
+++ /dev/null
@@ -1,20 +0,0 @@
-"""User tools - Get information about the current Databricks user."""
-
-from typing import Dict, Any
-
-from databricks_tools_core.auth import get_current_username
-
-from ..server import mcp
-
-
-@mcp.tool(timeout=30)
-def get_current_user() -> Dict[str, Any]:
- """Get current Databricks user identity.
-
- Returns: {username (email), home_path (/Workspace/Users/user@example.com/)}."""
- username = get_current_username()
- home_path = f"/Workspace/Users/{username}/" if username else None
- return {
- "username": username,
- "home_path": home_path,
- }
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/vector_search.py b/databricks-mcp-server/databricks_mcp_server/tools/vector_search.py
deleted file mode 100644
index a9520511..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/vector_search.py
+++ /dev/null
@@ -1,333 +0,0 @@
-"""Vector Search tools - Manage endpoints, indexes, and query vector data.
-
-Consolidated into 4 tools:
-- manage_vs_endpoint: create_or_update, get, list, delete
-- manage_vs_index: create_or_update, get, list, delete
-- query_vs_index: query vectors (hot path - kept separate)
-- manage_vs_data: upsert, delete, scan, sync
-"""
-
-import json
-import logging
-from typing import Any, Dict, List, Optional, Union
-
-from databricks_tools_core.vector_search import (
- create_vs_endpoint as _create_vs_endpoint,
- get_vs_endpoint as _get_vs_endpoint,
- list_vs_endpoints as _list_vs_endpoints,
- delete_vs_endpoint as _delete_vs_endpoint,
- create_vs_index as _create_vs_index,
- get_vs_index as _get_vs_index,
- list_vs_indexes as _list_vs_indexes,
- delete_vs_index as _delete_vs_index,
- sync_vs_index as _sync_vs_index,
- query_vs_index as _query_vs_index,
- upsert_vs_data as _upsert_vs_data,
- delete_vs_data as _delete_vs_data,
- scan_vs_index as _scan_vs_index,
-)
-
-from ..server import mcp
-
-logger = logging.getLogger(__name__)
-
-
-# ============================================================================
-# Helpers
-# ============================================================================
-
-
-def _find_endpoint_by_name(name: str) -> Optional[Dict[str, Any]]:
- """Find a vector search endpoint by name, returns None if not found."""
- try:
- result = _get_vs_endpoint(name=name)
- if result.get("state") == "NOT_FOUND":
- return None
- return result
- except Exception:
- return None
-
-
-def _find_index_by_name(index_name: str) -> Optional[Dict[str, Any]]:
- """Find a vector search index by name, returns None if not found."""
- try:
- result = _get_vs_index(index_name=index_name)
- if result.get("state") == "NOT_FOUND":
- return None
- return result
- except Exception:
- return None
-
-
-# ============================================================================
-# Tool 1: manage_vs_endpoint
-# ============================================================================
-
-
-@mcp.tool(timeout=120)
-def manage_vs_endpoint(
- action: str,
- name: Optional[str] = None,
- endpoint_type: str = "STANDARD",
-) -> Dict[str, Any]:
- """Manage Vector Search endpoints: create, get, list, delete.
-
- Actions:
- - create_or_update: Idempotent create. Returns existing if found. Requires name.
- endpoint_type: "STANDARD" (<100ms latency) or "STORAGE_OPTIMIZED" (~250ms, 1B+ vectors).
- Async creation - poll with action="get" until state=ONLINE.
- Returns: {name, endpoint_type, state, created: bool}.
- - get: Get endpoint details. Requires name.
- Returns: {name, state, num_indexes, ...}.
- - list: List all endpoints.
- Returns: {endpoints: [{name, state, ...}, ...]}.
- - delete: Delete endpoint. All indexes must be deleted first. Requires name.
- Returns: {name, status}.
-
- See databricks-vector-search skill for endpoint configuration."""
- act = action.lower()
-
- if act == "create_or_update":
- if not name:
- return {"error": "create_or_update requires: name"}
-
- existing = _find_endpoint_by_name(name)
- if existing:
- return {**existing, "created": False}
-
- result = _create_vs_endpoint(name=name, endpoint_type=endpoint_type)
-
- try:
- from ..manifest import track_resource
- track_resource(resource_type="vs_endpoint", name=name, resource_id=name)
- except Exception:
- pass
-
- return {**result, "created": True}
-
- elif act == "get":
- if not name:
- return {"error": "get requires: name"}
- return _get_vs_endpoint(name=name)
-
- elif act == "list":
- return {"endpoints": _list_vs_endpoints()}
-
- elif act == "delete":
- if not name:
- return {"error": "delete requires: name"}
- return _delete_vs_endpoint(name=name)
-
- else:
- return {"error": f"Invalid action '{action}'. Valid actions: create_or_update, get, list, delete"}
-
-
-# ============================================================================
-# Tool 2: manage_vs_index
-# ============================================================================
-
-
-@mcp.tool(timeout=120)
-def manage_vs_index(
- action: str,
- # For create_or_update:
- name: Optional[str] = None,
- endpoint_name: Optional[str] = None,
- primary_key: Optional[str] = None,
- index_type: str = "DELTA_SYNC",
- delta_sync_index_spec: Optional[Dict[str, Any]] = None,
- direct_access_index_spec: Optional[Dict[str, Any]] = None,
-) -> Dict[str, Any]:
- """Manage Vector Search indexes: create, get, list, delete.
-
- Actions:
- - create_or_update: Idempotent create. Returns existing if found. Auto-triggers initial sync for DELTA_SYNC.
- Requires name, endpoint_name, primary_key.
- index_type: "DELTA_SYNC" (auto-sync from Delta table) or "DIRECT_ACCESS" (manual CRUD via manage_vs_data).
- delta_sync_index_spec: {source_table, embedding_source_columns OR embedding_vector_columns, pipeline_type}.
- - embedding_source_columns: List of text columns for managed embeddings (Databricks generates vectors).
- - embedding_vector_columns: List of {name, dimension} for self-managed embeddings (you provide vectors).
- - pipeline_type: "TRIGGERED" (manual sync) or "CONTINUOUS" (auto-sync on changes).
- direct_access_index_spec: {embedding_vector_columns: [{name, dimension}], schema_json}.
- Returns: {name, created: bool, sync_triggered}.
- - get: Get index details. Requires name (format: catalog.schema.index_name).
- Returns: {name, state, index_type, ...}.
- - list: List indexes. Optional endpoint_name to filter. Omit for all indexes across all endpoints.
- Returns: {indexes: [...]}.
- - delete: Delete index. Requires name.
- Returns: {name, status}.
-
- See databricks-vector-search skill for full spec details and examples."""
- act = action.lower()
-
- if act == "create_or_update":
- if not all([name, endpoint_name, primary_key]):
- return {"error": "create_or_update requires: name, endpoint_name, primary_key"}
-
- existing = _find_index_by_name(name)
- if existing:
- return {**existing, "created": False}
-
- result = _create_vs_index(
- name=name,
- endpoint_name=endpoint_name,
- primary_key=primary_key,
- index_type=index_type,
- delta_sync_index_spec=delta_sync_index_spec,
- direct_access_index_spec=direct_access_index_spec,
- )
-
- # Trigger initial sync for DELTA_SYNC indexes
- if index_type == "DELTA_SYNC" and result.get("status") != "ALREADY_EXISTS":
- try:
- _sync_vs_index(index_name=name)
- result["sync_triggered"] = True
- except Exception as e:
- logger.warning("Failed to trigger initial sync for index '%s': %s", name, e)
- result["sync_triggered"] = False
-
- try:
- from ..manifest import track_resource
- track_resource(resource_type="vs_index", name=name, resource_id=name)
- except Exception:
- pass
-
- return {**result, "created": True}
-
- elif act == "get":
- if not name:
- return {"error": "get requires: name"}
- return _get_vs_index(index_name=name)
-
- elif act == "list":
- if endpoint_name:
- return {"indexes": _list_vs_indexes(endpoint_name=endpoint_name)}
-
- # List all indexes across all endpoints
- all_indexes = []
- endpoints = _list_vs_endpoints()
- for ep in endpoints:
- ep_name = ep.get("name")
- if not ep_name:
- continue
- try:
- indexes = _list_vs_indexes(endpoint_name=ep_name)
- for idx in indexes:
- idx["endpoint_name"] = ep_name
- all_indexes.extend(indexes)
- except Exception:
- logger.warning("Failed to list indexes on endpoint '%s'", ep_name)
- return {"indexes": all_indexes}
-
- elif act == "delete":
- if not name:
- return {"error": "delete requires: name"}
- return _delete_vs_index(index_name=name)
-
- else:
- return {"error": f"Invalid action '{action}'. Valid actions: create_or_update, get, list, delete"}
-
-
-# ============================================================================
-# Tool 3: query_vs_index (HOT PATH - kept separate for performance)
-# ============================================================================
-
-
-@mcp.tool(timeout=60)
-def query_vs_index(
- index_name: str,
- columns: List[str],
- query_text: Optional[str] = None,
- query_vector: Optional[List[float]] = None,
- num_results: int = 5,
- filters_json: Optional[Union[str, dict]] = None,
- filter_string: Optional[str] = None,
- query_type: Optional[str] = None,
-) -> Dict[str, Any]:
- """Query a Vector Search index for similar documents.
-
- Use ONE OF:
- - query_text: For managed embeddings (Databricks generates vector from text).
- - query_vector: For self-managed embeddings (you provide the vector).
-
- columns: List of columns to return in results.
- num_results: Number of results to return (default 5).
- Filters (use one based on endpoint type):
- - filters_json: For STANDARD endpoints. Dict like {"field": "value"} or {"field NOT": "value"}.
- - filter_string: For STORAGE_OPTIMIZED endpoints. SQL WHERE clause like "field = 'value'".
- query_type: "ANN" (default, approximate) or "HYBRID" (combines vector + keyword search).
-
- Returns: {columns, data (with similarity score appended), num_results}."""
- # MCP deserializes JSON params, so filters_json may arrive as a dict
- if isinstance(filters_json, dict):
- filters_json = json.dumps(filters_json)
-
- return _query_vs_index(
- index_name=index_name,
- columns=columns,
- query_text=query_text,
- query_vector=query_vector,
- num_results=num_results,
- filters_json=filters_json,
- filter_string=filter_string,
- query_type=query_type,
- )
-
-
-# ============================================================================
-# Tool 4: manage_vs_data
-# ============================================================================
-
-
-@mcp.tool(timeout=120)
-def manage_vs_data(
- action: str,
- index_name: str,
- # For upsert:
- inputs_json: Optional[Union[str, list]] = None,
- # For delete:
- primary_keys: Optional[List[str]] = None,
- # For scan:
- num_results: int = 100,
-) -> Dict[str, Any]:
- """Manage Vector Search index data: upsert, delete, scan, sync.
-
- Actions:
- - upsert: Insert or update records. Requires inputs_json.
- inputs_json: List of records, each with primary key + embedding vector.
- Example: [{"id": "doc1", "text": "...", "embedding": [0.1, 0.2, ...]}]
- Returns: {status, upserted_count}.
- - delete: Delete records by primary key. Requires primary_keys.
- primary_keys: List of primary key values to delete.
- Returns: {status, deleted_count}.
- - scan: Scan index contents. Optional num_results (default 100).
- Returns: {columns, data, num_results}.
- - sync: Trigger re-sync for TRIGGERED DELTA_SYNC indexes.
- Returns: {index_name, status: "sync_triggered"}.
-
- For DIRECT_ACCESS indexes, use upsert/delete to manage data.
- For DELTA_SYNC indexes, use sync to trigger refresh from source table."""
- act = action.lower()
-
- if act == "upsert":
- if inputs_json is None:
- return {"error": "upsert requires: inputs_json"}
- # MCP deserializes JSON params, so inputs_json may arrive as a list
- if isinstance(inputs_json, (dict, list)):
- inputs_json = json.dumps(inputs_json)
- return _upsert_vs_data(index_name=index_name, inputs_json=inputs_json)
-
- elif act == "delete":
- if primary_keys is None:
- return {"error": "delete requires: primary_keys"}
- return _delete_vs_data(index_name=index_name, primary_keys=primary_keys)
-
- elif act == "scan":
- return _scan_vs_index(index_name=index_name, num_results=num_results)
-
- elif act == "sync":
- _sync_vs_index(index_name=index_name)
- return {"index_name": index_name, "status": "sync_triggered"}
-
- else:
- return {"error": f"Invalid action '{action}'. Valid actions: upsert, delete, scan, sync"}
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/volume_files.py b/databricks-mcp-server/databricks_mcp_server/tools/volume_files.py
deleted file mode 100644
index 73485c91..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/volume_files.py
+++ /dev/null
@@ -1,162 +0,0 @@
-"""Volume file tools - Manage files in Unity Catalog Volumes.
-
-Consolidated into 1 tool:
-- manage_volume_files: list, upload, download, delete, mkdir, get_info
-"""
-
-from typing import Dict, Any, Optional
-
-from databricks_tools_core.unity_catalog import (
- list_volume_files as _list_volume_files,
- upload_to_volume as _upload_to_volume,
- download_from_volume as _download_from_volume,
- delete_from_volume as _delete_from_volume,
- create_volume_directory as _create_volume_directory,
- get_volume_file_metadata as _get_volume_file_metadata,
-)
-
-from ..server import mcp
-
-
-@mcp.tool(timeout=300)
-def manage_volume_files(
- action: str,
- volume_path: str,
- # For upload:
- local_path: Optional[str] = None,
- # For download:
- local_destination: Optional[str] = None,
- # For list:
- max_results: int = 500,
- # For delete:
- recursive: bool = False,
- # Common:
- max_workers: int = 4,
- overwrite: bool = True,
-) -> Dict[str, Any]:
- """Manage Unity Catalog Volume files: list, upload, download, delete, mkdir, get_info.
-
- Actions:
- - list: List files in volume path. Returns: {files: [{name, path, is_directory, file_size}], truncated}.
- max_results: Limit results (default 500, max 1000).
- - upload: Upload local file/folder/glob to volume. Auto-creates directories.
- Requires volume_path, local_path. Returns: {total_files, successful, failed}.
- - download: Download file from volume to local path.
- Requires volume_path, local_destination. Returns: {success, error}.
- - delete: Delete file/directory from volume.
- recursive=True for non-empty directories. Returns: {files_deleted, directories_deleted}.
- - mkdir: Create directory in volume (like mkdir -p). Idempotent.
- Returns: {success}.
- - get_info: Get file/directory metadata.
- Returns: {name, path, is_directory, file_size, last_modified}.
-
- volume_path format: /Volumes/catalog/schema/volume/path/to/file_or_dir
- Supports tilde expansion (~) and glob patterns for local_path."""
- act = action.lower()
-
- if act == "list":
- # Cap max_results to prevent buffer overflow (1MB JSON limit)
- capped_max = min(max_results, 1000)
-
- # Fetch one extra to detect if there are more results
- results = _list_volume_files(volume_path, max_results=capped_max + 1)
- truncated = len(results) > capped_max
-
- # Only return up to max_results
- results = results[:capped_max]
-
- files = [
- {
- "name": r.name,
- "path": r.path,
- "is_directory": r.is_directory,
- "file_size": r.file_size,
- "last_modified": r.last_modified,
- }
- for r in results
- ]
-
- return {
- "files": files,
- "returned_count": len(files),
- "truncated": truncated,
- "message": f"Results limited to {len(files)} items. Use a more specific path to see more."
- if truncated
- else None,
- }
-
- elif act == "upload":
- if not local_path:
- return {"error": "upload requires: local_path"}
-
- result = _upload_to_volume(
- local_path=local_path,
- volume_path=volume_path,
- max_workers=max_workers,
- overwrite=overwrite,
- )
- return {
- "local_folder": result.local_folder,
- "remote_folder": result.remote_folder,
- "total_files": result.total_files,
- "successful": result.successful,
- "failed": result.failed,
- "success": result.success,
- "failed_uploads": [{"local_path": r.local_path, "error": r.error} for r in result.get_failed_uploads()]
- if result.failed > 0
- else [],
- }
-
- elif act == "download":
- if not local_destination:
- return {"error": "download requires: local_destination"}
-
- result = _download_from_volume(
- volume_path=volume_path,
- local_path=local_destination,
- overwrite=overwrite,
- )
- return {
- "volume_path": result.volume_path,
- "local_path": result.local_path,
- "success": result.success,
- "error": result.error,
- }
-
- elif act == "delete":
- result = _delete_from_volume(
- volume_path=volume_path,
- recursive=recursive,
- max_workers=max_workers,
- )
- return {
- "volume_path": result.volume_path,
- "success": result.success,
- "files_deleted": result.files_deleted,
- "directories_deleted": result.directories_deleted,
- "error": result.error,
- }
-
- elif act == "mkdir":
- try:
- _create_volume_directory(volume_path)
- return {"volume_path": volume_path, "success": True}
- except Exception as e:
- return {"volume_path": volume_path, "success": False, "error": str(e)}
-
- elif act == "get_info":
- try:
- info = _get_volume_file_metadata(volume_path)
- return {
- "name": info.name,
- "path": info.path,
- "is_directory": info.is_directory,
- "file_size": info.file_size,
- "last_modified": info.last_modified,
- "success": True,
- }
- except Exception as e:
- return {"volume_path": volume_path, "success": False, "error": str(e)}
-
- else:
- return {"error": f"Invalid action '{action}'. Valid actions: list, upload, download, delete, mkdir, get_info"}
diff --git a/databricks-mcp-server/databricks_mcp_server/tools/workspace.py b/databricks-mcp-server/databricks_mcp_server/tools/workspace.py
deleted file mode 100644
index 2973c559..00000000
--- a/databricks-mcp-server/databricks_mcp_server/tools/workspace.py
+++ /dev/null
@@ -1,240 +0,0 @@
-"""Workspace management tool - switch between Databricks workspaces at runtime."""
-
-import configparser
-import os
-import subprocess
-from typing import Any, Dict, List, Optional
-
-from databricks_tools_core.auth import (
- get_active_workspace,
- get_workspace_client,
- set_active_workspace,
-)
-
-from ..server import mcp
-
-_DATABRICKS_CFG_PATH = os.path.expanduser("~/.databrickscfg")
-_VALID_ACTIONS = ("status", "list", "switch", "login")
-
-_TOKEN_EXPIRED_PATTERNS = (
- "refresh token is invalid",
- "token is expired",
- "access token could not be retrieved",
- "invalid_grant",
- "token has expired",
- "unauthenticated",
- "invalid access token",
-)
-
-
-def _read_profiles() -> List[Dict[str, str]]:
- """Parse ~/.databrickscfg and return a list of profile dicts.
-
- configparser treats [DEFAULT] as a special section that does not appear
- in cfg.sections(), so we handle it explicitly via cfg.defaults().
- """
- cfg = configparser.ConfigParser()
- try:
- cfg.read(_DATABRICKS_CFG_PATH)
- except Exception:
- return []
- profiles = []
- # Include DEFAULT section if it has any keys
- if cfg.defaults():
- host = cfg.defaults().get("host", None)
- profiles.append({"profile": "DEFAULT", "host": host or "(no host configured)"})
- for section in cfg.sections():
- host = cfg.get(section, "host", fallback=None)
- profiles.append({"profile": section, "host": host or "(no host configured)"})
- return profiles
-
-
-def _derive_profile_name(host: str) -> str:
- """Derive a profile name from a workspace URL.
-
- E.g. https://adb-1234567890.7.azuredatabricks.net -> adb-1234567890
- """
- # Strip scheme and trailing slash
- name = host.rstrip("/")
- if "://" in name:
- name = name.split("://", 1)[1]
- # Take the first hostname segment (before the first dot)
- name = name.split(".")[0]
- return name or "workspace"
-
-
-def _validate_and_switch(profile: Optional[str] = None, host: Optional[str] = None) -> Dict[str, Any]:
- """Set active workspace state and validate by calling current_user.me().
-
- Rolls back if validation fails.
-
- Returns a success dict on success, raises on failure.
- """
- previous = get_active_workspace()
- set_active_workspace(profile=profile, host=host)
- try:
- client = get_workspace_client()
- me = client.current_user.me()
- return {
- "host": client.config.host,
- "profile": profile or host,
- "username": me.user_name,
- }
- except Exception as exc:
- # Roll back to previous state
- set_active_workspace(
- profile=previous["profile"],
- host=previous["host"],
- )
- raise exc
-
-
-def _manage_workspace_impl(
- action: str,
- profile: Optional[str] = None,
- host: Optional[str] = None,
-) -> Dict[str, Any]:
- """Business logic for manage_workspace. Separated from the MCP decorator
- so it can be imported and tested directly without FastMCP wrapping."""
-
- if action not in _VALID_ACTIONS:
- return {"error": f"Invalid action '{action}'. Valid actions: {', '.join(_VALID_ACTIONS)}"}
-
- # -------------------------------------------------------------------------
- # status: return info about the currently connected workspace
- # -------------------------------------------------------------------------
- if action == "status":
- try:
- client = get_workspace_client()
- me = client.current_user.me()
- active = get_active_workspace()
- env_profile = os.environ.get("DATABRICKS_CONFIG_PROFILE")
- return {
- "host": client.config.host,
- "profile": active["profile"] or env_profile or "(default)",
- "username": me.user_name,
- }
- except Exception as exc:
- return {"error": f"Failed to get workspace status: {exc}"}
-
- # -------------------------------------------------------------------------
- # list: show all profiles from ~/.databrickscfg
- # -------------------------------------------------------------------------
- if action == "list":
- profiles = _read_profiles()
- if not profiles:
- return {
- "profiles": [],
- "message": f"No profiles found in {_DATABRICKS_CFG_PATH}. "
- "Run manage_workspace(action='login', host='...') to add one.",
- }
- active = get_active_workspace()
- env_profile = os.environ.get("DATABRICKS_CONFIG_PROFILE")
- current_profile = active["profile"] or env_profile
-
- for p in profiles:
- p["active"] = p["profile"] == current_profile
-
- return {"profiles": profiles}
-
- # -------------------------------------------------------------------------
- # switch: change to an existing profile or host
- # -------------------------------------------------------------------------
- if action == "switch":
- if not profile and not host:
- return {"error": "Provide either 'profile' (name from ~/.databrickscfg) or 'host' (workspace URL)."}
-
- if profile:
- # Verify profile exists in config
- known = {p["profile"] for p in _read_profiles()}
- if profile not in known:
- suggestions = ", ".join(sorted(known)) if known else "none configured"
- return {
- "error": f"Profile '{profile}' not found in {_DATABRICKS_CFG_PATH}. "
- f"Available profiles: {suggestions}. "
- "Use action='login' to authenticate a new workspace."
- }
-
- try:
- result = _validate_and_switch(profile=profile, host=host)
- result["message"] = f"Switched to workspace: {result['host']}"
- return result
- except Exception as exc:
- err_str = str(exc).lower()
- is_expired = any(p in err_str for p in _TOKEN_EXPIRED_PATTERNS)
- if is_expired:
- # Look up the host for this profile so the LLM can call login directly
- profile_host = host
- if not profile_host and profile:
- for p in _read_profiles():
- if p["profile"] == profile:
- profile_host = p["host"]
- break
- return {
- "error": "Token expired or invalid for this workspace.",
- "token_expired": True,
- "profile": profile,
- "host": profile_host,
- "action_required": f"Run manage_workspace(action='login', host='{profile_host}') "
- "to re-authenticate via browser OAuth.",
- }
- return {
- "error": f"Failed to connect to workspace: {exc}",
- "hint": "Check your credentials or use action='login' to re-authenticate.",
- }
-
- # -------------------------------------------------------------------------
- # login: run OAuth via the Databricks CLI then switch
- # -------------------------------------------------------------------------
- if action == "login":
- if not host:
- return {"error": "Provide 'host' (workspace URL) for the login action."}
-
- derived_profile = _derive_profile_name(host)
-
- try:
- proc = subprocess.run(
- ["databricks", "auth", "login", "--host", host, "--profile", derived_profile],
- capture_output=True,
- text=True,
- stdin=subprocess.DEVNULL,
- timeout=120,
- )
- except subprocess.TimeoutExpired:
- return {
- "error": "OAuth login timed out after 120 seconds. "
- "Please complete the browser authorization flow promptly, "
- "or run 'databricks auth login --host ' manually in a terminal."
- }
- except FileNotFoundError:
- return {
- "error": "Databricks CLI not found. Install it with: pip install databricks-cli "
- "or brew install databricks/tap/databricks"
- }
-
- if proc.returncode != 0:
- stderr = proc.stderr.strip() or proc.stdout.strip()
- return {"error": f"databricks auth login failed (exit {proc.returncode}): {stderr}"}
-
- try:
- conn = _validate_and_switch(profile=derived_profile, host=host)
- conn["message"] = f"Logged in and switched to workspace: {conn['host']}"
- return conn
- except Exception as exc:
- return {
- "error": f"Login succeeded but validation failed: {exc}",
- "hint": f"Try manage_workspace(action='switch', profile='{derived_profile}') manually.",
- }
-
-
-@mcp.tool(timeout=60)
-def manage_workspace(
- action: str,
- profile: Optional[str] = None,
- host: Optional[str] = None,
-) -> Dict[str, Any]:
- """Manage active Databricks workspace connection (session-scoped).
-
- Actions: status (current workspace), list (profiles from ~/.databrickscfg), switch (profile or host), login (OAuth via CLI).
- Returns: {host, profile, username} or {profiles: [...]}."""
- return _manage_workspace_impl(action=action, profile=profile, host=host)
diff --git a/databricks-mcp-server/pyproject.toml b/databricks-mcp-server/pyproject.toml
deleted file mode 100644
index 24ec1489..00000000
--- a/databricks-mcp-server/pyproject.toml
+++ /dev/null
@@ -1,37 +0,0 @@
-[build-system]
-requires = ["setuptools>=61.0", "wheel"]
-build-backend = "setuptools.build_meta"
-
-[project]
-name = "databricks-mcp-server"
-version = "0.1.0"
-description = "MCP server exposing Databricks operations via FastMCP"
-readme = "README.md"
-requires-python = ">=3.9"
-license = {file = "LICENSE.md"}
-authors = [
- {name = "Databricks"},
-]
-classifiers = [
- "Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.9",
- "Programming Language :: Python :: 3.10",
- "Programming Language :: Python :: 3.11",
- "Programming Language :: Python :: 3.12",
-]
-dependencies = [
- "databricks-tools-core",
- "fastmcp==3.1.1",
-]
-
-[project.optional-dependencies]
-dev = [
- "pytest>=7.0.0",
- "pytest-asyncio>=1.0.0",
- "black>=23.0.0",
- "ruff>=0.1.0",
-]
-
-[tool.setuptools.packages.find]
-where = ["."]
-include = ["databricks_mcp_server*"]
diff --git a/databricks-mcp-server/run_server.py b/databricks-mcp-server/run_server.py
deleted file mode 100755
index 78520d62..00000000
--- a/databricks-mcp-server/run_server.py
+++ /dev/null
@@ -1,18 +0,0 @@
-#!/usr/bin/env python
-"""Run the Databricks MCP Server."""
-
-import logging
-import os
-import sys
-
-if os.environ.get("DATABRICKS_MCP_DEBUG"):
- logging.basicConfig(
- level=logging.DEBUG,
- stream=sys.stderr,
- format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
- )
-
-from databricks_mcp_server.server import mcp
-
-if __name__ == "__main__":
- mcp.run(transport="stdio")
diff --git a/databricks-mcp-server/setup.sh b/databricks-mcp-server/setup.sh
deleted file mode 100755
index f4839ee3..00000000
--- a/databricks-mcp-server/setup.sh
+++ /dev/null
@@ -1,87 +0,0 @@
-#!/bin/bash
-#
-# Setup script for databricks-mcp-server
-# Creates virtual environment and installs dependencies
-#
-
-set -e
-
-SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-PARENT_DIR="$(dirname "${SCRIPT_DIR}")"
-TOOLS_CORE_DIR="${PARENT_DIR}/databricks-tools-core"
-echo AI Dev Kit directory: $PARENT_DIR
-echo MCP Server directory: $SCRIPT_DIR
-echo Tools Core directory: $TOOLS_CORE_DIR
-
-
-echo "======================================"
-echo "Setting up Databricks MCP Server"
-echo "======================================"
-echo ""
-
-# Check for uv
-if ! command -v uv &> /dev/null; then
- echo "Error: 'uv' is not installed."
- echo "Install it with: curl -LsSf https://astral.sh/uv/install.sh | sh"
- exit 1
-fi
-echo "✓ uv is installed"
-
-# Check if tools-core directory exists
-if [ ! -d "$TOOLS_CORE_DIR" ]; then
- echo "Error: databricks-tools-core not found at $TOOLS_CORE_DIR"
- exit 1
-fi
-echo "✓ databricks-tools-core found"
-
-
-# Create virtual environment
-echo ""
-echo "Creating virtual environment..."
-uv venv --python 3.11
-echo "✓ Virtual environment created"
-
-
-# Install packages
-echo ""
-echo "Installing databricks-tools-core (editable)..."
-uv pip install --python .venv/bin/python -e "$TOOLS_CORE_DIR" --quiet
-echo "✓ databricks-tools-core installed"
-
-echo ""
-echo "Installing databricks-mcp-server (editable)..."
-
-uv pip install --python .venv/bin/python -e "$SCRIPT_DIR" --quiet
-echo "✓ databricks-mcp-server installed"
-
-# Verify
-echo ""
-echo "Verifying installation..."
-if .venv/bin/python -c "import databricks_mcp_server; print('✓ MCP server can be imported')"; then
- echo ""
- echo "======================================"
- echo "Setup complete!"
- echo "======================================"
- echo ""
- echo "To run the MCP server:"
- echo " .venv/bin/python run_server.py"
- echo ""
- echo "To setup in the project, paste this into .mcp.json (Claude) or .cursor/mcp.json (Cursor):"
- cat < .test-results/$TIMESTAMP/sql.txt 2>&1 &
- python -m pytest tests/integration/genie -m integration -v > .test-results/$TIMESTAMP/genie.txt 2>&1 &
- python -m pytest tests/integration/apps -m integration -v > .test-results/$TIMESTAMP/apps.txt 2>&1 &
- python -m pytest tests/integration/agent_bricks -m integration -v > .test-results/$TIMESTAMP/agent_bricks.txt 2>&1 &
- python -m pytest tests/integration/dashboards -m integration -v > .test-results/$TIMESTAMP/dashboards.txt 2>&1 &
- python -m pytest tests/integration/lakebase -m integration -v > .test-results/$TIMESTAMP/lakebase.txt 2>&1 &
- python -m pytest tests/integration/compute -m integration -v > .test-results/$TIMESTAMP/compute.txt 2>&1 &
- python -m pytest tests/integration/pipelines -m integration -v > .test-results/$TIMESTAMP/pipelines.txt 2>&1 &
- python -m pytest tests/integration/jobs -m integration -v > .test-results/$TIMESTAMP/jobs.txt 2>&1 &
- python -m pytest tests/integration/vector_search -m integration -v > .test-results/$TIMESTAMP/vector_search.txt 2>&1 &
- python -m pytest tests/integration/volume_files -m integration -v > .test-results/$TIMESTAMP/volume_files.txt 2>&1 &
- python -m pytest tests/integration/serving -m integration -v > .test-results/$TIMESTAMP/serving.txt 2>&1 &
- python -m pytest tests/integration/workspace_files -m integration -v > .test-results/$TIMESTAMP/workspace_files.txt 2>&1 &
- python -m pytest tests/integration/pdf -m integration -v > .test-results/$TIMESTAMP/pdf.txt 2>&1 &
- wait
-) && echo "Results in: .test-results/$TIMESTAMP/"
-```
-
-### Analyze Results
-
-After running tests in parallel, analyze results:
-
-```bash
-# Show summary of all test results
-for f in .test-results/$(ls -t .test-results | head -1)/*.txt; do
- name=$(basename "$f" .txt)
- result=$(grep -E "passed|failed|error" "$f" | tail -1)
- echo "$name: $result"
-done
-
-# Show failures only
-grep -l FAILED .test-results/$(ls -t .test-results | head -1)/*.txt | \
- xargs -I{} sh -c 'echo "=== {} ===" && grep -A5 "FAILED\|ERROR" {}'
-```
-
-## Test Structure
-
-### Test Markers
-
-- `@pytest.mark.integration` - All integration tests
-- `@pytest.mark.slow` - Tests that take >10s (list operations, lifecycle tests)
-
-### Test Categories
-
-| Module | Fast Tests | Lifecycle Tests | Notes |
-|--------|------------|-----------------|-------|
-| sql | Yes | No | SQL query execution |
-| genie | Yes | Yes | Genie space CRUD + queries |
-| apps | Yes | Yes | App deployment (slow) |
-| agent_bricks | Yes | Yes | KA/MAS creation (very slow) |
-| dashboards | Yes | No | Dashboard CRUD |
-| lakebase | Yes | Yes | Autoscale project lifecycle |
-| compute | Yes | Yes | Cluster lifecycle |
-| pipelines | Yes | Yes | DLT pipeline lifecycle |
-| jobs | Yes | Yes | Job lifecycle |
-| vector_search | Yes | Yes | VS endpoint/index lifecycle |
-| volume_files | Yes | No | Volume file operations |
-| workspace_files | Yes | No | Workspace file operations |
-| serving | Yes | No | Model serving endpoints |
-| pdf | Yes | No | PDF processing |
-
-### Naming Conventions
-
-Test resources use the prefix `ai_dev_kit_test_` to enable safe cleanup:
-- Apps: `ai-dev-kit-test-app-{uuid}` (apps require lowercase/dashes only)
-- Other resources: `ai_dev_kit_test_{type}_{uuid}`
-
-## Environment Variables
-
-| Variable | Description | Default |
-|----------|-------------|---------|
-| `TEST_CATALOG` | Unity Catalog for test resources | `ai_dev_kit_test` |
-| `DATABRICKS_HOST` | Workspace URL | From CLI profile |
-| `DATABRICKS_TOKEN` | Personal access token | From CLI profile |
-
-## Test Output
-
-Test results are stored in `.test-results/` (gitignored):
-- Each run creates a timestamped folder: `.test-results/20250331_123456/`
-- Each module gets its own file: `sql.txt`, `genie.txt`, etc.
-- Summary in `summary.txt`
-
-## Troubleshooting
-
-### Tests Timeout
-
-Some lifecycle tests (apps, agent_bricks, compute) may take 5+ minutes:
-```bash
-# Increase pytest timeout
-python -m pytest tests/integration/apps -m integration -v --timeout=600
-```
-
-### Resource Cleanup
-
-Test resources are automatically cleaned up. Manual cleanup:
-```bash
-# List test resources
-databricks apps list | grep ai-dev-kit-test
-databricks clusters list | grep ai_dev_kit_test
-
-# Delete orphaned resources
-databricks apps delete ai-dev-kit-test-app-abc123
-```
-
-### SDK Version Issues
-
-If you see API errors like `unexpected keyword argument`:
-```bash
-# Update SDK
-pip install --upgrade databricks-sdk
-```
diff --git a/databricks-mcp-server/tests/__init__.py b/databricks-mcp-server/tests/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/databricks-mcp-server/tests/conftest.py b/databricks-mcp-server/tests/conftest.py
deleted file mode 100644
index e33907f3..00000000
--- a/databricks-mcp-server/tests/conftest.py
+++ /dev/null
@@ -1,449 +0,0 @@
-"""
-Pytest fixtures for databricks-mcp-server integration tests.
-
-Uses centralized configuration from test_config.py.
-Each test module gets its own schema to enable parallel execution.
-"""
-
-import logging
-import os
-from pathlib import Path
-from typing import Generator, Callable
-
-import pytest
-from databricks.sdk import WorkspaceClient
-
-from .test_config import TEST_CATALOG, SCHEMAS, TEST_RESOURCE_PREFIX, get_full_schema_name
-
-# Load .env.test file if it exists
-_env_file = Path(__file__).parent.parent / ".env.test"
-if _env_file.exists():
- from dotenv import load_dotenv
- load_dotenv(_env_file)
- logging.getLogger(__name__).info(f"Loaded environment from {_env_file}")
-
-logger = logging.getLogger(__name__)
-
-
-def pytest_configure(config):
- """Configure pytest with custom markers."""
- config.addinivalue_line("markers", "integration: mark test as integration test requiring Databricks")
- config.addinivalue_line("markers", "slow: mark test as slow (may take a while to run)")
-
-
-# =============================================================================
-# Core Fixtures (Session-scoped)
-# =============================================================================
-
-@pytest.fixture(scope="session")
-def workspace_client() -> WorkspaceClient:
- """
- Create a WorkspaceClient for the test session.
-
- Uses standard Databricks authentication:
- 1. DATABRICKS_HOST + DATABRICKS_TOKEN env vars
- 2. ~/.databrickscfg profile
- """
- try:
- client = WorkspaceClient()
- # Verify connection works
- client.current_user.me()
- logger.info(f"Connected to Databricks: {client.config.host}")
- return client
- except Exception as e:
- pytest.skip(f"Could not connect to Databricks: {e}")
-
-
-@pytest.fixture(scope="session")
-def current_user(workspace_client: WorkspaceClient) -> str:
- """Get current user's email/username."""
- return workspace_client.current_user.me().user_name
-
-
-@pytest.fixture(scope="session")
-def test_catalog(workspace_client: WorkspaceClient, warehouse_id: str) -> str:
- """
- Ensure test catalog exists and current user has permissions.
-
- Returns the catalog name.
- """
- try:
- workspace_client.catalogs.get(TEST_CATALOG)
- logger.info(f"Using existing catalog: {TEST_CATALOG}")
- except Exception:
- logger.info(f"Creating catalog: {TEST_CATALOG}")
- workspace_client.catalogs.create(name=TEST_CATALOG)
-
- # Grant ALL_PRIVILEGES on the catalog to the current user using SQL
- current_user = workspace_client.current_user.me().user_name
- try:
- # Use backticks to escape the email address (contains @)
- grant_sql = f"GRANT ALL PRIVILEGES ON CATALOG `{TEST_CATALOG}` TO `{current_user}`"
- workspace_client.statement_execution.execute_statement(
- warehouse_id=warehouse_id,
- statement=grant_sql,
- wait_timeout="30s",
- )
- logger.info(f"Granted ALL_PRIVILEGES on {TEST_CATALOG} to {current_user}")
- except Exception as e:
- logger.warning(f"Could not grant permissions on catalog (may already have them): {e}")
-
- return TEST_CATALOG
-
-
-@pytest.fixture(scope="session")
-def warehouse_id(workspace_client: WorkspaceClient) -> str:
- """
- Get a running SQL warehouse for tests.
-
- Prefers shared endpoints, falls back to any running warehouse.
- """
- from databricks.sdk.service.sql import State
-
- warehouses = list(workspace_client.warehouses.list())
-
- # Priority: running shared endpoint
- for w in warehouses:
- if w.state == State.RUNNING and "shared" in (w.name or "").lower():
- logger.info(f"Using warehouse: {w.name} ({w.id})")
- return w.id
-
- # Fallback: any running warehouse
- for w in warehouses:
- if w.state == State.RUNNING:
- logger.info(f"Using warehouse: {w.name} ({w.id})")
- return w.id
-
- # No running warehouse found
- pytest.skip("No running SQL warehouse available for tests")
-
-
-# =============================================================================
-# Schema Fixtures (Module-scoped, per test module)
-# =============================================================================
-
-def _create_test_schema(
- workspace_client: WorkspaceClient,
- test_catalog: str,
- schema_name: str,
-) -> Generator[str, None, None]:
- """Helper to create and cleanup a test schema."""
- full_schema_name = f"{test_catalog}.{schema_name}"
-
- # Drop schema if exists (cascade to remove all objects)
- try:
- logger.info(f"Dropping existing schema: {full_schema_name}")
- workspace_client.schemas.delete(full_schema_name, force=True)
- except Exception as e:
- logger.debug(f"Schema delete failed (may not exist): {e}")
-
- # Create fresh schema
- logger.info(f"Creating schema: {full_schema_name}")
- try:
- workspace_client.schemas.create(
- name=schema_name,
- catalog_name=test_catalog,
- )
- except Exception as e:
- if "already exists" in str(e).lower():
- logger.info(f"Schema already exists, reusing: {full_schema_name}")
- else:
- raise
-
- yield schema_name
-
- # Cleanup after tests
- try:
- logger.info(f"Cleaning up schema: {full_schema_name}")
- workspace_client.schemas.delete(full_schema_name, force=True)
- except Exception as e:
- logger.warning(f"Failed to cleanup schema: {e}")
-
-
-@pytest.fixture(scope="module")
-def sql_schema(workspace_client: WorkspaceClient, test_catalog: str) -> Generator[str, None, None]:
- """Schema for SQL tests."""
- yield from _create_test_schema(workspace_client, test_catalog, SCHEMAS["sql"])
-
-
-@pytest.fixture(scope="module")
-def pipelines_schema(workspace_client: WorkspaceClient, test_catalog: str) -> Generator[str, None, None]:
- """Schema for pipeline tests."""
- yield from _create_test_schema(workspace_client, test_catalog, SCHEMAS["pipelines"])
-
-
-@pytest.fixture(scope="module")
-def vector_search_schema(workspace_client: WorkspaceClient, test_catalog: str) -> Generator[str, None, None]:
- """Schema for vector search tests."""
- yield from _create_test_schema(workspace_client, test_catalog, SCHEMAS["vector_search"])
-
-
-@pytest.fixture(scope="module")
-def genie_schema(workspace_client: WorkspaceClient, test_catalog: str) -> Generator[str, None, None]:
- """Schema for Genie tests."""
- yield from _create_test_schema(workspace_client, test_catalog, SCHEMAS["genie"])
-
-
-@pytest.fixture(scope="module")
-def dashboards_schema(workspace_client: WorkspaceClient, test_catalog: str) -> Generator[str, None, None]:
- """Schema for dashboard tests."""
- yield from _create_test_schema(workspace_client, test_catalog, SCHEMAS["dashboards"])
-
-
-@pytest.fixture(scope="module")
-def jobs_schema(workspace_client: WorkspaceClient, test_catalog: str) -> Generator[str, None, None]:
- """Schema for jobs tests."""
- yield from _create_test_schema(workspace_client, test_catalog, SCHEMAS["jobs"])
-
-
-@pytest.fixture(scope="module")
-def volume_files_schema(workspace_client: WorkspaceClient, test_catalog: str) -> Generator[str, None, None]:
- """Schema for volume files tests."""
- yield from _create_test_schema(workspace_client, test_catalog, SCHEMAS["volume_files"])
-
-
-@pytest.fixture(scope="module")
-def compute_schema(workspace_client: WorkspaceClient, test_catalog: str) -> Generator[str, None, None]:
- """Schema for compute tests."""
- yield from _create_test_schema(workspace_client, test_catalog, SCHEMAS["compute"])
-
-
-@pytest.fixture(scope="module")
-def agent_bricks_schema(workspace_client: WorkspaceClient, test_catalog: str) -> Generator[str, None, None]:
- """Schema for agent bricks tests."""
- yield from _create_test_schema(workspace_client, test_catalog, SCHEMAS["agent_bricks"])
-
-
-@pytest.fixture(scope="module")
-def pdf_schema(workspace_client: WorkspaceClient, test_catalog: str) -> Generator[str, None, None]:
- """Schema for PDF tests."""
- yield from _create_test_schema(workspace_client, test_catalog, SCHEMAS["pdf"])
-
-
-# =============================================================================
-# Cleanup Fixtures
-# =============================================================================
-
-@pytest.fixture(scope="function")
-def cleanup_pipelines() -> Generator[Callable[[str], None], None, None]:
- """Register pipelines for cleanup after test."""
- from databricks_mcp_server.tools.pipelines import manage_pipeline
-
- pipelines_to_cleanup = []
-
- def register(pipeline_id: str):
- pipelines_to_cleanup.append(pipeline_id)
-
- yield register
-
- for pid in pipelines_to_cleanup:
- try:
- manage_pipeline(action="delete", pipeline_id=pid)
- logger.info(f"Cleaned up pipeline: {pid}")
- except Exception as e:
- logger.warning(f"Failed to cleanup pipeline {pid}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_vs_endpoints() -> Generator[Callable[[str], None], None, None]:
- """Register vector search endpoints for cleanup after test."""
- from databricks_mcp_server.tools.vector_search import manage_vs_endpoint
-
- endpoints_to_cleanup = []
-
- def register(endpoint_name: str):
- endpoints_to_cleanup.append(endpoint_name)
-
- yield register
-
- for name in endpoints_to_cleanup:
- try:
- manage_vs_endpoint(action="delete", name=name)
- logger.info(f"Cleaned up VS endpoint: {name}")
- except Exception as e:
- logger.warning(f"Failed to cleanup VS endpoint {name}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_vs_indexes() -> Generator[Callable[[str], None], None, None]:
- """Register vector search indexes for cleanup after test."""
- from databricks_mcp_server.tools.vector_search import manage_vs_index
-
- indexes_to_cleanup = []
-
- def register(index_name: str):
- indexes_to_cleanup.append(index_name)
-
- yield register
-
- for name in indexes_to_cleanup:
- try:
- manage_vs_index(action="delete", index_name=name)
- logger.info(f"Cleaned up VS index: {name}")
- except Exception as e:
- logger.warning(f"Failed to cleanup VS index {name}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_jobs() -> Generator[Callable[[str], None], None, None]:
- """Register jobs for cleanup after test."""
- from databricks_mcp_server.tools.jobs import manage_jobs
-
- jobs_to_cleanup = []
-
- def register(job_id: str):
- jobs_to_cleanup.append(job_id)
-
- yield register
-
- for jid in jobs_to_cleanup:
- try:
- manage_jobs(action="delete", job_id=jid)
- logger.info(f"Cleaned up job: {jid}")
- except Exception as e:
- logger.warning(f"Failed to cleanup job {jid}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_dashboards() -> Generator[Callable[[str], None], None, None]:
- """Register dashboards for cleanup after test."""
- from databricks_mcp_server.tools.aibi_dashboards import manage_dashboard
-
- dashboards_to_cleanup = []
-
- def register(dashboard_id: str):
- dashboards_to_cleanup.append(dashboard_id)
-
- yield register
-
- for did in dashboards_to_cleanup:
- try:
- manage_dashboard(action="delete", dashboard_id=did)
- logger.info(f"Cleaned up dashboard: {did}")
- except Exception as e:
- logger.warning(f"Failed to cleanup dashboard {did}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_genie_spaces() -> Generator[Callable[[str], None], None, None]:
- """Register Genie spaces for cleanup after test."""
- from databricks_mcp_server.tools.genie import manage_genie
-
- spaces_to_cleanup = []
-
- def register(space_id: str):
- spaces_to_cleanup.append(space_id)
-
- yield register
-
- for sid in spaces_to_cleanup:
- try:
- manage_genie(action="delete", space_id=sid)
- logger.info(f"Cleaned up Genie space: {sid}")
- except Exception as e:
- logger.warning(f"Failed to cleanup Genie space {sid}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_apps() -> Generator[Callable[[str], None], None, None]:
- """Register apps for cleanup after test."""
- from databricks_mcp_server.tools.apps import manage_app
-
- apps_to_cleanup = []
-
- def register(app_name: str):
- apps_to_cleanup.append(app_name)
-
- yield register
-
- for name in apps_to_cleanup:
- try:
- manage_app(action="delete", name=name)
- logger.info(f"Cleaned up app: {name}")
- except Exception as e:
- logger.warning(f"Failed to cleanup app {name}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_lakebase_instances() -> Generator[Callable[[str], None], None, None]:
- """Register Lakebase instances for cleanup after test."""
- from databricks_mcp_server.tools.lakebase import manage_lakebase_database
-
- instances_to_cleanup = []
-
- def register(name: str, db_type: str = "provisioned"):
- instances_to_cleanup.append((name, db_type))
-
- yield register
-
- for name, db_type in instances_to_cleanup:
- try:
- manage_lakebase_database(action="delete", name=name, type=db_type, force=True)
- logger.info(f"Cleaned up Lakebase {db_type} instance: {name}")
- except Exception as e:
- logger.warning(f"Failed to cleanup Lakebase instance {name}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_ka() -> Generator[Callable[[str], None], None, None]:
- """Register Knowledge Assistants for cleanup after test."""
- from databricks_mcp_server.tools.agent_bricks import manage_ka
-
- kas_to_cleanup = []
-
- def register(tile_id: str):
- kas_to_cleanup.append(tile_id)
-
- yield register
-
- for tile_id in kas_to_cleanup:
- try:
- manage_ka(action="delete", tile_id=tile_id)
- logger.info(f"Cleaned up KA: {tile_id}")
- except Exception as e:
- logger.warning(f"Failed to cleanup KA {tile_id}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_mas() -> Generator[Callable[[str], None], None, None]:
- """Register Multi-Agent Supervisors for cleanup after test."""
- from databricks_mcp_server.tools.agent_bricks import manage_mas
-
- mas_to_cleanup = []
-
- def register(tile_id: str):
- mas_to_cleanup.append(tile_id)
-
- yield register
-
- for tile_id in mas_to_cleanup:
- try:
- manage_mas(action="delete", tile_id=tile_id)
- logger.info(f"Cleaned up MAS: {tile_id}")
- except Exception as e:
- logger.warning(f"Failed to cleanup MAS {tile_id}: {e}")
-
-
-# =============================================================================
-# Workspace Path Fixtures
-# =============================================================================
-
-@pytest.fixture(scope="module")
-def workspace_test_path(workspace_client: WorkspaceClient, current_user: str) -> Generator[str, None, None]:
- """Get a workspace path for test files and clean up after tests."""
- path = f"/Workspace/Users/{current_user}/ai_dev_kit_test/workspace_files/resources"
-
- # Delete if exists
- try:
- workspace_client.workspace.delete(path, recursive=True)
- except Exception:
- pass
-
- yield path
-
- # Cleanup after tests
- try:
- workspace_client.workspace.delete(path, recursive=True)
- logger.info(f"Cleaned up workspace path: {path}")
- except Exception as e:
- logger.warning(f"Failed to cleanup workspace path: {e}")
diff --git a/databricks-mcp-server/tests/integration/README.md b/databricks-mcp-server/tests/integration/README.md
deleted file mode 100644
index 9146727d..00000000
--- a/databricks-mcp-server/tests/integration/README.md
+++ /dev/null
@@ -1,115 +0,0 @@
-# Integration Tests
-
-This directory contains integration tests for the Databricks MCP Server tools. These tests run against a real Databricks workspace.
-
-## Prerequisites
-
-1. **Databricks Authentication**: Configure your Databricks credentials via environment variables or `~/.databrickscfg`
-2. **Test Catalog**: Set `TEST_CATALOG` in `tests/test_config.py` or use the default
-3. **Python Dependencies**: Install test dependencies with `pip install -e ".[dev]"`
-
-## Running Tests
-
-### Quick Start: Run All Tests
-
-```bash
-# Run all tests (excluding slow tests)
-python tests/integration/run_tests.py
-
-# Run all tests including slow tests (cluster lifecycle, etc.)
-python tests/integration/run_tests.py --all
-```
-
-### View Test Reports
-
-```bash
-# Show report from the latest test run
-python tests/integration/run_tests.py --report
-
-# Show report from a specific run (by timestamp)
-python tests/integration/run_tests.py --report 20260331_112315
-```
-
-### Check Status of Running Tests
-
-```bash
-# Show status of ongoing and recently completed runs
-python tests/integration/run_tests.py --status
-```
-
-### Advanced Options
-
-```bash
-# Run with fewer parallel workers (default: 8)
-python tests/integration/run_tests.py -j 4
-
-# Combine options
-python tests/integration/run_tests.py --all -j 4
-
-# Clean up old test results (keeps last 5 runs)
-python tests/integration/run_tests.py --cleanup-results
-```
-
-### Run Individual Test Folders
-
-```bash
-# Run a specific test folder
-python -m pytest tests/integration/sql -m integration -v
-
-# Run a specific test
-python -m pytest tests/integration/sql/test_sql.py::TestExecuteSql::test_simple_query -v
-```
-
-## Test Output
-
-Test results are saved to `.test-results//`:
-
-```
-.test-results/
-└── 20260331_112315/
- ├── results.json # Machine-readable results
- ├── sql.txt # Logs for sql tests
- ├── workspace_files.txt
- ├── dashboards.txt
- └── ...
-```
-
-## Test Markers
-
-- `@pytest.mark.integration` - Standard integration tests
-- `@pytest.mark.slow` - Tests that take a long time (cluster creation, etc.)
-
-## Test Folders
-
-| Folder | Description |
-|--------|-------------|
-| `sql/` | SQL execution and query tests |
-| `workspace_files/` | Workspace file upload/download tests |
-| `volume_files/` | Unity Catalog volume file operations |
-| `dashboards/` | AI/BI dashboard management |
-| `genie/` | Genie (AI assistant) spaces |
-| `agent_bricks/` | Agent Bricks tool tests |
-| `compute/` | Cluster and serverless compute |
-| `jobs/` | Job creation and execution |
-| `pipelines/` | DLT pipeline management |
-| `vector_search/` | Vector search endpoints and indexes |
-| `serving/` | Model serving endpoints |
-| `apps/` | Databricks Apps |
-| `lakebase/` | Lakebase database operations |
-| `pdf/` | PDF processing tests |
-
-## Re-running Failed Tests
-
-After a test run, you can re-run specific failed tests:
-
-```bash
-# View the failure details
-cat .test-results//jobs.txt
-
-# Re-run with more verbose output
-python -m pytest tests/integration/jobs -v --tb=long
-```
-
-## Cleanup
-
-Test resources are automatically cleaned up after tests. If cleanup fails, resources are prefixed with `ai_dev_kit_test_` for easy identification.
diff --git a/databricks-mcp-server/tests/integration/__init__.py b/databricks-mcp-server/tests/integration/__init__.py
deleted file mode 100644
index a968bf1a..00000000
--- a/databricks-mcp-server/tests/integration/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Integration tests for databricks-mcp-server
diff --git a/databricks-mcp-server/tests/integration/agent_bricks/__init__.py b/databricks-mcp-server/tests/integration/agent_bricks/__init__.py
deleted file mode 100644
index 7efbac60..00000000
--- a/databricks-mcp-server/tests/integration/agent_bricks/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Agent Bricks integration tests
diff --git a/databricks-mcp-server/tests/integration/agent_bricks/resources/api_reference.pdf b/databricks-mcp-server/tests/integration/agent_bricks/resources/api_reference.pdf
deleted file mode 100644
index bb88ef8a..00000000
Binary files a/databricks-mcp-server/tests/integration/agent_bricks/resources/api_reference.pdf and /dev/null differ
diff --git a/databricks-mcp-server/tests/integration/agent_bricks/resources/product_guide.pdf b/databricks-mcp-server/tests/integration/agent_bricks/resources/product_guide.pdf
deleted file mode 100644
index 113e9722..00000000
Binary files a/databricks-mcp-server/tests/integration/agent_bricks/resources/product_guide.pdf and /dev/null differ
diff --git a/databricks-mcp-server/tests/integration/agent_bricks/test_agent_bricks.py b/databricks-mcp-server/tests/integration/agent_bricks/test_agent_bricks.py
deleted file mode 100644
index 7d2275a9..00000000
--- a/databricks-mcp-server/tests/integration/agent_bricks/test_agent_bricks.py
+++ /dev/null
@@ -1,455 +0,0 @@
-"""
-Integration tests for Agent Bricks MCP tools.
-
-Tests:
-- manage_ka: create_or_update, get, find_by_name, delete
-- manage_mas: create_or_update, get, find_by_name, delete
-"""
-
-import logging
-import time
-import uuid
-from pathlib import Path
-
-import pytest
-from databricks.sdk.service.catalog import VolumeType
-
-from databricks_mcp_server.tools.agent_bricks import manage_ka, manage_mas
-from databricks_mcp_server.tools.volume_files import manage_volume_files
-from tests.test_config import TEST_RESOURCE_PREFIX, TEST_CATALOG
-
-logger = logging.getLogger(__name__)
-
-# Path to test resources (static PDFs)
-RESOURCES_DIR = Path(__file__).parent / "resources"
-
-
-@pytest.mark.integration
-class TestManageKA:
- """Tests for manage_ka tool - fast validation tests."""
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- result = manage_ka(action="invalid_action")
-
- assert "error" in result
-
- def test_find_by_name_nonexistent(self):
- """Should return not found for nonexistent KA."""
- result = manage_ka(
- action="find_by_name",
- name="nonexistent_ka_xyz_12345"
- )
-
- logger.info(f"Find nonexistent KA result: {result}")
-
- # Should not crash, and return found=False
- assert result is not None
- assert result.get("found") is False
-
- def test_find_by_name_missing_param(self):
- """Should return error when name not provided."""
- result = manage_ka(action="find_by_name")
-
- assert "error" in result
- assert "name" in result["error"]
-
- def test_get_missing_tile_id(self):
- """Should return error when tile_id not provided for get."""
- result = manage_ka(action="get")
-
- assert "error" in result
- assert "tile_id" in result["error"]
-
- def test_get_nonexistent_ka(self):
- """Should handle nonexistent KA gracefully."""
- result = manage_ka(
- action="get",
- tile_id="nonexistent_tile_id_xyz_12345"
- )
-
- logger.info(f"Get nonexistent KA result: {result}")
-
- assert "error" in result
-
- def test_delete_missing_tile_id(self):
- """Should return error when tile_id not provided for delete."""
- result = manage_ka(action="delete")
-
- assert "error" in result
- assert "tile_id" in result["error"]
-
- def test_create_or_update_requires_params(self):
- """Should require name and volume_path for create_or_update."""
- result = manage_ka(action="create_or_update")
-
- assert "error" in result
- assert "name" in result["error"] or "volume_path" in result["error"]
-
-
-@pytest.mark.integration
-class TestManageMAS:
- """Tests for manage_mas tool - fast validation tests."""
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- result = manage_mas(action="invalid_action")
-
- assert "error" in result
-
- def test_find_by_name_nonexistent(self):
- """Should return not found for nonexistent MAS."""
- result = manage_mas(
- action="find_by_name",
- name="nonexistent_mas_xyz_12345"
- )
-
- logger.info(f"Find nonexistent MAS result: {result}")
-
- # Should not crash, and return found=False
- assert result is not None
- assert result.get("found") is False
-
- def test_find_by_name_missing_param(self):
- """Should return error when name not provided."""
- result = manage_mas(action="find_by_name")
-
- assert "error" in result
- assert "name" in result["error"]
-
- def test_get_missing_tile_id(self):
- """Should return error when tile_id not provided for get."""
- result = manage_mas(action="get")
-
- assert "error" in result
- assert "tile_id" in result["error"]
-
- def test_get_nonexistent_mas(self):
- """Should handle nonexistent MAS gracefully."""
- result = manage_mas(
- action="get",
- tile_id="nonexistent_tile_id_xyz_12345"
- )
-
- logger.info(f"Get nonexistent MAS result: {result}")
-
- assert "error" in result
-
- def test_delete_missing_tile_id(self):
- """Should return error when tile_id not provided for delete."""
- result = manage_mas(action="delete")
-
- assert "error" in result
- assert "tile_id" in result["error"]
-
- def test_create_or_update_requires_params(self):
- """Should require name and agents for create_or_update."""
- result = manage_mas(action="create_or_update")
-
- assert "error" in result
- assert "name" in result["error"] or "agents" in result["error"]
-
- def test_create_or_update_agents_validation(self):
- """Should validate agent configuration."""
- result = manage_mas(
- action="create_or_update",
- name="test_mas",
- agents=[
- {"name": "agent1"} # Missing description and agent type
- ]
- )
-
- assert "error" in result
- assert "description" in result["error"]
-
- def test_create_or_update_agent_type_validation(self):
- """Should require exactly one agent type."""
- result = manage_mas(
- action="create_or_update",
- name="test_mas",
- agents=[
- {
- "name": "agent1",
- "description": "Test agent",
- # Missing agent type (endpoint_name, genie_space_id, etc.)
- }
- ]
- )
-
- assert "error" in result
- assert "endpoint_name" in result["error"] or "one of" in result["error"].lower()
-
- def test_create_or_update_multiple_agent_types_validation(self):
- """Should reject multiple agent types on same agent."""
- result = manage_mas(
- action="create_or_update",
- name="test_mas",
- agents=[
- {
- "name": "agent1",
- "description": "Test agent",
- "endpoint_name": "some-endpoint",
- "genie_space_id": "some-space-id", # Multiple types!
- }
- ]
- )
-
- assert "error" in result
- assert "multiple" in result["error"].lower()
-
-
-@pytest.mark.integration
-class TestAgentBricksLifecycle:
- """End-to-end test for KA + MAS lifecycle: upload PDFs -> create KA -> create MAS -> verify -> delete."""
-
- def test_full_ka_mas_lifecycle(
- self,
- workspace_client,
- test_catalog: str,
- agent_bricks_schema: str,
- cleanup_ka,
- cleanup_mas,
- ):
- """Test complete lifecycle: upload PDFs, create KA, create MAS using KA, test all actions, delete both."""
- test_start = time.time()
- unique_id = uuid.uuid4().hex[:6]
-
- # Names with underscores (API normalizes spaces and / to underscores)
- ka_name = f"{TEST_RESOURCE_PREFIX}KA_Test_{unique_id}"
- mas_name = f"{TEST_RESOURCE_PREFIX}MAS_Test_{unique_id}"
-
- volume_name = f"{TEST_RESOURCE_PREFIX}ka_docs_{unique_id}"
- full_volume_path = f"/Volumes/{test_catalog}/{agent_bricks_schema}/{volume_name}"
-
- ka_tile_id = None
- mas_tile_id = None
-
- def log_time(msg):
- elapsed = time.time() - test_start
- logger.info(f"[{elapsed:.1f}s] {msg}")
-
- try:
- # ==================== SETUP: Create volume and upload PDFs ====================
- log_time("Step 1: Creating volume and uploading PDFs...")
-
- # Create volume
- try:
- workspace_client.volumes.create(
- catalog_name=test_catalog,
- schema_name=agent_bricks_schema,
- name=volume_name,
- volume_type=VolumeType.MANAGED,
- )
- log_time(f"Created volume: {full_volume_path}")
- except Exception as e:
- if "already exists" in str(e).lower():
- log_time(f"Volume already exists: {full_volume_path}")
- else:
- raise
-
- # Upload PDFs from resources folder
- for pdf_file in RESOURCES_DIR.glob("*.pdf"):
- upload_result = manage_volume_files(
- action="upload",
- volume_path=f"{full_volume_path}/docs/{pdf_file.name}",
- local_path=str(pdf_file),
- )
- log_time(f"Uploaded {pdf_file.name}: {upload_result.get('success', upload_result)}")
-
- # ==================== KA LIFECYCLE ====================
- log_time(f"Step 2: Creating KA '{ka_name}'...")
-
- # KA create_or_update
- create_ka_result = manage_ka(
- action="create_or_update",
- name=ka_name,
- volume_path=full_volume_path,
- description="Test KA for integration tests with special chars in name",
- instructions="Answer questions about the test documents.",
- add_examples_from_volume=False,
- )
- log_time(f"Create KA result: {create_ka_result}")
- assert "error" not in create_ka_result, f"Create KA failed: {create_ka_result}"
- assert create_ka_result.get("tile_id"), "Should return tile_id"
-
- ka_tile_id = create_ka_result["tile_id"]
- cleanup_ka(ka_tile_id)
- log_time(f"KA created with tile_id: {ka_tile_id}")
-
- # Wait for endpoint to start provisioning
- time.sleep(10)
-
- # KA get (tile_id)
- log_time("Step 3: Testing KA get...")
- get_ka_result = manage_ka(action="get", tile_id=ka_tile_id)
- log_time(f"Get KA result: {get_ka_result}")
- assert "error" not in get_ka_result, f"Get KA failed: {get_ka_result}"
- assert get_ka_result.get("tile_id") == ka_tile_id
- assert get_ka_result.get("name") == ka_name
-
- # KA find_by_name (name)
- log_time("Step 4: Testing KA find_by_name...")
- find_ka_result = manage_ka(action="find_by_name", name=ka_name)
- log_time(f"Find KA result: {find_ka_result}")
- assert find_ka_result.get("found") is True, f"KA should be found: {find_ka_result}"
- assert find_ka_result.get("tile_id") == ka_tile_id
-
- # Get KA endpoint name for MAS
- ka_endpoint_name = get_ka_result.get("endpoint_name")
- log_time(f"KA endpoint name: {ka_endpoint_name}")
-
- # Wait for KA endpoint to be online (needed for MAS)
- log_time("Step 5: Waiting for KA endpoint to be online...")
- max_wait = 600 # 10 minutes - KA provisioning can take a while
- wait_interval = 15
- waited = 0
- ka_ready = False
-
- while waited < max_wait:
- check_result = manage_ka(action="get", tile_id=ka_tile_id)
- endpoint_status = check_result.get("endpoint_status", "UNKNOWN")
- log_time(f"KA endpoint status after {waited}s: {endpoint_status}")
-
- if endpoint_status == "ONLINE":
- ka_ready = True
- ka_endpoint_name = check_result.get("endpoint_name")
- break
- elif endpoint_status in ("FAILED", "ERROR"):
- log_time(f"KA endpoint failed: {check_result}")
- break
-
- time.sleep(wait_interval)
- waited += wait_interval
-
- if not ka_ready:
- log_time(f"KA endpoint not online after {max_wait}s, skipping MAS creation")
- pytest.skip("KA endpoint not ready, cannot test MAS")
-
- # KA create_or_update (UPDATE existing - tests name lookup and API 2.1 update)
- # Must wait for ONLINE status before update is allowed
- log_time("Step 5b: Testing KA create_or_update on EXISTING KA...")
- update_ka_result = manage_ka(
- action="create_or_update",
- name=ka_name, # Same name - should find existing and update
- volume_path=full_volume_path,
- description="UPDATED description for integration test",
- instructions="UPDATED instructions for the test.",
- add_examples_from_volume=False,
- )
- log_time(f"Update KA result: {update_ka_result}")
- assert "error" not in update_ka_result, f"Update KA failed: {update_ka_result}"
- assert update_ka_result.get("tile_id") == ka_tile_id, "Should return same tile_id"
- assert update_ka_result.get("operation") == "updated", "Should report 'updated' operation"
-
- # Verify the update was applied
- verify_result = manage_ka(action="get", tile_id=ka_tile_id)
- assert "UPDATED description" in verify_result.get("description", ""), "Description should be updated"
- assert "UPDATED instructions" in verify_result.get("instructions", ""), "Instructions should be updated"
- log_time("KA update verified successfully")
-
- # ==================== MAS LIFECYCLE ====================
- log_time(f"Step 6: Creating MAS '{mas_name}' using KA endpoint...")
-
- # MAS create_or_update (name + agents)
- create_mas_result = manage_mas(
- action="create_or_update",
- name=mas_name,
- description="Test MAS for integration tests with KA agent",
- instructions="Route questions to the Knowledge Assistant.",
- agents=[
- {
- "name": "knowledge_agent",
- "description": "Answers questions using the Knowledge Assistant",
- "endpoint_name": ka_endpoint_name,
- }
- ],
- )
- log_time(f"Create MAS result: {create_mas_result}")
- assert "error" not in create_mas_result, f"Create MAS failed: {create_mas_result}"
- assert create_mas_result.get("tile_id"), "Should return tile_id"
- assert create_mas_result.get("agents_count") == 1
-
- mas_tile_id = create_mas_result["tile_id"]
- cleanup_mas(mas_tile_id)
- log_time(f"MAS created with tile_id: {mas_tile_id}")
-
- # Wait for MAS to be created
- time.sleep(10)
-
- # MAS get (tile_id)
- log_time("Step 7: Testing MAS get...")
- get_mas_result = manage_mas(action="get", tile_id=mas_tile_id)
- log_time(f"Get MAS result: {get_mas_result}")
- assert "error" not in get_mas_result, f"Get MAS failed: {get_mas_result}"
- assert get_mas_result.get("tile_id") == mas_tile_id
- assert get_mas_result.get("name") == mas_name
- assert len(get_mas_result.get("agents", [])) == 1
-
- # MAS find_by_name (name)
- log_time("Step 8: Testing MAS find_by_name...")
- find_mas_result = manage_mas(action="find_by_name", name=mas_name)
- log_time(f"Find MAS result: {find_mas_result}")
- assert find_mas_result.get("found") is True, f"MAS should be found: {find_mas_result}"
- assert find_mas_result.get("tile_id") == mas_tile_id
- assert find_mas_result.get("agents_count") == 1
-
- # ==================== CLEANUP: Delete MAS then KA ====================
- # MAS delete (tile_id)
- log_time("Step 9: Deleting MAS...")
- delete_mas_result = manage_mas(action="delete", tile_id=mas_tile_id)
- log_time(f"Delete MAS result: {delete_mas_result}")
- assert delete_mas_result.get("success") is True, f"Delete MAS failed: {delete_mas_result}"
-
- # Wait for deletion
- time.sleep(10)
-
- # Verify MAS is gone
- log_time("Step 10: Verifying MAS deleted...")
- find_mas_after = manage_mas(action="find_by_name", name=mas_name)
- log_time(f"Find MAS after delete: {find_mas_after}")
- assert find_mas_after.get("found") is False, f"MAS should be deleted: {find_mas_after}"
- mas_tile_id = None # Mark as deleted
-
- # KA delete (tile_id)
- log_time("Step 11: Deleting KA...")
- delete_ka_result = manage_ka(action="delete", tile_id=ka_tile_id)
- log_time(f"Delete KA result: {delete_ka_result}")
- assert delete_ka_result.get("success") is True, f"Delete KA failed: {delete_ka_result}"
-
- # Wait for deletion
- time.sleep(10)
-
- # Verify KA is gone
- log_time("Step 12: Verifying KA deleted...")
- find_ka_after = manage_ka(action="find_by_name", name=ka_name)
- log_time(f"Find KA after delete: {find_ka_after}")
- assert find_ka_after.get("found") is False, f"KA should be deleted: {find_ka_after}"
- ka_tile_id = None # Mark as deleted
-
- log_time("Full KA + MAS lifecycle test PASSED!")
-
- except Exception as e:
- log_time(f"Test failed: {e}")
- raise
- finally:
- # Cleanup on failure
- if mas_tile_id:
- log_time(f"Cleanup: deleting MAS {mas_tile_id}")
- try:
- manage_mas(action="delete", tile_id=mas_tile_id)
- except Exception:
- pass
-
- if ka_tile_id:
- log_time(f"Cleanup: deleting KA {ka_tile_id}")
- try:
- manage_ka(action="delete", tile_id=ka_tile_id)
- except Exception:
- pass
-
- # Cleanup volume
- try:
- workspace_client.volumes.delete(f"{test_catalog}.{agent_bricks_schema}.{volume_name}")
- log_time(f"Cleaned up volume: {full_volume_path}")
- except Exception as e:
- log_time(f"Failed to cleanup volume: {e}")
diff --git a/databricks-mcp-server/tests/integration/apps/__init__.py b/databricks-mcp-server/tests/integration/apps/__init__.py
deleted file mode 100644
index 7ec0c1f4..00000000
--- a/databricks-mcp-server/tests/integration/apps/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Apps integration tests
diff --git a/databricks-mcp-server/tests/integration/apps/resources/app.py b/databricks-mcp-server/tests/integration/apps/resources/app.py
deleted file mode 100644
index fb3bac46..00000000
--- a/databricks-mcp-server/tests/integration/apps/resources/app.py
+++ /dev/null
@@ -1,25 +0,0 @@
-"""
-Simple Databricks App for MCP Integration Tests.
-
-This is a minimal Gradio app that just displays a greeting.
-"""
-
-import gradio as gr
-
-
-def greet(name: str) -> str:
- """Return a greeting message."""
- return f"Hello, {name}! This is a test app for MCP integration tests."
-
-
-# Create simple Gradio interface
-demo = gr.Interface(
- fn=greet,
- inputs=gr.Textbox(label="Your Name", placeholder="Enter your name"),
- outputs=gr.Textbox(label="Greeting"),
- title="MCP Test App",
- description="A simple test app for integration tests.",
-)
-
-# For Databricks Apps
-app = demo.app
diff --git a/databricks-mcp-server/tests/integration/apps/resources/app.yaml b/databricks-mcp-server/tests/integration/apps/resources/app.yaml
deleted file mode 100644
index c674f6b8..00000000
--- a/databricks-mcp-server/tests/integration/apps/resources/app.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-# Databricks App configuration for MCP Integration Tests
-command:
- - python
- - app.py
-
-env:
- - name: GRADIO_SERVER_NAME
- value: "0.0.0.0"
diff --git a/databricks-mcp-server/tests/integration/apps/test_apps.py b/databricks-mcp-server/tests/integration/apps/test_apps.py
deleted file mode 100644
index fdf5e067..00000000
--- a/databricks-mcp-server/tests/integration/apps/test_apps.py
+++ /dev/null
@@ -1,216 +0,0 @@
-"""
-Integration tests for Databricks Apps MCP tool.
-
-Tests:
-- manage_app: create_or_update, get, list, delete
-"""
-
-import logging
-import time
-import uuid
-from pathlib import Path
-
-import pytest
-
-from databricks_mcp_server.tools.apps import manage_app
-from databricks_mcp_server.tools.file import manage_workspace_files
-from tests.test_config import TEST_RESOURCE_PREFIX
-
-logger = logging.getLogger(__name__)
-
-# Path to test app resources
-RESOURCES_DIR = Path(__file__).parent / "resources"
-
-
-@pytest.mark.integration
-class TestManageApp:
- """Tests for manage_app tool - fast validation tests."""
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- result = manage_app(action="invalid_action")
-
- assert "error" in result
-
- def test_get_nonexistent_app(self):
- """Should handle nonexistent app gracefully."""
- try:
- result = manage_app(action="get", name="nonexistent_app_xyz_12345")
-
- logger.info(f"Get nonexistent result: {result}")
-
- # Should return error or not found
- assert result.get("error") or result.get("status") == "NOT_FOUND"
- except Exception as e:
- # SDK raises exception for nonexistent app - this is acceptable
- error_msg = str(e).lower()
- assert "not exist" in error_msg or "not found" in error_msg or "deleted" in error_msg
-
- def test_create_or_update_requires_name(self):
- """Should require name for create_or_update."""
- result = manage_app(action="create_or_update")
-
- assert "error" in result
- assert "name" in result["error"]
-
- def test_delete_requires_name(self):
- """Should require name for delete."""
- result = manage_app(action="delete")
-
- assert "error" in result
- assert "name" in result["error"]
-
- @pytest.mark.slow
- def test_list_apps(self):
- """Should list all apps (slow due to pagination)."""
- result = manage_app(action="list")
-
- logger.info(f"List result: {result}")
-
- assert "error" not in result, f"List failed: {result}"
-
-
-@pytest.mark.integration
-class TestAppLifecycle:
- """End-to-end test for app lifecycle: upload source -> create -> deploy -> get -> delete."""
-
- def test_full_app_lifecycle(
- self,
- workspace_client,
- current_user: str,
- cleanup_apps,
- ):
- """Test complete app lifecycle: upload source, create, wait for deployment, get, delete, verify."""
- test_start = time.time()
- unique_id = uuid.uuid4().hex[:6]
-
- # App names can only contain lowercase letters, numbers, and dashes
- # Convert TEST_RESOURCE_PREFIX underscores to dashes for valid app name
- app_name = f"ai-dev-kit-test-app-{unique_id}"
-
- workspace_path = f"/Workspace/Users/{current_user}/ai_dev_kit_test/apps/{app_name}"
-
- def log_time(msg):
- elapsed = time.time() - test_start
- logger.info(f"[{elapsed:.1f}s] {msg}")
-
- try:
- # ==================== SETUP: Upload app source code ====================
- log_time(f"Step 1: Uploading app source to {workspace_path}...")
-
- upload_result = manage_workspace_files(
- action="upload",
- local_path=str(RESOURCES_DIR) + "/", # Trailing slash = upload contents only
- workspace_path=workspace_path,
- overwrite=True,
- )
-
- assert upload_result.get("success", False) or upload_result.get("status") == "success", \
- f"Failed to upload app resources: {upload_result}"
- log_time(f"Upload result: {upload_result}")
-
- # ==================== APP LIFECYCLE ====================
- # Step 2: Create and deploy app
- log_time(f"Step 2: Creating app '{app_name}'...")
-
- create_result = manage_app(
- action="create_or_update",
- name=app_name,
- source_code_path=workspace_path,
- description="Test app for MCP integration tests with special chars",
- )
-
- log_time(f"Create result: {create_result}")
-
- # Skip test if quota exceeded (workspace limitation, not a test failure)
- if "error" in create_result:
- error_msg = str(create_result.get("error", "")).lower()
- if "quota" in error_msg or "limit" in error_msg or "exceeded" in error_msg:
- pytest.skip(f"Skipping test due to quota/limit: {create_result['error']}")
-
- assert "error" not in create_result, f"Create failed: {create_result}"
- assert create_result.get("name") == app_name
-
- cleanup_apps(app_name)
-
- # Step 3: Wait for deployment
- log_time("Step 3: Waiting for app deployment...")
- max_wait = 600 # 10 minutes
- wait_interval = 15
- waited = 0
- deployed = False
-
- while waited < max_wait:
- get_result = manage_app(action="get", name=app_name)
- compute_status = get_result.get("status") or get_result.get("state")
-
- # Check deployment state (this is where deployment failures are reported)
- active_deployment = get_result.get("active_deployment", {})
- deployment_state = active_deployment.get("state") or active_deployment.get("status")
-
- log_time(f"App after {waited}s: compute={compute_status}, deployment={deployment_state}")
-
- # Check for deployment failure first (most important)
- if deployment_state and "FAILED" in str(deployment_state).upper():
- deployment_msg = active_deployment.get("status", {}).get("message", "")
- log_time(f"App deployment FAILED: {deployment_msg}")
- pytest.fail(f"App deployment failed: {deployment_state} - {deployment_msg}")
-
- # Check for successful deployment
- if deployment_state and "SUCCEEDED" in str(deployment_state).upper():
- deployed = True
- log_time("App deployment succeeded!")
- break
- elif compute_status in ("RUNNING", "READY", "DEPLOYED"):
- deployed = True
- break
-
- time.sleep(wait_interval)
- waited += wait_interval
-
- if not deployed:
- log_time(f"App not fully deployed after {max_wait}s, continuing with tests")
-
- # Step 4: Verify app via get
- log_time("Step 4: Verifying app via get...")
- get_result = manage_app(action="get", name=app_name)
- log_time(f"Get result: {get_result}")
- assert "error" not in get_result, f"Get app failed: {get_result}"
- assert get_result.get("name") == app_name
-
- # App should have a URL (may be None if not fully deployed)
- url = get_result.get("url")
- log_time(f"App URL: {url}")
-
- # Step 5: Delete app
- log_time("Step 5: Deleting app...")
- delete_result = manage_app(action="delete", name=app_name)
- log_time(f"Delete result: {delete_result}")
- assert "error" not in delete_result, f"Delete failed: {delete_result}"
-
- # Step 6: Verify app is deleted or deleting
- log_time("Step 6: Verifying app deleted...")
- time.sleep(10)
- get_after = manage_app(action="get", name=app_name)
- log_time(f"Get after delete: {get_after}")
-
- # Should return error, indicate not found, or be in DELETING state
- status_str = str(get_after).lower()
- assert (
- "error" in get_after
- or "not found" in status_str
- or "deleting" in status_str
- ), f"App should be deleted or deleting: {get_after}"
-
- log_time("Full app lifecycle test PASSED!")
-
- except Exception as e:
- log_time(f"Test failed: {e}")
- raise
- finally:
- # Cleanup workspace files
- try:
- workspace_client.workspace.delete(workspace_path, recursive=True)
- log_time(f"Cleaned up workspace path: {workspace_path}")
- except Exception as e:
- log_time(f"Failed to cleanup workspace path: {e}")
diff --git a/databricks-mcp-server/tests/integration/compute/__init__.py b/databricks-mcp-server/tests/integration/compute/__init__.py
deleted file mode 100644
index b4382f56..00000000
--- a/databricks-mcp-server/tests/integration/compute/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Compute integration tests
diff --git a/databricks-mcp-server/tests/integration/compute/test_compute.py b/databricks-mcp-server/tests/integration/compute/test_compute.py
deleted file mode 100644
index 23e8e00a..00000000
--- a/databricks-mcp-server/tests/integration/compute/test_compute.py
+++ /dev/null
@@ -1,262 +0,0 @@
-"""
-Integration tests for compute MCP tools.
-
-Tests:
-- execute_code: serverless and cluster execution
-- list_compute: clusters, node_types, spark_versions
-- manage_cluster: create, modify, start, terminate, delete
-"""
-
-import logging
-import time
-
-import pytest
-
-from databricks_mcp_server.tools.compute import execute_code, list_compute, manage_cluster
-from tests.test_config import TEST_RESOURCE_PREFIX
-
-logger = logging.getLogger(__name__)
-
-# Deterministic name for tests (enables safe cleanup/restart)
-CLUSTER_NAME = f"{TEST_RESOURCE_PREFIX}cluster_lifecycle"
-
-
-@pytest.mark.integration
-class TestListCompute:
- """Tests for list_compute tool."""
-
- @pytest.mark.slow
- def test_list_clusters(self):
- """Should list all clusters (slow - iterates all clusters)."""
- result = list_compute(resource="clusters")
-
- logger.info(f"List clusters result: {result}")
-
- assert "error" not in result, f"List failed: {result}"
- assert "clusters" in result
- assert isinstance(result["clusters"], list)
-
- def test_list_node_types(self):
- """Should list available node types."""
- result = list_compute(resource="node_types")
-
- logger.info(f"List node types result: {result}")
-
- assert "error" not in result, f"List failed: {result}"
- assert "node_types" in result
- assert isinstance(result["node_types"], list)
- assert len(result["node_types"]) > 0
-
- def test_list_spark_versions(self):
- """Should list available Spark versions."""
- result = list_compute(resource="spark_versions")
-
- logger.info(f"List spark versions result: {result}")
-
- assert "error" not in result, f"List failed: {result}"
- assert "spark_versions" in result
- assert isinstance(result["spark_versions"], list)
- assert len(result["spark_versions"]) > 0
-
- def test_invalid_resource(self):
- """Should return error for invalid resource type."""
- result = list_compute(resource="invalid_resource")
-
- assert "error" in result
-
-
-@pytest.mark.integration
-class TestExecuteCode:
- """Tests for execute_code tool.
-
- Consolidated into fewer tests to minimize serverless cold start overhead (~30s each).
- """
-
- def test_execute_serverless_comprehensive(self):
- """Test Python, SQL via spark.sql, and error handling in one serverless run.
-
- This consolidates multiple scenarios into one test to avoid repeated cold starts.
- Tests: print output, spark.sql execution, try/except error handling.
- """
- # Comprehensive notebook that tests multiple scenarios
- code = '''
-# Test 1: Basic Python output
-print("TEST1: Hello from MCP test")
-result_value = 42
-print(f"TEST1: Result value is {result_value}")
-
-# Test 2: SQL via spark.sql
-df = spark.sql("SELECT 42 as answer, 'hello' as greeting")
-print(f"TEST2: SQL row count = {df.count()}")
-row = df.first()
-print(f"TEST2: answer={row.answer}, greeting={row.greeting}")
-
-# Test 3: Error handling (try invalid code in try/except)
-error_caught = False
-try:
- exec("this is not valid python!!!")
-except SyntaxError as e:
- error_caught = True
- print(f"TEST3: Caught expected SyntaxError: {type(e).__name__}")
-
-if not error_caught:
- raise Exception("TEST3 FAILED: SyntaxError was not caught")
-
-print("ALL TESTS PASSED")
-dbutils.notebook.exit("success")
-'''
- result = execute_code(
- code=code,
- compute_type="serverless",
- timeout=180,
- )
-
- logger.info(f"Comprehensive execute result: {result}")
-
- assert not result.get("error"), f"Execute failed: {result}"
- assert result.get("success", False), f"Execution should succeed: {result}"
-
- # Serverless notebooks return dbutils.notebook.exit() value as result
- # Print statements go to logs which may not be captured in result
- output = str(result.get("output", "")) + str(result.get("result", ""))
-
- # If we got "success" it means all internal tests passed (including SQL and error handling)
- # The notebook only exits with "success" if all tests pass
- assert "success" in output.lower(), \
- f"Notebook should exit with success (all internal tests passed): {output}"
-
-
-@pytest.mark.integration
-class TestManageCluster:
- """Tests for manage_cluster tool (read-only operations)."""
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- result = manage_cluster(action="invalid_action")
-
- assert "error" in result
-
-
-@pytest.mark.integration
-class TestClusterLifecycle:
- """End-to-end test for complete cluster lifecycle.
-
- Consolidated into a single test: create -> list -> start -> terminate -> delete.
- This is much faster than separate tests that each create their own cluster.
- """
-
- def test_full_cluster_lifecycle(self, workspace_client):
- """Test complete cluster lifecycle: create, list, start, terminate, delete."""
- cluster_id = None
- test_start = time.time()
-
- def log_time(msg):
- elapsed = time.time() - test_start
- print(f"[{elapsed:.1f}s] {msg}", flush=True)
-
- try:
- # Step 1: Create cluster
- log_time("Step 1: Creating cluster...")
- t0 = time.time()
- create_result = manage_cluster(
- action="create",
- name=CLUSTER_NAME,
- num_workers=0, # Single-node for cost savings
- autotermination_minutes=10,
- )
- log_time(f"Create took {time.time()-t0:.1f}s - Result: {create_result}")
- assert create_result.get("cluster_id"), f"Create failed: {create_result}"
- cluster_id = create_result["cluster_id"]
-
- # Step 2: Get cluster to verify it exists (test both APIs)
- log_time("Step 2a: Verifying cluster via manage_cluster(get)...")
- t0 = time.time()
- status = manage_cluster(action="get", cluster_id=cluster_id)
- log_time(f"Get took {time.time()-t0:.1f}s - Status: {status.get('state')}")
- assert "error" not in status, f"Get cluster failed: {status}"
- assert status.get("cluster_name") == CLUSTER_NAME, f"Name mismatch: {status}"
-
- log_time("Step 2b: Verifying cluster via list_compute(cluster_id)...")
- t0 = time.time()
- status2 = list_compute(resource="clusters", cluster_id=cluster_id)
- log_time(f"list_compute took {time.time()-t0:.1f}s - Status: {status2.get('state')}")
- assert "error" not in status2, f"list_compute failed: {status2}"
-
- # Step 3: Start the cluster
- log_time("Step 3: Starting cluster...")
- t0 = time.time()
- start_result = manage_cluster(action="start", cluster_id=cluster_id)
- log_time(f"Start took {time.time()-t0:.1f}s - Result: {start_result}")
- # Start may fail if cluster is already starting/running, that's ok
- if "error" in str(start_result).lower() and "already" not in str(start_result).lower():
- assert False, f"Start failed unexpectedly: {start_result}"
-
- # Wait for cluster to start processing the request
- log_time("Waiting 10s after start request...")
- time.sleep(10)
-
- # Check state (should be PENDING, RUNNING, or STARTING)
- t0 = time.time()
- status = manage_cluster(action="get", cluster_id=cluster_id)
- log_time(f"State check took {time.time()-t0:.1f}s - State: {status.get('state')}")
-
- # Step 4: Terminate the cluster
- log_time("Step 4: Terminating cluster...")
- t0 = time.time()
- terminate_result = manage_cluster(action="terminate", cluster_id=cluster_id)
- log_time(f"Terminate took {time.time()-t0:.1f}s - Result: {terminate_result}")
- assert "error" not in str(terminate_result).lower() or terminate_result.get("success"), \
- f"Terminate failed: {terminate_result}"
-
- # Wait for termination to process
- log_time("Waiting 10s after terminate request...")
- time.sleep(10)
-
- # Poll for TERMINATED state (max 60s)
- max_wait = 60
- waited = 0
- while waited < max_wait:
- t0 = time.time()
- status = manage_cluster(action="get", cluster_id=cluster_id)
- state = status.get("state")
- log_time(f"Poll took {time.time()-t0:.1f}s - State after {waited}s: {state}")
- if state == "TERMINATED":
- break
- time.sleep(10)
- waited += 10
-
- # Step 5: Delete the cluster (permanent)
- log_time("Step 5: Deleting cluster...")
- t0 = time.time()
- delete_result = manage_cluster(action="delete", cluster_id=cluster_id)
- log_time(f"Delete took {time.time()-t0:.1f}s - Result: {delete_result}")
- assert "error" not in str(delete_result).lower() or delete_result.get("success"), \
- f"Delete failed: {delete_result}"
-
- # Wait for deletion
- time.sleep(5)
-
- # Verify cluster is gone or marked as deleted
- t0 = time.time()
- final_status = manage_cluster(action="get", cluster_id=cluster_id)
- log_time(f"Final check took {time.time()-t0:.1f}s - Status: {final_status}")
-
- is_deleted = (
- final_status.get("exists") is False or
- final_status.get("state") in ("DELETED", "TERMINATED")
- )
- assert is_deleted, f"Cluster should be deleted: {final_status}"
-
- log_time("Full cluster lifecycle test PASSED!")
- cluster_id = None # Clear so finally block doesn't try to clean up
-
- finally:
- # Cleanup if test failed partway through
- if cluster_id:
- log_time(f"Cleanup: attempting to delete cluster {cluster_id}")
- try:
- manage_cluster(action="terminate", cluster_id=cluster_id)
- time.sleep(5)
- manage_cluster(action="delete", cluster_id=cluster_id)
- except Exception as e:
- logger.warning(f"Cleanup failed: {e}")
diff --git a/databricks-mcp-server/tests/integration/dashboards/__init__.py b/databricks-mcp-server/tests/integration/dashboards/__init__.py
deleted file mode 100644
index 3002af54..00000000
--- a/databricks-mcp-server/tests/integration/dashboards/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Dashboard integration tests
diff --git a/databricks-mcp-server/tests/integration/dashboards/test_dashboards.py b/databricks-mcp-server/tests/integration/dashboards/test_dashboards.py
deleted file mode 100644
index 2ef80c70..00000000
--- a/databricks-mcp-server/tests/integration/dashboards/test_dashboards.py
+++ /dev/null
@@ -1,466 +0,0 @@
-"""
-Integration tests for AI/BI dashboards MCP tool.
-
-Tests:
-- manage_dashboard: create_or_update, get, list, delete, publish, unpublish
-"""
-
-import json
-import logging
-import uuid
-
-import pytest
-
-from databricks_mcp_server.tools.aibi_dashboards import manage_dashboard
-from databricks_mcp_server.tools.sql import manage_warehouse
-from tests.test_config import TEST_CATALOG, SCHEMAS, TEST_RESOURCE_PREFIX
-
-logger = logging.getLogger(__name__)
-
-# Deterministic dashboard names for tests (enables safe cleanup/restart)
-DASHBOARD_NAME = f"{TEST_RESOURCE_PREFIX}dashboard"
-DASHBOARD_UPDATE = f"{TEST_RESOURCE_PREFIX}dashboard_update"
-DASHBOARD_PUBLISH = f"{TEST_RESOURCE_PREFIX}dashboard_publish"
-
-
-@pytest.fixture(scope="module")
-def clean_dashboards(current_user: str):
- """Pre-test cleanup: delete any existing test dashboards.
-
- Uses direct path lookup instead of listing all dashboards (much faster).
- """
- from databricks_tools_core.aibi_dashboards import find_dashboard_by_path
-
- dashboards_to_clean = [DASHBOARD_NAME, DASHBOARD_UPDATE, DASHBOARD_PUBLISH]
- parent_path = f"/Workspace/Users/{current_user}/ai_dev_kit_test/dashboards"
-
- for dash_name in dashboards_to_clean:
- try:
- # Direct path lookup - much faster than listing all dashboards
- dashboard_path = f"{parent_path}/{dash_name}.lvdash.json"
- dash_id = find_dashboard_by_path(dashboard_path)
- if dash_id:
- manage_dashboard(action="delete", dashboard_id=dash_id)
- logger.info(f"Pre-cleanup: deleted dashboard {dash_name}")
- except Exception as e:
- logger.warning(f"Pre-cleanup failed for {dash_name}: {e}")
-
- yield
-
- # Post-test cleanup
- for dash_name in dashboards_to_clean:
- try:
- dashboard_path = f"{parent_path}/{dash_name}.lvdash.json"
- dash_id = find_dashboard_by_path(dashboard_path)
- if dash_id:
- manage_dashboard(action="delete", dashboard_id=dash_id)
- logger.info(f"Post-cleanup: deleted dashboard {dash_name}")
- except Exception:
- pass
-
-
-@pytest.fixture(scope="module")
-def simple_dashboard_json() -> str:
- """Create a simple dashboard JSON for testing."""
- dashboard = {
- "datasets": [
- {
- "name": "simple_data",
- "displayName": "Simple Data",
- "queryLines": ["SELECT 1 as id, 'test' as value"]
- }
- ],
- "pages": [
- {
- "name": "page1",
- "displayName": "Test Page",
- "pageType": "PAGE_TYPE_CANVAS",
- "layout": [
- {
- "widget": {
- "name": "counter1",
- "queries": [
- {
- "name": "main_query",
- "query": {
- "datasetName": "simple_data",
- "fields": [{"name": "id", "expression": "`id`"}],
- "disaggregated": True
- }
- }
- ],
- "spec": {
- "version": 2,
- "widgetType": "counter",
- "encodings": {
- "value": {"fieldName": "id", "displayName": "Count"}
- },
- "frame": {"title": "Test Counter", "showTitle": True}
- }
- },
- "position": {"x": 0, "y": 0, "width": 2, "height": 3}
- }
- ]
- }
- ]
- }
- return json.dumps(dashboard)
-
-
-@pytest.fixture(scope="module")
-def existing_dashboard(workspace_client, current_user: str) -> str:
- """Find an existing dashboard in our test folder for read-only tests."""
- from databricks_tools_core.aibi_dashboards import find_dashboard_by_path
-
- test_folder = f"/Workspace/Users/{current_user}/ai_dev_kit_test/dashboards"
-
- # Try to find one of our test dashboards
- for dash_name in [DASHBOARD_NAME, DASHBOARD_UPDATE, DASHBOARD_PUBLISH]:
- try:
- dashboard_path = f"{test_folder}/{dash_name}.lvdash.json"
- dashboard_id = find_dashboard_by_path(dashboard_path)
- if dashboard_id:
- logger.info(f"Using existing dashboard: {dash_name}")
- return dashboard_id
- except Exception:
- pass
-
- # If no test dashboard exists, list the test folder
- try:
- items = workspace_client.workspace.list(test_folder)
- for item in items:
- if item.path and item.path.endswith(".lvdash.json"):
- dashboard_id = item.resource_id
- if dashboard_id:
- logger.info(f"Using dashboard from test folder: {item.path}")
- return dashboard_id
- except Exception as e:
- logger.warning(f"Could not list test folder: {e}")
-
- pytest.skip("No existing dashboard in test folder")
-
-
-@pytest.mark.integration
-class TestManageDashboard:
- """Tests for manage_dashboard tool."""
-
- @pytest.mark.slow
- def test_list_dashboards(self):
- """Should list all dashboards (slow - iterates all dashboards)."""
- result = manage_dashboard(action="list")
-
- logger.info(f"List result: {result}")
-
- assert not result.get("error"), f"List failed: {result}"
-
- def test_get_dashboard(self, existing_dashboard: str):
- """Should get dashboard details."""
- result = manage_dashboard(action="get", dashboard_id=existing_dashboard)
-
- logger.info(f"Get result: {result}")
-
- assert not result.get("error"), f"Get failed: {result}"
-
- def test_get_nonexistent_dashboard(self):
- """Should handle nonexistent dashboard gracefully."""
- try:
- result = manage_dashboard(action="get", dashboard_id="nonexistent_dashboard_12345")
- logger.info(f"Get nonexistent result: {result}")
- # Should return error
- assert result.get("error")
- except Exception as e:
- # SDK raises exception for nonexistent dashboard - this is acceptable
- error_msg = str(e).lower()
- assert "invalid" in error_msg or "not found" in error_msg or "not exist" in error_msg
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- result = manage_dashboard(action="invalid_action")
-
- assert "error" in result
-
-
-@pytest.mark.integration
-class TestDashboardLifecycle:
- """End-to-end tests for dashboard lifecycle."""
-
- def test_create_dashboard(
- self,
- current_user: str,
- simple_dashboard_json: str,
- warehouse_id: str,
- cleanup_dashboards,
- ):
- """Should create a dashboard and verify its structure."""
- dashboard_name = f"{TEST_RESOURCE_PREFIX}dash_create_{uuid.uuid4().hex[:6]}"
- parent_path = f"/Workspace/Users/{current_user}/ai_dev_kit_test/dashboards"
-
- result = manage_dashboard(
- action="create_or_update",
- display_name=dashboard_name,
- parent_path=parent_path,
- serialized_dashboard=simple_dashboard_json,
- warehouse_id=warehouse_id,
- )
-
- logger.info(f"Create result: {result}")
-
- assert not result.get("error"), f"Create failed: {result}"
-
- dashboard_id = result.get("dashboard_id") or result.get("id")
- assert dashboard_id, f"Dashboard ID should be returned: {result}"
-
- cleanup_dashboards(dashboard_id)
-
- # Verify dashboard can be retrieved with correct structure
- get_result = manage_dashboard(action="get", dashboard_id=dashboard_id)
-
- logger.info(f"Get result after create: {get_result}")
-
- assert not get_result.get("error"), f"Get failed: {get_result}"
-
- # Verify dashboard name
- retrieved_name = get_result.get("display_name") or get_result.get("name")
- assert dashboard_name in str(retrieved_name), \
- f"Dashboard name mismatch: expected {dashboard_name}, got {retrieved_name}"
-
- # Verify dashboard has serialized content (proving it was created with our definition)
- serialized = get_result.get("serialized_dashboard")
- if serialized:
- # Should contain our dataset and widget
- assert "simple_data" in serialized, "Dashboard should contain our dataset"
- assert "counter" in serialized.lower(), "Dashboard should contain our counter widget"
-
- def test_create_and_delete_dashboard(
- self,
- current_user: str,
- simple_dashboard_json: str,
- warehouse_id: str,
- ):
- """Should create and delete a dashboard, verifying each step."""
- dashboard_name = f"{TEST_RESOURCE_PREFIX}dash_del_{uuid.uuid4().hex[:6]}"
- parent_path = f"/Workspace/Users/{current_user}/ai_dev_kit_test/dashboards"
-
- # Create
- create_result = manage_dashboard(
- action="create_or_update",
- display_name=dashboard_name,
- parent_path=parent_path,
- serialized_dashboard=simple_dashboard_json,
- warehouse_id=warehouse_id,
- )
-
- logger.info(f"Create result: {create_result}")
-
- dashboard_id = create_result.get("dashboard_id") or create_result.get("id")
- assert dashboard_id, f"Dashboard not created: {create_result}"
-
- # Verify dashboard exists before delete
- get_result = manage_dashboard(action="get", dashboard_id=dashboard_id)
- assert not get_result.get("error"), f"Dashboard should exist after create: {get_result}"
-
- # Delete
- delete_result = manage_dashboard(action="delete", dashboard_id=dashboard_id)
-
- logger.info(f"Delete result: {delete_result}")
-
- assert not delete_result.get("error") or delete_result.get("status") == "deleted"
-
- # Verify dashboard is gone
- get_after_delete = manage_dashboard(action="get", dashboard_id=dashboard_id)
- # Should return error or indicate not found
- assert "error" in get_after_delete or get_after_delete.get("lifecycle_state") == "TRASHED", \
- f"Dashboard should be deleted/trashed: {get_after_delete}"
-
-
-@pytest.mark.integration
-class TestDashboardUpdate:
- """Tests for dashboard update functionality."""
-
- def test_update_dashboard(
- self,
- current_user: str,
- simple_dashboard_json: str,
- warehouse_id: str,
- clean_dashboards,
- cleanup_dashboards,
- ):
- """Should create a dashboard, update it, and verify changes."""
- parent_path = f"/Workspace/Users/{current_user}/ai_dev_kit_test/dashboards"
-
- # Create initial dashboard
- create_result = manage_dashboard(
- action="create_or_update",
- display_name=DASHBOARD_UPDATE,
- parent_path=parent_path,
- serialized_dashboard=simple_dashboard_json,
- warehouse_id=warehouse_id,
- )
-
- logger.info(f"Create for update test: {create_result}")
-
- assert not create_result.get("error"), f"Create failed: {create_result}"
-
- dashboard_id = create_result.get("dashboard_id") or create_result.get("id")
- assert dashboard_id, f"Dashboard ID should be returned: {create_result}"
-
- cleanup_dashboards(dashboard_id)
-
- # Create updated dashboard JSON with different dataset
- updated_dashboard = {
- "datasets": [
- {
- "name": "updated_data",
- "displayName": "Updated Data",
- "queryLines": ["SELECT 42 as answer, 'updated' as status"]
- }
- ],
- "pages": [
- {
- "name": "page1",
- "displayName": "Updated Page",
- "pageType": "PAGE_TYPE_CANVAS",
- "layout": [
- {
- "widget": {
- "name": "counter1",
- "queries": [
- {
- "name": "main_query",
- "query": {
- "datasetName": "updated_data",
- "fields": [{"name": "answer", "expression": "`answer`"}],
- "disaggregated": True
- }
- }
- ],
- "spec": {
- "version": 2,
- "widgetType": "counter",
- "encodings": {
- "value": {"fieldName": "answer", "displayName": "Answer"}
- },
- "frame": {"title": "Updated Counter", "showTitle": True}
- }
- },
- "position": {"x": 0, "y": 0, "width": 2, "height": 3}
- }
- ]
- }
- ]
- }
-
- # Update the dashboard
- update_result = manage_dashboard(
- action="create_or_update",
- display_name=DASHBOARD_UPDATE,
- parent_path=parent_path,
- serialized_dashboard=json.dumps(updated_dashboard),
- warehouse_id=warehouse_id,
- )
-
- logger.info(f"Update result: {update_result}")
-
- assert not update_result.get("error"), f"Update failed: {update_result}"
-
- # Verify the dashboard was updated
- get_result = manage_dashboard(action="get", dashboard_id=dashboard_id)
-
- logger.info(f"Get after update: {get_result}")
-
- assert not get_result.get("error"), f"Get after update failed: {get_result}"
-
- # Verify updated content
- serialized = get_result.get("serialized_dashboard")
- if serialized:
- assert "updated_data" in serialized, \
- f"Dashboard should contain updated dataset: {serialized[:200]}..."
- assert "42" in serialized or "answer" in serialized, \
- f"Dashboard should contain updated query: {serialized[:200]}..."
-
-
-@pytest.mark.integration
-class TestDashboardPublish:
- """Tests for dashboard publish/unpublish functionality."""
-
- def test_publish_and_unpublish_dashboard(
- self,
- current_user: str,
- simple_dashboard_json: str,
- warehouse_id: str,
- clean_dashboards,
- cleanup_dashboards,
- ):
- """Should create a dashboard, publish it, verify published state, then unpublish."""
- parent_path = f"/Workspace/Users/{current_user}/ai_dev_kit_test/dashboards"
-
- # Create dashboard
- create_result = manage_dashboard(
- action="create_or_update",
- display_name=DASHBOARD_PUBLISH,
- parent_path=parent_path,
- serialized_dashboard=simple_dashboard_json,
- warehouse_id=warehouse_id,
- )
-
- logger.info(f"Create for publish test: {create_result}")
-
- assert not create_result.get("error"), f"Create failed: {create_result}"
-
- dashboard_id = create_result.get("dashboard_id") or create_result.get("id")
- assert dashboard_id, f"Dashboard ID should be returned: {create_result}"
-
- cleanup_dashboards(dashboard_id)
-
- # Verify initial state is DRAFT
- get_before = manage_dashboard(action="get", dashboard_id=dashboard_id)
- initial_state = get_before.get("lifecycle_state") or get_before.get("state")
- logger.info(f"Initial lifecycle state: {initial_state}")
-
- # Publish the dashboard
- publish_result = manage_dashboard(
- action="publish",
- dashboard_id=dashboard_id,
- warehouse_id=warehouse_id,
- )
-
- logger.info(f"Publish result: {publish_result}")
-
- assert not publish_result.get("error"), f"Publish failed: {publish_result}"
-
- # Verify published state
- get_after_publish = manage_dashboard(action="get", dashboard_id=dashboard_id)
-
- logger.info(f"Get after publish: {get_after_publish}")
-
- assert not get_after_publish.get("error"), f"Get after publish failed: {get_after_publish}"
-
- # Check lifecycle_state is PUBLISHED or similar
- published_state = get_after_publish.get("lifecycle_state") or get_after_publish.get("state")
- logger.info(f"State after publish: {published_state}")
-
- # The dashboard should indicate it's published
- assert published_state in ("PUBLISHED", "ACTIVE") or get_after_publish.get("is_published"), \
- f"Dashboard should be published, got state: {published_state}"
-
- # Unpublish the dashboard
- unpublish_result = manage_dashboard(
- action="unpublish",
- dashboard_id=dashboard_id,
- )
-
- logger.info(f"Unpublish result: {unpublish_result}")
-
- assert not unpublish_result.get("error"), f"Unpublish failed: {unpublish_result}"
-
- # Verify unpublished state
- get_after_unpublish = manage_dashboard(action="get", dashboard_id=dashboard_id)
-
- logger.info(f"Get after unpublish: {get_after_unpublish}")
-
- unpublished_state = get_after_unpublish.get("lifecycle_state") or get_after_unpublish.get("state")
- logger.info(f"State after unpublish: {unpublished_state}")
-
- # Should be back to DRAFT or similar
- assert unpublished_state in ("DRAFT", "ACTIVE") or not get_after_unpublish.get("is_published"), \
- f"Dashboard should be unpublished, got state: {unpublished_state}"
diff --git a/databricks-mcp-server/tests/integration/genie/__init__.py b/databricks-mcp-server/tests/integration/genie/__init__.py
deleted file mode 100644
index 99c30014..00000000
--- a/databricks-mcp-server/tests/integration/genie/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Genie integration tests
diff --git a/databricks-mcp-server/tests/integration/genie/test_genie.py b/databricks-mcp-server/tests/integration/genie/test_genie.py
deleted file mode 100644
index c8c1d8db..00000000
--- a/databricks-mcp-server/tests/integration/genie/test_genie.py
+++ /dev/null
@@ -1,248 +0,0 @@
-"""
-Integration tests for Genie MCP tools.
-
-Tests:
-- manage_genie: create_or_update, get, list, delete
-- ask_genie: basic queries
-"""
-
-import logging
-import time
-import uuid
-
-import pytest
-
-from databricks_mcp_server.tools.genie import manage_genie, ask_genie
-from databricks_mcp_server.tools.sql import execute_sql
-from tests.test_config import TEST_RESOURCE_PREFIX
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.fixture(scope="module")
-def genie_source_table(
- workspace_client,
- test_catalog: str,
- genie_schema: str,
- warehouse_id: str,
-) -> str:
- """Create a source table for Genie space."""
- table_name = f"{test_catalog}.{genie_schema}.sales_data"
-
- # Create table with test data
- execute_sql(
- sql_query=f"""
- CREATE OR REPLACE TABLE {table_name} (
- order_id INT,
- customer STRING,
- amount DECIMAL(10, 2),
- order_date DATE
- )
- """,
- warehouse_id=warehouse_id,
- )
-
- execute_sql(
- sql_query=f"""
- INSERT INTO {table_name} VALUES
- (1, 'Alice', 100.00, '2024-01-15'),
- (2, 'Bob', 150.00, '2024-01-16'),
- (3, 'Alice', 200.00, '2024-01-17')
- """,
- warehouse_id=warehouse_id,
- )
-
- logger.info(f"Created Genie source table: {table_name}")
- return table_name
-
-
-@pytest.mark.integration
-class TestManageGenie:
- """Tests for manage_genie tool."""
-
- # TODO: Re-enable once pagination performance is improved - takes too long in large workspaces
- # @pytest.mark.slow
- # def test_list_spaces(self):
- # """Should list all Genie spaces (slow due to pagination)."""
- # result = manage_genie(action="list")
- #
- # logger.info(f"List result: found {len(result.get('spaces', []))} spaces")
- #
- # assert "error" not in result, f"List failed: {result}"
- # # Should have spaces in a real workspace
- # assert "spaces" in result
-
- def test_get_nonexistent_space(self):
- """Should handle nonexistent space gracefully."""
- result = manage_genie(action="get", space_id="nonexistent_space_12345")
-
- # Should return error or not found
- logger.info(f"Get nonexistent result: {result}")
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- result = manage_genie(action="invalid_action")
-
- assert "error" in result
-
-
-@pytest.mark.integration
-class TestGenieLifecycle:
- """End-to-end test for Genie space lifecycle: create -> update -> query -> export -> import -> delete."""
-
- def test_full_genie_lifecycle(
- self,
- genie_source_table: str,
- warehouse_id: str,
- current_user: str,
- ):
- """Test complete Genie lifecycle with space name containing spaces and /."""
- test_start = time.time()
- # Name with spaces to test edge cases (no / allowed in display name)
- space_name = f"{TEST_RESOURCE_PREFIX}Genie Test Space {uuid.uuid4().hex[:6]}"
- space_id = None
- imported_space_id = None
-
- def log_time(msg):
- elapsed = time.time() - test_start
- logger.info(f"[{elapsed:.1f}s] {msg}")
-
- try:
- # Step 1: Create space with special characters in name
- log_time(f"Step 1: Creating space '{space_name}'...")
- create_result = manage_genie(
- action="create_or_update",
- display_name=space_name,
- description="Initial description",
- warehouse_id=warehouse_id,
- table_identifiers=[genie_source_table],
- )
- log_time(f"Create result: {create_result}")
- assert "error" not in create_result, f"Create failed: {create_result}"
- space_id = create_result.get("space_id")
- assert space_id, f"Space ID should be returned: {create_result}"
-
- # Step 2: Update the space
- log_time("Step 2: Updating space...")
- update_result = manage_genie(
- action="create_or_update",
- space_id=space_id,
- display_name=space_name,
- description="Updated description",
- warehouse_id=warehouse_id,
- table_identifiers=[genie_source_table],
- )
- log_time(f"Update result: {update_result}")
- assert "error" not in update_result, f"Update failed: {update_result}"
-
- # Step 3: Get and verify update
- log_time("Step 3: Verifying update via get...")
- get_result = manage_genie(action="get", space_id=space_id)
- assert "error" not in get_result, f"Get failed: {get_result}"
- description = get_result.get("description") or get_result.get("spec", {}).get("description")
- log_time(f"Description after update: {description}")
- # Note: The Genie API may not immediately reflect description updates, so we just log it
-
- # Step 4: Query the space (wait for ready)
- log_time("Step 4: Querying space...")
- max_wait = 120
- wait_interval = 10
- waited = 0
- query_success = False
-
- while waited < max_wait:
- test_query = ask_genie(
- space_id=space_id,
- question="How many orders are there?",
- timeout_seconds=30,
- )
- log_time(f"Query attempt after {waited}s: status={test_query.get('status')}")
-
- if test_query.get("status") not in ("FAILED", "ERROR") and "error" not in test_query:
- query_success = True
- response_text = str(test_query.get("response", "")) + str(test_query.get("result", ""))
- if "3" in response_text or test_query.get("status") == "COMPLETED":
- log_time("Query succeeded with expected result")
- break
- time.sleep(wait_interval)
- waited += wait_interval
-
- if not query_success:
- log_time(f"Space not ready after {max_wait}s, skipping query verification")
-
- # Step 5: Export the space
- log_time("Step 5: Exporting space...")
- export_result = manage_genie(action="export", space_id=space_id)
- log_time(f"Export result: {list(export_result.keys()) if isinstance(export_result, dict) else 'error'}")
- assert "error" not in export_result, f"Export failed: {export_result}"
- serialized_space = export_result.get("serialized_space")
- assert serialized_space, f"serialized_space should be returned: {export_result}"
-
- # Step 6: Import as new space
- log_time("Step 6: Importing as new space...")
- parent_path = f"/Workspace/Users/{current_user}"
- import_name = f"{TEST_RESOURCE_PREFIX}Genie Imported Space"
- import_result = manage_genie(
- action="import",
- serialized_space=serialized_space,
- title=import_name,
- parent_path=parent_path,
- warehouse_id=warehouse_id,
- )
- log_time(f"Import result: {import_result}")
- assert "error" not in import_result, f"Import failed: {import_result}"
- imported_space_id = import_result.get("space_id")
- assert imported_space_id, f"Imported space ID should be returned: {import_result}"
- assert imported_space_id != space_id, "Imported space should have different ID"
-
- # Step 7: Delete original space
- log_time("Step 7: Deleting original space...")
- delete_result = manage_genie(action="delete", space_id=space_id)
- log_time(f"Delete result: {delete_result}")
- assert "error" not in delete_result, f"Delete failed: {delete_result}"
- assert delete_result.get("success") is True, f"Delete should return success=True: {delete_result}"
- space_id = None # Mark as deleted
-
- # Step 8: Delete imported space
- log_time("Step 8: Deleting imported space...")
- delete_imported = manage_genie(action="delete", space_id=imported_space_id)
- log_time(f"Delete imported result: {delete_imported}")
- assert "error" not in delete_imported, f"Delete imported failed: {delete_imported}"
- imported_space_id = None # Mark as deleted
-
- log_time("Full Genie lifecycle test PASSED!")
-
- except Exception as e:
- log_time(f"Test failed: {e}")
- raise
- finally:
- # Cleanup on failure
- if space_id:
- log_time(f"Cleanup: deleting space {space_id}")
- try:
- manage_genie(action="delete", space_id=space_id)
- except Exception:
- pass
- if imported_space_id:
- log_time(f"Cleanup: deleting imported space {imported_space_id}")
- try:
- manage_genie(action="delete", space_id=imported_space_id)
- except Exception:
- pass
-
-
-@pytest.mark.integration
-class TestAskGenie:
- """Tests for ask_genie tool."""
-
- def test_ask_nonexistent_space(self):
- """Should handle nonexistent space gracefully."""
- result = ask_genie(
- space_id="nonexistent_space_12345",
- question="test question",
- timeout_seconds=10,
- )
-
- # Should return error
- logger.info(f"Ask nonexistent result: {result}")
- assert result.get("status") == "FAILED" or "error" in result
diff --git a/databricks-mcp-server/tests/integration/jobs/__init__.py b/databricks-mcp-server/tests/integration/jobs/__init__.py
deleted file mode 100644
index 0b9e38c8..00000000
--- a/databricks-mcp-server/tests/integration/jobs/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Jobs integration tests
diff --git a/databricks-mcp-server/tests/integration/jobs/test_jobs.py b/databricks-mcp-server/tests/integration/jobs/test_jobs.py
deleted file mode 100644
index 6c1f2b55..00000000
--- a/databricks-mcp-server/tests/integration/jobs/test_jobs.py
+++ /dev/null
@@ -1,697 +0,0 @@
-"""
-Integration tests for jobs MCP tools.
-
-Tests:
-- manage_jobs: create, get, list, update, delete
-- manage_job_runs: run_now, get, list, cancel
-"""
-
-import logging
-import time
-import uuid
-
-import pytest
-
-from databricks_mcp_server.tools.jobs import manage_jobs, manage_job_runs
-from databricks_mcp_server.tools.file import manage_workspace_files
-from tests.test_config import TEST_CATALOG, SCHEMAS, TEST_RESOURCE_PREFIX
-
-logger = logging.getLogger(__name__)
-
-# Deterministic job names for tests (enables safe cleanup/restart)
-JOB_NAME = f"{TEST_RESOURCE_PREFIX}job"
-JOB_UPDATE = f"{TEST_RESOURCE_PREFIX}job_update"
-JOB_CANCEL = f"{TEST_RESOURCE_PREFIX}job_cancel"
-
-# Simple notebook content that exits successfully
-TEST_NOTEBOOK_CONTENT = """# Databricks notebook source
-# MAGIC %md
-# MAGIC # Test Notebook for MCP Integration Tests
-
-# COMMAND ----------
-
-print("Hello from MCP integration test!")
-result = 1 + 1
-print(f"1 + 1 = {result}")
-
-# COMMAND ----------
-
-dbutils.notebook.exit("SUCCESS")
-"""
-
-
-@pytest.fixture(scope="module")
-def test_notebook_path(workspace_client, current_user: str):
- """Create a test notebook for job execution."""
- import tempfile
- import shutil
- import os
-
- # The notebook path without extension (Databricks strips .py from notebooks)
- notebook_path = f"/Workspace/Users/{current_user}/ai_dev_kit_test/jobs/resources/test_notebook"
-
- # Create temp directory with properly named notebook file
- temp_dir = tempfile.mkdtemp()
- temp_notebook_file = os.path.join(temp_dir, "test_notebook.py")
- with open(temp_notebook_file, "w") as f:
- f.write(TEST_NOTEBOOK_CONTENT)
-
- try:
- # Upload notebook directly to the full path (upload_to_workspace places single files
- # at the workspace_path directly, so we specify the full notebook path)
- result = manage_workspace_files(
- action="upload",
- local_path=temp_notebook_file,
- workspace_path=notebook_path,
- overwrite=True,
- )
- logger.info(f"Uploaded test notebook: {result}")
-
- yield notebook_path
-
- finally:
- # Cleanup temp directory
- shutil.rmtree(temp_dir, ignore_errors=True)
-
- # Cleanup workspace notebook
- try:
- workspace_client.workspace.delete(notebook_path)
- except Exception as e:
- logger.warning(f"Failed to cleanup notebook: {e}")
-
-
-@pytest.fixture(scope="module")
-def clean_jobs():
- """Pre-test cleanup: delete any existing test jobs using find_by_name (fast)."""
- jobs_to_clean = [JOB_NAME, JOB_UPDATE, JOB_CANCEL]
-
- for job_name in jobs_to_clean:
- try:
- # Use find_by_name for O(1) lookup instead of listing all jobs
- result = manage_jobs(action="find_by_name", name=job_name)
- job_id = result.get("job_id")
- if job_id:
- manage_jobs(action="delete", job_id=str(job_id))
- logger.info(f"Pre-cleanup: deleted job {job_name}")
- except Exception as e:
- logger.warning(f"Pre-cleanup failed for {job_name}: {e}")
-
- yield
-
- # Post-test cleanup
- for job_name in jobs_to_clean:
- try:
- result = manage_jobs(action="find_by_name", name=job_name)
- job_id = result.get("job_id")
- if job_id:
- manage_jobs(action="delete", job_id=str(job_id))
- logger.info(f"Post-cleanup: deleted job {job_name}")
- except Exception:
- pass
-
-
-@pytest.fixture(scope="module")
-def clean_job():
- """Ensure job doesn't exist before tests and cleanup after using find_by_name (fast)."""
- # Try to find and delete existing job with this name
- try:
- result = manage_jobs(action="find_by_name", name=JOB_NAME)
- job_id = result.get("job_id")
- if job_id:
- manage_jobs(action="delete", job_id=str(job_id))
- logger.info(f"Cleaned up existing job: {JOB_NAME}")
- except Exception as e:
- logger.warning(f"Error during pre-cleanup: {e}")
-
- yield JOB_NAME
-
- # Cleanup after tests
- try:
- result = manage_jobs(action="find_by_name", name=JOB_NAME)
- job_id = result.get("job_id")
- if job_id:
- manage_jobs(action="delete", job_id=str(job_id))
- logger.info(f"Final cleanup of job: {JOB_NAME}")
- except Exception as e:
- logger.warning(f"Error during post-cleanup: {e}")
-
-
-@pytest.mark.integration
-class TestManageJobs:
- """Tests for manage_jobs tool."""
-
- def test_list_jobs(self):
- """Should list all jobs."""
- result = manage_jobs(action="list")
-
- logger.info(f"List result keys: {result.keys() if isinstance(result, dict) else result}")
-
- assert not result.get("error"), f"List failed: {result}"
- # API may return "jobs" or "items" depending on SDK version
- jobs = result.get("jobs") or result.get("items", [])
- assert isinstance(jobs, list)
-
- def test_create_job(self, clean_job: str, test_notebook_path: str, cleanup_jobs):
- """Should create a new job and verify its configuration."""
- # Create a simple notebook task job
- result = manage_jobs(
- action="create",
- name=clean_job,
- tasks=[
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- },
- "new_cluster": {
- "spark_version": "14.3.x-scala2.12",
- "num_workers": 0,
- "node_type_id": "i3.xlarge",
- },
- }
- ],
- )
-
- logger.info(f"Create result: {result}")
-
- assert "error" not in result, f"Create failed: {result}"
- assert result.get("job_id") is not None
-
- job_id = result["job_id"]
- cleanup_jobs(str(job_id))
-
- # Verify job configuration
- get_result = manage_jobs(action="get", job_id=str(job_id))
- assert "error" not in get_result, f"Get job failed: {get_result}"
-
- # Verify job name matches
- settings = get_result.get("settings", {})
- assert settings.get("name") == clean_job, f"Job name mismatch: {settings.get('name')}"
-
- # Verify task is configured
- tasks = settings.get("tasks", [])
- assert len(tasks) >= 1, f"Job should have at least 1 task: {tasks}"
- assert tasks[0].get("task_key") == "test_task"
-
- def test_create_job_with_optional_params(self, test_notebook_path: str, cleanup_jobs):
- """Should create a job with optional params (email_notifications, schedule, queue).
-
- This tests the fix for passing raw dicts to SDK - they must be converted to SDK objects.
- """
- job_name = f"{TEST_RESOURCE_PREFIX}job_optional_{uuid.uuid4().hex[:6]}"
-
- # Create job with optional parameters that require SDK type conversion
- result = manage_jobs(
- action="create",
- name=job_name,
- tasks=[
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- },
- "new_cluster": {
- "spark_version": "14.3.x-scala2.12",
- "num_workers": 0,
- "node_type_id": "i3.xlarge",
- },
- }
- ],
- # These optional params require SDK type conversion (JobEmailNotifications, CronSchedule, QueueSettings)
- email_notifications={
- "on_start": [], # Empty list is valid
- "on_success": [],
- "on_failure": [],
- "no_alert_for_skipped_runs": True,
- },
- schedule={
- "quartz_cron_expression": "0 0 9 * * ?", # Daily at 9 AM
- "timezone_id": "UTC",
- "pause_status": "PAUSED", # Start paused so it doesn't actually run
- },
- queue={
- "enabled": True,
- },
- )
-
- logger.info(f"Create with optional params result: {result}")
-
- assert "error" not in result, f"Create with optional params failed: {result}"
- assert result.get("job_id") is not None
-
- job_id = result["job_id"]
- cleanup_jobs(str(job_id))
-
- # Verify job configuration includes the optional params
- get_result = manage_jobs(action="get", job_id=str(job_id))
- assert "error" not in get_result, f"Get job failed: {get_result}"
-
- settings = get_result.get("settings", {})
-
- # Verify email_notifications was persisted
- email_notif = settings.get("email_notifications", {})
- assert email_notif.get("no_alert_for_skipped_runs") is True, \
- f"email_notifications should be persisted: {email_notif}"
-
- # Verify schedule was persisted
- schedule = settings.get("schedule", {})
- assert schedule.get("quartz_cron_expression") == "0 0 9 * * ?", \
- f"schedule should be persisted: {schedule}"
- assert schedule.get("timezone_id") == "UTC", \
- f"schedule timezone should be UTC: {schedule}"
- assert schedule.get("pause_status") == "PAUSED", \
- f"schedule should be paused: {schedule}"
-
- # Verify queue was persisted
- queue = settings.get("queue", {})
- assert queue.get("enabled") is True, \
- f"queue should be enabled: {queue}"
-
- logger.info("Successfully created job with optional params and verified they were persisted")
-
- def test_get_job(self, clean_job: str, test_notebook_path: str, cleanup_jobs):
- """Should get job details and verify structure."""
- # First create a job
- job_name = f"{TEST_RESOURCE_PREFIX}job_get_{uuid.uuid4().hex[:6]}"
- create_result = manage_jobs(
- action="create",
- name=job_name,
- tasks=[
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- },
- "new_cluster": {
- "spark_version": "14.3.x-scala2.12",
- "num_workers": 0,
- "node_type_id": "i3.xlarge",
- },
- }
- ],
- )
-
- job_id = create_result.get("job_id")
- assert job_id, f"Job not created: {create_result}"
- cleanup_jobs(str(job_id))
-
- # Get job details
- result = manage_jobs(action="get", job_id=str(job_id))
-
- logger.info(f"Get result keys: {result.keys() if isinstance(result, dict) else result}")
-
- assert "error" not in result, f"Get failed: {result}"
-
- # Verify expected fields are present
- assert "job_id" in result or "settings" in result, f"Missing expected fields: {result}"
-
- def test_delete_job(self, test_notebook_path: str, cleanup_jobs):
- """Should delete a job and verify it's gone."""
- # Create a job to delete
- job_name = f"{TEST_RESOURCE_PREFIX}job_delete_{uuid.uuid4().hex[:6]}"
- create_result = manage_jobs(
- action="create",
- name=job_name,
- tasks=[
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- },
- "new_cluster": {
- "spark_version": "14.3.x-scala2.12",
- "num_workers": 0,
- "node_type_id": "i3.xlarge",
- },
- }
- ],
- )
-
- job_id = create_result.get("job_id")
- assert job_id, f"Job not created: {create_result}"
-
- # Verify job exists before delete
- get_before = manage_jobs(action="get", job_id=str(job_id))
- assert "error" not in get_before, f"Job should exist before delete: {get_before}"
-
- # Delete the job
- result = manage_jobs(action="delete", job_id=str(job_id))
-
- logger.info(f"Delete result: {result}")
-
- assert result.get("status") == "deleted" or "error" not in result
-
- # Verify job is gone - the get action raises an exception for deleted jobs
- try:
- get_after = manage_jobs(action="get", job_id=str(job_id))
- # If we get here without exception, check for error in response
- assert "error" in get_after or "not found" in str(get_after).lower(), \
- f"Job should be deleted: {get_after}"
- except Exception as e:
- # Exception is expected - job doesn't exist
- assert "does not exist" in str(e).lower() or "not found" in str(e).lower(), \
- f"Expected 'does not exist' error, got: {e}"
- logger.info(f"Confirmed job was deleted - get raised expected error: {e}")
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- try:
- result = manage_jobs(action="invalid_action")
- assert "error" in result
- except ValueError as e:
- # Function raises ValueError for invalid action - this is acceptable
- assert "invalid" in str(e).lower()
-
-
-@pytest.mark.integration
-class TestManageJobRuns:
- """Tests for manage_job_runs tool."""
-
- def test_list_runs(self):
- """Should list job runs."""
- result = manage_job_runs(action="list")
-
- logger.info(f"List runs result keys: {result.keys() if isinstance(result, dict) else result}")
-
- assert not result.get("error"), f"List runs failed: {result}"
-
- def test_get_run_nonexistent(self):
- """Should handle nonexistent run gracefully."""
- try:
- result = manage_job_runs(action="get", run_id="999999999999")
- # Should return error or not found
- logger.info(f"Get nonexistent run result: {result}")
- except Exception as e:
- # SDK raises exception for nonexistent run - this is acceptable
- logger.info(f"Expected error for nonexistent run: {e}")
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- try:
- result = manage_job_runs(action="invalid_action")
- assert "error" in result
- except ValueError as e:
- # Function raises ValueError for invalid action - this is acceptable
- assert "invalid" in str(e).lower()
-
-
-@pytest.mark.integration
-class TestJobExecution:
- """Tests for actual job execution (slow)."""
-
- def test_run_job_and_verify_completion(self, test_notebook_path: str, cleanup_jobs):
- """Should run a job and verify it completes successfully."""
- # Create a job
- job_name = f"{TEST_RESOURCE_PREFIX}job_run_{uuid.uuid4().hex[:6]}"
- create_result = manage_jobs(
- action="create",
- name=job_name,
- tasks=[
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- },
- "new_cluster": {
- "spark_version": "14.3.x-scala2.12",
- "num_workers": 0,
- "node_type_id": "i3.xlarge",
- },
- }
- ],
- )
-
- job_id = create_result.get("job_id")
- assert job_id, f"Job not created: {create_result}"
- cleanup_jobs(str(job_id))
-
- # Run the job
- run_result = manage_job_runs(
- action="run_now",
- job_id=str(job_id),
- )
-
- logger.info(f"Run result: {run_result}")
-
- assert "error" not in run_result, f"Run failed: {run_result}"
- run_id = run_result.get("run_id")
- assert run_id, f"Run ID should be returned: {run_result}"
-
- # Wait for job to complete (with timeout)
- max_wait = 600 # 10 minutes
- wait_interval = 15
- waited = 0
- final_state = None
-
- while waited < max_wait:
- status_result = manage_job_runs(action="get", run_id=str(run_id))
-
- logger.info(f"Run status after {waited}s: {status_result}")
-
- state = status_result.get("state", {})
- life_cycle_state = state.get("life_cycle_state")
- result_state = state.get("result_state")
-
- if life_cycle_state in ("TERMINATED", "SKIPPED", "INTERNAL_ERROR"):
- final_state = result_state
- break
-
- time.sleep(wait_interval)
- waited += wait_interval
-
- assert final_state is not None, f"Job did not complete within {max_wait}s"
- assert final_state == "SUCCESS", f"Job should succeed, got: {final_state}"
-
-
-@pytest.mark.integration
-class TestJobUpdate:
- """Tests for job update functionality."""
-
- def test_update_job(
- self,
- test_notebook_path: str,
- clean_jobs,
- cleanup_jobs,
- ):
- """Should create a job, update its configuration, and verify changes."""
- # Create initial job
- create_result = manage_jobs(
- action="create",
- name=JOB_UPDATE,
- tasks=[
- {
- "task_key": "initial_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- },
- "new_cluster": {
- "spark_version": "14.3.x-scala2.12",
- "num_workers": 0,
- "node_type_id": "i3.xlarge",
- },
- }
- ],
- )
-
- logger.info(f"Create for update test: {create_result}")
-
- assert "error" not in create_result, f"Create failed: {create_result}"
-
- job_id = create_result.get("job_id")
- assert job_id, f"Job ID should be returned: {create_result}"
-
- cleanup_jobs(str(job_id))
-
- # Verify initial configuration
- get_before = manage_jobs(action="get", job_id=str(job_id))
- initial_tasks = get_before.get("settings", {}).get("tasks", [])
- assert len(initial_tasks) == 1, f"Should have 1 task initially: {initial_tasks}"
- assert initial_tasks[0].get("task_key") == "initial_task"
-
- # Update the job with a new task key
- update_result = manage_jobs(
- action="update",
- job_id=str(job_id),
- name=JOB_UPDATE,
- tasks=[
- {
- "task_key": "updated_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- },
- "new_cluster": {
- "spark_version": "14.3.x-scala2.12",
- "num_workers": 0,
- "node_type_id": "i3.xlarge",
- },
- }
- ],
- )
-
- logger.info(f"Update result: {update_result}")
-
- assert "error" not in update_result, f"Update failed: {update_result}"
-
- # Verify the update
- get_after = manage_jobs(action="get", job_id=str(job_id))
-
- logger.info(f"Get after update: {get_after}")
-
- assert "error" not in get_after, f"Get after update failed: {get_after}"
-
- # Verify task was added (Jobs API update does partial updates, adding tasks)
- updated_tasks = get_after.get("settings", {}).get("tasks", [])
- task_keys = [t.get("task_key") for t in updated_tasks]
- # Partial update adds the new task to existing tasks
- assert "updated_task" in task_keys, \
- f"updated_task should be in task list: {task_keys}"
- assert len(updated_tasks) >= 1, f"Should have at least 1 task after update: {updated_tasks}"
-
-
-@pytest.mark.integration
-class TestJobCancel:
- """Tests for job cancellation."""
-
- def test_cancel_run(
- self,
- test_notebook_path: str,
- clean_jobs,
- cleanup_jobs,
- ):
- """Should start a job run and cancel it, verifying the cancellation."""
- # Create a job that takes a while (sleep in notebook)
- long_running_notebook = """# Databricks notebook source
-# MAGIC %md
-# MAGIC # Long Running Notebook for Cancel Test
-
-# COMMAND ----------
-
-import time
-print("Starting long running task...")
-time.sleep(300) # Sleep for 5 minutes - should be cancelled before this completes
-print("This should not print if cancelled")
-
-# COMMAND ----------
-
-dbutils.notebook.exit("COMPLETED")
-"""
- import tempfile
- import os
- from pathlib import Path
-
- # Create temp directory with properly named notebook file
- temp_dir = tempfile.mkdtemp()
- temp_notebook_file = os.path.join(temp_dir, "long_notebook.py")
- with open(temp_notebook_file, "w") as f:
- f.write(long_running_notebook)
-
- try:
- # Upload long-running notebook
- from databricks_mcp_server.tools.file import manage_workspace_files
-
- # Get user from test_notebook_path
- user = test_notebook_path.split('/Users/')[1].split('/')[0]
- # Use the same resources folder that the fixture already created
- # This ensures the parent directory exists
- notebook_path = f"/Workspace/Users/{user}/ai_dev_kit_test/jobs/resources/long_notebook"
-
- # Upload the notebook file directly to the full path
- upload_result = manage_workspace_files(
- action="upload",
- local_path=temp_notebook_file,
- workspace_path=notebook_path,
- overwrite=True,
- )
- logger.info(f"Upload result for cancel test notebook: {upload_result}")
- assert upload_result.get("success", False) or upload_result.get("status") == "success", \
- f"Failed to upload cancel test notebook: {upload_result}"
-
- # Create job
- create_result = manage_jobs(
- action="create",
- name=JOB_CANCEL,
- tasks=[
- {
- "task_key": "long_task",
- "notebook_task": {
- "notebook_path": notebook_path,
- },
- "new_cluster": {
- "spark_version": "14.3.x-scala2.12",
- "num_workers": 0,
- "node_type_id": "i3.xlarge",
- },
- }
- ],
- )
-
- logger.info(f"Create for cancel test: {create_result}")
-
- assert "error" not in create_result, f"Create failed: {create_result}"
-
- job_id = create_result.get("job_id")
- assert job_id, f"Job ID should be returned: {create_result}"
-
- cleanup_jobs(str(job_id))
-
- # Start the job
- run_result = manage_job_runs(
- action="run_now",
- job_id=str(job_id),
- )
-
- logger.info(f"Run result: {run_result}")
-
- assert "error" not in run_result, f"Run failed: {run_result}"
-
- run_id = run_result.get("run_id")
- assert run_id, f"Run ID should be returned: {run_result}"
-
- # Wait a bit for the job to start
- time.sleep(10)
-
- # Verify the job is running
- status_before = manage_job_runs(action="get", run_id=str(run_id))
- life_cycle_state = status_before.get("state", {}).get("life_cycle_state")
- logger.info(f"State before cancel: {life_cycle_state}")
-
- # Cancel the run
- cancel_result = manage_job_runs(
- action="cancel",
- run_id=str(run_id),
- )
-
- logger.info(f"Cancel result: {cancel_result}")
-
- assert "error" not in cancel_result, f"Cancel failed: {cancel_result}"
-
- # Wait for cancellation to take effect
- max_wait = 60
- waited = 0
- cancelled = False
-
- while waited < max_wait:
- status_after = manage_job_runs(action="get", run_id=str(run_id))
- state = status_after.get("state", {})
- life_cycle_state = state.get("life_cycle_state")
- result_state = state.get("result_state")
-
- logger.info(f"State after cancel ({waited}s): lifecycle={life_cycle_state}, result={result_state}")
-
- if life_cycle_state == "TERMINATED":
- # Check if it was cancelled
- if result_state == "CANCELED":
- cancelled = True
- break
- # If terminated but not cancelled, check other states
- break
-
- time.sleep(5)
- waited += 5
-
- assert cancelled, f"Job run should be cancelled, got state: {status_after.get('state')}"
-
- finally:
- # Cleanup temp directory
- import shutil
- shutil.rmtree(temp_dir, ignore_errors=True)
diff --git a/databricks-mcp-server/tests/integration/lakebase/__init__.py b/databricks-mcp-server/tests/integration/lakebase/__init__.py
deleted file mode 100644
index a6e0a476..00000000
--- a/databricks-mcp-server/tests/integration/lakebase/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Lakebase integration tests
diff --git a/databricks-mcp-server/tests/integration/lakebase/test_lakebase.py b/databricks-mcp-server/tests/integration/lakebase/test_lakebase.py
deleted file mode 100644
index 9bfb1165..00000000
--- a/databricks-mcp-server/tests/integration/lakebase/test_lakebase.py
+++ /dev/null
@@ -1,287 +0,0 @@
-"""
-Integration tests for Lakebase MCP tools.
-
-Tests:
-- manage_lakebase_database: create_or_update, get, list, delete
-- manage_lakebase_branch: create_or_update, delete
-- manage_lakebase_sync: (requires existing provisioned instance)
-- generate_lakebase_credential: (read-only)
-"""
-
-import logging
-import time
-import uuid
-
-import pytest
-
-from databricks_mcp_server.tools.lakebase import (
- manage_lakebase_database,
- manage_lakebase_branch,
- manage_lakebase_sync,
- generate_lakebase_credential,
-)
-from tests.test_config import TEST_RESOURCE_PREFIX
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.mark.integration
-class TestManageLakebaseDatabase:
- """Tests for manage_lakebase_database tool."""
-
- @pytest.mark.slow
- def test_list_all_databases(self):
- """Should list all Lakebase databases (slow due to pagination)."""
- result = manage_lakebase_database(action="list")
-
- logger.info(f"List all result: found {len(result.get('databases', []))} databases")
-
- assert "error" not in result, f"List failed: {result}"
- assert "databases" in result
- assert isinstance(result["databases"], list)
-
- def test_get_nonexistent_database(self):
- """Should handle nonexistent database gracefully."""
- result = manage_lakebase_database(
- action="get",
- name="nonexistent_db_xyz_12345",
- )
-
- logger.info(f"Get nonexistent result: {result}")
-
- assert "error" in result
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- result = manage_lakebase_database(action="invalid_action")
-
- assert "error" in result
-
-
-@pytest.mark.integration
-class TestManageLakebaseBranch:
- """Tests for manage_lakebase_branch tool (autoscale only)."""
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- result = manage_lakebase_branch(action="invalid_action")
-
- assert "error" in result
-
- def test_create_requires_params(self):
- """Should require project_name and branch_id for create."""
- result = manage_lakebase_branch(action="create_or_update")
-
- assert "error" in result
- assert "project_name" in result["error"] or "branch_id" in result["error"]
-
- def test_delete_requires_name(self):
- """Should require name for delete."""
- result = manage_lakebase_branch(action="delete")
-
- assert "error" in result
- assert "name" in result["error"]
-
-
-@pytest.mark.integration
-class TestManageLakebaseSync:
- """Tests for manage_lakebase_sync tool."""
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- result = manage_lakebase_sync(action="invalid_action")
-
- assert "error" in result
-
- def test_create_requires_params(self):
- """Should require instance_name, source_table_name, target_table_name for create."""
- result = manage_lakebase_sync(action="create_or_update")
-
- assert "error" in result
- assert "instance_name" in result["error"] or "source_table_name" in result["error"]
-
- def test_delete_requires_table_name(self):
- """Should require table_name for delete."""
- result = manage_lakebase_sync(action="delete")
-
- assert "error" in result
- assert "table_name" in result["error"]
-
-
-@pytest.mark.integration
-class TestGenerateLakebaseCredential:
- """Tests for generate_lakebase_credential tool."""
-
- def test_requires_instance_or_endpoint(self):
- """Should require either instance_names or endpoint."""
- result = generate_lakebase_credential()
-
- logger.info(f"Generate credential without params result: {result}")
-
- assert "error" in result
-
-
-@pytest.mark.integration
-class TestAutoscaleLifecycle:
- """End-to-end test for autoscale project lifecycle: create -> branch -> delete."""
-
- def test_full_autoscale_lifecycle(self, cleanup_lakebase_instances):
- """Test complete autoscale lifecycle: create project, add branch, delete branch, delete project."""
- test_start = time.time()
- # Unique name for this test run
- # Note: Lakebase project_id must match pattern ^[a-z]([a-z0-9-]{0,61}[a-z0-9])?$
- # So we use hyphens instead of underscores and lowercase only
- project_name = f"ai-dev-kit-test-lakebase-{uuid.uuid4().hex[:6]}"
- branch_id = "ai-dev-kit-test-branch"
- branch_name = None
-
- def log_time(msg):
- elapsed = time.time() - test_start
- logger.info(f"[{elapsed:.1f}s] {msg}")
-
- try:
- # Step 0: Pre-cleanup - delete any existing project with this name (from crashed previous runs)
- log_time(f"Step 0: Pre-cleanup - checking for existing project '{project_name}'...")
- existing = manage_lakebase_database(
- action="get",
- name=project_name,
- type="autoscale",
- )
- if "error" not in existing:
- log_time(f"Found existing project, deleting...")
- manage_lakebase_database(
- action="delete",
- name=project_name,
- type="autoscale",
- )
- # Wait for deletion to propagate
- time.sleep(15)
- log_time("Existing project deleted")
-
- # Step 1: Create autoscale project
- log_time(f"Step 1: Creating autoscale project '{project_name}'...")
- create_result = manage_lakebase_database(
- action="create_or_update",
- name=project_name,
- type="autoscale",
- display_name=f"Test Autoscale Project",
- )
- log_time(f"Create result: {create_result}")
- assert "error" not in create_result, f"Create project failed: {create_result}"
-
- cleanup_lakebase_instances(project_name, "autoscale")
-
- # Step 2: Wait for project to be ready
- log_time("Step 2: Waiting for project to be ready...")
- max_wait = 120
- wait_interval = 10
- waited = 0
- project_ready = False
-
- while waited < max_wait:
- get_result = manage_lakebase_database(
- action="get",
- name=project_name,
- type="autoscale",
- )
- state = get_result.get("state") or get_result.get("status")
- log_time(f"Project state after {waited}s: {state}")
-
- if state in ("ACTIVE", "READY", "RUNNING"):
- project_ready = True
- break
- elif state in ("FAILED", "ERROR", "DELETED"):
- pytest.fail(f"Project creation failed: {get_result}")
-
- time.sleep(wait_interval)
- waited += wait_interval
-
- if not project_ready:
- log_time(f"Project not fully ready after {max_wait}s, continuing anyway")
-
- # Step 3: Create a branch
- log_time("Step 3: Creating branch...")
- branch_result = manage_lakebase_branch(
- action="create_or_update",
- project_name=project_name,
- branch_id=branch_id,
- ttl_seconds=3600,
- )
- log_time(f"Create branch result: {branch_result}")
- assert "error" not in branch_result, f"Create branch failed: {branch_result}"
- branch_name = branch_result.get("name", f"{project_name}/branches/{branch_id}")
-
- # Step 4: Wait for branch and verify it exists
- log_time("Step 4: Verifying branch exists...")
- time.sleep(10)
- get_project = manage_lakebase_database(
- action="get",
- name=project_name,
- type="autoscale",
- )
- branches = get_project.get("branches", [])
- branch_names = [b.get("name", "") for b in branches]
- log_time(f"Branches found: {branch_names}")
- assert any(branch_id in name for name in branch_names), \
- f"Branch should exist: {branch_names}"
-
- # Step 5: Delete the branch
- log_time("Step 5: Deleting branch...")
- delete_branch_result = manage_lakebase_branch(
- action="delete",
- name=branch_name,
- )
- log_time(f"Delete branch result: {delete_branch_result}")
- assert "error" not in delete_branch_result, f"Delete branch failed: {delete_branch_result}"
-
- # Step 6: Verify branch is gone
- log_time("Step 6: Verifying branch deleted...")
- time.sleep(10)
- get_after_branch_delete = manage_lakebase_database(
- action="get",
- name=project_name,
- type="autoscale",
- )
- branches_after = get_after_branch_delete.get("branches", [])
- branch_names_after = [b.get("name", "") for b in branches_after]
- log_time(f"Branches after delete: {branch_names_after}")
- assert not any(branch_id in name for name in branch_names_after), \
- f"Branch should be deleted: {branch_names_after}"
- branch_name = None # Mark as deleted
-
- # Step 7: Delete the project
- log_time("Step 7: Deleting project...")
- delete_result = manage_lakebase_database(
- action="delete",
- name=project_name,
- type="autoscale",
- )
- log_time(f"Delete project result: {delete_result}")
- assert "error" not in delete_result, f"Delete project failed: {delete_result}"
-
- # Step 8: Verify project is gone
- log_time("Step 8: Verifying project deleted...")
- time.sleep(10)
- get_after = manage_lakebase_database(
- action="get",
- name=project_name,
- type="autoscale",
- )
- log_time(f"Get after delete: {get_after}")
- assert "error" in get_after or "not found" in str(get_after).lower(), \
- f"Project should be deleted: {get_after}"
-
- log_time("Full autoscale lifecycle test PASSED!")
-
- except Exception as e:
- log_time(f"Test failed: {e}")
- raise
- finally:
- # Cleanup on failure
- if branch_name:
- log_time(f"Cleanup: deleting branch {branch_name}")
- try:
- manage_lakebase_branch(action="delete", name=branch_name)
- except Exception:
- pass
- # Project cleanup is handled by cleanup_lakebase_instances fixture
diff --git a/databricks-mcp-server/tests/integration/pdf/__init__.py b/databricks-mcp-server/tests/integration/pdf/__init__.py
deleted file mode 100644
index cf10fd9c..00000000
--- a/databricks-mcp-server/tests/integration/pdf/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# PDF integration tests
diff --git a/databricks-mcp-server/tests/integration/pdf/test_pdf.py b/databricks-mcp-server/tests/integration/pdf/test_pdf.py
deleted file mode 100644
index 46778a26..00000000
--- a/databricks-mcp-server/tests/integration/pdf/test_pdf.py
+++ /dev/null
@@ -1,241 +0,0 @@
-"""
-Integration tests for PDF MCP tool.
-
-Tests:
-- generate_and_upload_pdf: create PDF from HTML and upload to volume
-"""
-
-import logging
-
-import pytest
-from databricks.sdk.service.catalog import VolumeType
-
-from databricks_mcp_server.tools.pdf import generate_and_upload_pdf
-from databricks_mcp_server.tools.volume_files import manage_volume_files
-from tests.test_config import TEST_RESOURCE_PREFIX
-
-logger = logging.getLogger(__name__)
-
-# Deterministic names for tests
-PDF_FILENAME = f"{TEST_RESOURCE_PREFIX}test.pdf"
-PDF_VOLUME_NAME = f"{TEST_RESOURCE_PREFIX}pdf_vol"
-
-
-@pytest.fixture(scope="module")
-def pdf_volume(workspace_client, test_catalog: str, pdf_schema: str):
- """Create a volume for PDF tests."""
- volume_name = PDF_VOLUME_NAME
-
- # Create volume if not exists
- try:
- workspace_client.volumes.create(
- catalog_name=test_catalog,
- schema_name=pdf_schema,
- name=volume_name,
- volume_type=VolumeType.MANAGED,
- )
- logger.info(f"Created volume: {volume_name}")
- except Exception as e:
- if "already exists" in str(e).lower():
- logger.info(f"Volume already exists: {volume_name}")
- else:
- raise
-
- yield volume_name
-
- # Cleanup
- try:
- workspace_client.volumes.delete(f"{test_catalog}.{pdf_schema}.{volume_name}")
- logger.info(f"Cleaned up volume: {volume_name}")
- except Exception as e:
- logger.warning(f"Failed to cleanup volume: {e}")
-
-
-@pytest.mark.integration
-class TestGenerateAndUploadPdf:
- """Tests for generate_and_upload_pdf tool."""
-
- def test_generate_simple_pdf(
- self,
- test_catalog: str,
- pdf_schema: str,
- pdf_volume: str,
- ):
- """Should generate a PDF from HTML and upload to volume."""
- html_content = """
-
-
-
-
-
- Test PDF Document
- This is a test paragraph with bold and italic text.
- Generated by MCP integration tests.
- Features Tested
-
- HTML to PDF conversion
- CSS styling support
- Upload to Unity Catalog volume
-
-
-
- """
-
- result = generate_and_upload_pdf(
- html_content=html_content,
- filename=PDF_FILENAME,
- catalog=test_catalog,
- schema=pdf_schema,
- volume=pdf_volume,
- folder="test_pdfs",
- )
-
- logger.info(f"Generate PDF result: {result}")
-
- assert result.get("success"), f"PDF generation failed: {result}"
- assert result.get("volume_path"), "Should return volume path"
- assert PDF_FILENAME in result["volume_path"]
-
- # Verify file exists by listing the folder
- volume_path = f"/Volumes/{test_catalog}/{pdf_schema}/{pdf_volume}/test_pdfs"
- list_result = manage_volume_files(
- action="list",
- volume_path=volume_path,
- )
-
- logger.info(f"List volume result: {list_result}")
-
- # Check that our PDF is in the list
- files = list_result.get("files", [])
- file_names = [f.get("name", "") for f in files]
- assert PDF_FILENAME in file_names, f"PDF should exist in volume: {file_names}"
-
- def test_generate_pdf_with_complex_html(
- self,
- test_catalog: str,
- pdf_schema: str,
- pdf_volume: str,
- ):
- """Should handle complex HTML with tables and images."""
- html_content = """
-
-
-
-
-
- Data Report
-
-
- ID
- Name
- Value
- Status
-
-
- 1
- Item A
- $100.00
- Active
-
-
- 2
- Item B
- $250.50
- Pending
-
-
- 3
- Item C
- $75.25
- Complete
-
-
- Report generated automatically by MCP tests.
-
-
- """
-
- result = generate_and_upload_pdf(
- html_content=html_content,
- filename="complex_report.pdf",
- catalog=test_catalog,
- schema=pdf_schema,
- volume=pdf_volume,
- folder="reports",
- )
-
- logger.info(f"Generate complex PDF result: {result}")
-
- assert result.get("success"), f"PDF generation failed: {result}"
- assert result.get("volume_path"), "Should return volume path"
- assert "reports" in result["volume_path"]
-
- def test_generate_pdf_minimal_html(
- self,
- test_catalog: str,
- pdf_schema: str,
- pdf_volume: str,
- ):
- """Should handle minimal HTML content."""
- result = generate_and_upload_pdf(
- html_content="Minimal content
",
- filename="minimal.pdf",
- catalog=test_catalog,
- schema=pdf_schema,
- volume=pdf_volume,
- )
-
- logger.info(f"Generate minimal PDF result: {result}")
-
- assert result.get("success"), f"PDF generation failed: {result}"
-
- def test_generate_pdf_nested_folder(
- self,
- test_catalog: str,
- pdf_schema: str,
- pdf_volume: str,
- ):
- """Should create PDF in nested folder path."""
- result = generate_and_upload_pdf(
- html_content="Nested folder test",
- filename="nested.pdf",
- catalog=test_catalog,
- schema=pdf_schema,
- volume=pdf_volume,
- folder="level1/level2/level3",
- )
-
- logger.info(f"Generate nested PDF result: {result}")
-
- assert result.get("success"), f"PDF generation failed: {result}"
- assert "level1/level2/level3" in result.get("volume_path", "")
-
- def test_generate_pdf_special_characters_filename(
- self,
- test_catalog: str,
- pdf_schema: str,
- pdf_volume: str,
- ):
- """Should handle special characters in filename (within limits)."""
- # Use underscores and hyphens which are safe
- result = generate_and_upload_pdf(
- html_content="Special chars test",
- filename="report_2024-01-15_final.pdf",
- catalog=test_catalog,
- schema=pdf_schema,
- volume=pdf_volume,
- )
-
- logger.info(f"Generate special chars PDF result: {result}")
-
- assert result.get("success"), f"PDF generation failed: {result}"
diff --git a/databricks-mcp-server/tests/integration/pipelines/__init__.py b/databricks-mcp-server/tests/integration/pipelines/__init__.py
deleted file mode 100644
index 73d0373c..00000000
--- a/databricks-mcp-server/tests/integration/pipelines/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Pipeline integration tests
diff --git a/databricks-mcp-server/tests/integration/pipelines/resources/simple_bronze.sql b/databricks-mcp-server/tests/integration/pipelines/resources/simple_bronze.sql
deleted file mode 100644
index 2af4f58c..00000000
--- a/databricks-mcp-server/tests/integration/pipelines/resources/simple_bronze.sql
+++ /dev/null
@@ -1,9 +0,0 @@
--- Simple bronze table for testing
--- Reads from samples.nyctaxi.trips (available in all workspaces)
-CREATE OR REFRESH STREAMING TABLE bronze_test
-AS SELECT
- tpep_pickup_datetime,
- tpep_dropoff_datetime,
- trip_distance,
- fare_amount
-FROM STREAM samples.nyctaxi.trips
diff --git a/databricks-mcp-server/tests/integration/pipelines/resources/simple_silver.sql b/databricks-mcp-server/tests/integration/pipelines/resources/simple_silver.sql
deleted file mode 100644
index 24c8ec47..00000000
--- a/databricks-mcp-server/tests/integration/pipelines/resources/simple_silver.sql
+++ /dev/null
@@ -1,10 +0,0 @@
--- Simple silver table for testing
--- Reads from bronze and adds transformation
-CREATE OR REFRESH MATERIALIZED VIEW silver_test
-AS SELECT
- tpep_pickup_datetime,
- tpep_dropoff_datetime,
- trip_distance,
- fare_amount,
- ROUND(fare_amount / NULLIF(trip_distance, 0), 2) AS fare_per_mile
-FROM bronze_test
diff --git a/databricks-mcp-server/tests/integration/pipelines/test_pipelines.py b/databricks-mcp-server/tests/integration/pipelines/test_pipelines.py
deleted file mode 100644
index 982b9c98..00000000
--- a/databricks-mcp-server/tests/integration/pipelines/test_pipelines.py
+++ /dev/null
@@ -1,200 +0,0 @@
-"""
-Integration tests for manage_pipeline and manage_pipeline_run MCP tools.
-
-Tests:
-- manage_pipeline: create_or_update, get, delete
-- manage_pipeline_run: start, get_events, stop
-"""
-
-import logging
-import time
-import uuid
-from pathlib import Path
-
-import pytest
-
-from databricks_mcp_server.tools.pipelines import manage_pipeline, manage_pipeline_run
-from databricks_mcp_server.tools.file import manage_workspace_files
-from tests.test_config import TEST_CATALOG, TEST_RESOURCE_PREFIX
-
-logger = logging.getLogger(__name__)
-
-# Path to test pipeline SQL files
-RESOURCES_DIR = Path(__file__).parent / "resources"
-
-
-@pytest.mark.integration
-class TestPipelineLifecycle:
- """End-to-end test for pipeline lifecycle: upload SQL -> create -> start -> stop -> delete."""
-
- def test_full_pipeline_lifecycle(
- self,
- workspace_client,
- current_user: str,
- test_catalog: str,
- pipelines_schema: str,
- cleanup_pipelines,
- ):
- """Test complete pipeline lifecycle in a single test."""
- test_start = time.time()
- unique_id = uuid.uuid4().hex[:6]
- pipeline_name = f"{TEST_RESOURCE_PREFIX}pipeline_{unique_id}"
- workspace_path = f"/Workspace/Users/{current_user}/ai_dev_kit_test/pipelines/{pipeline_name}"
- pipeline_id = None
-
- def log_time(msg):
- elapsed = time.time() - test_start
- logger.info(f"[{elapsed:.1f}s] {msg}")
-
- try:
- # Step 1: Upload SQL files to workspace
- # Use trailing slash to upload folder contents directly (not the folder itself)
- log_time(f"Step 1: Uploading SQL files to {workspace_path}...")
- upload_result = manage_workspace_files(
- action="upload",
- local_path=str(RESOURCES_DIR) + "/", # Trailing slash = upload contents only
- workspace_path=workspace_path,
- overwrite=True,
- )
- log_time(f"Upload result: {upload_result}")
- assert upload_result.get("success", False) or upload_result.get("status") == "success", \
- f"Failed to upload pipeline files: {upload_result}"
-
- # Step 2: Create pipeline with create_or_update
- log_time(f"Step 2: Creating pipeline '{pipeline_name}'...")
- bronze_path = f"{workspace_path}/simple_bronze.sql"
-
- create_result = manage_pipeline(
- action="create_or_update",
- name=pipeline_name,
- root_path=workspace_path,
- catalog=test_catalog,
- schema=pipelines_schema,
- workspace_file_paths=[bronze_path],
- start_run=False,
- )
- log_time(f"Create result: {create_result}")
- assert "error" not in create_result, f"Create failed: {create_result}"
- assert create_result.get("pipeline_id"), "Should return pipeline_id"
- assert create_result.get("created") is True, "Should be a new pipeline"
-
- pipeline_id = create_result["pipeline_id"]
- cleanup_pipelines(pipeline_id)
- log_time(f"Pipeline created with ID: {pipeline_id}")
-
- # Step 3: Get pipeline details
- log_time("Step 3: Getting pipeline details...")
- get_result = manage_pipeline(action="get", pipeline_id=pipeline_id)
- log_time(f"Get result: {get_result}")
- assert "error" not in get_result, f"Get failed: {get_result}"
- assert get_result.get("name") == pipeline_name
-
- # Verify catalog and schema
- spec = get_result.get("spec", {})
- assert spec.get("catalog") == test_catalog, f"Catalog mismatch: {spec.get('catalog')}"
-
- # Step 4: Update pipeline (add another SQL file)
- log_time("Step 4: Updating pipeline with additional file...")
- silver_path = f"{workspace_path}/simple_silver.sql"
- update_result = manage_pipeline(
- action="create_or_update",
- name=pipeline_name,
- root_path=workspace_path,
- catalog=test_catalog,
- schema=pipelines_schema,
- workspace_file_paths=[bronze_path, silver_path],
- start_run=False,
- )
- log_time(f"Update result: {update_result}")
- assert "error" not in update_result, f"Update failed: {update_result}"
- assert update_result.get("created") is False, "Should update existing pipeline"
-
- # Step 5: Start a run and wait for completion
- log_time("Step 5: Starting pipeline run (wait for completion)...")
- start_result = manage_pipeline_run(
- action="start",
- pipeline_id=pipeline_id,
- full_refresh=True,
- wait=True, # Wait for completion to verify it works
- timeout=600, # 10 minutes timeout for serverless pipelines
- )
- log_time(f"Start result: {start_result}")
- assert "error" not in start_result, f"Start failed: {start_result}"
- update_id = start_result.get("update_id")
- log_time(f"Pipeline run completed with update_id: {update_id}")
-
- # Verify the run completed successfully
- state = start_result.get("state") or start_result.get("status")
- assert state in ("COMPLETED", "IDLE", "SUCCESS", None) or "error" not in str(start_result).lower(), \
- f"Pipeline run should complete successfully, got state: {state}, result: {start_result}"
-
- # Step 6: Get events
- log_time("Step 6: Getting pipeline events...")
- events_result = manage_pipeline_run(
- action="get_events",
- pipeline_id=pipeline_id,
- max_results=10,
- )
- log_time(f"Events result: {len(events_result.get('events', []))} events")
- assert "error" not in events_result, f"Get events failed: {events_result}"
- assert "events" in events_result
-
- # Step 7: Stop the pipeline (cleanup)
- log_time("Step 7: Stopping pipeline...")
- stop_result = manage_pipeline_run(action="stop", pipeline_id=pipeline_id)
- log_time(f"Stop result: {stop_result}")
- assert stop_result.get("status") == "stopped", f"Stop failed: {stop_result}"
-
- # Step 8: Delete the pipeline
- log_time("Step 8: Deleting pipeline...")
- delete_result = manage_pipeline(action="delete", pipeline_id=pipeline_id)
- log_time(f"Delete result: {delete_result}")
- assert delete_result.get("status") == "deleted", f"Delete failed: {delete_result}"
- pipeline_id = None # Mark as deleted
-
- log_time("Full pipeline lifecycle test PASSED!")
-
- except Exception as e:
- log_time(f"Test failed: {e}")
- raise
- finally:
- # Cleanup on failure
- if pipeline_id:
- log_time(f"Cleanup: deleting pipeline {pipeline_id}")
- try:
- manage_pipeline_run(action="stop", pipeline_id=pipeline_id)
- except Exception:
- pass
- try:
- manage_pipeline(action="delete", pipeline_id=pipeline_id)
- except Exception:
- pass
-
- # Cleanup workspace files
- try:
- workspace_client.workspace.delete(workspace_path, recursive=True)
- log_time(f"Cleaned up workspace path: {workspace_path}")
- except Exception as e:
- log_time(f"Failed to cleanup workspace: {e}")
-
-
-@pytest.mark.integration
-class TestPipelineErrors:
- """Fast validation tests for error handling."""
-
- def test_create_missing_params(self):
- """Should return error for missing required params."""
- result = manage_pipeline(action="create", name="test")
- assert "error" in result
- assert "requires" in result["error"].lower()
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- result = manage_pipeline(action="invalid_action")
- assert "error" in result
- assert "invalid action" in result["error"].lower()
-
- def test_run_invalid_action(self):
- """Should return error for invalid run action."""
- result = manage_pipeline_run(action="invalid_action", pipeline_id="fake-id")
- assert "error" in result
diff --git a/databricks-mcp-server/tests/integration/run_tests.py b/databricks-mcp-server/tests/integration/run_tests.py
deleted file mode 100644
index 51d18d92..00000000
--- a/databricks-mcp-server/tests/integration/run_tests.py
+++ /dev/null
@@ -1,712 +0,0 @@
-#!/usr/bin/env python3
-"""
-Integration Test Runner
-
-Run all integration tests in parallel with detailed reporting.
-
-Usage:
- python tests/integration/run_tests.py # Run all tests (excluding slow)
- python tests/integration/run_tests.py --all # Run all tests including slow
- python tests/integration/run_tests.py --report # Show report from latest run
- python tests/integration/run_tests.py --status # Check status of ongoing/recent runs
-"""
-
-import argparse
-import json
-import os
-import re
-import shutil
-import subprocess
-import sys
-import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from dataclasses import dataclass, field
-from datetime import datetime
-from pathlib import Path
-from typing import Optional
-
-
-# ANSI color codes
-class Colors:
- HEADER = "\033[95m"
- BLUE = "\033[94m"
- CYAN = "\033[96m"
- GREEN = "\033[92m"
- YELLOW = "\033[93m"
- RED = "\033[91m"
- BOLD = "\033[1m"
- DIM = "\033[2m"
- RESET = "\033[0m"
-
-
-@dataclass
-class TestResult:
- """Result from a single test folder."""
- folder: str
- passed: int = 0
- failed: int = 0
- skipped: int = 0
- errors: int = 0
- duration: float = 0.0
- log_file: str = ""
- error_details: list = field(default_factory=list)
- status: str = "unknown" # unknown, running, completed
-
- @property
- def total(self) -> int:
- return self.passed + self.failed + self.skipped + self.errors
-
- @property
- def success(self) -> bool:
- return self.failed == 0 and self.errors == 0
-
-
-def format_timestamp(ts_str: str) -> str:
- """Format a timestamp string (YYYYMMDD_HHMMSS or ISO) into human-readable format."""
- try:
- # Try ISO format first
- if "T" in ts_str:
- dt = datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
- else:
- # Try YYYYMMDD_HHMMSS format
- dt = datetime.strptime(ts_str, "%Y%m%d_%H%M%S")
- return dt.strftime("%Y-%m-%d %H:%M:%S")
- except ValueError:
- return ts_str
-
-
-def format_duration(seconds: float) -> str:
- """Format duration in human-readable format."""
- if seconds < 60:
- return f"{seconds:.1f}s"
- elif seconds < 3600:
- mins = int(seconds // 60)
- secs = int(seconds % 60)
- return f"{mins}m {secs}s"
- else:
- hours = int(seconds // 3600)
- mins = int((seconds % 3600) // 60)
- return f"{hours}h {mins}m"
-
-
-def get_test_folders() -> list[str]:
- """Get all test folders in the integration directory."""
- integration_dir = Path(__file__).parent
- folders = []
- for item in sorted(integration_dir.iterdir()):
- if item.is_dir() and not item.name.startswith(("__", ".")):
- # Check if it contains test files
- if list(item.glob("test_*.py")):
- folders.append(item.name)
- return folders
-
-
-def get_results_dir() -> Path:
- """Get the results directory path."""
- return Path(__file__).parent / ".test-results"
-
-
-def run_test_folder(
- folder: str,
- output_dir: Path,
- include_slow: bool = False,
-) -> TestResult:
- """Run tests for a single folder and return results."""
- result = TestResult(folder=folder, status="running")
- log_file = output_dir / f"{folder}.txt"
- result.log_file = str(log_file)
-
- # Write initial status
- log_file.write_text(f"[RUNNING] Started at {datetime.now().isoformat()}\n")
-
- # Build pytest command
- test_path = Path(__file__).parent / folder
- cmd = [
- sys.executable, "-m", "pytest",
- str(test_path),
- "-v",
- "-s", # Stream output to see real-time logs
- "--tb=short",
- "-m", "integration" if not include_slow else "integration or slow",
- ]
-
- start_time = time.time()
-
- try:
- proc = subprocess.run(
- cmd,
- capture_output=True,
- text=True,
- timeout=1200, # 20 minute timeout per folder
- )
- output = proc.stdout + proc.stderr
- except subprocess.TimeoutExpired:
- output = f"TIMEOUT: Tests in {folder} exceeded 10 minute limit"
- result.errors = 1
- except Exception as e:
- output = f"ERROR: {e}"
- result.errors = 1
-
- result.duration = time.time() - start_time
- result.status = "completed"
-
- # Save log file
- log_file.write_text(output)
-
- # Parse results from output
- result = parse_pytest_output(output, result)
-
- return result
-
-
-def parse_pytest_output(output: str, result: TestResult) -> TestResult:
- """Parse pytest output to extract test counts and errors."""
- # Look for summary line like "5 passed, 2 failed, 1 skipped in 10.5s"
- summary_pattern = r"(\d+)\s+passed"
- failed_pattern = r"(\d+)\s+failed"
- skipped_pattern = r"(\d+)\s+skipped"
- error_pattern = r"(\d+)\s+error"
-
- if match := re.search(summary_pattern, output):
- result.passed = int(match.group(1))
- if match := re.search(failed_pattern, output):
- result.failed = int(match.group(1))
- if match := re.search(skipped_pattern, output):
- result.skipped = int(match.group(1))
- if match := re.search(error_pattern, output):
- result.errors = int(match.group(1))
-
- # Extract failure details
- if result.failed > 0 or result.errors > 0:
- # Find FAILURES section
- failures_start = output.find("=== FAILURES ===")
- if failures_start == -1:
- failures_start = output.find("FAILED")
-
- if failures_start != -1:
- # Extract test names and short error messages
- failed_tests = re.findall(r"FAILED\s+([\w/:.]+)", output)
- for test in failed_tests[:5]: # Limit to 5 failures
- result.error_details.append(test)
-
- # Also capture assertion errors
- assertions = re.findall(r"AssertionError:\s*(.+?)(?:\n|$)", output)
- for assertion in assertions[:3]:
- result.error_details.append(f" -> {assertion[:100]}")
-
- return result
-
-
-def parse_log_file_status(log_file: Path) -> tuple[str, Optional[TestResult]]:
- """Parse a log file to determine if test is running or completed."""
- if not log_file.exists():
- return "pending", None
-
- content = log_file.read_text()
-
- # Check if still running
- if content.startswith("[RUNNING]"):
- return "running", None
-
- # Check for timeout (test was killed due to exceeding time limit)
- if "TIMEOUT:" in content:
- result = TestResult(folder=log_file.stem, log_file=str(log_file))
- result.errors = 1
- result.status = "timeout"
- result.error_details = ["Test timed out"]
- return "timeout", result
-
- # Parse completed results
- result = TestResult(folder=log_file.stem, log_file=str(log_file))
- result = parse_pytest_output(content, result)
-
- if result.total > 0 or "passed" in content.lower() or "failed" in content.lower():
- result.status = "completed"
- return "completed", result
-
- return "running", None
-
-
-def print_progress(folder: str, status: str, duration: float = 0):
- """Print progress update."""
- if status == "running":
- print(f" {Colors.CYAN}[RUNNING]{Colors.RESET} {folder}...")
- elif status == "done":
- print(f" {Colors.GREEN}[DONE]{Colors.RESET} {folder} ({format_duration(duration)})")
- elif status == "failed":
- print(f" {Colors.RED}[FAILED]{Colors.RESET} {folder} ({format_duration(duration)})")
-
-
-def print_header(text: str):
- """Print a section header."""
- width = 70
- print()
- print(f"{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}")
- print(f"{Colors.BOLD}{Colors.BLUE}{text.center(width)}{Colors.RESET}")
- print(f"{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}")
- print()
-
-
-def print_summary(results: list[TestResult], total_duration: float, output_dir: Path, run_timestamp: str = None):
- """Print a detailed summary of test results."""
- print_header("TEST RESULTS SUMMARY")
-
- # Show run timestamp
- if run_timestamp:
- print(f" {Colors.BOLD}Run Date:{Colors.RESET} {format_timestamp(run_timestamp)}")
- print()
-
- # Calculate totals
- total_passed = sum(r.passed for r in results)
- total_failed = sum(r.failed for r in results)
- total_skipped = sum(r.skipped for r in results)
- total_errors = sum(r.errors for r in results)
- total_tests = total_passed + total_failed + total_skipped + total_errors
-
- # Overall status
- all_passed = total_failed == 0 and total_errors == 0
- status_color = Colors.GREEN if all_passed else Colors.RED
- status_text = "ALL TESTS PASSED" if all_passed else "SOME TESTS FAILED"
-
- print(f" {status_color}{Colors.BOLD}{status_text}{Colors.RESET}")
- print()
-
- # Summary stats
- print(f" {Colors.BOLD}Overall Statistics:{Colors.RESET}")
- print(f" Total Tests: {total_tests}")
- print(f" {Colors.GREEN}Passed:{Colors.RESET} {total_passed}")
- if total_failed > 0:
- print(f" {Colors.RED}Failed:{Colors.RESET} {total_failed}")
- if total_errors > 0:
- print(f" {Colors.RED}Errors:{Colors.RESET} {total_errors}")
- if total_skipped > 0:
- print(f" {Colors.YELLOW}Skipped:{Colors.RESET} {total_skipped}")
- print(f" Duration: {format_duration(total_duration)}")
- print()
-
- # Per-folder breakdown
- print(f" {Colors.BOLD}Results by Folder:{Colors.RESET}")
- print()
-
- # Header
- print(f" {'Folder':<20} {'Status':<10} {'Passed':<8} {'Failed':<8} {'Skip':<8} {'Time':<10}")
- print(f" {'-' * 20} {'-' * 10} {'-' * 8} {'-' * 8} {'-' * 8} {'-' * 10}")
-
- for r in sorted(results, key=lambda x: (x.success, -x.failed)):
- status = f"{Colors.GREEN}PASS{Colors.RESET}" if r.success else f"{Colors.RED}FAIL{Colors.RESET}"
- failed_str = f"{Colors.RED}{r.failed}{Colors.RESET}" if r.failed > 0 else str(r.failed)
- print(f" {r.folder:<20} {status:<19} {r.passed:<8} {failed_str:<17} {r.skipped:<8} {format_duration(r.duration)}")
-
- print()
-
- # Show failures
- failed_results = [r for r in results if not r.success]
- if failed_results:
- print(f" {Colors.BOLD}{Colors.RED}Failed Tests:{Colors.RESET}")
- print()
- for r in failed_results:
- print(f" {Colors.RED}{r.folder}:{Colors.RESET}")
- for detail in r.error_details[:5]:
- print(f" {Colors.DIM}{detail}{Colors.RESET}")
- print(f" {Colors.DIM}Log: {r.log_file}{Colors.RESET}")
- print()
-
- # Output location
- print(f" {Colors.BOLD}Output Location:{Colors.RESET}")
- print(f" {output_dir}")
- print()
-
- # Quick commands
- print(f" {Colors.BOLD}Useful Commands:{Colors.RESET}")
- print(f" View report: python tests/integration/run_tests.py --report")
- print(f" Check status: python tests/integration/run_tests.py --status")
- print(f" Re-run failed: python -m pytest -v --tb=long")
- print()
-
-
-def save_results_json(results: list[TestResult], total_duration: float, output_dir: Path, status: str = "completed"):
- """Save results as JSON for later reporting."""
- data = {
- "timestamp": datetime.now().isoformat(),
- "status": status,
- "total_duration": total_duration,
- "results": [
- {
- "folder": r.folder,
- "passed": r.passed,
- "failed": r.failed,
- "skipped": r.skipped,
- "errors": r.errors,
- "duration": r.duration,
- "log_file": r.log_file,
- "error_details": r.error_details,
- "status": r.status,
- }
- for r in results
- ],
- }
-
- json_file = output_dir / "results.json"
- json_file.write_text(json.dumps(data, indent=2))
-
-
-def list_all_runs() -> list[dict]:
- """List all test runs with their status."""
- results_dir = get_results_dir()
- if not results_dir.exists():
- return []
-
- runs = []
- for run_dir in sorted(results_dir.iterdir(), reverse=True):
- if not run_dir.is_dir():
- continue
-
- json_file = run_dir / "results.json"
- if json_file.exists():
- try:
- data = json.loads(json_file.read_text())
- runs.append({
- "dir": run_dir,
- "timestamp": run_dir.name,
- "status": data.get("status", "completed"),
- "data": data,
- })
- except json.JSONDecodeError:
- runs.append({
- "dir": run_dir,
- "timestamp": run_dir.name,
- "status": "error",
- "data": None,
- })
- else:
- # Check if any log files indicate running tests
- log_files = list(run_dir.glob("*.txt"))
- has_running = any(
- f.read_text().startswith("[RUNNING]")
- for f in log_files if f.exists()
- )
- runs.append({
- "dir": run_dir,
- "timestamp": run_dir.name,
- "status": "running" if has_running else "incomplete",
- "data": None,
- })
-
- return runs
-
-
-def show_status():
- """Show status of the most recent test run."""
- print_header("TEST RUN STATUS")
-
- runs = list_all_runs()
-
- if not runs:
- print(f" {Colors.YELLOW}No test runs found.{Colors.RESET}")
- print(f" Run tests with: python tests/integration/run_tests.py")
- return
-
- # Get the most recent run (running or completed)
- latest = runs[0]
- run_dir = latest["dir"]
- is_running = latest["status"] == "running"
-
- # Header
- status_label = f"{Colors.CYAN}RUNNING{Colors.RESET}" if is_running else "completed"
- print(f" {Colors.BOLD}Last run:{Colors.RESET} {format_timestamp(latest['timestamp'])} ({status_label})")
- print()
-
- # Collect status for all folders
- all_folders = get_test_folders()
- folder_status = {}
-
- for folder in all_folders:
- log_file = run_dir / f"{folder}.txt"
- if log_file.exists():
- status, result = parse_log_file_status(log_file)
- folder_status[folder] = (status, result)
- else:
- folder_status[folder] = ("pending", None)
-
- # Count totals
- total_passed = 0
- total_failed = 0
- running_count = 0
- completed_count = 0
-
- for folder, (status, result) in folder_status.items():
- if status == "running":
- running_count += 1
- elif result:
- completed_count += 1
- total_passed += result.passed
- total_failed += result.failed + result.errors
-
- # Show progress if running
- if is_running:
- print(f" Progress: {completed_count}/{len(all_folders)} folders completed, {running_count} running")
- print()
-
- # Show per-folder status
- print(f" {'Folder':<20} {'Status':<12} {'Result':<30}")
- print(f" {'-' * 20} {'-' * 12} {'-' * 30}")
-
- for folder in sorted(all_folders):
- status, result = folder_status.get(folder, ("pending", None))
-
- if status == "running":
- status_str = f"{Colors.CYAN}RUNNING{Colors.RESET}"
- result_str = ""
- elif status == "timeout":
- status_str = f"{Colors.RED}TIMEOUT{Colors.RESET}"
- result_str = f"{Colors.RED}Test timed out{Colors.RESET}"
- elif status == "pending":
- status_str = f"{Colors.DIM}pending{Colors.RESET}"
- result_str = ""
- elif result:
- if result.success:
- status_str = f"{Colors.GREEN}PASS{Colors.RESET}"
- result_str = f"{result.passed} passed"
- else:
- status_str = f"{Colors.RED}FAIL{Colors.RESET}"
- result_str = f"{Colors.RED}{result.passed} passed, {result.failed} failed{Colors.RESET}"
- else:
- status_str = f"{Colors.YELLOW}unknown{Colors.RESET}"
- result_str = ""
-
- print(f" {folder:<20} {status_str:<21} {result_str}")
-
- print()
-
- # Summary line
- if not is_running:
- all_pass = total_failed == 0
- status_color = Colors.GREEN if all_pass else Colors.RED
- status_text = "ALL PASSED" if all_pass else f"{total_failed} FAILED"
- print(f" {Colors.BOLD}Total:{Colors.RESET} {total_passed} passed, {status_color}{status_text}{Colors.RESET}")
- print()
-
- print(f" {Colors.BOLD}Commands:{Colors.RESET}")
- print(f" View full report: python tests/integration/run_tests.py --report")
- print()
-
-
-def load_and_show_report(timestamp: Optional[str] = None):
- """Load and display a report from a previous run."""
- results_dir = get_results_dir()
-
- if not results_dir.exists():
- print(f"{Colors.RED}No test results found. Run tests first.{Colors.RESET}")
- return
-
- # Find the results directory
- if timestamp and timestamp != "latest":
- output_dir = results_dir / timestamp
- if not output_dir.exists():
- print(f"{Colors.RED}No results found for timestamp: {timestamp}{Colors.RESET}")
- available = sorted([d.name for d in results_dir.iterdir() if d.is_dir()])[-5:]
- print(f"Available runs: {', '.join(available)}")
- return
- else:
- # Use latest
- dirs = sorted([d for d in results_dir.iterdir() if d.is_dir()])
- if not dirs:
- print(f"{Colors.RED}No test results found.{Colors.RESET}")
- return
- output_dir = dirs[-1]
-
- # Load JSON results
- json_file = output_dir / "results.json"
- if not json_file.exists():
- # Try to build results from log files
- print(f" {Colors.YELLOW}No results.json found, parsing log files...{Colors.RESET}")
- results = []
- for log_file in output_dir.glob("*.txt"):
- status, result = parse_log_file_status(log_file)
- if result:
- results.append(result)
- elif status == "running":
- results.append(TestResult(folder=log_file.stem, status="running"))
-
- if results:
- total_duration = sum(r.duration for r in results)
- print_summary(results, total_duration, output_dir, output_dir.name)
- else:
- print(f"{Colors.RED}No results found in {output_dir}{Colors.RESET}")
- return
-
- data = json.loads(json_file.read_text())
-
- # Convert to TestResult objects
- results = [
- TestResult(
- folder=r["folder"],
- passed=r["passed"],
- failed=r["failed"],
- skipped=r["skipped"],
- errors=r["errors"],
- duration=r["duration"],
- log_file=r["log_file"],
- error_details=r["error_details"],
- status=r.get("status", "completed"),
- )
- for r in data["results"]
- ]
-
- print_summary(results, data["total_duration"], output_dir, data.get("timestamp", output_dir.name))
-
-
-def cleanup_results(keep_last: int = 5):
- """Delete old test result directories."""
- print_header("CLEANUP TEST RESULTS")
-
- results_dir = get_results_dir()
- if not results_dir.exists():
- print(f" No test results to clean up.")
- return
-
- dirs = sorted([d for d in results_dir.iterdir() if d.is_dir()])
-
- if len(dirs) <= keep_last:
- print(f" Only {len(dirs)} runs found, keeping all.")
- return
-
- to_delete = dirs[:-keep_last]
- print(f" Keeping last {keep_last} runs, deleting {len(to_delete)} old runs...")
- print()
-
- for d in to_delete:
- print(f" Deleting: {d.name}")
- shutil.rmtree(d)
-
- print()
- print(f" {Colors.GREEN}Cleaned up {len(to_delete)} old test runs.{Colors.RESET}")
-
-
-def run_all_tests(include_slow: bool = False, max_workers: int = 8):
- """Run all integration tests in parallel."""
- print_header("INTEGRATION TEST RUNNER")
-
- # Get test folders
- folders = get_test_folders()
- print(f" Found {len(folders)} test folders: {', '.join(folders)}")
- print(f" Include slow tests: {include_slow}")
- print(f" Max parallel workers: {max_workers}")
- print()
-
- # Create output directory
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- results_dir = get_results_dir()
- output_dir = results_dir / timestamp
- output_dir.mkdir(parents=True, exist_ok=True)
-
- print(f" Started at: {format_timestamp(timestamp)}")
- print(f" Output directory: {output_dir}")
- print()
-
- # Run tests in parallel
- print(f" {Colors.BOLD}Running tests...{Colors.RESET}")
- print()
-
- results = []
- start_time = time.time()
-
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- # Submit all tasks
- future_to_folder = {
- executor.submit(run_test_folder, folder, output_dir, include_slow): folder
- for folder in folders
- }
-
- # Track running
- for folder in folders:
- print_progress(folder, "running")
-
- # Collect results as they complete
- for future in as_completed(future_to_folder):
- folder = future_to_folder[future]
- try:
- result = future.result()
- results.append(result)
- status = "done" if result.success else "failed"
- print_progress(folder, status, result.duration)
- except Exception as e:
- print(f" {Colors.RED}[ERROR]{Colors.RESET} {folder}: {e}")
- results.append(TestResult(folder=folder, errors=1, status="error"))
-
- total_duration = time.time() - start_time
-
- # Save results
- save_results_json(results, total_duration, output_dir, status="completed")
-
- # Print summary
- print_summary(results, total_duration, output_dir, timestamp)
-
- # Return exit code
- all_passed = all(r.success for r in results)
- return 0 if all_passed else 1
-
-
-def main():
- parser = argparse.ArgumentParser(
- description="Run integration tests in parallel with detailed reporting",
- formatter_class=argparse.RawDescriptionHelpFormatter,
- epilog="""
-Examples:
- python tests/integration/run_tests.py # Run tests (excluding slow)
- python tests/integration/run_tests.py --all # Run all tests including slow
- python tests/integration/run_tests.py --report # Show latest report
- python tests/integration/run_tests.py --status # Check ongoing/recent runs
- python tests/integration/run_tests.py -j 4 # Run with 4 parallel workers
- """,
- )
-
- parser.add_argument(
- "--all", "-a",
- action="store_true",
- help="Include slow tests",
- )
- parser.add_argument(
- "--report", "-r",
- nargs="?",
- const="latest",
- metavar="TIMESTAMP",
- help="Show report from a previous run (default: latest)",
- )
- parser.add_argument(
- "--status", "-s",
- action="store_true",
- help="Show status of ongoing and recent test runs",
- )
- parser.add_argument(
- "--cleanup-results",
- action="store_true",
- help="Delete old test result directories (keeps last 5)",
- )
- parser.add_argument(
- "-j", "--jobs",
- type=int,
- default=8,
- help="Number of parallel test workers (default: 8)",
- )
-
- args = parser.parse_args()
-
- if args.status:
- show_status()
- return 0
-
- if args.cleanup_results:
- cleanup_results()
- return 0
-
- if args.report:
- timestamp = None if args.report == "latest" else args.report
- load_and_show_report(timestamp)
- return 0
-
- return run_all_tests(include_slow=args.all, max_workers=args.jobs)
-
-
-if __name__ == "__main__":
- sys.exit(main())
diff --git a/databricks-mcp-server/tests/integration/serving/__init__.py b/databricks-mcp-server/tests/integration/serving/__init__.py
deleted file mode 100644
index fcfaba9c..00000000
--- a/databricks-mcp-server/tests/integration/serving/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Serving integration tests
diff --git a/databricks-mcp-server/tests/integration/serving/test_serving.py b/databricks-mcp-server/tests/integration/serving/test_serving.py
deleted file mode 100644
index 8aaf57dc..00000000
--- a/databricks-mcp-server/tests/integration/serving/test_serving.py
+++ /dev/null
@@ -1,116 +0,0 @@
-"""
-Integration tests for model serving MCP tool.
-
-Tests:
-- manage_serving_endpoint: get, list, query
-
-Note: The manage_serving_endpoint tool only supports read-only operations.
-Creating/updating/deleting serving endpoints requires a separate skill
-(databricks-model-serving) or direct SDK usage.
-"""
-
-import logging
-
-import pytest
-
-from databricks_mcp_server.tools.serving import manage_serving_endpoint
-from tests.test_config import TEST_RESOURCE_PREFIX
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.fixture(scope="module")
-def existing_serving_endpoint(workspace_client) -> str:
- """Find an existing serving endpoint for tests."""
- try:
- result = manage_serving_endpoint(action="list")
- endpoints = result.get("endpoints", [])
- for ep in endpoints:
- state = ep.get("state") or ep.get("status")
- if state in ("READY", "ONLINE", "NOT_UPDATING"):
- name = ep.get("name")
- logger.info(f"Using existing serving endpoint: {name}")
- return name
- except Exception as e:
- logger.warning(f"Could not list serving endpoints: {e}")
-
- pytest.skip("No existing serving endpoint available")
-
-
-@pytest.mark.integration
-class TestManageServingEndpoint:
- """Tests for manage_serving_endpoint tool."""
-
- def test_list_endpoints(self):
- """Should list all serving endpoints."""
- result = manage_serving_endpoint(action="list")
-
- logger.info(f"List result: {result}")
-
- assert "error" not in result, f"List failed: {result}"
- assert "endpoints" in result
- assert isinstance(result["endpoints"], list)
-
- def test_get_endpoint(self, existing_serving_endpoint: str):
- """Should get endpoint details."""
- result = manage_serving_endpoint(action="get", name=existing_serving_endpoint)
-
- logger.info(f"Get result: {result}")
-
- # API returns error: None as a field, check truthiness not presence
- assert not result.get("error"), f"Get failed: {result}"
- assert result.get("name") == existing_serving_endpoint
-
- def test_get_nonexistent_endpoint(self):
- """Should handle nonexistent endpoint gracefully."""
- try:
- result = manage_serving_endpoint(action="get", name="nonexistent_endpoint_xyz_12345")
-
- logger.info(f"Get nonexistent result: {result}")
-
- assert result.get("state") == "NOT_FOUND" or result.get("error")
- except Exception as e:
- # Function raises exception for nonexistent endpoint - this is acceptable
- error_msg = str(e).lower()
- assert "not exist" in error_msg or "not found" in error_msg
-
- def test_query_foundation_model(self):
- """Should query a foundation model endpoint."""
- result = manage_serving_endpoint(
- action="query",
- name="databricks-meta-llama-3-3-70b-instruct",
- messages=[
- {"role": "user", "content": "Say hello in one word."}
- ],
- max_tokens=10,
- )
-
- logger.info(f"Query result: {result}")
-
- assert "error" not in result, f"Query failed: {result}"
- # Should have some response
- assert result.get("choices") or result.get("predictions") or result.get("output")
-
- def test_query_with_invalid_endpoint(self):
- """Should handle invalid endpoint query gracefully."""
- try:
- result = manage_serving_endpoint(
- action="query",
- name="nonexistent_endpoint_xyz_12345",
- messages=[{"role": "user", "content": "test"}],
- )
-
- logger.info(f"Query invalid endpoint result: {result}")
-
- # Should return error
- assert result.get("error") or result.get("status") == "error"
- except Exception as e:
- # Function raises exception for invalid endpoint - this is acceptable
- error_msg = str(e).lower()
- assert "not exist" in error_msg or "not found" in error_msg
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- result = manage_serving_endpoint(action="invalid_action")
-
- assert "error" in result
diff --git a/databricks-mcp-server/tests/integration/sql/__init__.py b/databricks-mcp-server/tests/integration/sql/__init__.py
deleted file mode 100644
index c3707b3c..00000000
--- a/databricks-mcp-server/tests/integration/sql/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# SQL integration tests
diff --git a/databricks-mcp-server/tests/integration/sql/test_sql.py b/databricks-mcp-server/tests/integration/sql/test_sql.py
deleted file mode 100644
index cb46c4cd..00000000
--- a/databricks-mcp-server/tests/integration/sql/test_sql.py
+++ /dev/null
@@ -1,182 +0,0 @@
-"""
-Integration tests for SQL MCP tools.
-
-Tests:
-- execute_sql: basic queries, catalog/schema context
-- manage_warehouse: list, get_best
-"""
-
-import logging
-
-import pytest
-
-from databricks_mcp_server.tools.sql import execute_sql, manage_warehouse
-from tests.test_config import TEST_CATALOG, SCHEMAS
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.mark.integration
-class TestExecuteSql:
- """Tests for execute_sql tool."""
-
- def test_simple_select(self, warehouse_id: str):
- """Should execute a simple SELECT statement."""
- result = execute_sql(
- sql_query="SELECT 1 as num, 'hello' as greeting",
- warehouse_id=warehouse_id,
- )
-
- logger.info(f"Result: {result}")
-
- # Result is now a markdown-formatted string
- assert isinstance(result, str)
- assert "(1 row)" in result
- assert "num" in result and "greeting" in result
- assert "| 1 |" in result and "hello" in result
-
- def test_select_with_multiple_rows(self, warehouse_id: str):
- """Should return multiple rows correctly."""
- result = execute_sql(
- sql_query="""
- SELECT * FROM (
- VALUES (1, 'a'), (2, 'b'), (3, 'c')
- ) AS t(id, letter)
- """,
- warehouse_id=warehouse_id,
- )
-
- # Result is now a markdown-formatted string
- assert isinstance(result, str)
- assert "(3 rows)" in result
- assert "id" in result and "letter" in result
- # Check all three rows are present
- assert "| 1 |" in result and "| a |" in result
- assert "| 2 |" in result and "| b |" in result
- assert "| 3 |" in result and "| c |" in result
-
- def test_create_and_query_table(
- self,
- warehouse_id: str,
- test_catalog: str,
- sql_schema: str,
- ):
- """Should create a table and query it."""
- table_name = f"{test_catalog}.{sql_schema}.test_table"
-
- # Create table
- execute_sql(
- sql_query=f"""
- CREATE OR REPLACE TABLE {table_name} (
- id INT,
- name STRING
- )
- """,
- warehouse_id=warehouse_id,
- )
-
- # Insert data
- execute_sql(
- sql_query=f"""
- INSERT INTO {table_name} VALUES
- (1, 'Alice'),
- (2, 'Bob')
- """,
- warehouse_id=warehouse_id,
- )
-
- # Query
- result = execute_sql(
- sql_query=f"SELECT * FROM {table_name} ORDER BY id",
- warehouse_id=warehouse_id,
- )
-
- # Result is now a markdown-formatted string
- assert isinstance(result, str)
- assert "(2 rows)" in result
- assert "Alice" in result
- assert "Bob" in result
-
- def test_catalog_schema_context(
- self,
- warehouse_id: str,
- test_catalog: str,
- sql_schema: str,
- ):
- """Should use catalog/schema context for unqualified names."""
- table_name = f"{test_catalog}.{sql_schema}.context_test"
-
- # Create table with qualified name
- execute_sql(
- sql_query=f"CREATE OR REPLACE TABLE {table_name} AS SELECT 1 as val",
- warehouse_id=warehouse_id,
- )
-
- # Query with unqualified name using context
- result = execute_sql(
- sql_query="SELECT * FROM context_test",
- warehouse_id=warehouse_id,
- catalog=test_catalog,
- schema=sql_schema,
- )
-
- # Result is now a markdown-formatted string
- assert isinstance(result, str)
- assert "(1 row)" in result
-
- def test_auto_select_warehouse(self, test_catalog: str, sql_schema: str):
- """Should auto-select warehouse when not provided."""
- result = execute_sql(
- sql_query="SELECT 1 as num",
- # warehouse_id not provided
- )
-
- # Result is now a markdown-formatted string
- assert isinstance(result, str)
- assert "(1 row)" in result
-
- def test_invalid_sql_returns_error(self, warehouse_id: str):
- """Should handle invalid SQL gracefully."""
- # This should raise or return error, not crash
- try:
- result = execute_sql(
- sql_query="SELECT * FROM nonexistent_table_xyz_12345",
- warehouse_id=warehouse_id,
- )
- # If it returns instead of raising, check for error indicators
- logger.info(f"Result for invalid SQL: {result}")
- except Exception as e:
- logger.info(f"Expected error for invalid SQL: {e}")
- error_msg = str(e).lower()
- assert "not found" in error_msg or "does not exist" in error_msg or "cannot be found" in error_msg
-
-
-@pytest.mark.integration
-class TestManageWarehouse:
- """Tests for manage_warehouse tool."""
-
- def test_list_warehouses(self):
- """Should list all warehouses."""
- result = manage_warehouse(action="list")
-
- logger.info(f"List result: {result}")
-
- assert "error" not in result, f"List failed: {result}"
- assert "warehouses" in result
- assert isinstance(result["warehouses"], list)
-
- def test_get_best_warehouse(self):
- """Should return the best available warehouse."""
- result = manage_warehouse(action="get_best")
-
- logger.info(f"Get best result: {result}")
-
- assert "error" not in result, f"Get best failed: {result}"
- # Should have warehouse info
- assert result.get("warehouse_id") or result.get("id")
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- result = manage_warehouse(action="invalid_action")
-
- assert "error" in result
diff --git a/databricks-mcp-server/tests/integration/vector_search/__init__.py b/databricks-mcp-server/tests/integration/vector_search/__init__.py
deleted file mode 100644
index 53745641..00000000
--- a/databricks-mcp-server/tests/integration/vector_search/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Vector search integration tests
diff --git a/databricks-mcp-server/tests/integration/vector_search/test_vector_search.py b/databricks-mcp-server/tests/integration/vector_search/test_vector_search.py
deleted file mode 100644
index 49960bce..00000000
--- a/databricks-mcp-server/tests/integration/vector_search/test_vector_search.py
+++ /dev/null
@@ -1,362 +0,0 @@
-"""
-Integration tests for vector search MCP tools.
-
-Tests:
-- manage_vs_endpoint: create, get, list, delete
-- manage_vs_index: create, get, sync, delete
-- query_vs_index: basic queries
-"""
-
-import logging
-import uuid
-import time
-
-import pytest
-
-from databricks_mcp_server.tools.vector_search import (
- manage_vs_endpoint,
- manage_vs_index,
- query_vs_index,
-)
-from databricks_mcp_server.tools.sql import execute_sql
-from tests.test_config import TEST_CATALOG, SCHEMAS, TEST_RESOURCE_PREFIX
-
-logger = logging.getLogger(__name__)
-
-# Deterministic names for tests (enables safe cleanup/restart)
-VS_ENDPOINT_NAME = f"{TEST_RESOURCE_PREFIX}vs_endpoint"
-VS_ENDPOINT_DELETE = f"{TEST_RESOURCE_PREFIX}vs_ep_delete"
-VS_INDEX_NAME_SUFFIX = f"{TEST_RESOURCE_PREFIX}vs_index"
-VS_INDEX_DELETE_SUFFIX = f"{TEST_RESOURCE_PREFIX}vs_idx_delete"
-
-
-@pytest.fixture(scope="module")
-def clean_vs_resources():
- """Pre-test cleanup: delete any existing test VS endpoints and indexes."""
- endpoints_to_clean = [VS_ENDPOINT_NAME, VS_ENDPOINT_DELETE]
-
- for ep_name in endpoints_to_clean:
- try:
- result = manage_vs_endpoint(action="get", name=ep_name)
- if result.get("state") != "NOT_FOUND" and not result.get("error"):
- manage_vs_endpoint(action="delete", name=ep_name)
- logger.info(f"Pre-cleanup: deleted VS endpoint {ep_name}")
- except Exception as e:
- logger.warning(f"Pre-cleanup failed for endpoint {ep_name}: {e}")
-
- yield
-
- # Post-test cleanup
- for ep_name in endpoints_to_clean:
- try:
- result = manage_vs_endpoint(action="get", name=ep_name)
- if result.get("state") != "NOT_FOUND" and not result.get("error"):
- manage_vs_endpoint(action="delete", name=ep_name)
- logger.info(f"Post-cleanup: deleted VS endpoint {ep_name}")
- except Exception:
- pass
-
-
-@pytest.fixture(scope="module")
-def vs_source_table(
- workspace_client,
- test_catalog: str,
- vector_search_schema: str,
- warehouse_id: str,
-) -> str:
- """Create a source table for vector search index."""
- table_name = f"{test_catalog}.{vector_search_schema}.vs_source_table"
-
- # Create table with text content for embedding
- execute_sql(
- sql_query=f"""
- CREATE OR REPLACE TABLE {table_name} (
- id STRING,
- content STRING,
- category STRING
- )
- TBLPROPERTIES (delta.enableChangeDataFeed = true)
- """,
- warehouse_id=warehouse_id,
- )
-
- # Insert test data
- execute_sql(
- sql_query=f"""
- INSERT INTO {table_name} VALUES
- ('doc1', 'Databricks is a unified analytics platform', 'tech'),
- ('doc2', 'Machine learning helps analyze data', 'tech'),
- ('doc3', 'Python is a programming language', 'tech')
- """,
- warehouse_id=warehouse_id,
- )
-
- logger.info(f"Created source table: {table_name}")
- return table_name
-
-
-@pytest.fixture(scope="module")
-def existing_vs_endpoint(workspace_client) -> str:
- """Find an existing VS endpoint for read-only tests."""
- try:
- result = manage_vs_endpoint(action="list")
- endpoints = result.get("endpoints", [])
- for ep in endpoints:
- if ep.get("state") == "ONLINE":
- logger.info(f"Using existing endpoint: {ep['name']}")
- return ep["name"]
- except Exception as e:
- logger.warning(f"Could not list endpoints: {e}")
-
- pytest.skip("No existing VS endpoint available")
-
-
-@pytest.mark.integration
-class TestManageVsEndpoint:
- """Tests for manage_vs_endpoint tool."""
-
- def test_list_endpoints(self):
- """Should list all VS endpoints."""
- result = manage_vs_endpoint(action="list")
-
- logger.info(f"List result: {result}")
-
- assert not result.get("error"), f"List failed: {result}"
- assert "endpoints" in result
- assert isinstance(result["endpoints"], list)
-
- def test_get_endpoint(self, existing_vs_endpoint: str):
- """Should get endpoint details."""
- result = manage_vs_endpoint(action="get", name=existing_vs_endpoint)
-
- logger.info(f"Get result: {result}")
-
- # API returns error: None as a field, check truthiness not presence
- assert not result.get("error"), f"Get failed: {result}"
- assert result.get("name") == existing_vs_endpoint
- assert result.get("state") is not None
-
- def test_get_nonexistent_endpoint(self):
- """Should handle nonexistent endpoint gracefully."""
- result = manage_vs_endpoint(action="get", name="nonexistent_endpoint_xyz_12345")
-
- assert result.get("state") == "NOT_FOUND" or "error" in result
-
- def test_create_endpoint(self, cleanup_vs_endpoints):
- """Should create a new endpoint."""
- name = f"{TEST_RESOURCE_PREFIX}vs_create_{uuid.uuid4().hex[:6]}"
- cleanup_vs_endpoints(name)
-
- result = manage_vs_endpoint(
- action="create_or_update", # API only supports create_or_update
- name=name,
- endpoint_type="STANDARD",
- )
-
- logger.info(f"Create result: {result}")
-
- assert not result.get("error"), f"Create failed: {result}"
- assert result.get("name") == name
- assert result.get("status") in ("CREATING", "ALREADY_EXISTS", "ONLINE", None)
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- result = manage_vs_endpoint(action="invalid_action")
-
- assert "error" in result
-
-
-@pytest.mark.integration
-class TestManageVsIndex:
- """Tests for manage_vs_index tool."""
-
- def test_list_indexes(self, existing_vs_endpoint: str):
- """Should list indexes on an endpoint."""
- result = manage_vs_index(action="list", endpoint_name=existing_vs_endpoint)
-
- logger.info(f"List indexes result: {result}")
-
- # May have no indexes, but should not error
- assert not result.get("error") or "indexes" in result
-
- def test_get_nonexistent_index(self):
- """Should handle nonexistent index gracefully."""
- result = manage_vs_index(
- action="get",
- name="nonexistent.schema.index_xyz_12345", # Param is 'name' not 'index_name'
- )
-
- assert result.get("state") == "NOT_FOUND" or result.get("error")
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- result = manage_vs_index(action="invalid_action")
-
- assert "error" in result
-
-
-@pytest.mark.integration
-class TestVsIndexLifecycle:
- """End-to-end test for VS index lifecycle: create -> get -> query -> delete."""
-
- def test_full_index_lifecycle(
- self,
- existing_vs_endpoint: str,
- vs_source_table: str,
- test_catalog: str,
- vector_search_schema: str,
- ):
- """Test complete index lifecycle: create, wait, get, query, delete, verify."""
- index_name = f"{test_catalog}.{vector_search_schema}.{VS_INDEX_NAME_SUFFIX}"
- test_start = time.time()
-
- def log_time(msg):
- elapsed = time.time() - test_start
- logger.info(f"[{elapsed:.1f}s] {msg}")
-
- try:
- # Step 1: Create index
- log_time("Step 1: Creating index...")
- create_result = manage_vs_index(
- action="create_or_update",
- name=index_name,
- endpoint_name=existing_vs_endpoint,
- primary_key="id",
- index_type="DELTA_SYNC",
- delta_sync_index_spec={
- "source_table": vs_source_table,
- "embedding_source_columns": [
- {
- "name": "content",
- "embedding_model_endpoint_name": "databricks-gte-large-en",
- }
- ],
- "pipeline_type": "TRIGGERED",
- "columns_to_sync": ["id", "content", "category"],
- },
- )
- log_time(f"Create result: {create_result}")
- assert not create_result.get("error"), f"Create index failed: {create_result}"
- # create_or_update returns "created" boolean and may have "status" field
- assert create_result.get("created") is True or create_result.get("status") in ("CREATING", "ALREADY_EXISTS", None)
-
- # Step 2: Wait for index to be ready and verify via get
- log_time("Step 2: Waiting for index to be ready...")
- max_wait = 300 # 5 minutes
- wait_interval = 10
- waited = 0
-
- while waited < max_wait:
- get_result = manage_vs_index(action="get", name=index_name)
- state = get_result.get("state", get_result.get("status"))
- ready = get_result.get("ready", False)
- log_time(f"Index state: {state}, ready: {ready}")
- if ready:
- log_time(f"Index ready after {waited}s")
- break
- time.sleep(wait_interval)
- waited += wait_interval
- else:
- pytest.skip(f"Index not ready after {max_wait}s, skipping remaining tests")
-
- # Step 3: Query the index
- log_time("Step 3: Querying index...")
- query_result = query_vs_index(
- index_name=index_name,
- columns=["id", "content", "category"],
- query_text="analytics platform",
- num_results=3,
- )
- log_time(f"Query result: {query_result}")
- assert not query_result.get("error"), f"Query failed: {query_result}"
-
- results = query_result.get("result", {}).get("data_array", []) or query_result.get("results", [])
- assert len(results) > 0, f"Query should return results: {query_result}"
-
- # Verify the expected document is found
- result_contents = str(results)
- assert "doc1" in result_contents or "Databricks" in result_contents or "analytics" in result_contents, \
- f"Query should return doc about Databricks analytics: {results}"
-
- # Step 4: Delete the index
- log_time("Step 4: Deleting index...")
- delete_result = manage_vs_index(action="delete", name=index_name)
- log_time(f"Delete result: {delete_result}")
- assert not delete_result.get("error"), f"Delete index failed: {delete_result}"
-
- # Step 5: Verify index is gone
- log_time("Step 5: Verifying deletion...")
- time.sleep(10)
- get_after = manage_vs_index(action="get", name=index_name)
- log_time(f"Get after delete: {get_after}")
- assert get_after.get("state") == "NOT_FOUND" or "error" in get_after, \
- f"Index should be deleted: {get_after}"
-
- log_time("Full index lifecycle test PASSED!")
-
- except Exception as e:
- # Cleanup on failure
- log_time(f"Test failed, attempting cleanup: {e}")
- try:
- manage_vs_index(action="delete", name=index_name)
- except Exception:
- pass
- raise
-
-
-@pytest.mark.integration
-class TestVsEndpointLifecycle:
- """End-to-end test for VS endpoint lifecycle: create -> get -> delete."""
-
- def test_full_endpoint_lifecycle(self, clean_vs_resources):
- """Test complete endpoint lifecycle: create, get, delete, verify."""
- test_start = time.time()
-
- def log_time(msg):
- elapsed = time.time() - test_start
- logger.info(f"[{elapsed:.1f}s] {msg}")
-
- try:
- # Step 1: Create endpoint
- log_time("Step 1: Creating endpoint...")
- create_result = manage_vs_endpoint(
- action="create_or_update",
- name=VS_ENDPOINT_DELETE,
- endpoint_type="STANDARD",
- )
- log_time(f"Create result: {create_result}")
- assert not create_result.get("error"), f"Create failed: {create_result}"
- assert create_result.get("name") == VS_ENDPOINT_DELETE
-
- # Step 2: Verify endpoint exists via get
- log_time("Step 2: Verifying endpoint exists...")
- time.sleep(5)
- get_before = manage_vs_endpoint(action="get", name=VS_ENDPOINT_DELETE)
- log_time(f"Get result: {get_before}")
- assert get_before.get("state") != "NOT_FOUND", \
- f"Endpoint should exist after create: {get_before}"
-
- # Step 3: Delete endpoint
- log_time("Step 3: Deleting endpoint...")
- delete_result = manage_vs_endpoint(action="delete", name=VS_ENDPOINT_DELETE)
- log_time(f"Delete result: {delete_result}")
- assert not delete_result.get("error"), f"Delete failed: {delete_result}"
-
- # Step 4: Verify endpoint is gone
- log_time("Step 4: Verifying deletion...")
- time.sleep(5)
- get_after = manage_vs_endpoint(action="get", name=VS_ENDPOINT_DELETE)
- log_time(f"Get after delete: {get_after}")
- assert get_after.get("state") == "NOT_FOUND" or "error" in get_after, \
- f"Endpoint should be deleted: {get_after}"
-
- log_time("Full endpoint lifecycle test PASSED!")
-
- except Exception as e:
- # Cleanup on failure
- log_time(f"Test failed, attempting cleanup: {e}")
- try:
- manage_vs_endpoint(action="delete", name=VS_ENDPOINT_DELETE)
- except Exception:
- pass
- raise
diff --git a/databricks-mcp-server/tests/integration/volume_files/__init__.py b/databricks-mcp-server/tests/integration/volume_files/__init__.py
deleted file mode 100644
index 047b6220..00000000
--- a/databricks-mcp-server/tests/integration/volume_files/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Volume files integration tests
diff --git a/databricks-mcp-server/tests/integration/volume_files/test_volume_files.py b/databricks-mcp-server/tests/integration/volume_files/test_volume_files.py
deleted file mode 100644
index 318160f5..00000000
--- a/databricks-mcp-server/tests/integration/volume_files/test_volume_files.py
+++ /dev/null
@@ -1,265 +0,0 @@
-"""
-Integration tests for volume files MCP tool.
-
-Tests:
-- manage_volume_files: upload, download, list, delete
-"""
-
-import logging
-import tempfile
-from pathlib import Path
-
-import pytest
-
-from databricks_mcp_server.tools.volume_files import manage_volume_files
-from tests.test_config import TEST_CATALOG, SCHEMAS, TEST_RESOURCE_PREFIX
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.fixture(scope="module")
-def test_volume(
- workspace_client,
- test_catalog: str,
- volume_files_schema: str,
-) -> str:
- """Create a test volume for file operations."""
- from databricks.sdk.service.catalog import VolumeType
-
- volume_name = f"{TEST_RESOURCE_PREFIX}volume"
- full_volume_name = f"{test_catalog}.{volume_files_schema}.{volume_name}"
-
- # Delete if exists
- try:
- workspace_client.volumes.delete(full_volume_name)
- except Exception:
- pass
-
- # Create volume
- workspace_client.volumes.create(
- catalog_name=test_catalog,
- schema_name=volume_files_schema,
- name=volume_name,
- volume_type=VolumeType.MANAGED,
- )
-
- logger.info(f"Created test volume: {full_volume_name}")
-
- yield full_volume_name
-
- # Cleanup
- try:
- workspace_client.volumes.delete(full_volume_name)
- logger.info(f"Cleaned up volume: {full_volume_name}")
- except Exception as e:
- logger.warning(f"Failed to cleanup volume: {e}")
-
-
-@pytest.fixture
-def test_local_file():
- """Create a temporary local file for upload tests."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
- f.write("Hello from MCP integration test!\n")
- f.write("Line 2 of test file.\n")
- temp_path = f.name
-
- yield temp_path
-
- # Cleanup
- try:
- Path(temp_path).unlink()
- except Exception:
- pass
-
-
-@pytest.fixture
-def test_local_dir():
- """Create a temporary local directory with files for upload tests."""
- with tempfile.TemporaryDirectory() as temp_dir:
- # Create some test files
- (Path(temp_dir) / "file1.txt").write_text("File 1 content")
- (Path(temp_dir) / "file2.txt").write_text("File 2 content")
- (Path(temp_dir) / "subdir").mkdir()
- (Path(temp_dir) / "subdir" / "file3.txt").write_text("File 3 in subdir")
-
- yield temp_dir
-
-
-@pytest.mark.integration
-class TestManageVolumeFiles:
- """Tests for manage_volume_files tool."""
-
- def test_upload_single_file(
- self,
- test_volume: str,
- test_local_file: str,
- test_catalog: str,
- volume_files_schema: str,
- ):
- """Should upload a single file to volume and verify it exists."""
- file_name = Path(test_local_file).name
- # For single file upload, volume_path must include the destination filename
- volume_path = f"/Volumes/{test_catalog}/{volume_files_schema}/{TEST_RESOURCE_PREFIX}volume/{file_name}"
-
- result = manage_volume_files(
- action="upload",
- volume_path=volume_path,
- local_path=test_local_file,
- overwrite=True,
- )
-
- logger.info(f"Upload result: {result}")
-
- assert not result.get("error"), f"Upload failed: {result}"
- assert result.get("success", False) or result.get("status") == "success"
-
- # Verify file exists by listing the parent directory
- volume_dir = f"/Volumes/{test_catalog}/{volume_files_schema}/{TEST_RESOURCE_PREFIX}volume"
- list_result = manage_volume_files(action="list", volume_path=volume_dir)
- assert not list_result.get("error"), f"List failed: {list_result}"
-
- # Check file appears in listing
- files = list_result.get("files", []) or list_result.get("contents", [])
- file_names = [f.get("name") or f.get("path", "").split("/")[-1] for f in files]
- assert file_name in file_names, f"Uploaded file {file_name} not found in {file_names}"
-
- def test_upload_directory(
- self,
- test_volume: str,
- test_local_dir: str,
- test_catalog: str,
- volume_files_schema: str,
- ):
- """Should upload a directory to volume."""
- volume_path = f"/Volumes/{test_catalog}/{volume_files_schema}/{TEST_RESOURCE_PREFIX}volume/test_dir"
-
- result = manage_volume_files(
- action="upload",
- local_path=test_local_dir,
- volume_path=volume_path,
- overwrite=True,
- )
-
- logger.info(f"Upload directory result: {result}")
-
- assert "error" not in result, f"Upload failed: {result}"
-
- def test_list_files(
- self,
- test_volume: str,
- test_catalog: str,
- volume_files_schema: str,
- ):
- """Should list files in volume."""
- volume_path = f"/Volumes/{test_catalog}/{volume_files_schema}/{TEST_RESOURCE_PREFIX}volume"
-
- result = manage_volume_files(
- action="list",
- volume_path=volume_path,
- )
-
- logger.info(f"List result: {result}")
-
- assert "error" not in result, f"List failed: {result}"
-
- def test_download_file(
- self,
- test_volume: str,
- test_local_file: str,
- test_catalog: str,
- volume_files_schema: str,
- ):
- """Should download a file from volume and verify content matches."""
- volume_dir = f"/Volumes/{test_catalog}/{volume_files_schema}/{TEST_RESOURCE_PREFIX}volume"
- file_name = Path(test_local_file).name
- # For single file upload, volume_path must include the destination filename
- volume_file_path = f"{volume_dir}/{file_name}"
-
- # Read original content
- original_content = Path(test_local_file).read_text()
-
- # First upload a file
- manage_volume_files(
- action="upload",
- volume_path=volume_file_path,
- local_path=test_local_file,
- overwrite=True,
- )
-
- # Download to temp location
- with tempfile.TemporaryDirectory() as temp_dir:
- # local_destination must be the full file path, not just directory
- local_file_path = str(Path(temp_dir) / file_name)
- result = manage_volume_files(
- action="download",
- volume_path=volume_file_path,
- local_destination=local_file_path,
- )
-
- logger.info(f"Download result: {result}")
-
- assert not result.get("error"), f"Download failed: {result}"
-
- # Verify downloaded content matches original
- downloaded_file = Path(local_file_path)
- assert downloaded_file.exists(), f"Downloaded file not found at {downloaded_file}"
-
- downloaded_content = downloaded_file.read_text()
- assert downloaded_content == original_content, \
- f"Content mismatch: expected {original_content!r}, got {downloaded_content!r}"
-
- def test_delete_file(
- self,
- test_volume: str,
- test_local_file: str,
- test_catalog: str,
- volume_files_schema: str,
- ):
- """Should delete a file from volume and verify it's gone."""
- volume_dir = f"/Volumes/{test_catalog}/{volume_files_schema}/{TEST_RESOURCE_PREFIX}volume"
- # Use a unique file name for this test to avoid conflicts
- delete_test_file = Path(test_local_file).parent / f"delete_test_{Path(test_local_file).name}"
- delete_test_file.write_text("File to be deleted")
- file_name = delete_test_file.name
- # For single file upload, volume_path must include the destination filename
- volume_file_path = f"{volume_dir}/{file_name}"
-
- try:
- # First upload a file
- manage_volume_files(
- action="upload",
- volume_path=volume_file_path,
- local_path=str(delete_test_file),
- overwrite=True,
- )
-
- # Verify file exists before delete
- list_before = manage_volume_files(action="list", volume_path=volume_dir)
- files_before = list_before.get("files", []) or list_before.get("contents", [])
- file_names_before = [f.get("name") or f.get("path", "").split("/")[-1] for f in files_before]
- assert file_name in file_names_before, f"File {file_name} should exist before delete"
-
- # Delete it
- result = manage_volume_files(
- action="delete",
- volume_path=volume_file_path,
- )
-
- logger.info(f"Delete result: {result}")
-
- assert not result.get("error"), f"Delete failed: {result}"
-
- # Verify file is gone
- list_after = manage_volume_files(action="list", volume_path=volume_dir)
- files_after = list_after.get("files", []) or list_after.get("contents", [])
- file_names_after = [f.get("name") or f.get("path", "").split("/")[-1] for f in files_after]
- assert file_name not in file_names_after, f"File {file_name} should be deleted but still exists"
- finally:
- # Cleanup local temp file
- delete_test_file.unlink(missing_ok=True)
-
- def test_invalid_action(self):
- """Should return error for invalid action."""
- result = manage_volume_files(action="invalid_action", volume_path="/Volumes/dummy/path")
-
- assert "error" in result
diff --git a/databricks-mcp-server/tests/integration/workspace_files/__init__.py b/databricks-mcp-server/tests/integration/workspace_files/__init__.py
deleted file mode 100644
index 3b06becf..00000000
--- a/databricks-mcp-server/tests/integration/workspace_files/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Workspace files integration tests
diff --git a/databricks-mcp-server/tests/integration/workspace_files/test_workspace_files.py b/databricks-mcp-server/tests/integration/workspace_files/test_workspace_files.py
deleted file mode 100644
index 05039db3..00000000
--- a/databricks-mcp-server/tests/integration/workspace_files/test_workspace_files.py
+++ /dev/null
@@ -1,587 +0,0 @@
-"""
-Integration tests for workspace files MCP tool.
-
-Tests:
-- manage_workspace_files: upload, delete
-- File type preservation (Python files should remain FILE, not NOTEBOOK)
-"""
-
-import logging
-import tempfile
-from pathlib import Path
-
-import pytest
-from databricks.sdk import WorkspaceClient
-
-from databricks_mcp_server.tools.file import manage_workspace_files
-from tests.test_config import TEST_RESOURCE_PREFIX
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.fixture
-def test_local_file():
- """Create a temporary local file for upload tests."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
- f.write("# Test Python file\n")
- f.write("print('Hello from MCP test')\n")
- temp_path = f.name
-
- yield temp_path
-
- # Cleanup
- try:
- Path(temp_path).unlink()
- except Exception:
- pass
-
-
-@pytest.fixture
-def test_local_dir():
- """Create a temporary local directory with files for upload tests."""
- with tempfile.TemporaryDirectory() as temp_dir:
- # Create some test files
- (Path(temp_dir) / "script1.py").write_text("# Script 1\nprint('one')")
- (Path(temp_dir) / "script2.py").write_text("# Script 2\nprint('two')")
- (Path(temp_dir) / "subdir").mkdir()
- (Path(temp_dir) / "subdir" / "script3.py").write_text("# Script 3\nprint('three')")
-
- yield temp_dir
-
-
-@pytest.mark.integration
-class TestManageWorkspaceFiles:
- """Tests for manage_workspace_files tool."""
-
- def test_upload_single_file(
- self,
- workspace_client: WorkspaceClient,
- workspace_test_path: str,
- test_local_file: str,
- ):
- """Should upload a single file to workspace and verify it exists as FILE type (not NOTEBOOK)."""
- upload_path = f"{workspace_test_path}/single_file_test"
- file_name = Path(test_local_file).name
-
- result = manage_workspace_files(
- action="upload",
- local_path=test_local_file,
- workspace_path=upload_path,
- overwrite=True,
- )
-
- logger.info(f"Upload result: {result}")
-
- assert "error" not in result or result.get("error") is None, f"Upload failed: {result}"
- assert result.get("success", False), f"Upload not successful: {result}"
-
- # List the parent directory to see what was created
- parent_objects = list(workspace_client.workspace.list(workspace_test_path))
- logger.info(f"Objects in parent {workspace_test_path}: {[(obj.path, obj.object_type) for obj in parent_objects]}")
-
- # Find what was created at our upload_path
- created_obj = next((obj for obj in parent_objects if "single_file_test" in obj.path), None)
- assert created_obj is not None, f"Upload path not found in {[obj.path for obj in parent_objects]}"
-
- logger.info(f"Created object: path={created_obj.path}, type={created_obj.object_type}")
-
- # If it's a directory, list its contents to find the .py file
- if created_obj.object_type and created_obj.object_type.value == "DIRECTORY":
- inner_objects = list(workspace_client.workspace.list(upload_path))
- logger.info(f"Contents of {upload_path}: {[(obj.path, obj.object_type) for obj in inner_objects]}")
-
- # Find the .py file
- uploaded_file = next((obj for obj in inner_objects if obj.path.endswith(".py")), None)
- assert uploaded_file is not None, f"Could not find .py file in {[obj.path for obj in inner_objects]}"
-
- object_type = uploaded_file.object_type.value if uploaded_file.object_type else None
- else:
- # The upload might have created a file directly (rare case)
- object_type = created_obj.object_type.value if created_obj.object_type else None
- uploaded_file = created_obj
-
- logger.info(f"Uploaded file object_type: {object_type}")
-
- # Python files should be stored as FILE, not NOTEBOOK
- assert object_type == "FILE", \
- f"Python file should be uploaded as FILE type, not {object_type}. " \
- f"This indicates a bug where .py files are converted to notebooks during import."
-
- def test_upload_directory(
- self,
- workspace_test_path: str,
- test_local_dir: str,
- ):
- """Should upload a directory to workspace."""
- result = manage_workspace_files(
- action="upload",
- local_path=test_local_dir,
- workspace_path=f"{workspace_test_path}/test_dir",
- overwrite=True,
- )
-
- logger.info(f"Upload directory result: {result}")
-
- assert "error" not in result or result.get("error") is None, f"Upload failed: {result}"
- assert result.get("success", False), f"Upload not successful: {result}"
-
- def test_list_files_via_sdk(
- self,
- workspace_client: WorkspaceClient,
- workspace_test_path: str,
- test_local_dir: str,
- ):
- """Should upload files and verify listing via SDK."""
- # First upload some files
- upload_path = f"{workspace_test_path}/list_test"
- manage_workspace_files(
- action="upload",
- local_path=test_local_dir,
- workspace_path=upload_path,
- overwrite=True,
- )
-
- # List files using SDK
- objects = list(workspace_client.workspace.list(upload_path))
- logger.info(f"Listed objects: {[obj.path for obj in objects]}")
-
- assert len(objects) > 0, "Should have uploaded files"
-
- def test_delete_path(
- self,
- workspace_client: WorkspaceClient,
- workspace_test_path: str,
- test_local_file: str,
- ):
- """Should delete a file/directory from workspace and verify it's gone."""
- # First upload a file
- upload_path = f"{workspace_test_path}/delete_test"
- manage_workspace_files(
- action="upload",
- local_path=test_local_file,
- workspace_path=upload_path,
- overwrite=True,
- )
-
- # Verify it exists before delete using SDK
- objects_before = list(workspace_client.workspace.list(workspace_test_path))
- paths_before = [obj.path for obj in objects_before]
- assert any("delete_test" in p for p in paths_before), f"Path should exist before delete: {paths_before}"
-
- # Delete it
- result = manage_workspace_files(
- action="delete",
- workspace_path=upload_path,
- recursive=True,
- )
-
- logger.info(f"Delete result: {result}")
-
- assert result.get("success", False), f"Delete failed: {result}"
-
- # Verify it's gone using SDK
- objects_after = list(workspace_client.workspace.list(workspace_test_path))
- paths_after = [obj.path for obj in objects_after]
- assert not any("delete_test" in p for p in paths_after), f"Path should be deleted: {paths_after}"
-
- def test_invalid_action(self, workspace_test_path: str):
- """Should return error for invalid action."""
- result = manage_workspace_files(
- action="invalid_action",
- workspace_path=workspace_test_path,
- )
-
- assert "error" in result
-
- def test_file_type_preservation(
- self,
- workspace_client: WorkspaceClient,
- workspace_test_path: str,
- ):
- """Should preserve file types during upload - .py files should remain FILE, not NOTEBOOK.
-
- This test specifically catches bugs where Python files are incorrectly
- converted to Databricks notebooks during workspace import.
- """
- # Create various file types
- with tempfile.TemporaryDirectory() as temp_dir:
- temp_path = Path(temp_dir)
-
- # Create files with different extensions
- test_files = {
- "script.py": ("# Python script\nprint('hello')", "FILE"),
- "data.json": ('{"key": "value"}', "FILE"),
- "config.yaml": ("key: value", "FILE"),
- "readme.txt": ("Plain text file", "FILE"),
- }
-
- for filename, (content, expected_type) in test_files.items():
- (temp_path / filename).write_text(content)
-
- # Upload all files
- upload_path = f"{workspace_test_path}/type_preservation_test"
- result = manage_workspace_files(
- action="upload",
- local_path=str(temp_path),
- workspace_path=upload_path,
- overwrite=True,
- )
-
- assert result.get("success", False), f"Upload failed: {result}"
-
- # List contents of the upload directory
- # When uploading a temp directory, it creates a subdirectory with the temp dir name
- objects = list(workspace_client.workspace.list(upload_path))
- logger.info(f"Listed objects in {upload_path}: {[(obj.path, obj.object_type) for obj in objects]}")
-
- # If there's a subdirectory (from temp dir), look inside it
- if objects and objects[0].object_type and objects[0].object_type.value == "DIRECTORY":
- inner_dir = objects[0].path
- objects = list(workspace_client.workspace.list(inner_dir))
- logger.info(f"Listed objects in nested dir {inner_dir}: {[(obj.path, obj.object_type) for obj in objects]}")
-
- for filename, (_, expected_type) in test_files.items():
- # Find this file in the listing
- file_obj = next(
- (obj for obj in objects if filename in obj.path),
- None
- )
-
- assert file_obj is not None, f"File {filename} not found in workspace listing: {[obj.path for obj in objects]}"
-
- actual_type = file_obj.object_type.value if file_obj.object_type else None
-
- assert actual_type == expected_type, \
- f"File {filename} should be {expected_type}, but got {actual_type}. " \
- f"This indicates a bug in file type handling during workspace import."
-
- logger.info("All file types preserved correctly")
-
-
-@pytest.mark.integration
-class TestNotebookUpload:
- """Tests for notebook vs file type handling during upload.
-
- Databricks notebooks have special markers (e.g., '# Databricks notebook source')
- that distinguish them from regular files. Files with these markers should be
- imported as NOTEBOOK objects, while regular files should remain as FILE objects.
- """
-
- def test_upload_python_notebook(
- self,
- workspace_client: WorkspaceClient,
- workspace_test_path: str,
- ):
- """Python files with notebook marker should be uploaded as NOTEBOOK type."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
- # Write Databricks notebook marker + content
- f.write("# Databricks notebook source\n")
- f.write("print('Hello from Python notebook')\n")
- temp_path = f.name
-
- try:
- upload_path = f"{workspace_test_path}/python_notebook_test"
-
- result = manage_workspace_files(
- action="upload",
- local_path=temp_path,
- workspace_path=upload_path,
- overwrite=True,
- )
-
- logger.info(f"Upload Python notebook result: {result}")
- assert result.get("success", False), f"Upload failed: {result}"
-
- # Verify the uploaded object type
- info = workspace_client.workspace.get_status(upload_path)
- logger.info(f"Python notebook status: type={info.object_type}, language={info.language}")
-
- assert info.object_type.value == "NOTEBOOK", \
- f"Python notebook should be NOTEBOOK type, got {info.object_type}"
- assert info.language.value == "PYTHON", \
- f"Python notebook should have PYTHON language, got {info.language}"
-
- finally:
- Path(temp_path).unlink(missing_ok=True)
-
- def test_upload_sql_notebook(
- self,
- workspace_client: WorkspaceClient,
- workspace_test_path: str,
- ):
- """SQL files with notebook marker should be uploaded as NOTEBOOK type."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".sql", delete=False) as f:
- # Write Databricks notebook marker + content
- f.write("-- Databricks notebook source\n")
- f.write("SELECT 1 AS test_value\n")
- temp_path = f.name
-
- try:
- upload_path = f"{workspace_test_path}/sql_notebook_test"
-
- result = manage_workspace_files(
- action="upload",
- local_path=temp_path,
- workspace_path=upload_path,
- overwrite=True,
- )
-
- logger.info(f"Upload SQL notebook result: {result}")
- assert result.get("success", False), f"Upload failed: {result}"
-
- # Verify the uploaded object type
- info = workspace_client.workspace.get_status(upload_path)
- logger.info(f"SQL notebook status: type={info.object_type}, language={info.language}")
-
- assert info.object_type.value == "NOTEBOOK", \
- f"SQL notebook should be NOTEBOOK type, got {info.object_type}"
- assert info.language.value == "SQL", \
- f"SQL notebook should have SQL language, got {info.language}"
-
- finally:
- Path(temp_path).unlink(missing_ok=True)
-
- def test_upload_scala_notebook(
- self,
- workspace_client: WorkspaceClient,
- workspace_test_path: str,
- ):
- """Scala files with notebook marker should be uploaded as NOTEBOOK type."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".scala", delete=False) as f:
- # Write Databricks notebook marker + content
- f.write("// Databricks notebook source\n")
- f.write("println(\"Hello from Scala notebook\")\n")
- temp_path = f.name
-
- try:
- upload_path = f"{workspace_test_path}/scala_notebook_test"
-
- result = manage_workspace_files(
- action="upload",
- local_path=temp_path,
- workspace_path=upload_path,
- overwrite=True,
- )
-
- logger.info(f"Upload Scala notebook result: {result}")
- assert result.get("success", False), f"Upload failed: {result}"
-
- # Verify the uploaded object type
- info = workspace_client.workspace.get_status(upload_path)
- logger.info(f"Scala notebook status: type={info.object_type}, language={info.language}")
-
- assert info.object_type.value == "NOTEBOOK", \
- f"Scala notebook should be NOTEBOOK type, got {info.object_type}"
- assert info.language.value == "SCALA", \
- f"Scala notebook should have SCALA language, got {info.language}"
-
- finally:
- Path(temp_path).unlink(missing_ok=True)
-
- def test_upload_regular_python_file(
- self,
- workspace_client: WorkspaceClient,
- workspace_test_path: str,
- ):
- """Python files WITHOUT notebook marker should remain as FILE type."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
- # Regular Python file (no notebook marker)
- f.write("# Regular Python script\n")
- f.write("def hello():\n")
- f.write(" print('Hello from regular Python file')\n")
- temp_path = f.name
-
- try:
- upload_path = f"{workspace_test_path}/regular_python_test.py"
-
- result = manage_workspace_files(
- action="upload",
- local_path=temp_path,
- workspace_path=upload_path,
- overwrite=True,
- )
-
- logger.info(f"Upload regular Python file result: {result}")
- assert result.get("success", False), f"Upload failed: {result}"
-
- # Verify the uploaded object type
- info = workspace_client.workspace.get_status(upload_path)
- logger.info(f"Regular Python file status: type={info.object_type}")
-
- assert info.object_type.value == "FILE", \
- f"Regular Python file should be FILE type, got {info.object_type}. " \
- f"Files without notebook markers should NOT be converted to notebooks."
-
- finally:
- Path(temp_path).unlink(missing_ok=True)
-
- def test_upload_regular_sql_file(
- self,
- workspace_client: WorkspaceClient,
- workspace_test_path: str,
- ):
- """SQL files WITHOUT notebook marker should remain as FILE type."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".sql", delete=False) as f:
- # Regular SQL file (no notebook marker)
- f.write("-- Regular SQL script\n")
- f.write("SELECT * FROM some_table WHERE id = 1;\n")
- temp_path = f.name
-
- try:
- upload_path = f"{workspace_test_path}/regular_sql_test.sql"
-
- result = manage_workspace_files(
- action="upload",
- local_path=temp_path,
- workspace_path=upload_path,
- overwrite=True,
- )
-
- logger.info(f"Upload regular SQL file result: {result}")
- assert result.get("success", False), f"Upload failed: {result}"
-
- # Verify the uploaded object type
- info = workspace_client.workspace.get_status(upload_path)
- logger.info(f"Regular SQL file status: type={info.object_type}")
-
- assert info.object_type.value == "FILE", \
- f"Regular SQL file should be FILE type, got {info.object_type}. " \
- f"Files without notebook markers should NOT be converted to notebooks."
-
- finally:
- Path(temp_path).unlink(missing_ok=True)
-
- def test_upload_mixed_directory(
- self,
- workspace_client: WorkspaceClient,
- workspace_test_path: str,
- ):
- """Uploading a directory with both notebooks and regular files should preserve types."""
- with tempfile.TemporaryDirectory() as temp_dir:
- temp_path = Path(temp_dir)
-
- # Create various file types
- test_files = {
- # Notebooks (with marker)
- "notebook_python.py": (
- "# Databricks notebook source\nprint('Python notebook')",
- "NOTEBOOK",
- "PYTHON"
- ),
- "notebook_sql.sql": (
- "-- Databricks notebook source\nSELECT 1",
- "NOTEBOOK",
- "SQL"
- ),
- # Regular files (no marker)
- "script.py": (
- "# Regular script\nprint('hello')",
- "FILE",
- None
- ),
- "query.sql": (
- "-- Regular query\nSELECT * FROM table",
- "FILE",
- None
- ),
- "data.json": (
- '{"key": "value"}',
- "FILE",
- None
- ),
- }
-
- for filename, (content, _, _) in test_files.items():
- (temp_path / filename).write_text(content)
-
- # Upload directory contents (trailing slash = copy contents, like cp -r src/ dest/)
- upload_path = f"{workspace_test_path}/mixed_directory_test"
- result = manage_workspace_files(
- action="upload",
- local_path=str(temp_path) + "/", # Trailing slash = copy contents directly
- workspace_path=upload_path,
- overwrite=True,
- )
-
- logger.info(f"Upload mixed directory result: {result}")
- assert result.get("success", False), f"Upload failed: {result}"
-
- # List and verify each file's type
- objects = list(workspace_client.workspace.list(upload_path))
- logger.info(f"Listed objects: {[(obj.path, obj.object_type) for obj in objects]}")
-
- for filename, (_, expected_type, expected_lang) in test_files.items():
- # Find the file - notebooks don't have extensions in path
- if expected_type == "NOTEBOOK":
- # Notebooks are stored without extension
- name_without_ext = filename.rsplit(".", 1)[0]
- file_obj = next(
- (obj for obj in objects if name_without_ext in obj.path and expected_type == obj.object_type.value),
- None
- )
- else:
- # Regular files keep their extension
- file_obj = next(
- (obj for obj in objects if filename in obj.path),
- None
- )
-
- assert file_obj is not None, \
- f"File {filename} not found in workspace: {[obj.path for obj in objects]}"
-
- actual_type = file_obj.object_type.value if file_obj.object_type else None
- assert actual_type == expected_type, \
- f"File {filename} should be {expected_type}, got {actual_type}"
-
- if expected_lang:
- actual_lang = file_obj.language.value if file_obj.language else None
- assert actual_lang == expected_lang, \
- f"Notebook {filename} should have language {expected_lang}, got {actual_lang}"
-
- logger.info("All files in mixed directory have correct types")
-
- def test_upload_notebook_to_directory(
- self,
- workspace_client: WorkspaceClient,
- workspace_test_path: str,
- ):
- """Uploading a notebook to a directory path should work correctly."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
- f.write("# Databricks notebook source\n")
- f.write("print('Notebook in directory')\n")
- temp_path = f.name
-
- try:
- # Create target directory first
- dir_path = f"{workspace_test_path}/notebook_in_dir"
- workspace_client.workspace.mkdirs(dir_path)
-
- # Upload to the directory (not to a specific file path)
- result = manage_workspace_files(
- action="upload",
- local_path=temp_path,
- workspace_path=dir_path,
- overwrite=True,
- )
-
- logger.info(f"Upload notebook to directory result: {result}")
- assert result.get("success", False), f"Upload failed: {result}"
-
- # List directory contents
- objects = list(workspace_client.workspace.list(dir_path))
- logger.info(f"Directory contents: {[(obj.path, obj.object_type, obj.language) for obj in objects]}")
-
- assert len(objects) > 0, f"Directory should contain the uploaded notebook"
-
- # Find the notebook
- notebook = next(
- (obj for obj in objects if obj.object_type.value == "NOTEBOOK"),
- None
- )
- assert notebook is not None, \
- f"Should find a NOTEBOOK in directory, got: {[(obj.path, obj.object_type) for obj in objects]}"
- assert notebook.language.value == "PYTHON", \
- f"Notebook should be PYTHON, got {notebook.language}"
-
- finally:
- Path(temp_path).unlink(missing_ok=True)
diff --git a/databricks-mcp-server/tests/test_compute_tools.py b/databricks-mcp-server/tests/test_compute_tools.py
deleted file mode 100644
index 677f350b..00000000
--- a/databricks-mcp-server/tests/test_compute_tools.py
+++ /dev/null
@@ -1,392 +0,0 @@
-"""
-Unit tests for consolidated compute tools.
-
-Tests the MCP tool wrapper routing logic without hitting Databricks APIs.
-"""
-
-import pytest
-from unittest.mock import patch, MagicMock
-from databricks_mcp_server.tools.compute import (
- execute_code,
- manage_cluster,
- manage_sql_warehouse,
- list_compute,
-)
-
-
-# ---------------------------------------------------------------------------
-# execute_code routing tests
-# ---------------------------------------------------------------------------
-
-
-class TestExecuteCodeRouting:
- """Test that execute_code routes to the correct backend."""
-
- def test_requires_code_or_file_path(self):
- result = execute_code()
- assert result["success"] is False
- assert "code" in result["error"].lower() or "file_path" in result["error"].lower()
-
- def test_empty_strings_treated_as_none(self):
- result = execute_code(code="", file_path="")
- assert result["success"] is False
- assert "code" in result["error"].lower() or "file_path" in result["error"].lower()
-
- @patch("databricks_mcp_server.tools.compute._run_code_on_serverless")
- def test_auto_routes_to_serverless_for_python(self, mock_serverless):
- mock_result = MagicMock()
- mock_result.to_dict.return_value = {"success": True, "output": "hello"}
- mock_serverless.return_value = mock_result
-
- execute_code(code="print('hi')", compute_type="auto")
-
- mock_serverless.assert_called_once()
- call_kwargs = mock_serverless.call_args[1]
- assert call_kwargs["code"] == "print('hi')"
- assert call_kwargs["language"] == "python"
-
- @patch("databricks_mcp_server.tools.compute._execute_databricks_command")
- def test_auto_routes_to_cluster_with_cluster_id(self, mock_cluster):
- mock_result = MagicMock()
- mock_result.to_dict.return_value = {"success": True}
- mock_cluster.return_value = mock_result
-
- execute_code(code="print('hi')", cluster_id="abc-123")
-
- mock_cluster.assert_called_once()
- assert mock_cluster.call_args[1]["cluster_id"] == "abc-123"
-
- @patch("databricks_mcp_server.tools.compute._execute_databricks_command")
- def test_auto_routes_to_cluster_with_context_id(self, mock_cluster):
- mock_result = MagicMock()
- mock_result.to_dict.return_value = {"success": True}
- mock_cluster.return_value = mock_result
-
- execute_code(code="print('hi')", context_id="ctx-456")
-
- mock_cluster.assert_called_once()
- assert mock_cluster.call_args[1]["context_id"] == "ctx-456"
-
- @patch("databricks_mcp_server.tools.compute._execute_databricks_command")
- def test_auto_routes_to_cluster_for_scala(self, mock_cluster):
- mock_result = MagicMock()
- mock_result.to_dict.return_value = {"success": True}
- mock_cluster.return_value = mock_result
-
- execute_code(code="println(42)", language="scala")
-
- mock_cluster.assert_called_once()
- assert mock_cluster.call_args[1]["language"] == "scala"
-
- @patch("databricks_mcp_server.tools.compute._execute_databricks_command")
- def test_auto_routes_to_cluster_for_r(self, mock_cluster):
- mock_result = MagicMock()
- mock_result.to_dict.return_value = {"success": True}
- mock_cluster.return_value = mock_result
-
- execute_code(code="print(42)", language="r")
-
- mock_cluster.assert_called_once()
-
- @patch("databricks_mcp_server.tools.compute._run_code_on_serverless")
- def test_explicit_serverless(self, mock_serverless):
- mock_result = MagicMock()
- mock_result.to_dict.return_value = {"success": True}
- mock_serverless.return_value = mock_result
-
- execute_code(code="print('hi')", compute_type="serverless")
-
- mock_serverless.assert_called_once()
-
- @patch("databricks_mcp_server.tools.compute._execute_databricks_command")
- def test_explicit_cluster(self, mock_cluster):
- mock_result = MagicMock()
- mock_result.to_dict.return_value = {"success": True}
- mock_cluster.return_value = mock_result
-
- execute_code(code="print('hi')", compute_type="cluster")
-
- mock_cluster.assert_called_once()
-
- @patch("databricks_mcp_server.tools.compute._run_code_on_serverless")
- def test_serverless_default_timeout(self, mock_serverless):
- mock_result = MagicMock()
- mock_result.to_dict.return_value = {"success": True}
- mock_serverless.return_value = mock_result
-
- execute_code(code="x", compute_type="serverless")
-
- assert mock_serverless.call_args[1]["timeout"] == 1800
-
- @patch("databricks_mcp_server.tools.compute._execute_databricks_command")
- def test_cluster_default_timeout(self, mock_cluster):
- mock_result = MagicMock()
- mock_result.to_dict.return_value = {"success": True}
- mock_cluster.return_value = mock_result
-
- execute_code(code="x", compute_type="cluster")
-
- assert mock_cluster.call_args[1]["timeout"] == 120
-
- @patch("databricks_mcp_server.tools.compute._run_code_on_serverless")
- def test_workspace_path_passed_to_serverless(self, mock_serverless):
- mock_result = MagicMock()
- mock_result.to_dict.return_value = {"success": True}
- mock_serverless.return_value = mock_result
-
- execute_code(code="x", compute_type="serverless", workspace_path="/Workspace/Users/a/b")
-
- call_kwargs = mock_serverless.call_args[1]
- assert call_kwargs["workspace_path"] == "/Workspace/Users/a/b"
- assert call_kwargs["cleanup"] is False
-
- @patch("databricks_mcp_server.tools.compute._run_file_on_databricks")
- def test_file_path_on_cluster(self, mock_run_file):
- mock_result = MagicMock()
- mock_result.to_dict.return_value = {"success": True}
- mock_run_file.return_value = mock_result
-
- execute_code(file_path="/tmp/test.py", compute_type="cluster")
-
- mock_run_file.assert_called_once()
- assert mock_run_file.call_args[1]["file_path"] == "/tmp/test.py"
-
- def test_file_path_not_found_serverless(self):
- result = execute_code(file_path="/nonexistent/file.py", compute_type="serverless")
- assert result["success"] is False
- assert "not found" in result["error"].lower()
-
- @patch("databricks_mcp_server.tools.compute._execute_databricks_command")
- def test_no_running_cluster_error(self, mock_cluster):
- from databricks_tools_core.compute import NoRunningClusterError
- mock_cluster.side_effect = NoRunningClusterError(
- available_clusters=[],
- skipped_clusters=[],
- startable_clusters=[{"cluster_id": "abc", "cluster_name": "test", "state": "TERMINATED"}],
- )
-
- result = execute_code(code="x", compute_type="cluster")
-
- assert result["success"] is False
- assert "startable_clusters" in result
- assert len(result["startable_clusters"]) == 1
-
-
-# ---------------------------------------------------------------------------
-# manage_cluster routing tests
-# ---------------------------------------------------------------------------
-
-
-class TestManageCluster:
- """Test manage_cluster action routing."""
-
- def test_invalid_action(self):
- result = manage_cluster(action="explode")
- assert result["success"] is False
- assert "unknown action" in result["error"].lower()
-
- def test_create_requires_name(self):
- result = manage_cluster(action="create")
- assert result["success"] is False
- assert "name" in result["error"].lower()
-
- def test_modify_requires_cluster_id(self):
- result = manage_cluster(action="modify")
- assert result["success"] is False
- assert "cluster_id" in result["error"].lower()
-
- def test_start_requires_cluster_id(self):
- result = manage_cluster(action="start")
- assert result["success"] is False
- assert "cluster_id" in result["error"].lower()
-
- def test_terminate_requires_cluster_id(self):
- result = manage_cluster(action="terminate")
- assert result["success"] is False
- assert "cluster_id" in result["error"].lower()
-
- def test_delete_requires_cluster_id(self):
- result = manage_cluster(action="delete")
- assert result["success"] is False
- assert "cluster_id" in result["error"].lower()
-
- @patch("databricks_mcp_server.tools.compute._create_cluster")
- def test_create_routes_correctly(self, mock_create):
- mock_create.return_value = {"cluster_id": "abc", "state": "PENDING"}
-
- result = manage_cluster(action="create", name="test-cluster", num_workers=2)
-
- mock_create.assert_called_once()
- assert mock_create.call_args[1]["name"] == "test-cluster"
- assert mock_create.call_args[1]["num_workers"] == 2
-
- @patch("databricks_mcp_server.tools.compute._modify_cluster")
- def test_modify_routes_correctly(self, mock_modify):
- mock_modify.return_value = {"cluster_id": "abc"}
-
- manage_cluster(action="modify", cluster_id="abc", num_workers=4)
-
- mock_modify.assert_called_once()
- assert mock_modify.call_args[1]["cluster_id"] == "abc"
- assert mock_modify.call_args[1]["num_workers"] == 4
-
- @patch("databricks_mcp_server.tools.compute._start_cluster")
- def test_start_routes_correctly(self, mock_start):
- mock_start.return_value = {"cluster_id": "abc", "state": "PENDING"}
-
- manage_cluster(action="start", cluster_id="abc")
-
- mock_start.assert_called_once_with("abc")
-
- @patch("databricks_mcp_server.tools.compute._terminate_cluster")
- def test_terminate_routes_correctly(self, mock_terminate):
- mock_terminate.return_value = {"cluster_id": "abc", "state": "TERMINATING"}
-
- manage_cluster(action="terminate", cluster_id="abc")
-
- mock_terminate.assert_called_once_with("abc")
-
- @patch("databricks_mcp_server.tools.compute._delete_cluster")
- def test_delete_routes_correctly(self, mock_delete):
- mock_delete.return_value = {"cluster_id": "abc", "state": "DELETED"}
-
- manage_cluster(action="delete", cluster_id="abc")
-
- mock_delete.assert_called_once_with("abc")
-
- @patch("databricks_mcp_server.tools.compute._create_cluster")
- def test_create_defaults(self, mock_create):
- mock_create.return_value = {"cluster_id": "abc"}
-
- manage_cluster(action="create", name="test")
-
- call_kwargs = mock_create.call_args[1]
- assert call_kwargs["num_workers"] == 1
- assert call_kwargs["autotermination_minutes"] == 120
-
- @patch("databricks_mcp_server.tools.compute._create_cluster")
- def test_create_with_spark_conf_json(self, mock_create):
- mock_create.return_value = {"cluster_id": "abc"}
-
- manage_cluster(
- action="create",
- name="test",
- spark_conf='{"spark.sql.shuffle.partitions": "8"}',
- )
-
- call_kwargs = mock_create.call_args[1]
- assert call_kwargs["spark_conf"] == {"spark.sql.shuffle.partitions": "8"}
-
-
-# ---------------------------------------------------------------------------
-# manage_sql_warehouse routing tests
-# ---------------------------------------------------------------------------
-
-
-class TestManageSqlWarehouse:
- """Test manage_sql_warehouse action routing."""
-
- def test_invalid_action(self):
- result = manage_sql_warehouse(action="explode")
- assert result["success"] is False
-
- def test_create_requires_name(self):
- result = manage_sql_warehouse(action="create")
- assert result["success"] is False
- assert "name" in result["error"].lower()
-
- def test_modify_requires_warehouse_id(self):
- result = manage_sql_warehouse(action="modify")
- assert result["success"] is False
- assert "warehouse_id" in result["error"].lower()
-
- def test_delete_requires_warehouse_id(self):
- result = manage_sql_warehouse(action="delete")
- assert result["success"] is False
- assert "warehouse_id" in result["error"].lower()
-
- @patch("databricks_mcp_server.tools.compute._create_sql_warehouse")
- def test_create_routes_correctly(self, mock_create):
- mock_create.return_value = {"warehouse_id": "abc"}
-
- manage_sql_warehouse(action="create", name="test-wh", size="Medium")
-
- mock_create.assert_called_once()
- assert mock_create.call_args[1]["name"] == "test-wh"
- assert mock_create.call_args[1]["size"] == "Medium"
-
- @patch("databricks_mcp_server.tools.compute._modify_sql_warehouse")
- def test_modify_routes_correctly(self, mock_modify):
- mock_modify.return_value = {"warehouse_id": "abc"}
-
- manage_sql_warehouse(action="modify", warehouse_id="abc", size="Large")
-
- mock_modify.assert_called_once()
- assert mock_modify.call_args[1]["size"] == "Large"
-
- @patch("databricks_mcp_server.tools.compute._delete_sql_warehouse")
- def test_delete_routes_correctly(self, mock_delete):
- mock_delete.return_value = {"warehouse_id": "abc"}
-
- manage_sql_warehouse(action="delete", warehouse_id="abc")
-
- mock_delete.assert_called_once_with("abc")
-
-
-# ---------------------------------------------------------------------------
-# list_compute routing tests
-# ---------------------------------------------------------------------------
-
-
-class TestListCompute:
- """Test list_compute resource routing."""
-
- @patch("databricks_mcp_server.tools.compute._list_clusters")
- def test_default_lists_clusters(self, mock_list):
- mock_list.return_value = [{"cluster_id": "abc", "cluster_name": "test"}]
-
- result = list_compute()
-
- mock_list.assert_called_once()
- assert "clusters" in result
-
- @patch("databricks_mcp_server.tools.compute._get_cluster_status")
- def test_cluster_id_gets_status(self, mock_status):
- mock_status.return_value = {"cluster_id": "abc", "state": "RUNNING"}
-
- result = list_compute(cluster_id="abc")
-
- mock_status.assert_called_once_with("abc")
- assert result["state"] == "RUNNING"
-
- @patch("databricks_mcp_server.tools.compute._get_best_cluster")
- def test_auto_select(self, mock_best):
- mock_best.return_value = "best-cluster-id"
-
- result = list_compute(auto_select=True)
-
- mock_best.assert_called_once()
- assert result["cluster_id"] == "best-cluster-id"
-
- @patch("databricks_mcp_server.tools.compute._list_node_types")
- def test_node_types(self, mock_nodes):
- mock_nodes.return_value = [{"node_type_id": "i3.xlarge"}]
-
- result = list_compute(resource="node_types")
-
- mock_nodes.assert_called_once()
- assert "node_types" in result
-
- @patch("databricks_mcp_server.tools.compute._list_spark_versions")
- def test_spark_versions(self, mock_versions):
- mock_versions.return_value = [{"key": "15.4.x-scala2.12"}]
-
- result = list_compute(resource="spark_versions")
-
- mock_versions.assert_called_once()
- assert "spark_versions" in result
-
- def test_invalid_resource(self):
- result = list_compute(resource="invalid")
- assert result["success"] is False
- assert "unknown resource" in result["error"].lower()
diff --git a/databricks-mcp-server/tests/test_config.py b/databricks-mcp-server/tests/test_config.py
deleted file mode 100644
index 4bd85936..00000000
--- a/databricks-mcp-server/tests/test_config.py
+++ /dev/null
@@ -1,140 +0,0 @@
-"""
-Centralized test configuration for MCP server integration tests.
-
-Each test module uses a unique schema to enable parallel test execution without conflicts.
-"""
-
-import os
-
-# =============================================================================
-# Core Test Configuration
-# =============================================================================
-
-# Default catalog for all tests (can be overridden via env var)
-TEST_CATALOG = os.environ.get("TEST_CATALOG", "ai_dev_kit_test")
-
-# =============================================================================
-# Per-Module Schema Configuration
-# Each module gets its own schema to avoid conflicts during parallel execution
-# =============================================================================
-
-SCHEMAS = {
- # SQL and core tests
- "sql": "test_sql",
- "warehouse": "test_warehouse",
-
- # Pipeline tests
- "pipelines": "test_pipelines",
-
- # Vector search tests
- "vector_search": "test_vs",
-
- # Genie tests
- "genie": "test_genie",
-
- # Serving tests
- "serving": "test_serving",
-
- # Dashboard tests
- "dashboards": "test_dashboards",
-
- # Apps tests
- "apps": "test_apps",
-
- # Jobs tests
- "jobs": "test_jobs",
-
- # Volume files tests
- "volume_files": "test_volume_files",
-
- # Workspace file tests
- "workspace_files": "test_workspace_files",
-
- # Lakebase tests
- "lakebase": "test_lakebase",
-
- # Compute tests
- "compute": "test_compute",
-
- # Agent bricks tests
- "agent_bricks": "test_agent_bricks",
-
- # Unity catalog tests
- "unity_catalog": "test_uc",
-
- # PDF tests
- "pdf": "test_pdf",
-}
-
-# =============================================================================
-# Resource Naming Conventions
-# =============================================================================
-
-# Prefix for all test resources (pipelines, endpoints, etc.)
-TEST_RESOURCE_PREFIX = "ai_dev_kit_test_"
-
-# Pipeline names
-PIPELINE_NAMES = {
- "basic": f"{TEST_RESOURCE_PREFIX}pipeline_basic",
- "with_run": f"{TEST_RESOURCE_PREFIX}pipeline_with_run",
-}
-
-# Vector search endpoint names
-VS_ENDPOINT_NAMES = {
- "basic": f"{TEST_RESOURCE_PREFIX}vs_endpoint",
-}
-
-# Vector search index names (use schema-qualified names)
-def get_vs_index_name(index_key: str) -> str:
- """Get fully-qualified VS index name."""
- return f"{TEST_CATALOG}.{SCHEMAS['vector_search']}.{TEST_RESOURCE_PREFIX}vs_index_{index_key}"
-
-# Genie space names
-GENIE_SPACE_NAMES = {
- "basic": f"{TEST_RESOURCE_PREFIX}genie_space",
-}
-
-# Serving endpoint names
-SERVING_ENDPOINT_NAMES = {
- "basic": f"{TEST_RESOURCE_PREFIX}serving_endpoint",
-}
-
-# Dashboard names
-DASHBOARD_NAMES = {
- "basic": f"{TEST_RESOURCE_PREFIX}dashboard",
-}
-
-# App names
-APP_NAMES = {
- "basic": f"{TEST_RESOURCE_PREFIX}app",
-}
-
-# Job names
-JOB_NAMES = {
- "basic": f"{TEST_RESOURCE_PREFIX}job",
-}
-
-# Volume names
-VOLUME_NAMES = {
- "basic": f"{TEST_RESOURCE_PREFIX}volume",
-}
-
-# =============================================================================
-# Helper Functions
-# =============================================================================
-
-def get_full_schema_name(module: str) -> str:
- """Get fully-qualified schema name for a test module."""
- return f"{TEST_CATALOG}.{SCHEMAS[module]}"
-
-def get_table_name(module: str, table: str) -> str:
- """Get fully-qualified table name for a test module."""
- return f"{TEST_CATALOG}.{SCHEMAS[module]}.{table}"
-
-def get_volume_path(module: str, volume: str = "test_volume") -> str:
- """Get volume path for a test module."""
- return f"/Volumes/{TEST_CATALOG}/{SCHEMAS[module]}/{volume}"
-
-def get_workspace_path(username: str, module: str) -> str:
- """Get workspace path for a test module."""
- return f"/Workspace/Users/{username}/ai_dev_kit_test/{module}/resources"
diff --git a/databricks-mcp-server/tests/test_middleware.py b/databricks-mcp-server/tests/test_middleware.py
deleted file mode 100644
index 0dabbdff..00000000
--- a/databricks-mcp-server/tests/test_middleware.py
+++ /dev/null
@@ -1,95 +0,0 @@
-"""Tests for the TimeoutHandlingMiddleware."""
-
-import asyncio
-import json
-from unittest.mock import AsyncMock, MagicMock
-
-import pytest
-
-from fastmcp.exceptions import ToolError
-
-from databricks_mcp_server.middleware import TimeoutHandlingMiddleware
-
-
-@pytest.fixture
-def middleware():
- return TimeoutHandlingMiddleware()
-
-
-def _make_context(tool_name="test_tool", arguments=None):
- """Build a minimal MiddlewareContext mock for on_call_tool."""
- ctx = MagicMock()
- ctx.message.name = tool_name
- ctx.message.arguments = arguments or {}
- return ctx
-
-
-@pytest.mark.asyncio
-async def test_normal_call_passes_through(middleware):
- """Tool results pass through unchanged when no error occurs."""
- expected = MagicMock()
- call_next = AsyncMock(return_value=expected)
- ctx = _make_context()
-
- result = await middleware.on_call_tool(ctx, call_next)
-
- assert result is expected
- call_next.assert_awaited_once_with(ctx)
-
-
-@pytest.mark.asyncio
-async def test_timeout_error_raises_tool_error(middleware):
- """TimeoutError is caught and re-raised as ToolError with structured JSON."""
- call_next = AsyncMock(side_effect=TimeoutError("Run did not complete within 3600 seconds"))
- ctx = _make_context(tool_name="wait_for_run")
-
- with pytest.raises(ToolError) as exc_info:
- await middleware.on_call_tool(ctx, call_next)
-
- payload = json.loads(str(exc_info.value))
- assert payload["error"] is True
- assert payload["error_type"] == "timeout"
- assert payload["tool"] == "wait_for_run"
- assert "3600 seconds" in payload["message"]
- assert "Do NOT retry" in payload["action_required"]
-
-
-@pytest.mark.asyncio
-async def test_asyncio_timeout_error_raises_tool_error(middleware):
- """asyncio.TimeoutError is caught and re-raised as ToolError with structured JSON."""
- call_next = AsyncMock(side_effect=asyncio.TimeoutError())
- ctx = _make_context(tool_name="long_running_tool")
-
- with pytest.raises(ToolError) as exc_info:
- await middleware.on_call_tool(ctx, call_next)
-
- payload = json.loads(str(exc_info.value))
- assert payload["error"] is True
- assert payload["error_type"] == "timeout"
- assert payload["tool"] == "long_running_tool"
-
-
-@pytest.mark.asyncio
-async def test_cancelled_error_is_reraised(middleware):
- """asyncio.CancelledError is re-raised to let MCP SDK handle cleanup."""
- call_next = AsyncMock(side_effect=asyncio.CancelledError())
- ctx = _make_context(tool_name="cancelled_tool")
-
- with pytest.raises(asyncio.CancelledError):
- await middleware.on_call_tool(ctx, call_next)
-
-
-@pytest.mark.asyncio
-async def test_generic_exception_raises_tool_error(middleware):
- """Generic exceptions are caught and re-raised as ToolError with structured JSON."""
- call_next = AsyncMock(side_effect=ValueError("bad input"))
- ctx = _make_context(tool_name="failing_tool")
-
- with pytest.raises(ToolError) as exc_info:
- await middleware.on_call_tool(ctx, call_next)
-
- payload = json.loads(str(exc_info.value))
- assert payload["error"] is True
- assert payload["error_type"] == "ValueError"
- assert payload["tool"] == "failing_tool"
- assert "bad input" in payload["message"]
diff --git a/databricks-mcp-server/tests/test_sql_output_format.py b/databricks-mcp-server/tests/test_sql_output_format.py
deleted file mode 100644
index 9b678abd..00000000
--- a/databricks-mcp-server/tests/test_sql_output_format.py
+++ /dev/null
@@ -1,76 +0,0 @@
-"""Unit tests for SQL output formatting (markdown vs JSON)."""
-
-from databricks_mcp_server.tools.sql import _format_results_markdown
-
-
-class TestFormatResultsMarkdown:
- """Tests for _format_results_markdown helper."""
-
- def test_empty_list_returns_no_results(self):
- assert _format_results_markdown([]) == "(no results)"
-
- def test_single_row(self):
- rows = [{"id": "1", "name": "Alice"}]
- result = _format_results_markdown(rows)
- lines = result.strip().split("\n")
- assert lines[0] == "| id | name |"
- assert lines[1] == "| --- | --- |"
- assert lines[2] == "| 1 | Alice |"
- assert "(1 row)" in result
-
- def test_multiple_rows(self):
- rows = [
- {"id": "1", "name": "Alice", "city": "NYC"},
- {"id": "2", "name": "Bob", "city": "Chicago"},
- {"id": "3", "name": "Carol", "city": "Denver"},
- ]
- result = _format_results_markdown(rows)
- lines = result.strip().split("\n")
- # Header + separator + 3 data rows + blank + count
- assert lines[0] == "| id | name | city |"
- assert lines[1] == "| --- | --- | --- |"
- assert lines[2] == "| 1 | Alice | NYC |"
- assert lines[3] == "| 2 | Bob | Chicago |"
- assert lines[4] == "| 3 | Carol | Denver |"
- assert "(3 rows)" in result
-
- def test_none_values_become_empty(self):
- rows = [{"id": "1", "name": None}]
- result = _format_results_markdown(rows)
- assert "| 1 | |" in result
-
- def test_pipe_chars_escaped(self):
- rows = [{"expr": "a | b"}]
- result = _format_results_markdown(rows)
- assert "a \\| b" in result
-
- def test_column_names_appear_once(self):
- """The whole point: column names should appear exactly once (in the header)."""
- rows = [
- {"event_id": "1", "event_name": "Concert A"},
- {"event_id": "2", "event_name": "Concert B"},
- {"event_id": "3", "event_name": "Concert C"},
- ]
- result = _format_results_markdown(rows)
- # Column name should appear once in header, not repeated per row
- assert result.count("event_id") == 1
- assert result.count("event_name") == 1
-
- def test_markdown_smaller_than_json(self):
- """Markdown output should be significantly smaller than JSON for many rows."""
- import json
-
- rows = [
- {
- "id": str(i),
- "name": f"User {i}",
- "email": f"user{i}@example.com",
- "department": "Engineering",
- "status": "Active",
- }
- for i in range(50)
- ]
- md = _format_results_markdown(rows)
- js = json.dumps(rows)
- # Markdown should be at least 30% smaller
- assert len(md) < len(js) * 0.7, f"Markdown ({len(md)} chars) should be <70% of JSON ({len(js)} chars)"
diff --git a/databricks-mcp-server/tests/test_windows_compat.py b/databricks-mcp-server/tests/test_windows_compat.py
deleted file mode 100644
index e426ca22..00000000
--- a/databricks-mcp-server/tests/test_windows_compat.py
+++ /dev/null
@@ -1,114 +0,0 @@
-"""Tests for the Windows compatibility wrapper (_wrap_sync_in_thread)."""
-
-import asyncio
-import inspect
-import threading
-
-import pydantic
-import pytest
-
-from databricks_mcp_server.server import _wrap_sync_in_thread
-
-
-def sample_tool(query: str, limit: int = 10) -> str:
- """Execute a sample query."""
- return f"result:{query}:{limit}"
-
-
-class TestWrapSyncInThread:
- """Tests for _wrap_sync_in_thread wrapper."""
-
- def test_preserves_function_name(self):
- wrapped = _wrap_sync_in_thread(sample_tool)
- assert wrapped.__name__ == "sample_tool"
-
- def test_preserves_docstring(self):
- wrapped = _wrap_sync_in_thread(sample_tool)
- assert wrapped.__doc__ == "Execute a sample query."
-
- def test_preserves_annotations(self):
- wrapped = _wrap_sync_in_thread(sample_tool)
- assert wrapped.__annotations__ == sample_tool.__annotations__
-
- def test_preserves_signature(self):
- wrapped = _wrap_sync_in_thread(sample_tool)
- original_sig = inspect.signature(sample_tool)
- wrapped_sig = inspect.signature(wrapped)
- assert str(original_sig) == str(wrapped_sig)
-
- def test_is_coroutine_function(self):
- wrapped = _wrap_sync_in_thread(sample_tool)
- assert inspect.iscoroutinefunction(wrapped)
-
- @pytest.mark.asyncio
- async def test_returns_correct_result(self):
- wrapped = _wrap_sync_in_thread(sample_tool)
- result = await wrapped(query="test", limit=5)
- assert result == "result:test:5"
-
- @pytest.mark.asyncio
- async def test_returns_correct_result_with_defaults(self):
- wrapped = _wrap_sync_in_thread(sample_tool)
- result = await wrapped(query="test")
- assert result == "result:test:10"
-
- @pytest.mark.asyncio
- async def test_runs_in_thread_pool(self):
- """Verify the sync function runs in a different thread than the event loop."""
- main_thread = threading.current_thread().ident
-
- def capture_thread(query: str) -> int:
- return threading.current_thread().ident
-
- wrapped = _wrap_sync_in_thread(capture_thread)
- worker_thread = await wrapped(query="test")
- assert worker_thread != main_thread
-
- @pytest.mark.asyncio
- async def test_does_not_block_event_loop(self):
- """Verify concurrent tasks can run while the wrapped function executes."""
- import time
-
- def slow_tool(query: str) -> str:
- time.sleep(0.2)
- return "done"
-
- wrapped = _wrap_sync_in_thread(slow_tool)
-
- concurrent_ran = False
-
- async def concurrent_task():
- nonlocal concurrent_ran
- await asyncio.sleep(0.05)
- concurrent_ran = True
-
- await asyncio.gather(wrapped(query="test"), concurrent_task())
- assert concurrent_ran
-
- @pytest.mark.asyncio
- async def test_propagates_exceptions(self):
- def failing_tool(query: str) -> str:
- raise ValueError("something went wrong")
-
- wrapped = _wrap_sync_in_thread(failing_tool)
- with pytest.raises(ValueError, match="something went wrong"):
- await wrapped(query="test")
-
- def test_pydantic_type_adapter_returns_awaitable(self):
- """Verify pydantic's TypeAdapter can call the wrapped function and get a coroutine.
-
- FastMCP uses TypeAdapter.validate_python() to invoke tool functions.
- The wrapper must produce a result that pydantic recognizes as callable
- with the original signature.
- """
- wrapped = _wrap_sync_in_thread(sample_tool)
- pydantic.TypeAdapter(wrapped.__annotations__.get("return", str))
- # TypeAdapter should be able to read the function's annotations
- sig = inspect.signature(wrapped)
- assert "query" in sig.parameters
- assert "limit" in sig.parameters
- # Calling the wrapper returns a coroutine
- coro = wrapped(query="test", limit=5)
- assert inspect.iscoroutine(coro)
- # Clean up the unawaited coroutine
- coro.close()
diff --git a/databricks-mcp-server/tests/test_workspace.py b/databricks-mcp-server/tests/test_workspace.py
deleted file mode 100644
index 81cf461a..00000000
--- a/databricks-mcp-server/tests/test_workspace.py
+++ /dev/null
@@ -1,293 +0,0 @@
-"""Tests for the manage_workspace MCP tool."""
-
-import subprocess
-from unittest.mock import MagicMock, patch
-
-import pytest
-
-from databricks_mcp_server.tools.workspace import _manage_workspace_impl as manage_workspace
-from databricks_tools_core.auth import clear_active_workspace, get_active_workspace
-
-# Patch targets
-_CFG_PATH = "databricks_mcp_server.tools.workspace._DATABRICKS_CFG_PATH"
-_VALIDATE_AND_SWITCH = "databricks_mcp_server.tools.workspace._validate_and_switch"
-_GET_WORKSPACE_CLIENT = "databricks_mcp_server.tools.workspace.get_workspace_client"
-_GET_ACTIVE_WORKSPACE = "databricks_mcp_server.tools.workspace.get_active_workspace"
-_SUBPROCESS_RUN = "databricks_mcp_server.tools.workspace.subprocess.run"
-
-
-@pytest.fixture(autouse=True)
-def reset_active_workspace():
- """Ensure active workspace is cleared before and after each test."""
- clear_active_workspace()
- yield
- clear_active_workspace()
-
-
-@pytest.fixture
-def tmp_databrickscfg(tmp_path):
- """Write a temporary ~/.databrickscfg with three known profiles."""
- cfg = tmp_path / ".databrickscfg"
- cfg.write_text(
- "[DEFAULT]\nhost = https://adb-111.azuredatabricks.net\n\n"
- "[prod]\nhost = https://adb-222.azuredatabricks.net\n\n"
- "[staging]\nhost = https://adb-333.azuredatabricks.net\n"
- )
- return cfg
-
-
-# ---------------------------------------------------------------------------
-# status
-# ---------------------------------------------------------------------------
-
-
-def test_status_returns_current_info():
- """action='status' returns host, profile, and username."""
- mock_client = MagicMock()
- mock_client.config.host = "https://adb-111.azuredatabricks.net"
- mock_client.current_user.me.return_value = MagicMock(user_name="user@example.com")
-
- with (
- patch(_GET_WORKSPACE_CLIENT, return_value=mock_client),
- patch(_GET_ACTIVE_WORKSPACE, return_value={"profile": "DEFAULT", "host": None}),
- ):
- result = manage_workspace(action="status")
-
- assert result["host"] == "https://adb-111.azuredatabricks.net"
- assert result["username"] == "user@example.com"
- assert result["profile"] == "DEFAULT"
-
-
-def test_status_returns_error_on_failure():
- """action='status' returns an error dict when the SDK raises."""
- with patch(_GET_WORKSPACE_CLIENT, side_effect=Exception("auth failed")):
- result = manage_workspace(action="status")
-
- assert "error" in result
- assert "auth failed" in result["error"]
-
-
-# ---------------------------------------------------------------------------
-# list
-# ---------------------------------------------------------------------------
-
-
-def test_list_returns_all_profiles(tmp_databrickscfg):
- """action='list' returns all profiles with host URLs and marks the active one."""
- with (
- patch(_CFG_PATH, str(tmp_databrickscfg)),
- patch(_GET_ACTIVE_WORKSPACE, return_value={"profile": "prod", "host": None}),
- ):
- result = manage_workspace(action="list")
-
- assert "profiles" in result
- assert len(result["profiles"]) == 3
- profiles_by_name = {p["profile"]: p for p in result["profiles"]}
- assert profiles_by_name["prod"]["active"] is True
- assert profiles_by_name["DEFAULT"]["active"] is False
- assert "adb-222" in profiles_by_name["prod"]["host"]
-
-
-def test_list_empty_config(tmp_path):
- """action='list' with an empty config returns empty list and a hint message."""
- empty_cfg = tmp_path / ".databrickscfg"
- empty_cfg.write_text("")
- with patch(_CFG_PATH, str(empty_cfg)), patch(_GET_ACTIVE_WORKSPACE, return_value={"profile": None, "host": None}):
- result = manage_workspace(action="list")
-
- assert result["profiles"] == []
- assert "message" in result
-
-
-def test_list_missing_config(tmp_path):
- """action='list' when the config file doesn't exist returns empty list."""
- with (
- patch(_CFG_PATH, str(tmp_path / "nonexistent.cfg")),
- patch(_GET_ACTIVE_WORKSPACE, return_value={"profile": None, "host": None}),
- ):
- result = manage_workspace(action="list")
-
- assert result["profiles"] == []
-
-
-def test_list_profile_without_host(tmp_path):
- """action='list' with a profile that has no host key still returns the profile."""
- cfg = tmp_path / ".databrickscfg"
- cfg.write_text("[nohostprofile]\ntoken = abc123\n")
- with patch(_CFG_PATH, str(cfg)), patch(_GET_ACTIVE_WORKSPACE, return_value={"profile": None, "host": None}):
- result = manage_workspace(action="list")
-
- assert len(result["profiles"]) == 1
- assert result["profiles"][0]["profile"] == "nohostprofile"
- assert "no host configured" in result["profiles"][0]["host"]
-
-
-# ---------------------------------------------------------------------------
-# switch
-# ---------------------------------------------------------------------------
-
-
-def test_switch_valid_profile(tmp_databrickscfg):
- """action='switch' with a known profile calls _validate_and_switch and returns success."""
- success = {"host": "https://adb-222.azuredatabricks.net", "profile": "prod", "username": "user@example.com"}
- with patch(_CFG_PATH, str(tmp_databrickscfg)), patch(_VALIDATE_AND_SWITCH, return_value=success) as mock_validate:
- result = manage_workspace(action="switch", profile="prod")
-
- mock_validate.assert_called_once_with(profile="prod", host=None)
- assert result["profile"] == "prod"
- assert "message" in result
-
-
-def test_switch_nonexistent_profile(tmp_databrickscfg):
- """action='switch' with an unknown profile name returns error with available profiles."""
- with patch(_CFG_PATH, str(tmp_databrickscfg)):
- result = manage_workspace(action="switch", profile="unknown-profile")
-
- assert "error" in result
- assert "unknown-profile" in result["error"]
- assert "DEFAULT" in result["error"] or "prod" in result["error"]
-
-
-def test_switch_with_host(tmp_databrickscfg):
- """action='switch' with a host URL calls _validate_and_switch with the host."""
- host = "https://adb-222.azuredatabricks.net"
- success = {"host": host, "profile": host, "username": "user@example.com"}
- with patch(_CFG_PATH, str(tmp_databrickscfg)), patch(_VALIDATE_AND_SWITCH, return_value=success) as mock_validate:
- result = manage_workspace(action="switch", host=host)
-
- mock_validate.assert_called_once_with(profile=None, host=host)
- assert "message" in result
-
-
-def test_switch_rollback_on_auth_failure(tmp_databrickscfg):
- """action='switch' returns error when validation fails; active workspace is NOT updated."""
- with (
- patch(_CFG_PATH, str(tmp_databrickscfg)),
- patch(_VALIDATE_AND_SWITCH, side_effect=Exception("invalid credentials")),
- ):
- result = manage_workspace(action="switch", profile="prod")
-
- assert "error" in result
- assert "invalid credentials" in result["error"]
- assert get_active_workspace()["profile"] is None
-
-
-def test_switch_expired_token_returns_structured_response(tmp_databrickscfg):
- """action='switch' with an expired token returns a structured response with token_expired flag."""
- expired_msg = "default auth: databricks-cli: cannot get access token: refresh token is invalid"
- with patch(_CFG_PATH, str(tmp_databrickscfg)), patch(_VALIDATE_AND_SWITCH, side_effect=Exception(expired_msg)):
- result = manage_workspace(action="switch", profile="prod")
-
- assert result.get("token_expired") is True
- assert result["profile"] == "prod"
- assert "adb-222" in result["host"]
- assert "login" in result["action_required"]
-
-
-def test_switch_no_profile_no_host():
- """action='switch' without profile or host returns a clear error."""
- result = manage_workspace(action="switch")
- assert "error" in result
- assert "profile" in result["error"].lower() or "host" in result["error"].lower()
-
-
-# ---------------------------------------------------------------------------
-# login
-# ---------------------------------------------------------------------------
-
-
-def test_login_calls_cli():
- """action='login' runs 'databricks auth login --host ...'."""
- mock_proc = MagicMock()
- mock_proc.returncode = 0
- success = {"host": "https://adb-999.net", "profile": "adb-999", "username": "u@x.com"}
-
- with patch(_SUBPROCESS_RUN, return_value=mock_proc) as mock_run, patch(_VALIDATE_AND_SWITCH, return_value=success):
- result = manage_workspace(action="login", host="https://adb-999.azuredatabricks.net")
-
- args = mock_run.call_args.args[0]
- assert "databricks" in args and "auth" in args and "login" in args
- assert "--host" in args and "https://adb-999.azuredatabricks.net" in args
- assert result["profile"] == "adb-999"
-
-
-def test_login_passes_stdin_devnull():
- """action='login' sets stdin=DEVNULL to avoid inheriting the MCP stdio pipe."""
- mock_proc = MagicMock()
- mock_proc.returncode = 0
- success = {"host": "https://adb-999.net", "profile": "adb-999", "username": "u@x.com"}
-
- with patch(_SUBPROCESS_RUN, return_value=mock_proc) as mock_run, patch(_VALIDATE_AND_SWITCH, return_value=success):
- manage_workspace(action="login", host="https://adb-999.azuredatabricks.net")
-
- call_kwargs = mock_run.call_args.kwargs
- assert call_kwargs.get("stdin") == subprocess.DEVNULL
-
-
-def test_login_timeout():
- """action='login' returns a clear error when the OAuth flow times out."""
- with patch(_SUBPROCESS_RUN, side_effect=subprocess.TimeoutExpired(cmd="databricks", timeout=120)):
- result = manage_workspace(action="login", host="https://adb-999.net")
-
- assert "error" in result
- assert "timed out" in result["error"].lower()
-
-
-def test_login_cli_failure():
- """action='login' returns an error when the CLI exits non-zero."""
- mock_proc = MagicMock()
- mock_proc.returncode = 1
- mock_proc.stderr = "Error: invalid workspace URL"
- mock_proc.stdout = ""
-
- with patch(_SUBPROCESS_RUN, return_value=mock_proc):
- result = manage_workspace(action="login", host="https://bad-host.net")
-
- assert "error" in result
- assert "invalid workspace URL" in result["error"]
-
-
-def test_login_cli_not_installed():
- """action='login' returns a helpful error when the Databricks CLI is not found."""
- with patch(_SUBPROCESS_RUN, side_effect=FileNotFoundError):
- result = manage_workspace(action="login", host="https://adb-999.net")
-
- assert "error" in result
- assert "CLI" in result["error"] or "databricks" in result["error"].lower()
-
-
-def test_login_switches_after_success():
- """action='login' updates the active workspace after a successful CLI call."""
- mock_proc = MagicMock()
- mock_proc.returncode = 0
- success = {"host": "https://adb-999.net", "profile": "adb-999", "username": "u@x.com"}
-
- with (
- patch(_SUBPROCESS_RUN, return_value=mock_proc),
- patch(_VALIDATE_AND_SWITCH, return_value=success) as mock_validate,
- ):
- result = manage_workspace(action="login", host="https://adb-999.azuredatabricks.net")
-
- mock_validate.assert_called_once()
- assert result["username"] == "u@x.com"
- assert "message" in result
-
-
-def test_login_no_host():
- """action='login' without a host returns a clear error."""
- result = manage_workspace(action="login")
- assert "error" in result
- assert "host" in result["error"].lower()
-
-
-# ---------------------------------------------------------------------------
-# invalid action
-# ---------------------------------------------------------------------------
-
-
-def test_invalid_action():
- """An unrecognised action returns an error listing valid actions."""
- result = manage_workspace(action="badaction")
- assert "error" in result
- for valid in ("status", "list", "switch", "login"):
- assert valid in result["error"]
diff --git a/databricks-skills/.tests/__init__.py b/databricks-skills/.tests/__init__.py
new file mode 100644
index 00000000..22366876
--- /dev/null
+++ b/databricks-skills/.tests/__init__.py
@@ -0,0 +1 @@
+"""databricks-skills integration tests."""
diff --git a/databricks-skills/.tests/conftest.py b/databricks-skills/.tests/conftest.py
new file mode 100644
index 00000000..f5612394
--- /dev/null
+++ b/databricks-skills/.tests/conftest.py
@@ -0,0 +1,78 @@
+"""
+Pytest fixtures for databricks-skills integration tests.
+
+These fixtures set up test resources in Databricks for testing the Python scripts
+in databricks-skills that use databricks-tools-core functionality.
+
+Requires a valid Databricks connection (via env vars or ~/.databrickscfg).
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import pytest
+from databricks.sdk import WorkspaceClient
+
+# Load .env.test file if it exists
+_env_file = Path(__file__).parent.parent.parent / "databricks-tools-core" / ".env.test"
+if _env_file.exists():
+ from dotenv import load_dotenv
+
+ load_dotenv(_env_file)
+ logging.getLogger(__name__).info(f"Loaded environment from {_env_file}")
+
+logger = logging.getLogger(__name__)
+
+
+def pytest_configure(config):
+ """Configure pytest with custom markers."""
+ config.addinivalue_line(
+ "markers", "integration: mark test as integration test requiring Databricks"
+ )
+
+
+@pytest.fixture(scope="session")
+def workspace_client() -> WorkspaceClient:
+ """
+ Create a WorkspaceClient for the test session.
+
+ Uses standard Databricks authentication:
+ 1. DATABRICKS_HOST + DATABRICKS_TOKEN env vars
+ 2. ~/.databrickscfg profile
+ """
+ try:
+ client = WorkspaceClient()
+ # Verify connection works
+ client.current_user.me()
+ logger.info(f"Connected to Databricks: {client.config.host}")
+ return client
+ except Exception as e:
+ pytest.skip(f"Could not connect to Databricks: {e}")
+
+
+@pytest.fixture(scope="session")
+def warehouse_id(workspace_client: WorkspaceClient) -> str:
+ """
+ Get a running SQL warehouse for tests.
+
+ Prefers shared endpoints, falls back to any running warehouse.
+ """
+ from databricks.sdk.service.sql import State
+
+ warehouses = list(workspace_client.warehouses.list())
+
+ # Priority: running shared endpoint
+ for w in warehouses:
+ if w.state == State.RUNNING and "shared" in (w.name or "").lower():
+ logger.info(f"Using warehouse: {w.name} ({w.id})")
+ return w.id
+
+ # Fallback: any running warehouse
+ for w in warehouses:
+ if w.state == State.RUNNING:
+ logger.info(f"Using warehouse: {w.name} ({w.id})")
+ return w.id
+
+ # No running warehouse found
+ pytest.skip("No running SQL warehouse available for tests")
diff --git a/databricks-skills/.tests/run_tests.py b/databricks-skills/.tests/run_tests.py
new file mode 100755
index 00000000..cae0da56
--- /dev/null
+++ b/databricks-skills/.tests/run_tests.py
@@ -0,0 +1,156 @@
+#!/usr/bin/env python3
+"""
+Test runner for databricks-skills.
+
+Runs unit tests (mocked, no Databricks connection required) and integration tests
+(require Databricks connection). Generates HTML and terminal reports.
+
+Usage:
+ python run_tests.py # Run all tests
+ python run_tests.py --unit # Run only unit tests
+ python run_tests.py --integration # Run only integration tests
+ python run_tests.py -v # Verbose output
+ python run_tests.py --html # Generate HTML report
+"""
+
+import argparse
+import os
+import subprocess
+import sys
+from datetime import datetime
+from pathlib import Path
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Run databricks-skills tests with reports"
+ )
+ parser.add_argument(
+ "--unit",
+ action="store_true",
+ help="Run only unit tests (mocked, no Databricks connection)",
+ )
+ parser.add_argument(
+ "--integration",
+ action="store_true",
+ help="Run only integration tests (requires Databricks connection)",
+ )
+ parser.add_argument(
+ "-v", "--verbose",
+ action="store_true",
+ help="Verbose output",
+ )
+ parser.add_argument(
+ "--html",
+ action="store_true",
+ help="Generate HTML report",
+ )
+ parser.add_argument(
+ "--xml",
+ action="store_true",
+ help="Generate JUnit XML report for CI",
+ )
+ parser.add_argument(
+ "-k",
+ metavar="EXPRESSION",
+ help="Only run tests matching the given expression",
+ )
+
+ args = parser.parse_args()
+
+ # Determine test directory
+ tests_dir = Path(__file__).parent
+ skills_dir = tests_dir.parent
+ repo_root = skills_dir.parent
+
+ # Results directory for reports
+ results_dir = tests_dir / ".test-results"
+ results_dir.mkdir(exist_ok=True)
+
+ # Build pytest command
+ pytest_args = [
+ sys.executable,
+ "-m", "pytest",
+ str(tests_dir),
+ ]
+
+ # Filter by test type
+ if args.unit and not args.integration:
+ # Unit tests: exclude integration marker
+ pytest_args.extend(["-m", "not integration"])
+ elif args.integration and not args.unit:
+ # Integration tests only
+ pytest_args.extend(["-m", "integration"])
+ # If both or neither specified, run all tests
+
+ # Add verbosity
+ if args.verbose:
+ pytest_args.append("-v")
+ else:
+ pytest_args.append("-q")
+
+ # Add expression filter
+ if args.k:
+ pytest_args.extend(["-k", args.k])
+
+ # Add HTML report
+ if args.html:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ html_path = results_dir / f"report_{timestamp}.html"
+ pytest_args.extend(["--html", str(html_path), "--self-contained-html"])
+ print(f"HTML report will be saved to: {html_path}")
+
+ # Add XML report
+ if args.xml:
+ xml_path = results_dir / "junit.xml"
+ pytest_args.extend(["--junitxml", str(xml_path)])
+ print(f"JUnit XML report will be saved to: {xml_path}")
+
+ # Add color output
+ pytest_args.append("--color=yes")
+
+ # Show captured output on failure
+ pytest_args.append("-rA")
+
+ # Set PYTHONPATH to include skills directory
+ env = os.environ.copy()
+ pythonpath = env.get("PYTHONPATH", "")
+ env["PYTHONPATH"] = f"{skills_dir}:{repo_root / 'databricks-tools-core'}:{pythonpath}"
+
+ # Print test configuration
+ print("=" * 60)
+ print("databricks-skills Test Runner")
+ print("=" * 60)
+ print(f"Tests directory: {tests_dir}")
+ print(f"Results directory: {results_dir}")
+ test_type = "all"
+ if args.unit and not args.integration:
+ test_type = "unit only"
+ elif args.integration and not args.unit:
+ test_type = "integration only"
+ print(f"Test type: {test_type}")
+ print("=" * 60)
+ print()
+
+ # Run pytest
+ result = subprocess.run(pytest_args, env=env)
+
+ # Print summary
+ print()
+ print("=" * 60)
+ if result.returncode == 0:
+ print("All tests PASSED")
+ else:
+ print(f"Tests FAILED (exit code: {result.returncode})")
+
+ if args.html:
+ print(f"HTML report: {html_path}")
+ if args.xml:
+ print(f"JUnit XML: {xml_path}")
+ print("=" * 60)
+
+ return result.returncode
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/databricks-skills/.tests/test_agent_bricks_manager.py b/databricks-skills/.tests/test_agent_bricks_manager.py
new file mode 100644
index 00000000..6708acba
--- /dev/null
+++ b/databricks-skills/.tests/test_agent_bricks_manager.py
@@ -0,0 +1,160 @@
+"""
+Integration tests for databricks-agent-bricks/scripts/mas_manager.py
+
+Tests the Supervisor Agent (MAS) CLI interface functions.
+Self-contained - requires only databricks-sdk and requests.
+"""
+
+import sys
+from pathlib import Path
+
+import pytest
+
+# Add the skills directory to the path
+SKILLS_DIR = Path(__file__).parent.parent
+sys.path.insert(0, str(SKILLS_DIR / "databricks-agent-bricks"))
+
+from mas_manager import (
+ create_mas,
+ get_mas,
+ find_mas,
+ update_mas,
+ delete_mas,
+ list_mas,
+ add_examples,
+ add_examples_wait,
+ list_examples,
+ _build_agents,
+)
+
+
+@pytest.fixture
+def sample_agent_config():
+ """Sample agent configuration for testing."""
+ return {
+ "name": "Test Agent",
+ "description": "A test agent for unit testing",
+ "endpoint_name": "test-endpoint",
+ }
+
+
+@pytest.fixture
+def sample_genie_agent():
+ """Sample Genie agent configuration."""
+ return {
+ "name": "Genie Agent",
+ "description": "A Genie-based agent",
+ "genie_space_id": "test-space-123",
+ }
+
+
+@pytest.fixture
+def sample_uc_function_agent():
+ """Sample UC Function agent configuration."""
+ return {
+ "name": "UC Function Agent",
+ "description": "A UC function agent",
+ "uc_function_name": "catalog.schema.function_name",
+ }
+
+
+class TestBuildAgents:
+ """Tests for _build_agents helper function."""
+
+ def test_build_serving_endpoint_agent(self, sample_agent_config):
+ """Should build serving endpoint agent config."""
+ result = _build_agents([sample_agent_config])
+
+ assert len(result) == 1
+ agent = result[0]
+ assert agent["name"] == "Test Agent"
+ assert agent["description"] == "A test agent for unit testing"
+ assert agent["agent_type"] == "serving_endpoint"
+ assert agent["serving_endpoint"]["name"] == "test-endpoint"
+
+ def test_build_genie_agent(self, sample_genie_agent):
+ """Should build Genie agent config."""
+ result = _build_agents([sample_genie_agent])
+
+ assert len(result) == 1
+ agent = result[0]
+ assert agent["agent_type"] == "genie"
+ assert agent["genie_space"]["id"] == "test-space-123"
+
+ def test_build_uc_function_agent(self, sample_uc_function_agent):
+ """Should build UC function agent config."""
+ result = _build_agents([sample_uc_function_agent])
+
+ assert len(result) == 1
+ agent = result[0]
+ assert agent["agent_type"] == "unity_catalog_function"
+ assert agent["unity_catalog_function"]["uc_path"]["catalog"] == "catalog"
+ assert agent["unity_catalog_function"]["uc_path"]["schema"] == "schema"
+ assert agent["unity_catalog_function"]["uc_path"]["name"] == "function_name"
+
+ def test_build_mcp_connection_agent(self):
+ """Should build external MCP server agent config."""
+ agent_config = {
+ "name": "MCP Agent",
+ "description": "External MCP server",
+ "connection_name": "my-mcp-connection",
+ }
+ result = _build_agents([agent_config])
+
+ assert len(result) == 1
+ agent = result[0]
+ assert agent["agent_type"] == "external_mcp_server"
+ assert agent["external_mcp_server"]["connection_name"] == "my-mcp-connection"
+
+ def test_build_multiple_agents(self, sample_agent_config, sample_genie_agent):
+ """Should build multiple agent configs."""
+ result = _build_agents([sample_agent_config, sample_genie_agent])
+
+ assert len(result) == 2
+ assert result[0]["agent_type"] == "serving_endpoint"
+ assert result[1]["agent_type"] == "genie"
+
+
+@pytest.mark.integration
+class TestMASLifecycle:
+ """Integration tests for MAS CRUD operations.
+
+ Note: These tests require a Databricks workspace with Agent Bricks enabled.
+ They are marked as integration tests and may be skipped if connection fails.
+ """
+
+ @pytest.fixture
+ def test_mas_name(self):
+ """Unique name for test MAS."""
+ import uuid
+ return f"test-mas-{uuid.uuid4().hex[:8]}"
+
+ def test_list_mas(self, workspace_client):
+ """Should list existing MAS tiles."""
+ try:
+ result = list_mas()
+ assert isinstance(result, list)
+ except Exception as e:
+ if "Agent Bricks" in str(e) or "not enabled" in str(e).lower():
+ pytest.skip("Agent Bricks not enabled in workspace")
+ raise
+
+ def test_find_mas_not_found(self, workspace_client):
+ """Should return not found for non-existent MAS."""
+ try:
+ result = find_mas("nonexistent-mas-name-xyz-123")
+ assert result["found"] is False
+ except Exception as e:
+ if "Agent Bricks" in str(e) or "not enabled" in str(e).lower():
+ pytest.skip("Agent Bricks not enabled in workspace")
+ raise
+
+ def test_get_mas_not_found(self, workspace_client):
+ """Should return error for non-existent tile ID."""
+ try:
+ result = get_mas("00000000-0000-0000-0000-000000000000")
+ assert "error" in result or result.get("tile_id") == ""
+ except Exception as e:
+ if "Agent Bricks" in str(e) or "not enabled" in str(e).lower():
+ pytest.skip("Agent Bricks not enabled in workspace")
+ raise
diff --git a/databricks-skills/.tests/test_genie_conversation.py b/databricks-skills/.tests/test_genie_conversation.py
new file mode 100644
index 00000000..0ada389f
--- /dev/null
+++ b/databricks-skills/.tests/test_genie_conversation.py
@@ -0,0 +1,204 @@
+"""
+Integration tests for databricks-genie/scripts/conversation.py
+
+Tests the Genie Conversation API CLI interface.
+Requires databricks.sdk for Genie Space operations.
+"""
+
+import json
+import os
+import sys
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+# Add the skills directory to the path
+SKILLS_DIR = Path(__file__).parent.parent
+sys.path.insert(0, str(SKILLS_DIR / "databricks-genie"))
+
+from conversation import ask_genie, _print_json
+
+
+class TestAskGenieFunction:
+ """Tests for the ask_genie function structure and error handling."""
+
+ def test_ask_genie_returns_dict(self):
+ """Should return a dictionary result."""
+ # Test with a mock to verify return structure
+ with patch("conversation.WorkspaceClient") as mock_client:
+ # Setup mock
+ mock_response = MagicMock()
+ mock_response.conversation_id = "conv-123"
+ mock_response.message_id = "msg-456"
+
+ mock_message = MagicMock()
+ mock_message.status = MagicMock()
+ mock_message.status.value = "COMPLETED"
+ mock_message.attachments = []
+ mock_message.query_result = None
+
+ mock_instance = mock_client.return_value
+ mock_instance.genie.start_conversation_and_wait.return_value = mock_response
+ mock_instance.genie.get_message.return_value = mock_message
+
+ result = ask_genie(
+ space_id="test-space",
+ question="Test question",
+ timeout_seconds=5,
+ )
+
+ assert isinstance(result, dict)
+ assert "question" in result
+ assert "conversation_id" in result
+ assert "message_id" in result
+ assert "status" in result
+
+ def test_ask_genie_with_conversation_id(self):
+ """Should pass conversation_id for follow-up questions."""
+ with patch("conversation.WorkspaceClient") as mock_client:
+ mock_response = MagicMock()
+ mock_response.conversation_id = "conv-123"
+ mock_response.message_id = "msg-456"
+
+ mock_message = MagicMock()
+ mock_message.status = MagicMock()
+ mock_message.status.value = "COMPLETED"
+ mock_message.attachments = []
+ mock_message.query_result = None
+
+ mock_instance = mock_client.return_value
+ mock_instance.genie.start_conversation_and_wait.return_value = mock_response
+ mock_instance.genie.get_message.return_value = mock_message
+
+ result = ask_genie(
+ space_id="test-space",
+ question="Follow-up question",
+ conversation_id="existing-conv-id",
+ timeout_seconds=5,
+ )
+
+ # Verify the conversation_id was passed
+ call_args = mock_instance.genie.start_conversation_and_wait.call_args
+ assert call_args.kwargs.get("conversation_id") == "existing-conv-id"
+
+ def test_ask_genie_handles_timeout(self):
+ """Should return timeout status when query exceeds timeout."""
+ with patch("conversation.WorkspaceClient") as mock_client:
+ mock_response = MagicMock()
+ mock_response.conversation_id = "conv-123"
+ mock_response.message_id = "msg-456"
+
+ mock_message = MagicMock()
+ mock_message.status = MagicMock()
+ mock_message.status.value = "EXECUTING_QUERY" # Never completes
+ mock_message.attachments = []
+
+ mock_instance = mock_client.return_value
+ mock_instance.genie.start_conversation_and_wait.return_value = mock_response
+ mock_instance.genie.get_message.return_value = mock_message
+
+ # Very short timeout to trigger timeout path
+ result = ask_genie(
+ space_id="test-space",
+ question="Test question",
+ timeout_seconds=0.1, # Will timeout immediately
+ )
+
+ assert result["status"] == "TIMEOUT"
+ assert "error" in result
+
+ def test_ask_genie_handles_failure(self):
+ """Should return failure status when query fails."""
+ with patch("conversation.WorkspaceClient") as mock_client:
+ mock_response = MagicMock()
+ mock_response.conversation_id = "conv-123"
+ mock_response.message_id = "msg-456"
+
+ mock_message = MagicMock()
+ mock_message.status = MagicMock()
+ mock_message.status.value = "FAILED"
+ mock_message.attachments = []
+
+ mock_instance = mock_client.return_value
+ mock_instance.genie.start_conversation_and_wait.return_value = mock_response
+ mock_instance.genie.get_message.return_value = mock_message
+
+ result = ask_genie(
+ space_id="test-space",
+ question="Test question",
+ timeout_seconds=5,
+ )
+
+ assert result["status"] == "FAILED"
+
+
+class TestPrintJson:
+ """Tests for the _print_json helper function."""
+
+ def test_print_json_dict(self, capsys):
+ """Should print dict as formatted JSON."""
+ _print_json({"key": "value", "number": 42})
+ captured = capsys.readouterr()
+ assert '"key": "value"' in captured.out
+ assert '"number": 42' in captured.out
+
+ def test_print_json_list(self, capsys):
+ """Should print list as formatted JSON."""
+ _print_json([1, 2, 3])
+ captured = capsys.readouterr()
+ assert "1" in captured.out
+ assert "2" in captured.out
+ assert "3" in captured.out
+
+
+@pytest.mark.integration
+class TestGenieConversationIntegration:
+ """Integration tests for Genie Conversation API.
+
+ Note: These tests require a Databricks workspace with Genie enabled
+ and a valid Genie Space ID configured via environment variable.
+ """
+
+ @pytest.fixture
+ def genie_space_id(self):
+ """Get Genie Space ID from environment."""
+ space_id = os.environ.get("TEST_GENIE_SPACE_ID")
+ if not space_id:
+ pytest.skip("TEST_GENIE_SPACE_ID not set - skipping Genie integration tests")
+ return space_id
+
+ def test_ask_genie_simple_question(self, workspace_client, genie_space_id):
+ """Should be able to ask a simple question to Genie."""
+ result = ask_genie(
+ space_id=genie_space_id,
+ question="How many rows are in the table?",
+ timeout_seconds=120,
+ )
+
+ # Should return a valid result
+ assert result["conversation_id"] is not None
+ assert result["status"] in ["COMPLETED", "FAILED", "TIMEOUT"]
+
+ def test_ask_genie_follow_up(self, workspace_client, genie_space_id):
+ """Should be able to ask follow-up questions."""
+ # First question
+ result1 = ask_genie(
+ space_id=genie_space_id,
+ question="Show me the first 5 rows",
+ timeout_seconds=120,
+ )
+
+ if result1["status"] != "COMPLETED":
+ pytest.skip("First query did not complete - skipping follow-up test")
+
+ # Follow-up question
+ result2 = ask_genie(
+ space_id=genie_space_id,
+ question="Now show me the count",
+ conversation_id=result1["conversation_id"],
+ timeout_seconds=120,
+ )
+
+ # Should use same conversation
+ assert result2["conversation_id"] == result1["conversation_id"]
diff --git a/databricks-skills/.tests/test_pdf_generator.py b/databricks-skills/.tests/test_pdf_generator.py
new file mode 100644
index 00000000..744cb6e7
--- /dev/null
+++ b/databricks-skills/.tests/test_pdf_generator.py
@@ -0,0 +1,450 @@
+"""
+Integration tests for databricks-unstructured-pdf-generation/scripts/pdf_generator.py
+
+Tests the HTML to PDF conversion functionality.
+Requires plutoprint for PDF conversion.
+"""
+
+import sys
+import time
+from pathlib import Path
+
+import pytest
+
+# Add the skills directory to the path
+SKILLS_DIR = Path(__file__).parent.parent
+sys.path.insert(0, str(SKILLS_DIR / "databricks-unstructured-pdf-generation" / "scripts"))
+
+from pdf_generator import (
+ convert_html_to_pdf,
+ convert_folder,
+ _needs_conversion,
+ ConversionResult,
+ BatchResult,
+)
+
+
+@pytest.fixture
+def sample_html_content():
+ """Sample HTML document for testing."""
+ return """
+
+
+
+
+
+ Test Document
+ This is a simple test paragraph.
+
+
This is highlighted content.
+
+
+ Item 1
+ Item 2
+ Item 3
+
+
+"""
+
+
+@pytest.fixture
+def html_file(tmp_path, sample_html_content):
+ """Create a temporary HTML file for testing."""
+ html_path = tmp_path / "test.html"
+ html_path.write_text(sample_html_content)
+ return html_path
+
+
+@pytest.fixture
+def html_folder(tmp_path, sample_html_content):
+ """Create a folder with multiple HTML files for testing."""
+ html_dir = tmp_path / "html"
+ html_dir.mkdir()
+
+ # Create multiple HTML files
+ (html_dir / "doc1.html").write_text(sample_html_content.replace("Test Document", "Document 1"))
+ (html_dir / "doc2.html").write_text(sample_html_content.replace("Test Document", "Document 2"))
+
+ # Create a subfolder with HTML files
+ subdir = html_dir / "subdir"
+ subdir.mkdir()
+ (subdir / "doc3.html").write_text(sample_html_content.replace("Test Document", "Document 3"))
+
+ return html_dir
+
+
+class TestConversionResult:
+ """Tests for the ConversionResult dataclass."""
+
+ def test_success_result(self):
+ """Should create a successful result."""
+ result = ConversionResult(
+ html_path="/path/to/test.html",
+ pdf_path="/path/to/test.pdf",
+ success=True,
+ )
+ assert result.success is True
+ assert result.pdf_path == "/path/to/test.pdf"
+ assert result.error is None
+ assert result.skipped is False
+
+ def test_skipped_result(self):
+ """Should create a skipped result."""
+ result = ConversionResult(
+ html_path="/path/to/test.html",
+ pdf_path="/path/to/test.pdf",
+ success=True,
+ skipped=True,
+ )
+ assert result.success is True
+ assert result.skipped is True
+
+ def test_failure_result(self):
+ """Should create a failure result."""
+ result = ConversionResult(
+ html_path="/path/to/test.html",
+ error="Something went wrong",
+ )
+ assert result.success is False
+ assert result.pdf_path is None
+ assert result.error == "Something went wrong"
+
+ def test_to_dict(self):
+ """Should convert to dictionary."""
+ result = ConversionResult(
+ html_path="/path/to/test.html",
+ pdf_path="/path/to/test.pdf",
+ success=True,
+ )
+ d = result.to_dict()
+ assert d == {
+ "html_path": "/path/to/test.html",
+ "pdf_path": "/path/to/test.pdf",
+ "success": True,
+ "skipped": False,
+ "error": None,
+ }
+
+
+class TestBatchResult:
+ """Tests for the BatchResult dataclass."""
+
+ def test_empty_result(self):
+ """Should create an empty batch result."""
+ result = BatchResult()
+ assert result.total == 0
+ assert result.converted == 0
+ assert result.skipped == 0
+ assert result.failed == 0
+ assert result.results == []
+
+ def test_to_dict(self):
+ """Should convert to dictionary."""
+ result = BatchResult(total=3, converted=2, skipped=1, failed=0)
+ d = result.to_dict()
+ assert d["total"] == 3
+ assert d["converted"] == 2
+ assert d["skipped"] == 1
+ assert d["failed"] == 0
+
+
+class TestNeedsConversion:
+ """Tests for the _needs_conversion function."""
+
+ def test_needs_conversion_pdf_missing(self, html_file, tmp_path):
+ """Should return True when PDF doesn't exist."""
+ pdf_path = tmp_path / "output" / "test.pdf"
+ assert _needs_conversion(html_file, pdf_path) is True
+
+ def test_needs_conversion_pdf_older(self, html_file, tmp_path):
+ """Should return True when PDF is older than HTML."""
+ pdf_path = tmp_path / "test.pdf"
+ pdf_path.write_bytes(b"fake pdf content")
+
+ # Make PDF older by setting mtime in the past
+ import os
+ old_time = time.time() - 3600 # 1 hour ago
+ os.utime(pdf_path, (old_time, old_time))
+
+ # Touch HTML to make it newer
+ html_file.touch()
+
+ assert _needs_conversion(html_file, pdf_path) is True
+
+ def test_needs_conversion_pdf_newer(self, html_file, tmp_path):
+ """Should return False when PDF is newer than HTML."""
+ pdf_path = tmp_path / "test.pdf"
+
+ # Create PDF after HTML
+ time.sleep(0.1) # Ensure time difference
+ pdf_path.write_bytes(b"fake pdf content")
+
+ assert _needs_conversion(html_file, pdf_path) is False
+
+
+class TestConvertHtmlToPdf:
+ """Tests for single file HTML to PDF conversion."""
+
+ def test_convert_simple_html(self, html_file, tmp_path):
+ """Test converting HTML to PDF."""
+ pdf_path = tmp_path / "output" / "test.pdf"
+ result = convert_html_to_pdf(html_file, pdf_path)
+
+ assert result.success, f"Conversion failed: {result.error}"
+ assert result.skipped is False
+ assert pdf_path.exists()
+ assert pdf_path.stat().st_size > 0
+ assert result.pdf_path == str(pdf_path)
+
+ def test_convert_minimal_html(self, tmp_path):
+ """Test converting minimal HTML."""
+ html_path = tmp_path / "minimal.html"
+ html_path.write_text("Hello ")
+
+ pdf_path = tmp_path / "minimal.pdf"
+ result = convert_html_to_pdf(html_path, pdf_path)
+
+ assert result.success
+ assert pdf_path.exists()
+
+ def test_convert_with_css(self, tmp_path):
+ """Test converting HTML with CSS styling."""
+ html_content = """
+
+
+
+
+
+ Styled Document
+ Content in a box
+
+"""
+ html_path = tmp_path / "styled.html"
+ html_path.write_text(html_content)
+
+ pdf_path = tmp_path / "styled.pdf"
+ result = convert_html_to_pdf(html_path, pdf_path)
+
+ assert result.success
+ assert pdf_path.exists()
+
+ def test_convert_creates_parent_directory(self, html_file, tmp_path):
+ """Test that conversion creates parent directories."""
+ pdf_path = tmp_path / "nested" / "dir" / "test.pdf"
+ result = convert_html_to_pdf(html_file, pdf_path)
+
+ assert result.success
+ assert pdf_path.exists()
+
+ def test_convert_skips_up_to_date(self, html_file, tmp_path):
+ """Test that conversion skips when PDF is up-to-date."""
+ pdf_path = tmp_path / "test.pdf"
+
+ # First conversion
+ result1 = convert_html_to_pdf(html_file, pdf_path)
+ assert result1.success
+ assert result1.skipped is False
+
+ # Second conversion should be skipped
+ result2 = convert_html_to_pdf(html_file, pdf_path)
+ assert result2.success
+ assert result2.skipped is True
+
+ def test_convert_force_reconvert(self, html_file, tmp_path):
+ """Test force reconversion."""
+ pdf_path = tmp_path / "test.pdf"
+
+ # First conversion
+ result1 = convert_html_to_pdf(html_file, pdf_path)
+ assert result1.success
+
+ # Force reconversion
+ result2 = convert_html_to_pdf(html_file, pdf_path, force=True)
+ assert result2.success
+ assert result2.skipped is False
+
+
+class TestConvertFolder:
+ """Tests for folder-based HTML to PDF conversion."""
+
+ def test_convert_folder_basic(self, html_folder, tmp_path):
+ """Test converting an entire folder."""
+ pdf_dir = tmp_path / "pdf"
+ result = convert_folder(html_folder, pdf_dir)
+
+ assert result.total == 3
+ assert result.converted == 3
+ assert result.skipped == 0
+ assert result.failed == 0
+
+ # Check PDFs exist
+ assert (pdf_dir / "doc1.pdf").exists()
+ assert (pdf_dir / "doc2.pdf").exists()
+ assert (pdf_dir / "subdir" / "doc3.pdf").exists()
+
+ def test_convert_folder_preserves_structure(self, html_folder, tmp_path):
+ """Test that folder structure is preserved."""
+ pdf_dir = tmp_path / "pdf"
+ convert_folder(html_folder, pdf_dir)
+
+ # Subfolder structure should be preserved
+ assert (pdf_dir / "subdir").is_dir()
+ assert (pdf_dir / "subdir" / "doc3.pdf").exists()
+
+ def test_convert_folder_skips_up_to_date(self, html_folder, tmp_path):
+ """Test that folder conversion skips up-to-date files."""
+ pdf_dir = tmp_path / "pdf"
+
+ # First conversion
+ result1 = convert_folder(html_folder, pdf_dir)
+ assert result1.converted == 3
+ assert result1.skipped == 0
+
+ # Second conversion should skip all
+ result2 = convert_folder(html_folder, pdf_dir)
+ assert result2.converted == 0
+ assert result2.skipped == 3
+
+ def test_convert_folder_force(self, html_folder, tmp_path):
+ """Test force reconversion of folder."""
+ pdf_dir = tmp_path / "pdf"
+
+ # First conversion
+ convert_folder(html_folder, pdf_dir)
+
+ # Force reconversion
+ result = convert_folder(html_folder, pdf_dir, force=True)
+ assert result.converted == 3
+ assert result.skipped == 0
+
+ def test_convert_folder_empty(self, tmp_path):
+ """Test converting empty folder."""
+ empty_dir = tmp_path / "empty"
+ empty_dir.mkdir()
+
+ pdf_dir = tmp_path / "pdf"
+ result = convert_folder(empty_dir, pdf_dir)
+
+ assert result.total == 0
+ assert result.converted == 0
+
+ def test_convert_folder_parallel(self, tmp_path):
+ """Test that parallel conversion works correctly."""
+ # Create many HTML files
+ html_dir = tmp_path / "html"
+ html_dir.mkdir()
+
+ for i in range(10):
+ (html_dir / f"doc{i}.html").write_text(
+ f"Document {i} "
+ )
+
+ pdf_dir = tmp_path / "pdf"
+ result = convert_folder(html_dir, pdf_dir, max_workers=4)
+
+ assert result.total == 10
+ assert result.converted == 10
+ assert result.failed == 0
+
+ # All PDFs should exist
+ for i in range(10):
+ assert (pdf_dir / f"doc{i}.pdf").exists()
+
+
+class TestComplexDocuments:
+ """Tests for complex document conversion."""
+
+ def test_convert_table_document(self, tmp_path):
+ """Test converting HTML with tables."""
+ html_content = """
+
+
+
+
+
+ Data Report
+
+ Name Value Status
+ Item A 100 Active
+ Item B 200 Pending
+ Item C 300 Complete
+
+
+"""
+ html_path = tmp_path / "table.html"
+ html_path.write_text(html_content)
+
+ pdf_path = tmp_path / "table.pdf"
+ result = convert_html_to_pdf(html_path, pdf_path)
+
+ assert result.success
+ assert pdf_path.exists()
+ assert pdf_path.stat().st_size > 1000 # Should be non-trivial size
+
+ def test_convert_css_variables(self, tmp_path):
+ """Test converting HTML with CSS variables."""
+ html_content = """
+
+
+
+
+
+ CSS Variables Test
+ This uses CSS variables for styling.
+
+"""
+ html_path = tmp_path / "css_vars.html"
+ html_path.write_text(html_content)
+
+ pdf_path = tmp_path / "css_vars.pdf"
+ result = convert_html_to_pdf(html_path, pdf_path)
+
+ assert result.success
+ assert pdf_path.exists()
+
+ def test_convert_flexbox_layout(self, tmp_path):
+ """Test converting HTML with flexbox layout."""
+ html_content = """
+
+
+
+
+
+ Flexbox Layout
+
+
Box 1
+
Box 2
+
Box 3
+
+
+"""
+ html_path = tmp_path / "flexbox.html"
+ html_path.write_text(html_content)
+
+ pdf_path = tmp_path / "flexbox.pdf"
+ result = convert_html_to_pdf(html_path, pdf_path)
+
+ assert result.success
+ assert pdf_path.exists()
diff --git a/databricks-skills/README.md b/databricks-skills/README.md
index a81730a2..95e9a3f2 100644
--- a/databricks-skills/README.md
+++ b/databricks-skills/README.md
@@ -1,6 +1,6 @@
# Databricks Skills for Claude Code
-Skills that teach Claude Code how to work effectively with Databricks - providing patterns, best practices, and code examples that work with Databricks MCP tools.
+Skills that teach Claude Code how to work effectively with Databricks - providing patterns, best practices, and code examples using the Databricks CLI, Python SDK, and REST APIs.
## Installation
@@ -113,22 +113,21 @@ cp -r ai-dev-kit/databricks-skills/databricks-agent-bricks .claude/skills/
## How It Works
```
-┌────────────────────────────────────────────────┐
-│ .claude/skills/ + .claude/mcp.json │
-│ (Knowledge) (Actions) │
-│ │
-│ Skills teach HOW + MCP does it │
-│ ↓ ↓ │
-│ Claude Code learns patterns and executes │
-└────────────────────────────────────────────────┘
+┌─────────────────────────────────────────────────────┐
+│ .claude/skills/ + Databricks CLI/SDK │
+│ (Knowledge) (Actions) │
+│ │
+│ Skills teach HOW + CLI/SDK executes │
+│ ↓ ↓ │
+│ Claude Code learns patterns and executes │
+└─────────────────────────────────────────────────────┘
```
**Example:** User says "Create a sales dashboard"
1. Claude loads `databricks-aibi-dashboards` skill → learns validation workflow
-2. Calls `get_table_stats_and_schema()` → gets schemas
-3. Calls `execute_sql()` → tests queries
-4. Calls `manage_dashboard(action="create_or_update")` → deploys
-5. Returns working dashboard URL
+2. Runs `databricks experimental aitools tools query` → tests queries
+3. Uses Python SDK to create dashboard via REST API
+4. Returns working dashboard URL
## Custom Skills
@@ -149,6 +148,26 @@ description: "What this teaches"
...
```
+## Testing
+
+Run tests for skill scripts (requires `pytest`):
+
+```bash
+cd databricks-skills/.tests
+
+# Run all tests (unit tests are mocked, no Databricks connection needed)
+python run_tests.py
+
+# Run only unit tests
+python run_tests.py --unit
+
+# Run integration tests (requires Databricks connection)
+python run_tests.py --integration
+
+# Verbose output
+python run_tests.py -v
+```
+
## Troubleshooting
**Skills not loading?** Check `.claude/skills/` exists and each skill has `SKILL.md`
@@ -158,6 +177,7 @@ description: "What this teaches"
## Related
- [databricks-tools-core](../databricks-tools-core/) - Python library
-- [databricks-mcp-server](../databricks-mcp-server/) - MCP server
+- [Databricks CLI](https://docs.databricks.com/dev-tools/cli/index.html) - Official CLI
+- [Databricks SDK](https://docs.databricks.com/en/dev-tools/sdk-python.html) - Python SDK
- [Databricks Docs](https://docs.databricks.com/) - Official documentation
- [MLflow Skills](https://github.com/mlflow/skills) - Upstream MLflow skills repository
diff --git a/databricks-skills/databricks-agent-bricks/1-knowledge-assistants.md b/databricks-skills/databricks-agent-bricks/1-knowledge-assistants.md
index 3adff469..f7d0a942 100644
--- a/databricks-skills/databricks-agent-bricks/1-knowledge-assistants.md
+++ b/databricks-skills/databricks-agent-bricks/1-knowledge-assistants.md
@@ -1,183 +1,68 @@
-# Knowledge Assistants (KA)
+# Knowledge Assistants - Details
-Knowledge Assistants are document-based Q&A systems that use RAG (Retrieval-Augmented Generation) to answer questions from indexed documents.
+For commands, see [SKILL.md](SKILL.md).
-## What is a Knowledge Assistant?
+## Source Types
-A KA connects to documents stored in a Unity Catalog Volume and allows users to ask natural language questions. The system:
+### Files (Volume)
-1. **Indexes** all documents in the volume (PDFs, text files, etc.)
-2. **Retrieves** relevant chunks when a question is asked
-3. **Generates** an answer using the retrieved context
-
-## When to Use
-
-Use a Knowledge Assistant when:
-- You have a collection of documents (policies, manuals, guides, reports)
-- Users need to find specific information without reading entire documents
-- You want to provide a conversational interface to documentation
-
-## Prerequisites
-
-Before creating a KA, you need documents in a Unity Catalog Volume:
-
-**Option 1: Use existing documents**
-- Upload PDFs/text files to a Volume manually or via SDK
-
-**Option 2: Generate synthetic documents**
-- Use the `databricks-unstructured-pdf-generation` skill to create realistic PDF documents
-- Each PDF gets a companion JSON file with question/guideline pairs for evaluation
-
-## Creating a Knowledge Assistant
-
-Use the `manage_ka` tool with `action="create_or_update"`:
-
-- `name`: "HR Policy Assistant"
-- `volume_path`: "/Volumes/my_catalog/my_schema/raw_data/hr_docs"
-- `description`: "Answers questions about HR policies and procedures"
-- `instructions`: "Be helpful and always cite the specific policy document when answering. If you're unsure, say so."
-
-The tool will:
-1. Create the KA with the specified volume as a knowledge source
-2. Scan the volume for JSON files with example questions (from PDF generation)
-3. Queue examples to be added once the endpoint is ready
-
-## Provisioning Timeline
-
-After creation, the KA endpoint needs to provision:
-
-| Status | Meaning | Duration |
-|--------|---------|----------|
-| `PROVISIONING` | Creating the endpoint | 2-5 minutes |
-| `ONLINE` | Ready to use | - |
-| `OFFLINE` | Not currently running | - |
-
-Use `manage_ka` with `action="get"` to check the status:
-
-- `tile_id`: ""
-
-## Adding Example Questions
-
-Example questions help with:
-- **Evaluation**: Test if the KA answers correctly
-- **User onboarding**: Show users what to ask
-
-### Automatic (from PDF generation)
-
-If you used `generate_pdf_documents`, each PDF has a companion JSON with:
```json
{
- "question": "What is the company's remote work policy?",
- "guideline": "Should mention the 3-day minimum in-office requirement"
+ "source_type": "files",
+ "files": {"path": "/Volumes/catalog/schema/volume/folder/"}
}
```
-These are automatically added when `add_examples_from_volume=true` (default).
+Supported formats: PDF, TXT, MD, DOCX
-### Manual
+### Vector Search Index
-Examples can also be specified in the `manage_ka` create_or_update call if needed.
+Use existing index instead of auto-indexing:
-## Best Practices
-
-### Document Organization
-
-- **One volume per topic**: e.g., `/Volumes/catalog/schema/raw_data/hr_docs`, `/Volumes/catalog/schema/raw_data/tech_docs`
-- **Clear naming**: Name files descriptively so chunks are identifiable
-
-### Instructions
-
-Good instructions improve answer quality:
-
-```
-Be helpful and professional. When answering:
-1. Always cite the specific document and section
-2. If multiple documents are relevant, mention all of them
-3. If the information isn't in the documents, clearly say so
-4. Use bullet points for multi-part answers
+```json
+{
+ "source_type": "index",
+ "index": {
+ "index_name": "catalog.schema.my_index",
+ "text_col": "content",
+ "doc_uri_col": "source_url"
+ }
+}
```
-### Updating Content
-
-To update the indexed documents:
-1. Add/remove/modify files in the volume
-2. Call `manage_ka` with `action="create_or_update"`, the same name and `tile_id`
-3. The KA will re-index the updated content
-
-## Example Workflow
-
-1. **Generate PDF documents** using `databricks-unstructured-pdf-generation` skill:
- - Creates PDFs in `/Volumes/catalog/schema/raw_data/pdf_documents`
- - Creates JSON files with question/guideline pairs
-
-2. **Create the Knowledge Assistant**:
- - `name`: "My Document Assistant"
- - `volume_path`: "/Volumes/catalog/schema/raw_data/pdf_documents"
+## Updating Content
-3. **Wait for ONLINE status** (2-5 minutes)
+1. Add/modify/remove files in the Volume
+2. Re-sync: `databricks knowledge-assistants sync-knowledge-sources "knowledge-assistants/{ka_id}"`
-4. **Examples are automatically added** from the JSON files
-
-5. **Test the KA** in the Databricks UI
-
-## Using KA in Supervisor Agents
-
-Knowledge Assistants can be used as agents in a Supervisor Agent (formerly Multi-Agent Supervisor, MAS). Each KA has an associated model serving endpoint.
+## Troubleshooting
-### Finding the Endpoint Name
+**KA stays in CREATING:**
+- Wait up to 10 minutes
+- Check workspace quotas
+- Verify volume path exists
-Use `manage_ka` with `action="get"` to retrieve the KA details. The response includes:
-- `tile_id`: The unique identifier for the KA
-- `name`: The KA name (sanitized)
-- `endpoint_status`: Current status (ONLINE, PROVISIONING, etc.)
+**Documents not indexed:**
+- Check file format (PDF, TXT, MD, DOCX)
+- Verify volume path (trailing slash matters)
+- Check file permissions
-The endpoint name follows this pattern: `ka-{tile_id}-endpoint`
+**Poor answer quality:**
+- Ensure documents are well-structured
+- Break large documents into smaller files
+- Add clear headings and sections
-### Finding a KA by Name
+## Evaluation Questions
-If you know the KA name but not the tile_id, use `manage_ka` with `action="find_by_name"`:
+When testing a KA, check if the volume or project contains a `pdf_eval_questions.json` file with test questions:
-```python
-manage_ka(action="find_by_name", name="HR_Policy_Assistant")
-# Returns: {"found": True, "tile_id": "01abc...", "name": "HR_Policy_Assistant", "endpoint_name": "ka-01abc...-endpoint"}
-```
-
-### Example: Adding KA to Supervisor Agent
-
-```python
-# First, find the KA
-manage_ka(action="find_by_name", name="HR_Policy_Assistant")
-
-# Then use the tile_id in a Supervisor Agent
-manage_mas(
- action="create_or_update",
- name="Support_MAS",
- agents=[
- {
- "name": "hr_agent",
- "ka_tile_id": "",
- "description": "Answers HR policy questions from the employee handbook"
- }
- ]
-)
+```json
+{
+ "api_errors_guide.pdf": {
+ "question": "What is the solution for error ERR-4521?",
+ "expected_fact": "Call /api/v2/auth/refresh with refresh_token before the 3600s TTL expires"
+ }
+}
```
-## Troubleshooting
-
-### Endpoint stays in PROVISIONING
-
-- Check workspace capacity and quotas
-- Verify the volume path is accessible
-- Wait up to 10 minutes before investigating further
-
-### Documents not indexed
-
-- Ensure files are in a supported format (PDF, TXT, MD)
-- Check file permissions in the volume
-- Verify the volume path is correct
-
-### Poor answer quality
-
-- Add more specific instructions
-- Ensure documents are well-structured
-- Consider breaking large documents into smaller files
+Use these questions to validate retrieval accuracy. See [databricks-unstructured-pdf-generation](../databricks-unstructured-pdf-generation/SKILL.md) for generating test PDFs with eval questions.
diff --git a/databricks-skills/databricks-agent-bricks/2-supervisor-agents.md b/databricks-skills/databricks-agent-bricks/2-supervisor-agents.md
index 7121bfcf..bbb296d0 100644
--- a/databricks-skills/databricks-agent-bricks/2-supervisor-agents.md
+++ b/databricks-skills/databricks-agent-bricks/2-supervisor-agents.md
@@ -1,394 +1,92 @@
-# Supervisor Agents (MAS)
+# Supervisor Agents - Details
-Supervisor Agents orchestrate multiple specialized agents, routing user queries to the most appropriate agent based on the query content.
-
-## What is a Supervisor Agent?
-
-A Supervisor Agent (formerly Multi-Agent Supervisor, MAS) acts as a traffic controller for multiple AI agents, routing user queries to the most appropriate agent. It supports five types of agents:
-
-1. **Knowledge Assistants (KA)**: Document-based Q&A from PDFs/files in Volumes
-2. **Genie Spaces**: Natural language to SQL for data exploration
-3. **Model Serving Endpoints**: Custom LLM agents, fine-tuned models, RAG applications
-4. **Unity Catalog Functions**: Callable UC functions for data operations
-5. **External MCP Servers**: JSON-RPC endpoints via UC HTTP Connections for external system integration
-
-When a user asks a question:
-1. **Analyzes** the query to understand the intent
-2. **Routes** to the most appropriate specialized agent
-3. **Returns** the agent's response to the user
-
-This allows you to combine multiple specialized agents into a single unified interface.
-
-## When to Use
-
-Use a Supervisor Agent when:
-- You have multiple specialized agents (billing, technical support, HR, etc.)
-- Users shouldn't need to know which agent to ask
-- You want to provide a unified conversational experience
-
-## Prerequisites
-
-Before creating a Supervisor Agent, you need agents of one or both types:
-
-**Model Serving Endpoints** (`endpoint_name`):
-- Knowledge Assistant (KA) endpoints (e.g., `ka-abc123-endpoint`)
-- Custom agents built with LangChain, LlamaIndex, etc.
-- Fine-tuned models
-- RAG applications
-
-**Genie Spaces** (`genie_space_id`):
-- Existing Genie spaces for SQL-based data exploration
-- Great for analytics, metrics, and data-driven questions
-- No separate endpoint deployment required - reference the space directly
-- To find a Genie space by name, use `find_genie_by_name(display_name="My Genie")`
-- **Note**: There is NO system table for Genie spaces - do not try to query `system.ai.genie_spaces`
+For commands, see [SKILL.md](SKILL.md).
## Unity Catalog Functions
-Unity Catalog Functions allow Supervisor Agents to call registered UC functions for data operations.
+Call registered UC functions from the Supervisor Agent.
-### Prerequisites
-
-- UC Function already exists (use SQL `CREATE FUNCTION` or Python UDF)
-- Agent service principal has `EXECUTE` privilege:
- ```sql
- GRANT EXECUTE ON FUNCTION catalog.schema.function_name TO ``;
- ```
-
-### Configuration
+**Prerequisites:**
+- UC Function exists (`CREATE FUNCTION` or Python UDF)
+- Grant execute: `GRANT EXECUTE ON FUNCTION catalog.schema.func TO \`\`;`
+**Config:**
```json
-{
- "name": "data_enrichment",
- "uc_function_name": "sales_analytics.utils.enrich_customer_data",
- "description": "Enriches customer records with demographic and purchase history data"
-}
+{"name": "enricher", "uc_function_name": "catalog.schema.enrich_data", "description": "Enriches customer records"}
```
-**Field**: `uc_function_name` - Fully-qualified function name in format `catalog.schema.function_name`
-
## External MCP Servers
-External MCP Servers enable Supervisor Agents to interact with external systems (ERP, CRM, etc.) via UC HTTP Connections. The MCP server implements a JSON-RPC 2.0 endpoint that exposes tools for the Supervisor Agent to call.
-
-### Prerequisites
-
-**1. MCP Server Endpoint**: Your external system must provide a JSON-RPC 2.0 endpoint (e.g., `/api/mcp`) that implements the MCP protocol:
-
-```python
-# Example MCP server tool definition
-TOOLS = [
- {
- "name": "approve_invoice",
- "description": "Approve a specific invoice",
- "inputSchema": {
- "type": "object",
- "properties": {
- "invoice_number": {"type": "string", "description": "Invoice number to approve"},
- "approver": {"type": "string", "description": "Name/email of approver"},
- },
- "required": ["invoice_number"],
- },
- },
-]
-
-# JSON-RPC methods: initialize, tools/list, tools/call
-```
-
-**2. UC HTTP Connection**: Create a Unity Catalog HTTP Connection that points to your MCP endpoint:
+Connect to external systems (ERP, CRM) via UC HTTP Connection implementing MCP protocol.
+**1. Create UC HTTP Connection:**
```sql
-CREATE CONNECTION my_mcp_connection TYPE HTTP
+CREATE CONNECTION my_mcp TYPE HTTP
OPTIONS (
- host 'https://my-app.databricksapps.com', -- Your MCP server URL
+ host 'https://my-app.databricksapps.com',
port '443',
- base_path '/api/mcp', -- Path to JSON-RPC endpoint
- client_id '', -- OAuth M2M credentials
- client_secret '',
+ base_path '/api/mcp',
+ client_id '',
+ client_secret '',
oauth_scope 'all-apis',
token_endpoint 'https://.azuredatabricks.net/oidc/v1/token',
- is_mcp_connection 'true' -- REQUIRED: Identifies as MCP connection
+ is_mcp_connection 'true'
);
```
-**3. Grant Permissions**: Agent service principal needs access to the connection:
-
+**2. Grant access:**
```sql
-GRANT USE CONNECTION ON my_mcp_connection TO ``;
+GRANT USE CONNECTION ON my_mcp TO ``;
```
-### Configuration
-
-Reference the UC Connection using the `connection_name` field:
-
-```python
-{
- "name": "external_operations",
- "connection_name": "my_mcp_connection",
- "description": "Execute external system operations: approve invoices, create records, trigger workflows"
-}
-```
-
-**Field**: `connection_name` - the name of the Unity Catalog HTTP Connection configured as an MCP server
-
-**Important**: Make the description comprehensive - it guides the Supervisor Agent's routing decisions for when to call this agent.
-
-### Complete Example: Multi-System Supervisor
-
-Example showing integration of Genie, KA, and external MCP:
-
-```python
-manage_mas(
- action="create_or_update",
- name="AP_Invoice_Supervisor",
- agents=[
- {
- "name": "billing_analyst",
- "genie_space_id": "01abc123...",
- "description": "SQL analytics on AP invoice data: spending trends, vendor analysis, aging reports"
- },
- {
- "name": "policy_expert",
- "ka_tile_id": "f32c5f73...",
- "description": "Answers questions about AP policies, approval workflows, and compliance requirements from policy documents"
- },
- {
- "name": "ap_operations",
- "connection_name": "ap_invoice_mcp",
- "description": (
- "Execute AP operations: approve/reject/flag invoices, search invoice details, "
- "get vendor summaries, trigger batch workflows. Use for ANY action or write operation."
- )
- }
- ],
- description="AP automation assistant with analytics, policy guidance, and operational actions",
- instructions="""
- Route queries as follows:
- - Data questions (invoice counts, spend analysis, vendor metrics) → billing_analyst
- - Policy questions (thresholds, SLAs, compliance rules) → policy_expert
- - Actions (approve, reject, flag, search, workflows) → ap_operations
-
- When a user asks to approve, reject, or flag an invoice, ALWAYS use ap_operations.
- """
-)
+**3. Config:**
+```json
+{"name": "operations", "connection_name": "my_mcp", "description": "Execute operations: approve invoices, trigger workflows"}
```
-### MCP Connection Testing
-
-Verify your connection before adding to MAS:
-
+**Test connection:**
```sql
--- Test tools/list method
-SELECT http_request(
- conn => 'my_mcp_connection',
- method => 'POST',
- path => '',
- json => '{"jsonrpc":"2.0","method":"tools/list","id":1}'
-);
+SELECT http_request(conn => 'my_mcp', method => 'POST', path => '', json => '{"jsonrpc":"2.0","method":"tools/list","id":1}');
```
-### Resources
-
-- **MCP Protocol Spec**: [Model Context Protocol](https://modelcontextprotocol.io)
-
-## Creating a Supervisor Agent
-
-Use the `manage_mas` tool with `action="create_or_update"`:
-
-- `name`: "Customer Support MAS"
-- `agents`:
- ```json
- [
- {
- "name": "policy_agent",
- "ka_tile_id": "f32c5f73-466b-4798-b3a0-5396b5ece2a5",
- "description": "Answers questions about company policies and procedures from indexed documents"
- },
- {
- "name": "usage_analytics",
- "genie_space_id": "01abc123-def4-5678-90ab-cdef12345678",
- "description": "Answers data questions about usage metrics, trends, and statistics"
- },
- {
- "name": "custom_agent",
- "endpoint_name": "my-custom-endpoint",
- "description": "Handles specialized queries via custom model endpoint"
- }
- ]
- ```
-- `description`: "Routes customer queries to specialized support agents"
-- `instructions`: "Analyze the user's question and route to the most appropriate agent. If unclear, ask for clarification."
-
-This example shows mixing Knowledge Assistants (policy_agent), Genie spaces (usage_analytics), and custom endpoints (custom_agent).
-
-## Agent Configuration
-
-Each agent in the `agents` list needs:
-
-| Field | Required | Description |
-|-------|----------|-------------|
-| `name` | Yes | Internal identifier for the agent |
-| `description` | Yes | What this agent handles (critical for routing) |
-| `ka_tile_id` | One of these | Knowledge Assistant tile ID (for document Q&A agents) |
-| `genie_space_id` | One of these | Genie space ID (for SQL-based data agents) |
-| `endpoint_name` | One of these | Model serving endpoint name (for custom agents) |
-| `uc_function_name` | One of these | Unity Catalog function name in format `catalog.schema.function_name` |
-| `connection_name` | One of these | Unity Catalog connection name (for external MCP servers) |
+## Writing Good Descriptions
-**Note**: Provide exactly one of: `ka_tile_id`, `genie_space_id`, `endpoint_name`, `uc_function_name`, or `connection_name`.
+The `description` field drives routing. Be specific:
-To find a KA tile_id, use `manage_ka(action="find_by_name", name="Your KA Name")`.
-To find a Genie space_id, use `find_genie_by_name(display_name="Your Genie Name")`.
+| Good | Bad |
+|------|-----|
+| "Handles billing: invoices, payments, refunds, subscriptions" | "Billing agent" |
+| "Answers API errors, integration issues, product bugs" | "Technical" |
+| "HR policies, PTO, benefits, employee handbook" | "Handles stuff" |
-### Writing Good Descriptions
+## Adding Examples
-The `description` field is critical for routing. Make it specific:
+Examples help evaluation and routing optimization. MAS must be ONLINE.
-**Good descriptions:**
-- "Handles billing questions including invoices, payments, refunds, and subscription changes"
-- "Answers technical questions about API errors, integration issues, and product bugs"
-- "Provides information about HR policies, PTO, benefits, and employee handbook"
+```bash
+python scripts/mas_manager.py add_examples TILE_ID '[
+ {"question": "I need my invoice for March", "guideline": "Route to billing_agent"},
+ {"question": "API returns 500 error", "guideline": "Route to tech_agent"}
+]'
-**Bad descriptions:**
-- "Billing agent" (too vague)
-- "Handles stuff" (not helpful)
-- "Technical" (not specific)
-
-## Provisioning Timeline
-
-After creation, the Supervisor Agent endpoint needs to provision:
-
-| Status | Meaning | Duration |
-|--------|---------|----------|
-| `PROVISIONING` | Creating the supervisor | 2-5 minutes |
-| `ONLINE` | Ready to route queries | - |
-| `OFFLINE` | Not currently running | - |
-
-Use `manage_mas` with `action="get"` to check the status.
-
-## Adding Example Questions
-
-Example questions help with evaluation and can guide routing optimization:
-
-```json
-{
- "examples": [
- {
- "question": "I haven't received my invoice for this month",
- "guideline": "Should be routed to billing_agent"
- },
- {
- "question": "The API is returning a 500 error",
- "guideline": "Should be routed to technical_agent"
- },
- {
- "question": "How many vacation days do I have?",
- "guideline": "Should be routed to hr_agent"
- }
- ]
-}
+python scripts/mas_manager.py list_examples TILE_ID
```
-If the Supervisor Agent is not yet `ONLINE`, examples are queued and added automatically when ready.
-
-## Best Practices
-
-### Agent Design
-
-1. **Specialized agents**: Each agent should have a clear, distinct purpose
-2. **Non-overlapping domains**: Avoid agents with similar descriptions
-3. **Clear boundaries**: Define what each agent does and doesn't handle
-
-### Instructions
-
-Provide routing instructions:
-
-```
-You are a customer support supervisor. Your job is to route user queries to the right specialist:
-
-1. For billing, payments, or subscription questions → billing_agent
-2. For technical issues, bugs, or API problems → technical_agent
-3. For HR, benefits, or policy questions → hr_agent
-
-If the query is unclear or spans multiple domains, ask the user to clarify.
+**In automated jobs** (waits for ONLINE):
+```bash
+python scripts/mas_manager.py add_examples_wait TILE_ID '[...]'
```
-### Fallback Handling
-
-Consider adding a general-purpose agent for queries that don't fit elsewhere:
-
-```json
-{
- "name": "general_agent",
- "endpoint_name": "general-support-endpoint",
- "description": "Handles general inquiries that don't fit other categories, provides navigation help"
-}
-```
-
-## Example Workflow
-
-1. **Deploy specialized agents** as model serving endpoints:
- - `billing-assistant-endpoint`
- - `tech-support-endpoint`
- - `hr-assistant-endpoint`
-
-2. **Create the MAS**:
- - Configure agents with clear descriptions
- - Add routing instructions
-
-3. **Wait for ONLINE status** (2-5 minutes)
-
-4. **Add example questions** for evaluation
-
-5. **Test routing** with various query types
-
-## Updating a Supervisor Agent
-
-To update an existing Supervisor Agent:
-
-1. **Add/remove agents**: Call `manage_mas` with `action="create_or_update"` and updated `agents` list
-2. **Update descriptions**: Change agent descriptions to improve routing
-3. **Modify instructions**: Update routing rules
-
-The tool finds the existing Supervisor Agent by name and updates it.
-
## Troubleshooting
-### Queries routed to wrong agent
-
-- Review and improve agent descriptions
-- Make descriptions more specific and distinct
-- Add examples that demonstrate correct routing
-
-### Endpoint not responding
-
-- Verify each underlying model serving endpoint is running
-- Check endpoint logs for errors
-- Ensure endpoints accept the expected input format
-
-### Slow responses
+**Wrong routing:**
+- Improve agent descriptions (more specific, less overlap)
+- Add examples demonstrating correct routing
-- Check latency of underlying endpoints
-- Consider endpoint scaling settings
-- Monitor for cold start issues
-
-## Advanced: Hierarchical Routing
-
-For complex scenarios, you can create multiple levels of Supervisor Agents:
-
-```
-Top-level Supervisor
-├── Customer Support Supervisor
-│ ├── billing_agent
-│ ├── technical_agent
-│ └── general_agent
-├── Sales Supervisor
-│ ├── pricing_agent
-│ ├── demo_agent
-│ └── contract_agent
-└── Internal Supervisor
- ├── hr_agent
- └── it_helpdesk_agent
-```
+**Endpoint not responding:**
+- Verify underlying endpoints are running
+- Check endpoint logs
-Each sub-supervisor is deployed as an endpoint and configured as an agent in the top-level supervisor.
+**Slow responses:**
+- Check underlying endpoint latency
+- Review endpoint scaling settings
diff --git a/databricks-skills/databricks-agent-bricks/SKILL.md b/databricks-skills/databricks-agent-bricks/SKILL.md
index 026f204a..1245a54d 100644
--- a/databricks-skills/databricks-agent-bricks/SKILL.md
+++ b/databricks-skills/databricks-agent-bricks/SKILL.md
@@ -1,212 +1,95 @@
---
name: databricks-agent-bricks
-description: "Create and manage Databricks Agent Bricks: Knowledge Assistants (KA) for document Q&A, Genie Spaces for SQL exploration, and Supervisor Agents (MAS) for multi-agent orchestration. Use when building conversational AI applications on Databricks."
+description: "Create Agent Bricks: Knowledge Assistants (KA) for document Q&A and Supervisor Agents for multi-agent orchestration (MAS). For Genie Spaces, see databricks-genie skill."
---
# Agent Bricks
-Create and manage Databricks Agent Bricks - pre-built AI components for building conversational applications.
-
-## Overview
-
-Agent Bricks are three types of pre-built AI tiles in Databricks:
+Agent Bricks are pre-built AI tiles in Databricks that provide conversational interfaces. This skill covers **Knowledge Assistants** and **Supervisor Agents**. For Genie Spaces, use the `databricks-genie` skill.
-| Brick | Purpose | Data Source |
-|-------|---------|-------------|
-| **Knowledge Assistant (KA)** | Document-based Q&A using RAG | PDF/text files in Volumes |
-| **Genie Space** | Natural language to SQL | Unity Catalog tables |
-| **Supervisor Agent (MAS)** | Multi-agent orchestration | Model serving endpoints |
+| Brick | Purpose | This Skill |
+|-------|---------|------------|
+| **Knowledge Assistant (KA)** | Document Q&A using RAG on PDFs/text in Volumes | ✓ |
+| **Supervisor Agent** | Orchestrates multiple agents (KA, Genie, endpoints, UC functions, MCP) | ✓ |
+| **Genie Space** | Natural language to SQL on Unity Catalog tables | `databricks-genie` |
-## Prerequisites
-
-Before creating Agent Bricks, ensure you have the required data:
-
-### For Knowledge Assistants
-- **Documents in a Volume**: PDF, text, or other files stored in a Unity Catalog volume
-- Generate synthetic documents using the `databricks-unstructured-pdf-generation` skill if needed
-
-### For Genie Spaces
-- **See the `databricks-genie` skill** for comprehensive Genie Space guidance
-- Tables in Unity Catalog with the data to explore
-- Generate raw data using the `databricks-synthetic-data-gen` skill
-- Create tables using the `databricks-spark-declarative-pipelines` skill
-
-### For Supervisor Agents
-- **Model Serving Endpoints**: Deployed agent endpoints (KA endpoints, custom agents, fine-tuned models)
-- **Genie Spaces**: Existing Genie spaces can be used directly as agents for SQL-based queries
-- Mix and match endpoint-based and Genie-based agents in the same Supervisor Agent
-
-### For Unity Catalog Functions
-- **Existing UC Function**: Function already registered in Unity Catalog
-- Agent service principal has `EXECUTE` privilege on the function
-
-### For External MCP Servers
-- **Existing UC HTTP Connection**: Connection configured with `is_mcp_connection: 'true'`
-- Agent service principal has `USE CONNECTION` privilege on the connection
-
-## MCP Tools
-
-### Knowledge Assistant Tool
+---
-**manage_ka** - Manage Knowledge Assistants (KA)
-- `action`: "create_or_update", "get", "find_by_name", or "delete"
-- `name`: Name for the KA (for create_or_update, find_by_name)
-- `volume_path`: Path to documents (e.g., `/Volumes/catalog/schema/volume/folder`) (for create_or_update)
-- `description`: (optional) What the KA does (for create_or_update)
-- `instructions`: (optional) How the KA should answer (for create_or_update)
-- `tile_id`: The KA tile ID (for get, delete, or update via create_or_update)
-- `add_examples_from_volume`: (optional, default: true) Auto-add examples from JSON files (for create_or_update)
+## Knowledge Assistant
-Actions:
-- **create_or_update**: Requires `name`, `volume_path`. Optionally pass `tile_id` to update.
-- **get**: Requires `tile_id`. Returns tile_id, name, description, endpoint_status, knowledge_sources, examples_count.
-- **find_by_name**: Requires `name` (exact match). Returns found, tile_id, name, endpoint_name, endpoint_status. Use this to look up an existing KA when you know the name but not the tile_id.
-- **delete**: Requires `tile_id`.
+```bash
+# Find volumes
+databricks volumes list CATALOG SCHEMA
+databricks experimental aitools tools query --warehouse WH "LIST '/Volumes/catalog/schema/volume/'"
-### Genie Space Tools
+# Create KA
+databricks knowledge-assistants create-knowledge-assistant "Name" "Description"
-**For comprehensive Genie guidance, use the `databricks-genie` skill.**
+# Add knowledge source
+databricks knowledge-assistants create-knowledge-source "knowledge-assistants/{ka_id}" \
+ --json '{"display_name": "Docs", "description": "...", "source_type": "files", "files": {"path": "/Volumes/catalog/schema/volume/"}}'
-Use `manage_genie` with actions:
-- `create_or_update` - Create or update a Genie Space
-- `get` - Get Genie Space details
-- `list` - List all Genie Spaces
-- `delete` - Delete a Genie Space
-- `export` / `import` - For migration
+# Sync and check status
+databricks knowledge-assistants sync-knowledge-sources "knowledge-assistants/{ka_id}"
+databricks knowledge-assistants get-knowledge-assistant "knowledge-assistants/{ka_id}"
-See `databricks-genie` skill for:
-- Table inspection workflow
-- Sample question best practices
-- Curation (instructions, certified queries)
+# List/manage
+databricks knowledge-assistants list-knowledge-assistants
+databricks knowledge-assistants delete-knowledge-assistant "knowledge-assistants/{ka_id}"
+```
-**IMPORTANT**: There is NO system table for Genie spaces (e.g., `system.ai.genie_spaces` does not exist). Use `manage_genie(action="list")` to find spaces.
+**Source types:** `files` (Volume path) or `index` (Vector Search: `index.index_name`, `index.text_col`, `index.doc_uri_col`)
-### Supervisor Agent Tool
+**Status:** `CREATING` (2-5 min) → `ONLINE` → `OFFLINE`
-**manage_mas** - Manage Supervisor Agents (MAS)
-- `action`: "create_or_update", "get", "find_by_name", or "delete"
-- `name`: Name for the Supervisor Agent (for create_or_update, find_by_name)
-- `agents`: List of agent configurations (for create_or_update), each with:
- - `name`: Agent identifier (required)
- - `description`: What this agent handles - critical for routing (required)
- - `ka_tile_id`: Knowledge Assistant tile ID (use for document Q&A agents - recommended for KAs)
- - `genie_space_id`: Genie space ID (use for SQL-based data agents)
- - `endpoint_name`: Model serving endpoint name (for custom agents)
- - `uc_function_name`: Unity Catalog function name in format `catalog.schema.function_name`
- - `connection_name`: Unity Catalog connection name (for external MCP servers)
- - Note: Provide exactly one of: `ka_tile_id`, `genie_space_id`, `endpoint_name`, `uc_function_name`, or `connection_name`
-- `description`: (optional) What the Supervisor Agent does (for create_or_update)
-- `instructions`: (optional) Routing instructions for the supervisor (for create_or_update)
-- `tile_id`: The Supervisor Agent tile ID (for get, delete, or update via create_or_update)
-- `examples`: (optional) List of example questions with `question` and `guideline` fields (for create_or_update)
+---
-Actions:
-- **create_or_update**: Requires `name`, `agents`. Optionally pass `tile_id` to update.
-- **get**: Requires `tile_id`. Returns tile_id, name, description, endpoint_status, agents, examples_count.
-- **find_by_name**: Requires `name` (exact match). Returns found, tile_id, name, endpoint_status, agents_count. Use this to look up an existing Supervisor Agent when you know the name but not the tile_id.
-- **delete**: Requires `tile_id`.
-
-## Typical Workflow
-
-### 1. Generate Source Data
-
-Before creating Agent Bricks, generate the required source data:
-
-**For KA (document Q&A)**:
-```
-1. Use `databricks-unstructured-pdf-generation` skill to generate PDFs
-2. PDFs are saved to a Volume with companion JSON files (question/guideline pairs)
+## Supervisor Agent
+
+**No CLI** - use `scripts/mas_manager.py` (run from skill folder):
+
+```bash
+# Create MAS
+python scripts/mas_manager.py create_mas "My Supervisor" '{
+ "description": "Routes queries to specialized agents",
+ "instructions": "Route data questions to analyst, document questions to docs_agent.",
+ "agents": [
+ {"name": "analyst", "genie_space_id": "01abc...", "description": "SQL analytics"},
+ {"name": "docs_agent", "ka_tile_id": "dab408a2-...", "description": "Answers from documents"}
+ ]
+}'
+
+# Check status and manage
+python scripts/mas_manager.py get_mas TILE_ID
+python scripts/mas_manager.py list_mas
+python scripts/mas_manager.py update_mas TILE_ID '{"agents": [...]}'
+python scripts/mas_manager.py delete_mas TILE_ID
+
+# Add examples (requires ONLINE)
+python scripts/mas_manager.py add_examples TILE_ID '[{"question": "...", "guideline": "..."}]'
+
+# Find IDs
+databricks knowledge-assistants list-knowledge-assistants --output json | jq '.[].id'
+databricks genie list-spaces --output json | jq '.[].space_id'
```
-**For Genie (SQL exploration)**:
-```
-1. Use `databricks-synthetic-data-gen` skill to create raw parquet data
-2. Use `databricks-spark-declarative-pipelines` skill to create bronze/silver/gold tables
-```
+**Agent types** (use exactly ONE per agent):
-### 2. Create the Agent Brick
-
-Use `manage_ka(action="create_or_update", ...)` or `manage_mas(action="create_or_update", ...)` with your data sources.
-
-### 3. Wait for Provisioning
-
-Newly created KA and MAS tiles need time to provision. The endpoint status will progress:
-- `PROVISIONING` - Being created (can take 2-5 minutes)
-- `ONLINE` - Ready to use
-- `OFFLINE` - Not running
-
-### 4. Add Examples (Automatic)
-
-For KA, if `add_examples_from_volume=true`, examples are automatically extracted from JSON files in the volume and added once the endpoint is `ONLINE`.
-
-## Best Practices
-
-1. **Use meaningful names**: Names are sanitized automatically (spaces become underscores)
-2. **Provide descriptions**: Helps users understand what the brick does
-3. **Add instructions**: Guide the AI's behavior and tone
-4. **Include sample questions**: Shows users how to interact with the brick
-5. **Use the workflow**: Generate data first, then create the brick
-
-## Example: Multi-Modal Supervisor Agent
-
-```python
-manage_mas(
- action="create_or_update",
- name="Enterprise Support Supervisor",
- agents=[
- {
- "name": "knowledge_base",
- "ka_tile_id": "f32c5f73-466b-...",
- "description": "Answers questions about company policies, procedures, and documentation from indexed files"
- },
- {
- "name": "analytics_engine",
- "genie_space_id": "01abc123...",
- "description": "Runs SQL analytics on usage metrics, performance stats, and operational data"
- },
- {
- "name": "ml_classifier",
- "endpoint_name": "custom-classification-endpoint",
- "description": "Classifies support tickets and predicts resolution time using custom ML model"
- },
- {
- "name": "data_enrichment",
- "uc_function_name": "support.utils.enrich_ticket_data",
- "description": "Enriches support ticket data with customer history and context"
- },
- {
- "name": "ticket_operations",
- "connection_name": "ticket_system_mcp",
- "description": "Creates, updates, assigns, and closes support tickets in external ticketing system"
- }
- ],
- description="Comprehensive enterprise support agent with knowledge retrieval, analytics, ML, data enrichment, and ticketing operations",
- instructions="""
- Route queries as follows:
- 1. Policy/procedure questions → knowledge_base
- 2. Data analysis requests → analytics_engine
- 3. Ticket classification → ml_classifier
- 4. Customer context lookups → data_enrichment
- 5. Ticket creation/updates → ticket_operations
-
- If a query spans multiple domains, chain agents:
- - First gather information (analytics_engine or knowledge_base)
- - Then take action (ticket_operations)
- """
-)
-```
+| Field | Type |
+|-------|------|
+| `ka_tile_id` | Knowledge Assistant |
+| `genie_space_id` | Genie Space |
+| `endpoint_name` | Model serving endpoint |
+| `uc_function_name` | UC function (`catalog.schema.func`) |
+| `connection_name` | MCP server (UC HTTP Connection) |
-## Related Skills
+**Status:** `NOT_READY` (2-5 min) → `ONLINE` → `OFFLINE`
-- **[databricks-genie](../databricks-genie/SKILL.md)** - Comprehensive Genie Space creation, curation, and Conversation API guidance
-- **[databricks-unstructured-pdf-generation](../databricks-unstructured-pdf-generation/SKILL.md)** - Generate synthetic PDFs to feed into Knowledge Assistants
-- **[databricks-synthetic-data-gen](../databricks-synthetic-data-gen/SKILL.md)** - Create raw data for Genie Space tables
-- **[databricks-spark-declarative-pipelines](../databricks-spark-declarative-pipelines/SKILL.md)** - Build bronze/silver/gold tables consumed by Genie Spaces
-- **[databricks-model-serving](../databricks-model-serving/SKILL.md)** - Deploy custom agent endpoints used as MAS agents
-- **[databricks-vector-search](../databricks-vector-search/SKILL.md)** - Build vector indexes for RAG applications paired with KAs
+---
-## See Also
+## Reference
-- `1-knowledge-assistants.md` - Detailed KA patterns and examples
-- `databricks-genie` skill - Detailed Genie patterns, curation, and examples
-- `2-supervisor-agents.md` - Detailed MAS patterns and examples
+| Topic | File |
+|-------|------|
+| KA source types, index, troubleshooting | [1-knowledge-assistants.md](1-knowledge-assistants.md) |
+| UC functions, MCP servers, examples | [2-supervisor-agents.md](2-supervisor-agents.md) |
diff --git a/databricks-skills/databricks-agent-bricks/scripts/mas_manager.py b/databricks-skills/databricks-agent-bricks/scripts/mas_manager.py
new file mode 100644
index 00000000..a2e1af58
--- /dev/null
+++ b/databricks-skills/databricks-agent-bricks/scripts/mas_manager.py
@@ -0,0 +1,372 @@
+#!/usr/bin/env python3
+"""
+Supervisor Agent (MAS) Manager - CLI for MAS operations.
+
+Usage:
+ python mas_manager.py create_mas "Name" '{"agents": [...], "description": "...", "instructions": "..."}'
+ python mas_manager.py get_mas TILE_ID
+ python mas_manager.py find_mas "Name"
+ python mas_manager.py update_mas TILE_ID '{"name": ..., "agents": [...], ...}'
+ python mas_manager.py delete_mas TILE_ID
+ python mas_manager.py list_mas
+ python mas_manager.py add_examples TILE_ID '[{"question": "...", "guideline": "..."}]'
+ python mas_manager.py add_examples_wait TILE_ID '[{"question": "...", "guideline": "..."}]'
+ python mas_manager.py list_examples TILE_ID
+
+The add_examples_wait command waits for the MAS to become ONLINE before adding examples.
+This is useful in jobs where you create a MAS and immediately need to add examples.
+
+Requires: databricks-sdk, requests
+ pip install databricks-sdk requests
+"""
+
+import json
+import re
+import sys
+import time
+from typing import Any
+
+import requests
+from databricks.sdk import WorkspaceClient
+
+# Global client - initialized lazily
+_client: WorkspaceClient = None
+
+
+def _get_client() -> WorkspaceClient:
+ """Get or create WorkspaceClient."""
+ global _client
+ if _client is None:
+ _client = WorkspaceClient()
+ return _client
+
+
+def _request(method: str, path: str, body: dict = None, params: dict = None) -> dict:
+ """Make authenticated HTTP request to Databricks API."""
+ w = _get_client()
+ url = f"{w.config.host}{path}"
+ headers = w.config.authenticate()
+ if body:
+ headers["Content-Type"] = "application/json"
+
+ resp = requests.request(method, url, headers=headers, json=body, params=params, timeout=300)
+
+ if resp.status_code >= 400:
+ try:
+ err = resp.json().get("message", resp.text)
+ except Exception:
+ err = resp.text
+ raise Exception(f"{method} {path}: {err}")
+
+ return resp.json() if resp.text else {}
+
+
+def _sanitize_name(name: str) -> str:
+ """Sanitize name to alphanumeric with hyphens/underscores."""
+ name = re.sub(r"[^a-zA-Z0-9_-]", "_", name.replace(" ", "_"))
+ name = re.sub(r"_+", "_", name).strip("_")
+ return name or "supervisor_agent"
+
+
+def _build_agents(agents: list[dict]) -> list[dict]:
+ """Convert simplified agent config to API format."""
+ result = []
+ for a in agents:
+ cfg = {"name": a.get("name", ""), "description": a.get("description", "")}
+
+ if a.get("genie_space_id"):
+ cfg["agent_type"] = "genie"
+ cfg["genie_space"] = {"id": a["genie_space_id"]}
+ elif a.get("ka_tile_id"):
+ cfg["agent_type"] = "serving_endpoint"
+ cfg["serving_endpoint"] = {"name": f"ka-{a['ka_tile_id'].split('-')[0]}-endpoint"}
+ elif a.get("uc_function_name"):
+ parts = a["uc_function_name"].split(".")
+ cfg["agent_type"] = "unity_catalog_function"
+ cfg["unity_catalog_function"] = {"uc_path": {"catalog": parts[0], "schema": parts[1], "name": parts[2]}}
+ elif a.get("connection_name"):
+ cfg["agent_type"] = "external_mcp_server"
+ cfg["external_mcp_server"] = {"connection_name": a["connection_name"]}
+ else:
+ cfg["agent_type"] = "serving_endpoint"
+ cfg["serving_endpoint"] = {"name": a.get("endpoint_name")}
+
+ result.append(cfg)
+ return result
+
+
+# ============================================================================
+# MAS CRUD Operations
+# ============================================================================
+
+
+def create_mas(name: str, agents: list[dict], description: str = None, instructions: str = None) -> dict:
+ """Create a Supervisor Agent."""
+ payload = {"name": _sanitize_name(name), "agents": _build_agents(agents)}
+ if description:
+ payload["description"] = description
+ if instructions:
+ payload["instructions"] = instructions
+
+ resp = _request("POST", "/api/2.0/multi-agent-supervisors", payload)
+ mas = resp.get("multi_agent_supervisor", {})
+
+ return {
+ "tile_id": mas.get("tile", {}).get("tile_id", ""),
+ "name": mas.get("tile", {}).get("name", name),
+ "endpoint_status": mas.get("status", {}).get("endpoint_status", "UNKNOWN"),
+ "agents_count": len(agents),
+ }
+
+
+def get_mas(tile_id: str) -> dict:
+ """Get a Supervisor Agent by tile ID."""
+ try:
+ resp = _request("GET", f"/api/2.0/multi-agent-supervisors/{tile_id}")
+ except Exception as e:
+ if "not found" in str(e).lower() or "does not exist" in str(e).lower():
+ return {"error": f"Supervisor Agent {tile_id} not found"}
+ raise
+
+ mas = resp.get("multi_agent_supervisor", {})
+ tile = mas.get("tile", {})
+
+ return {
+ "tile_id": tile.get("tile_id", tile_id),
+ "name": tile.get("name", ""),
+ "description": tile.get("description", ""),
+ "endpoint_status": mas.get("status", {}).get("endpoint_status", "UNKNOWN"),
+ "agents": mas.get("agents", []),
+ "instructions": tile.get("instructions", ""),
+ }
+
+
+def find_mas(name: str) -> dict:
+ """Find a Supervisor Agent by name."""
+ sanitized = _sanitize_name(name)
+ page_token = None
+
+ while True:
+ params = {"filter": f"name_contains={sanitized}&&tile_type=MAS"}
+ if page_token:
+ params["page_token"] = page_token
+
+ resp = _request("GET", "/api/2.0/tiles", params=params)
+
+ for t in resp.get("tiles", []):
+ if t.get("name") == sanitized:
+ details = get_mas(t["tile_id"])
+ return {
+ "found": True,
+ "tile_id": t["tile_id"],
+ "name": sanitized,
+ "endpoint_status": details.get("endpoint_status", "UNKNOWN"),
+ "agents_count": len(details.get("agents", [])),
+ }
+
+ page_token = resp.get("next_page_token")
+ if not page_token:
+ break
+
+ return {"found": False, "name": name}
+
+
+def update_mas(tile_id: str, name: str = None, agents: list[dict] = None,
+ description: str = None, instructions: str = None) -> dict:
+ """Update a Supervisor Agent."""
+ existing = get_mas(tile_id)
+ if "error" in existing:
+ return existing
+
+ payload = {"tile_id": tile_id}
+ if name:
+ payload["name"] = _sanitize_name(name)
+ if description:
+ payload["description"] = description
+ if instructions:
+ payload["instructions"] = instructions
+ if agents:
+ payload["agents"] = _build_agents(agents)
+
+ resp = _request("PATCH", f"/api/2.0/multi-agent-supervisors/{tile_id}", payload)
+ mas = resp.get("multi_agent_supervisor", {})
+
+ return {
+ "tile_id": mas.get("tile", {}).get("tile_id", tile_id),
+ "name": mas.get("tile", {}).get("name", ""),
+ "endpoint_status": mas.get("status", {}).get("endpoint_status", "UNKNOWN"),
+ }
+
+
+def delete_mas(tile_id: str) -> dict:
+ """Delete a Supervisor Agent."""
+ try:
+ _request("DELETE", f"/api/2.0/tiles/{tile_id}")
+ return {"success": True, "tile_id": tile_id}
+ except Exception as e:
+ return {"success": False, "tile_id": tile_id, "error": str(e)}
+
+
+def list_mas() -> list[dict]:
+ """List all Supervisor Agents."""
+ results = []
+ page_token = None
+
+ while True:
+ params = {"page_size": 100, "filter": "tile_type=MAS"}
+ if page_token:
+ params["page_token"] = page_token
+
+ resp = _request("GET", "/api/2.0/tiles", params=params)
+
+ for tile in resp.get("tiles", []):
+ if tile.get("tile_type") in ("MAS", "5"):
+ details = get_mas(tile["tile_id"])
+ if "error" not in details:
+ results.append({
+ "tile_id": tile["tile_id"],
+ "name": details.get("name", ""),
+ "endpoint_status": details.get("endpoint_status", "UNKNOWN"),
+ "agents_count": len(details.get("agents", [])),
+ })
+
+ page_token = resp.get("next_page_token")
+ if not page_token:
+ break
+
+ return results
+
+
+# ============================================================================
+# Examples Management
+# ============================================================================
+
+
+def add_examples(tile_id: str, examples: list[dict]) -> dict:
+ """Add example questions to a Supervisor Agent (must be ONLINE)."""
+ status = get_mas(tile_id)
+ if "error" in status:
+ return status
+
+ if status.get("endpoint_status") != "ONLINE":
+ return {
+ "error": f"MAS not ONLINE (status: {status.get('endpoint_status')}). Wait and retry.",
+ "tile_id": tile_id,
+ }
+
+ added = 0
+ for ex in examples:
+ question = ex.get("question", "")
+ if not question:
+ continue
+
+ guideline = ex.get("guideline")
+ payload = {"tile_id": tile_id, "question": question}
+ if guideline:
+ payload["guidelines"] = [guideline] if isinstance(guideline, str) else guideline
+
+ try:
+ _request("POST", f"/api/2.0/multi-agent-supervisors/{tile_id}/examples", payload)
+ added += 1
+ except Exception as e:
+ print(f"Warning: Failed to add example '{question[:50]}...': {e}", file=sys.stderr)
+
+ return {"tile_id": tile_id, "added_count": added, "total_requested": len(examples)}
+
+
+def add_examples_wait(tile_id: str, examples: list[dict], timeout: int = 600, poll_interval: int = 15) -> dict:
+ """Wait for MAS to become ONLINE, then add examples.
+
+ Useful in jobs where you create a MAS and immediately need to add examples.
+ Polls the MAS status until ONLINE or timeout is reached.
+
+ Args:
+ tile_id: The MAS tile ID
+ examples: List of example dicts with 'question' and optional 'guideline'
+ timeout: Max seconds to wait for ONLINE status (default: 600 = 10 minutes)
+ poll_interval: Seconds between status checks (default: 15)
+
+ Returns:
+ Result dict with added_count, or error if timeout/failure
+ """
+ elapsed = 0
+ while elapsed < timeout:
+ status = get_mas(tile_id)
+ if "error" in status:
+ return status
+
+ endpoint_status = status.get("endpoint_status", "UNKNOWN")
+ if endpoint_status == "ONLINE":
+ return add_examples(tile_id, examples)
+
+ if endpoint_status in ("FAILED", "OFFLINE"):
+ return {
+ "error": f"MAS endpoint is {endpoint_status}, cannot add examples",
+ "tile_id": tile_id,
+ }
+
+ print(f"Waiting for MAS to become ONLINE (current: {endpoint_status}, elapsed: {elapsed}s)...", file=sys.stderr)
+ time.sleep(poll_interval)
+ elapsed += poll_interval
+
+ return {
+ "error": f"Timeout waiting for MAS to become ONLINE after {timeout}s",
+ "tile_id": tile_id,
+ "last_status": status.get("endpoint_status", "UNKNOWN"),
+ }
+
+
+def list_examples(tile_id: str) -> dict:
+ """List all examples for a Supervisor Agent."""
+ resp = _request("GET", f"/api/2.0/multi-agent-supervisors/{tile_id}/examples", params={"page_size": 100})
+ examples = resp.get("examples", [])
+ return {"tile_id": tile_id, "examples": examples, "count": len(examples)}
+
+
+# ============================================================================
+# CLI Entry Point
+# ============================================================================
+
+
+def main():
+ """CLI entry point."""
+ if len(sys.argv) < 2:
+ print(__doc__)
+ sys.exit(1)
+
+ cmd = sys.argv[1]
+
+ try:
+ if cmd == "create_mas" and len(sys.argv) >= 4:
+ cfg = json.loads(sys.argv[3])
+ result = create_mas(sys.argv[2], cfg.get("agents", []), cfg.get("description"), cfg.get("instructions"))
+ elif cmd == "get_mas" and len(sys.argv) >= 3:
+ result = get_mas(sys.argv[2])
+ elif cmd == "find_mas" and len(sys.argv) >= 3:
+ result = find_mas(sys.argv[2])
+ elif cmd == "update_mas" and len(sys.argv) >= 4:
+ cfg = json.loads(sys.argv[3])
+ result = update_mas(sys.argv[2], cfg.get("name"), cfg.get("agents"), cfg.get("description"), cfg.get("instructions"))
+ elif cmd == "delete_mas" and len(sys.argv) >= 3:
+ result = delete_mas(sys.argv[2])
+ elif cmd == "list_mas":
+ result = list_mas()
+ elif cmd == "add_examples" and len(sys.argv) >= 4:
+ result = add_examples(sys.argv[2], json.loads(sys.argv[3]))
+ elif cmd == "add_examples_wait" and len(sys.argv) >= 4:
+ result = add_examples_wait(sys.argv[2], json.loads(sys.argv[3]))
+ elif cmd == "list_examples" and len(sys.argv) >= 3:
+ result = list_examples(sys.argv[2])
+ else:
+ print(__doc__)
+ sys.exit(1)
+
+ print(json.dumps(result, indent=2))
+
+ except Exception as e:
+ print(json.dumps({"error": str(e)}), file=sys.stderr)
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/databricks-skills/databricks-aibi-dashboards/1-widget-specifications.md b/databricks-skills/databricks-aibi-dashboards/1-widget-specifications.md
index d8e03c13..9f861c1b 100644
--- a/databricks-skills/databricks-aibi-dashboards/1-widget-specifications.md
+++ b/databricks-skills/databricks-aibi-dashboards/1-widget-specifications.md
@@ -297,7 +297,7 @@ Add `format` to any encoding to display values appropriately:
## Dataset Parameters
-Use `:param` syntax in SQL for dynamic filtering:
+Use `:param` syntax in SQL for dynamic filtering. Parameters can be bound to filter widgets (see [3-filters.md](3-filters.md)):
```json
{
@@ -323,19 +323,9 @@ Use `:param` syntax in SQL for dynamic filtering:
Allowed in `query.fields` (no CAST or complex SQL):
```json
-// Aggregations
-{"name": "sum(revenue)", "expression": "SUM(`revenue`)"}
-{"name": "avg(price)", "expression": "AVG(`price`)"}
-{"name": "count(id)", "expression": "COUNT(`id`)"}
-{"name": "countdistinct(id)", "expression": "COUNT(DISTINCT `id`)"}
-
-// Date truncation
-{"name": "daily(date)", "expression": "DATE_TRUNC(\"DAY\", `date`)"}
-{"name": "weekly(date)", "expression": "DATE_TRUNC(\"WEEK\", `date`)"}
-{"name": "monthly(date)", "expression": "DATE_TRUNC(\"MONTH\", `date`)"}
-
-// Simple reference
-{"name": "category", "expression": "`category`"}
+{"name": "[sum|avg|count|countdistinct|min|max](col)", "expression": "[SUM|AVG|COUNT|COUNT(DISTINCT)|MIN|MAX](`col`)"}
+{"name": "[daily|weekly|monthly](date)", "expression": "DATE_TRUNC(\"[DAY|WEEK|MONTH]\", `date`)"}
+{"name": "field", "expression": "`field`"}
```
For conditional logic, compute in dataset SQL instead.
diff --git a/databricks-skills/databricks-aibi-dashboards/3-examples.md b/databricks-skills/databricks-aibi-dashboards/3-examples.md
deleted file mode 100644
index fe128d6b..00000000
--- a/databricks-skills/databricks-aibi-dashboards/3-examples.md
+++ /dev/null
@@ -1,305 +0,0 @@
-# Complete Dashboard Examples
-
-Production-ready templates you can adapt for your use case.
-
-## Basic Dashboard (NYC Taxi)
-
-```python
-import json
-
-# Step 1: Check table schema
-table_info = get_table_stats_and_schema(catalog="samples", schema="nyctaxi")
-
-# Step 2: Test queries
-execute_sql("SELECT COUNT(*) as trips, AVG(fare_amount) as avg_fare, AVG(trip_distance) as avg_distance FROM samples.nyctaxi.trips")
-execute_sql("""
- SELECT pickup_zip, COUNT(*) as trip_count
- FROM samples.nyctaxi.trips
- GROUP BY pickup_zip
- ORDER BY trip_count DESC
- LIMIT 10
-""")
-
-# Step 3: Build dashboard JSON
-dashboard = {
- "datasets": [
- {
- "name": "summary",
- "displayName": "Summary Stats",
- "queryLines": [
- "SELECT COUNT(*) as trips, AVG(fare_amount) as avg_fare, ",
- "AVG(trip_distance) as avg_distance ",
- "FROM samples.nyctaxi.trips "
- ]
- },
- {
- "name": "by_zip",
- "displayName": "Trips by ZIP",
- "queryLines": [
- "SELECT pickup_zip, COUNT(*) as trip_count ",
- "FROM samples.nyctaxi.trips ",
- "GROUP BY pickup_zip ",
- "ORDER BY trip_count DESC ",
- "LIMIT 10 "
- ]
- }
- ],
- "pages": [{
- "name": "overview",
- "displayName": "NYC Taxi Overview",
- "pageType": "PAGE_TYPE_CANVAS",
- "layout": [
- # Text header - NO spec block! Use SEPARATE widgets for title and subtitle!
- {
- "widget": {
- "name": "title",
- "multilineTextboxSpec": {
- "lines": ["## NYC Taxi Dashboard"]
- }
- },
- "position": {"x": 0, "y": 0, "width": 6, "height": 1}
- },
- {
- "widget": {
- "name": "subtitle",
- "multilineTextboxSpec": {
- "lines": ["Trip statistics and analysis"]
- }
- },
- "position": {"x": 0, "y": 1, "width": 6, "height": 1}
- },
- # Counter - version 2, width 2!
- {
- "widget": {
- "name": "total-trips",
- "queries": [{
- "name": "main_query",
- "query": {
- "datasetName": "summary",
- "fields": [{"name": "trips", "expression": "`trips`"}],
- "disaggregated": True
- }
- }],
- "spec": {
- "version": 2,
- "widgetType": "counter",
- "encodings": {
- "value": {"fieldName": "trips", "displayName": "Total Trips"}
- },
- "frame": {"title": "Total Trips", "showTitle": True}
- }
- },
- "position": {"x": 0, "y": 2, "width": 2, "height": 3}
- },
- {
- "widget": {
- "name": "avg-fare",
- "queries": [{
- "name": "main_query",
- "query": {
- "datasetName": "summary",
- "fields": [{"name": "avg_fare", "expression": "`avg_fare`"}],
- "disaggregated": True
- }
- }],
- "spec": {
- "version": 2,
- "widgetType": "counter",
- "encodings": {
- "value": {"fieldName": "avg_fare", "displayName": "Avg Fare"}
- },
- "frame": {"title": "Average Fare", "showTitle": True}
- }
- },
- "position": {"x": 2, "y": 2, "width": 2, "height": 3}
- },
- {
- "widget": {
- "name": "total-distance",
- "queries": [{
- "name": "main_query",
- "query": {
- "datasetName": "summary",
- "fields": [{"name": "avg_distance", "expression": "`avg_distance`"}],
- "disaggregated": True
- }
- }],
- "spec": {
- "version": 2,
- "widgetType": "counter",
- "encodings": {
- "value": {"fieldName": "avg_distance", "displayName": "Avg Distance"}
- },
- "frame": {"title": "Average Distance", "showTitle": True}
- }
- },
- "position": {"x": 4, "y": 2, "width": 2, "height": 3}
- },
- # Bar chart - version 3
- {
- "widget": {
- "name": "trips-by-zip",
- "queries": [{
- "name": "main_query",
- "query": {
- "datasetName": "by_zip",
- "fields": [
- {"name": "pickup_zip", "expression": "`pickup_zip`"},
- {"name": "trip_count", "expression": "`trip_count`"}
- ],
- "disaggregated": True
- }
- }],
- "spec": {
- "version": 3,
- "widgetType": "bar",
- "encodings": {
- "x": {"fieldName": "pickup_zip", "scale": {"type": "categorical"}, "displayName": "ZIP"},
- "y": {"fieldName": "trip_count", "scale": {"type": "quantitative"}, "displayName": "Trips"}
- },
- "frame": {"title": "Trips by Pickup ZIP", "showTitle": True}
- }
- },
- "position": {"x": 0, "y": 5, "width": 6, "height": 5}
- },
- # Table - version 2, minimal column props!
- {
- "widget": {
- "name": "zip-table",
- "queries": [{
- "name": "main_query",
- "query": {
- "datasetName": "by_zip",
- "fields": [
- {"name": "pickup_zip", "expression": "`pickup_zip`"},
- {"name": "trip_count", "expression": "`trip_count`"}
- ],
- "disaggregated": True
- }
- }],
- "spec": {
- "version": 2,
- "widgetType": "table",
- "encodings": {
- "columns": [
- {"fieldName": "pickup_zip", "displayName": "ZIP Code"},
- {"fieldName": "trip_count", "displayName": "Trip Count"}
- ]
- },
- "frame": {"title": "Top ZIP Codes", "showTitle": True}
- }
- },
- "position": {"x": 0, "y": 10, "width": 6, "height": 5}
- }
- ]
- }]
-}
-
-# Step 4: Deploy
-result = manage_dashboard(
- action="create_or_update",
- display_name="NYC Taxi Dashboard",
- parent_path="/Workspace/Users/me/dashboards",
- serialized_dashboard=json.dumps(dashboard),
- warehouse_id=manage_warehouse(action="get_best"),
-)
-print(result["url"])
-```
-
-## Dashboard with Global Filters
-
-```python
-import json
-
-# Dashboard with a global filter for region
-dashboard_with_filters = {
- "datasets": [
- {
- "name": "sales",
- "displayName": "Sales Data",
- "queryLines": [
- "SELECT region, SUM(revenue) as total_revenue ",
- "FROM catalog.schema.sales ",
- "GROUP BY region"
- ]
- }
- ],
- "pages": [
- {
- "name": "overview",
- "displayName": "Sales Overview",
- "pageType": "PAGE_TYPE_CANVAS",
- "layout": [
- {
- "widget": {
- "name": "total-revenue",
- "queries": [{
- "name": "main_query",
- "query": {
- "datasetName": "sales",
- "fields": [{"name": "total_revenue", "expression": "`total_revenue`"}],
- "disaggregated": True
- }
- }],
- "spec": {
- "version": 2, # Version 2 for counters!
- "widgetType": "counter",
- "encodings": {
- "value": {"fieldName": "total_revenue", "displayName": "Total Revenue"}
- },
- "frame": {"title": "Total Revenue", "showTitle": True}
- }
- },
- "position": {"x": 0, "y": 0, "width": 6, "height": 3}
- }
- ]
- },
- {
- "name": "filters",
- "displayName": "Filters",
- "pageType": "PAGE_TYPE_GLOBAL_FILTERS", # Required for global filter page!
- "layout": [
- {
- "widget": {
- "name": "filter_region",
- "queries": [{
- "name": "ds_sales_region",
- "query": {
- "datasetName": "sales",
- "fields": [
- {"name": "region", "expression": "`region`"}
- # DO NOT use associative_filter_predicate_group - causes SQL errors!
- ],
- "disaggregated": False # False for filters!
- }
- }],
- "spec": {
- "version": 2, # Version 2 for filters!
- "widgetType": "filter-multi-select", # NOT "filter"!
- "encodings": {
- "fields": [{
- "fieldName": "region",
- "displayName": "Region",
- "queryName": "ds_sales_region" # Must match query name!
- }]
- },
- "frame": {"showTitle": True, "title": "Region"} # Always show title!
- }
- },
- "position": {"x": 0, "y": 0, "width": 2, "height": 2}
- }
- ]
- }
- ]
-}
-
-# Deploy with filters
-result = manage_dashboard(
- action="create_or_update",
- display_name="Sales Dashboard with Filters",
- parent_path="/Workspace/Users/me/dashboards",
- serialized_dashboard=json.dumps(dashboard_with_filters),
- warehouse_id=manage_warehouse(action="get_best"),
-)
-print(result["url"])
-```
diff --git a/databricks-skills/databricks-aibi-dashboards/SKILL.md b/databricks-skills/databricks-aibi-dashboards/SKILL.md
index 99cff124..0e51d509 100644
--- a/databricks-skills/databricks-aibi-dashboards/SKILL.md
+++ b/databricks-skills/databricks-aibi-dashboards/SKILL.md
@@ -1,82 +1,207 @@
---
name: databricks-aibi-dashboards
-description: "Create Databricks AI/BI dashboards. Use when creating, updating, or deploying Lakeview dashboards. CRITICAL: You MUST test ALL SQL queries via execute_sql BEFORE deploying. Follow guidelines strictly."
+description: "Create Databricks AI/BI dashboards. Must use when creating, updating, or deploying Lakeview dashboards as Databricks Dashboard have a unique json structure. CRITICAL: You MUST test ALL SQL queries via CLI BEFORE deploying. Follow guidelines strictly."
---
# AI/BI Dashboard Skill
-Create Databricks AI/BI dashboards (formerly Lakeview dashboards). **Follow these guidelines strictly.**
+Create Databricks AI/BI dashboards (formerly Lakeview dashboards).
+A dashboard should be showing something relevant for a human, typically some KPI on the top, and based on the story, some graph (often temporal), and we see "something happens".
+**Follow these guidelines strictly.**
-## CRITICAL: MANDATORY VALIDATION WORKFLOW
+## Quick Reference
-**You MUST follow this workflow exactly. Skipping validation causes broken dashboards.**
+| Task | Command |
+|------|---------|
+| List warehouses | `databricks warehouses list` |
+| List tables | `databricks experimental aitools tools query --warehouse WH "SHOW TABLES IN catalog.schema"` |
+| Get schema | `databricks experimental aitools tools discover-schema catalog.schema.table1 catalog.schema.table2` |
+| Test query | `databricks experimental aitools tools query --warehouse WH "SELECT..."` |
+| Create dashboard | `databricks lakeview create --display-name "X" --warehouse-id "Y" --serialized-dashboard "$(cat file.json)"` |
+| Update dashboard | `databricks lakeview update DASHBOARD_ID --serialized-dashboard "$(cat file.json)"` |
+| Publish | `databricks lakeview publish DASHBOARD_ID --warehouse-id WH` |
+| Delete | `databricks lakeview trash DASHBOARD_ID` |
+---
+
+## CRITICAL: Widget Version Requirements
+
+> **Wrong version = broken widget!** This is the #1 cause of dashboard errors.
+
+| Widget Type | Version | Notes |
+|-------------|---------|-------|
+| `counter` | **2** | KPI cards |
+| `table` | **2** | Data tables |
+| `bar`, `line`, `area`, `pie`, `scatter` | **3** | Charts |
+| `combo`, `choropleth-map` | **1** | Advanced charts |
+| `filter-*` | **2** | All filter types |
+
+---
+
+## NEW DASHBOARD CREATION WORKFLOW
+
+**You MUST test ALL SQL queries via CLI BEFORE deploying. Follow the overall logic in these steps for new dashboard - Skipping validation causes broken dashboards.**
+
+### Step 1: Get Warehouse ID if not already known
+
+```bash
+# List warehouses to find one for SQL execution
+databricks warehouses list
```
-┌─────────────────────────────────────────────────────────────────────┐
-│ STEP 1: Get table schemas via get_table_stats_and_schema(catalog, schema) │
-├─────────────────────────────────────────────────────────────────────┤
-│ STEP 2: Write SQL queries for each dataset │
-├─────────────────────────────────────────────────────────────────────┤
-│ STEP 3: TEST EVERY QUERY via execute_sql() ← DO NOT SKIP! │
-│ - If query fails, FIX IT before proceeding │
-│ - Verify column names match what widgets will reference │
-│ - Verify data types are correct (dates, numbers, strings) │
-├─────────────────────────────────────────────────────────────────────┤
-│ STEP 4: Build dashboard JSON using ONLY verified queries │
-├─────────────────────────────────────────────────────────────────────┤
-│ STEP 5: Deploy via manage_dashboard(action="create_or_update") │
-└─────────────────────────────────────────────────────────────────────┘
+
+### Step 2: Discover Table Schemas and existing data pattern
+
+```bash
+# Get table schemas for designing queries
+databricks experimental aitools tools query --warehouse WAREHOUSE_ID "SHOW TABLES IN catalog.schema" 2>&1
+# IMPORTANT: Use CATALOG.SCHEMA.TABLE format (full 3-part name required)
+databricks experimental aitools tools discover-schema catalog.schema.table1 catalog.schema.table2
+
+# Example:
+databricks experimental aitools tools discover-schema samples.nyctaxi.trips main.default.customers
+
+# Explore data patterns if needed to confirm the data tells the intended story (to understand what/how to visualize):
+databricks experimental aitools tools query --warehouse WAREHOUSE_ID ""
```
-**WARNING: If you deploy without testing queries, widgets WILL show "Invalid widget definition" errors!**
-
-## Available MCP Tools
-
-| Tool | Description |
-|------|-------------|
-| `get_table_stats_and_schema` | **STEP 1**: Get table schemas for designing queries |
-| `execute_sql` | **STEP 3**: Test SQL queries - MANDATORY before deployment! |
-| `manage_warehouse` (action="get_best") | Get available warehouse ID |
-| `manage_dashboard` | **STEP 5**: Dashboard lifecycle management (see actions below) |
-
-### manage_dashboard Actions
-
-| Action | Description | Required Params |
-|--------|-------------|-----------------|
-| `create_or_update` | Deploy dashboard JSON (only after validation!) | display_name, parent_path, serialized_dashboard, warehouse_id |
-| `get` | Get dashboard details by ID | dashboard_id |
-| `list` | List all dashboards | (none) |
-| `delete` | Move dashboard to trash | dashboard_id |
-| `publish` | Publish a dashboard | dashboard_id, warehouse_id |
-| `unpublish` | Unpublish a dashboard | dashboard_id |
-
-**Example usage:**
-```python
-# Create/update dashboard
-manage_dashboard(
- action="create_or_update",
- display_name="Sales Dashboard",
- parent_path="/Workspace/Users/me/dashboards",
- serialized_dashboard=dashboard_json,
- warehouse_id="abc123",
- publish=True # auto-publish after create
-)
-# Get dashboard details
-manage_dashboard(action="get", dashboard_id="dashboard_123")
+### Step 3: Verify Data Matches Story
+The datasets.querylines in the dashboard json (see example below) must be tested to ensure
+
+Before finalizing, run the SQL Queries you intend to add in each dataset to confirm that they run properly and that the result are valid.
+This is crucial, as the widget defined in the json will use the query field output to render the visualization. The value should also make sense at a business level.
+Remember that for the filter to work, the query should have the field available (so typically group by the filter field)
+
+If values don't match expectations, ensure the query is correct, fix the data if you can, or adjust the story before creating the dashboard.
+
+### Step 4: Plan Dashboard Structure
+
+Before writing JSON, plan your dashboard:
+
+1. You must know the expected specific JSON structure. For this, **Read reference files**: [1-widget-specifications.md](1-widget-specifications.md), [3-filters.md](3-filters.md), [4-examples.md](4-examples.md)
+
+2. Think: **What widgets?** Map each visualization to a dataset:
+ | Widget | Type | Dataset | Has filter field? |
+ |--------|------|---------|-------------------|
+ | Revenue KPI | counter | ds_sales | ✓ date, region |
+ | Trend Chart | line | ds_sales | ✓ date, region |
+ | Top Products | table | ds_products | ✗ no date |
+ ...
+
+3. **What filters?** For each filter, verify ALL datasets you want filtered contain the filter field.
+ > **Filters only affect datasets that have the filter field.** A pre-aggregated table without dates WON'T be date-filtered.
+
+4. **Write JSON locally** as a file.
+
+### Step 5: Dashboard Lifecycle
+Once created, you can edit the file as following:
+```bash
+# Create a dashboard
+# IMPORTANT: Use --display-name, --warehouse-id, and --serialized-dashboard (NOT --json @file.json with displayName in it)
+databricks lakeview create \
+ --display-name "My Dashboard" \
+ --warehouse-id "abc123def456" \
+ --serialized-dashboard "$(cat dashboard.json)"
+
+# Alternative: Use --json with the correct structure
+databricks lakeview create --json '{
+ "display_name": "My Dashboard",
+ "warehouse_id": "abc123def456",
+ "serialized_dashboard": "{\"datasets\":[...],\"pages\":[...]}"
+}'
# List all dashboards
-manage_dashboard(action="list")
+databricks lakeview list
+
+# Get dashboard details
+databricks lakeview get DASHBOARD_ID
+
+# Update a dashboard
+databricks lakeview update DASHBOARD_ID --serialized-dashboard "$(cat dashboard.json)"
+
+# Publish a dashboard
+databricks lakeview publish DASHBOARD_ID --warehouse-id WAREHOUSE_ID
+
+# Unpublish a dashboard
+databricks lakeview unpublish DASHBOARD_ID
+
+# Delete (trash) a dashboard
+databricks lakeview trash DASHBOARD_ID
```
+---
+
+## JSON Structure (Required Skeleton)
+
+Every dashboard's `serialized_dashboard` content must follow this exact structure:
+
+```json
+{
+ "datasets": [
+ {
+ "name": "ds_x",
+ "displayName": "Dataset X",
+ "queryLines": ["SELECT col1, col2 ", "FROM catalog.schema.table"]
+ }
+ ],
+ "pages": [
+ {
+ "name": "main",
+ "displayName": "Main",
+ "pageType": "PAGE_TYPE_CANVAS",
+ "layout": [
+ {"widget": {/* INLINE widget definition */}, "position": {"x":0,"y":0,"width":2,"height":3}}
+ ]
+ }
+ ]
+}
+```
+
+**Structural rules (violations cause "failed to parse serialized dashboard"):**
+- `queryLines`: Array of strings, NOT `"query": "string"`
+- Widgets: INLINE in `layout[].widget`, NOT a separate `"widgets"` array
+- `pageType`: Required on every page (`PAGE_TYPE_CANVAS` or `PAGE_TYPE_GLOBAL_FILTERS`)
+- Query binding: `query.fields[].name` must exactly match `encodings.*.fieldName`
+
+### Linking a Genie Space (Optional)
+
+To add an "Ask Genie" button to the dashboard, or to link a genie space/room with an ID, add `uiSettings.genieSpace` to the JSON:
+
+```json
+{
+ "datasets": [...],
+ "pages": [...],
+ "uiSettings": {
+ "genieSpace": {
+ "isEnabled": true,
+ "overrideId": "your-genie-space-id-here",
+ "enablementMode": "ENABLED"
+ }
+ }
+}
+```
+
+> **Genie is NOT a widget.** Link via `uiSettings.genieSpace` only. There is no `"widgetType": "assistant"`.
+
+---
+
+## Design Best Practices
+
+Apply unless user specifies otherwise:
+- **Global date filter**: When data has temporal columns, add a date range filter. Most dashboards need time-based filtering.
+- **KPI time bounds**: Use time-bounded metrics that enable period comparison (MoM, YoY). Unbounded "all-time" totals are less actionable.
+- **Value formatting**: Format values based on their meaning — currency with symbol, percentages with %, large numbers compacted (K/M/B).
+- **Chart selection**: Match cardinality to chart type. Few distinct values → pie/bar with color grouping; many values → table.
+
## Reference Files
| What are you building? | Reference |
|------------------------|-----------|
| Any widget (text, counter, table, chart) | [1-widget-specifications.md](1-widget-specifications.md) |
-| Dashboard with filters (global or page-level) | [2-filters.md](2-filters.md) |
-| Need a complete working template to adapt | [3-examples.md](3-examples.md) |
-| Debugging a broken dashboard | [4-troubleshooting.md](4-troubleshooting.md) |
+| Advanced charts (area, scatter/Bubble, combo (Line+Bar), Choropleth map) | [2-advanced-widget-specifications.md](2-advanced-widget-specifications.md) |
+| Dashboard with filters (global or page-level) | [3-filters.md](3-filters.md) |
+| Need a complete working template to adapt | [4-examples.md](4-examples.md) |
+| Debugging a broken dashboard | [5-troubleshooting.md](5-troubleshooting.md) |
---
@@ -84,12 +209,16 @@ manage_dashboard(action="list")
### 1) DATASET ARCHITECTURE
-- **One dataset per domain** (e.g., orders, customers, products)
+- **One dataset per domain** (e.g., orders, customers, products). Datasets shared across widgets benefit from the same filters.
- **Exactly ONE valid SQL query per dataset** (no multiple queries separated by `;`)
- Always use **fully-qualified table names**: `catalog.schema.table_name`
- SELECT must include all dimensions needed by widgets and all derived columns via `AS` aliases
- Put ALL business logic (CASE/WHEN, COALESCE, ratios) into the dataset SELECT with explicit aliases
- **Contract rule**: Every widget `fieldName` must exactly match a dataset column or alias
+- **Add ORDER BY** when visualization depends on data order:
+ - Time series: `ORDER BY date` for chronological display
+ - Rankings/Top-N: `ORDER BY metric DESC LIMIT 10` for "Top 10" charts
+ - Categorical charts: `ORDER BY metric DESC` to show largest values first
### 2) WIDGET FIELD EXPRESSIONS
@@ -117,26 +246,10 @@ manage_dashboard(action="list")
Allowed expressions in widget queries (you CANNOT use CAST or other SQL in expressions):
-**For numbers:**
```json
-{"name": "sum(revenue)", "expression": "SUM(`revenue`)"}
-{"name": "avg(price)", "expression": "AVG(`price`)"}
-{"name": "count(orders)", "expression": "COUNT(`order_id`)"}
-{"name": "countdistinct(customers)", "expression": "COUNT(DISTINCT `customer_id`)"}
-{"name": "min(date)", "expression": "MIN(`order_date`)"}
-{"name": "max(date)", "expression": "MAX(`order_date`)"}
-```
-
-**For dates** (use daily for timeseries, weekly/monthly for grouped comparisons):
-```json
-{"name": "daily(date)", "expression": "DATE_TRUNC(\"DAY\", `date`)"}
-{"name": "weekly(date)", "expression": "DATE_TRUNC(\"WEEK\", `date`)"}
-{"name": "monthly(date)", "expression": "DATE_TRUNC(\"MONTH\", `date`)"}
-```
-
-**Simple field reference** (for pre-aggregated data):
-```json
-{"name": "category", "expression": "`category`"}
+{"name": "[sum|avg|count|countdistinct|min|max](col)", "expression": "[SUM|AVG|COUNT|COUNT(DISTINCT)|MIN|MAX](`col`)"}
+{"name": "[daily|weekly|monthly](date)", "expression": "DATE_TRUNC(\"[DAY|WEEK|MONTH]\", `date`)"}
+{"name": "field", "expression": "`field`"}
```
If you need conditional logic or multi-field formulas, compute a derived column in the dataset SQL first.
@@ -153,13 +266,20 @@ Each widget has a position: `{"x": 0, "y": 0, "width": 2, "height": 4}`
**CRITICAL**: Each row must fill width=6 exactly. No gaps allowed.
+```
+CORRECT: WRONG:
+y=0: [w=6] y=0: [w=4]____ ← gap!
+y=1: [w=2][w=2][w=2] ← fills 6 y=1: [w=1][w=1][w=1][w=1]__ ← gap!
+y=4: [w=3][w=3] ← fills 6
+```
+
**Recommended widget sizes:**
| Widget Type | Width | Height | Notes |
|-------------|-------|--------|-------|
| Text header | 6 | 1 | Full width; use SEPARATE widgets for title and subtitle |
| Counter/KPI | 2 | **3-4** | **NEVER height=2** - too cramped! |
-| Line/Bar chart | 3 | **5-6** | Pair side-by-side to fill row |
+| Line/Bar/Area chart | 3 | **5-6** | Pair side-by-side to fill row |
| Pie chart | 3 | **5-6** | Needs space for legend |
| Full-width chart | 6 | 5-7 | For detailed time series |
| Table | 6 | 5-8 | Full width for readability |
@@ -182,11 +302,11 @@ y=12: Table (w=6, h=6) - Detailed data
| Dimension Type | Max Values | Examples |
|----------------|------------|----------|
| Chart color/groups | **3-8** | 4 regions, 5 product lines, 3 tiers |
-| Filters | 4-10 | 8 countries, 5 channels |
+| Filters | 4-15 | 8 countries, 5 channels |
| High cardinality | **Table only** | customer_id, order_id, SKU |
**Before creating any chart with color/grouping:**
-1. Check column cardinality (use `get_table_stats_and_schema` to see distinct values)
+1. Check column cardinality via discover-schema or a COUNT DISTINCT query
2. If >10 distinct values, aggregate to higher level OR use TOP-N + "Other" bucket
3. For high-cardinality dimensions, use a table widget instead of a chart
@@ -196,13 +316,29 @@ Before deploying, verify:
1. All widget names use only alphanumeric + hyphens + underscores
2. All rows sum to width=6 with no gaps
3. KPIs use height 3-4, charts use height 5-6
-4. Chart dimensions have ≤8 distinct values
+4. Chart dimensions have reasonable cardinality (≤8 for colors/groups)
5. All widget fieldNames match dataset columns exactly
6. **Field `name` in query.fields matches `fieldName` in encodings exactly** (e.g., both `"sum(spend)"`)
7. Counter datasets: use `disaggregated: true` for 1-row datasets, `disaggregated: false` with aggregation for multi-row
-8. Percent values are 0-1 (not 0-100)
+8. **Percent values must be 0-1 for `number-percent` format** (0.865 displays as "86.5%", don't forget to set the format). If data is 0-100, either divide by 100 in SQL or use `number` format instead.
9. SQL uses Spark syntax (date_sub, not INTERVAL)
-10. **All SQL queries tested via `execute_sql` and return expected data**
+10. **All SQL queries tested via CLI and return expected data**
+11. **Every dataset you want filtered MUST contain the filter field** — filters only affect datasets with that column in their query
+
+---
+
+## Data Variance Considerations
+
+Before creating trend charts, check if the metric has enough variance to visualize meaningfully:
+
+```sql
+SELECT MIN(metric), MAX(metric), MAX(metric) - MIN(metric) as range FROM dataset
+```
+
+If the range is very small relative to the scale (e.g., 83-89% on a 0-100 scale), the chart will appear nearly flat. Consider:
+- Showing as KPI with delta/comparison instead of chart
+- Using a table to display exact values
+- Adjusting the visualization to focus on the variance
---
diff --git a/databricks-skills/databricks-app-python/4-deployment.md b/databricks-skills/databricks-app-python/4-deployment.md
index b318bbdf..384c82ac 100644
--- a/databricks-skills/databricks-app-python/4-deployment.md
+++ b/databricks-skills/databricks-app-python/4-deployment.md
@@ -1,6 +1,6 @@
# Deploying Databricks Apps
-Three deployment options: Databricks CLI (simplest), Asset Bundles (multi-environment), or MCP tools (programmatic).
+Three deployment options: Databricks CLI (simplest), Asset Bundles (multi-environment), or CLI commands (programmatic).
**Cookbook deployment guide**: https://apps-cookbook.dev/docs/deploy
@@ -107,9 +107,9 @@ For complete DABs guidance, use the **databricks-bundles** skill.
---
-## Option 3: MCP Tools
+## Option 3: CLI Commands
-For programmatic app lifecycle management, see [6-mcp-approach.md](6-mcp-approach.md).
+For CLI-based app lifecycle management, see [6-cli-approach.md](6-cli-approach.md).
---
diff --git a/databricks-skills/databricks-app-python/6-cli-approach.md b/databricks-skills/databricks-app-python/6-cli-approach.md
new file mode 100644
index 00000000..01543509
--- /dev/null
+++ b/databricks-skills/databricks-app-python/6-cli-approach.md
@@ -0,0 +1,87 @@
+# CLI Commands for App Lifecycle
+
+Use the Databricks CLI to create, deploy, and manage Databricks Apps.
+
+---
+
+## databricks apps - App Lifecycle Management
+
+```bash
+# List all apps
+databricks apps list
+
+# Create an app
+databricks apps create --name my-dashboard --json '{"description": "Customer analytics dashboard"}'
+
+# Get app details
+databricks apps get my-dashboard
+
+# Deploy an app (from workspace source code)
+databricks apps deploy my-dashboard --source-code-path /Workspace/Users/user@example.com/my_app
+
+# Get app logs
+databricks apps logs my-dashboard
+
+# Delete an app
+databricks apps delete my-dashboard
+```
+
+---
+
+## Workflow
+
+### Step 1: Write App Files Locally
+
+Create your app files in a local folder:
+
+```
+my_app/
+├── app.py # Main application
+├── models.py # Pydantic models
+├── backend.py # Data access layer
+├── requirements.txt # Additional dependencies
+└── app.yaml # Databricks Apps configuration
+```
+
+### Step 2: Upload to Workspace
+
+```bash
+# Upload local folder to workspace
+databricks workspace import-dir /path/to/my_app /Workspace/Users/user@example.com/my_app
+```
+
+### Step 3: Create and Deploy App
+
+```bash
+# Create the app
+databricks apps create --name my-dashboard --json '{"description": "Customer analytics dashboard"}'
+
+# Deploy from workspace source
+databricks apps deploy my-dashboard --source-code-path /Workspace/Users/user@example.com/my_app
+```
+
+### Step 4: Verify
+
+```bash
+# Check app status
+databricks apps get my-dashboard
+
+# Check logs for errors
+databricks apps logs my-dashboard
+```
+
+### Step 5: Iterate
+
+1. Fix issues in local files
+2. Re-upload with `databricks workspace import-dir /path/to/my_app /Workspace/Users/user@example.com/my_app`
+3. Re-deploy with `databricks apps deploy my-dashboard --source-code-path ...`
+4. Check `databricks apps logs my-dashboard` for errors
+5. Repeat until app is healthy
+
+---
+
+## Notes
+
+- Add resources (SQL warehouse, Lakebase, etc.) via the Databricks Apps UI after creating the app
+- CLI uses your configured profile's credentials — ensure you have access to required resources
+- For DABs deployment, see [4-deployment.md](4-deployment.md)
diff --git a/databricks-skills/databricks-app-python/6-mcp-approach.md b/databricks-skills/databricks-app-python/6-mcp-approach.md
deleted file mode 100644
index 943c49ba..00000000
--- a/databricks-skills/databricks-app-python/6-mcp-approach.md
+++ /dev/null
@@ -1,79 +0,0 @@
-# MCP Tools for App Lifecycle
-
-Use MCP tools to create, deploy, and manage Databricks Apps programmatically. This mirrors the CLI workflow but can be invoked by AI agents.
-
----
-
-## manage_app - App Lifecycle Management
-
-| Action | Description | Required Params |
-|--------|-------------|-----------------|
-| `create_or_update` | Idempotent create, deploys if source_code_path provided | name |
-| `get` | Get app details (with optional logs) | name |
-| `list` | List all apps | (none, optional name_contains filter) |
-| `delete` | Delete an app | name |
-
----
-
-## Workflow
-
-### Step 1: Write App Files Locally
-
-Create your app files in a local folder:
-
-```
-my_app/
-├── app.py # Main application
-├── models.py # Pydantic models
-├── backend.py # Data access layer
-├── requirements.txt # Additional dependencies
-└── app.yaml # Databricks Apps configuration
-```
-
-### Step 2: Upload to Workspace
-
-```python
-# MCP Tool: manage_workspace_files
-manage_workspace_files(
- action="upload",
- local_path="/path/to/my_app",
- workspace_path="/Workspace/Users/user@example.com/my_app"
-)
-```
-
-### Step 3: Create and Deploy App
-
-```python
-# MCP Tool: manage_app (creates if needed + deploys)
-result = manage_app(
- action="create_or_update",
- name="my-dashboard",
- description="Customer analytics dashboard",
- source_code_path="/Workspace/Users/user@example.com/my_app"
-)
-# Returns: {"name": "my-dashboard", "url": "...", "created": True, "deployment": {...}}
-```
-
-### Step 4: Verify
-
-```python
-# MCP Tool: manage_app (get with logs)
-app = manage_app(action="get", name="my-dashboard", include_logs=True)
-# Returns: {"name": "...", "url": "...", "status": "RUNNING", "logs": "...", ...}
-```
-
-### Step 5: Iterate
-
-1. Fix issues in local files
-2. Re-upload with `manage_workspace_files(action="upload", ...)`
-3. Re-deploy with `manage_app(action="create_or_update", ...)` (will update existing + deploy)
-4. Check `manage_app(action="get", name=..., include_logs=True)` for errors
-5. Repeat until app is healthy
-
----
-
-## Notes
-
-- Add resources (SQL warehouse, Lakebase, etc.) via the Databricks Apps UI after creating the app
-- MCP tools use the service principal's permissions — ensure it has access to required resources
-- For manual deployment, see [4-deployment.md](4-deployment.md)
diff --git a/databricks-skills/databricks-app-python/SKILL.md b/databricks-skills/databricks-app-python/SKILL.md
index 777d3377..7b34b74b 100644
--- a/databricks-skills/databricks-app-python/SKILL.md
+++ b/databricks-skills/databricks-app-python/SKILL.md
@@ -72,7 +72,7 @@ Copy this checklist and verify each item:
**Lakebase**: Use [5-lakebase.md](5-lakebase.md) when using Lakebase (PostgreSQL) as your app's data layer — covers auto-injected env vars, psycopg2/asyncpg patterns, and when to choose Lakebase vs SQL warehouse. (Keywords: Lakebase, PostgreSQL, psycopg2, asyncpg, transactional, PGHOST)
-**MCP tools**: Use [6-mcp-approach.md](6-mcp-approach.md) for managing app lifecycle via MCP tools — covers creating, deploying, monitoring, and deleting apps programmatically. (Keywords: MCP, create app, deploy app, app logs)
+**CLI commands**: Use [6-cli-approach.md](6-cli-approach.md) for managing app lifecycle via CLI — covers creating, deploying, monitoring, and deleting apps. (Keywords: CLI, create app, deploy app, app logs)
**Foundation Models**: See [examples/llm_config.py](examples/llm_config.py) for calling Databricks foundation model APIs — covers OAuth M2M auth, OpenAI-compatible client wiring, and token caching. (Keywords: foundation model, LLM, OpenAI client, chat completions)
@@ -87,7 +87,7 @@ Copy this checklist and verify each item:
**Connecting to data/resources?** → Read [2-app-resources.md](2-app-resources.md)
**Using Lakebase (PostgreSQL)?** → Read [5-lakebase.md](5-lakebase.md)
**Deploying to Databricks?** → Read [4-deployment.md](4-deployment.md)
- **Using MCP tools?** → Read [6-mcp-approach.md](6-mcp-approach.md)
+ **Using CLI for app lifecycle?** → Read [6-cli-approach.md](6-cli-approach.md)
**Calling foundation model/LLM APIs?** → See [examples/llm_config.py](examples/llm_config.py)
2. Follow the instructions in the relevant guide
diff --git a/databricks-skills/databricks-config/SKILL.md b/databricks-skills/databricks-config/SKILL.md
index 118713d1..21728f19 100644
--- a/databricks-skills/databricks-config/SKILL.md
+++ b/databricks-skills/databricks-config/SKILL.md
@@ -3,20 +3,144 @@ name: databricks-config
description: "Manage Databricks workspace connections: check current workspace, switch profiles, list available workspaces, or authenticate to a new workspace. Use when the user mentions \"switch workspace\", \"which workspace\", \"current profile\", \"databrickscfg\", \"connect to workspace\", or \"databricks auth\"."
---
-Use the `manage_workspace` MCP tool for all workspace operations. Do NOT edit `~/.databrickscfg`, use Bash, or use the Databricks CLI.
+Use the Databricks CLI for all workspace operations.
-## Steps
+## CLI Commands
-1. Call `ToolSearch` with query `select:mcp__databricks__manage_workspace` to load the tool.
+### Check Current Workspace
-2. Map user intent to action:
- - status / which workspace / current → `action="status"`
- - list / available workspaces → `action="list"`
- - switch to X → call `list` first to find the profile name, then `action="switch", profile=""` (or `host=""` if a URL was given)
- - login / connect / authenticate → `action="login", host=""`
+```bash
+# Show current configuration status
+databricks auth describe
-3. Call `mcp__databricks__manage_workspace` with the action and any parameters.
+# Show current workspace URL
+databricks config get --key host
-4. Present the result. For `status`/`switch`/`login`: show host, profile, username. For `list`: formatted table with the active profile marked.
+# Show current profile
+databricks config get --key profile
+```
-> **Note:** The switch is session-scoped — it resets on MCP server restart. For permanent profile setup, use `databricks auth login -p ` and update `~/.databrickscfg` with `cluster_id` or `serverless_compute_id = auto`.
+### List Available Profiles
+
+```bash
+# List all configured profiles from ~/.databrickscfg
+cat ~/.databrickscfg | grep '^\[' | tr -d '[]'
+```
+
+### Switch Workspace/Profile
+
+```bash
+# Use a different profile for subsequent commands
+databricks --profile auth describe
+
+# Or set environment variable for the session
+export DATABRICKS_CONFIG_PROFILE=
+```
+
+### Authenticate to New Workspace
+
+```bash
+# OAuth login (opens browser)
+databricks auth login --host https://your-workspace.cloud.databricks.com
+
+# OAuth login with profile name
+databricks auth login --host https://your-workspace.cloud.databricks.com --profile my-profile
+
+# Configure with PAT
+databricks configure --profile my-profile
+```
+
+### Verify Authentication
+
+```bash
+# Check auth status
+databricks auth describe
+
+# Test by listing clusters
+databricks clusters list
+```
+
+## ~/.databrickscfg Format
+
+```ini
+[DEFAULT]
+host = https://your-workspace.cloud.databricks.com
+cluster_id = 0123-456789-abc123
+# or
+serverless_compute_id = auto
+
+[production]
+host = https://prod-workspace.cloud.databricks.com
+token = dapi...
+
+[development]
+host = https://dev-workspace.cloud.databricks.com
+```
+
+## Python SDK
+
+```python
+from databricks.sdk import WorkspaceClient
+
+# Use default profile
+w = WorkspaceClient()
+
+# Use specific profile
+w = WorkspaceClient(profile="production")
+
+# Use specific host
+w = WorkspaceClient(host="https://your-workspace.cloud.databricks.com")
+
+# Check current user
+print(w.current_user.me().user_name)
+```
+
+> **Note:** Profile changes via environment variables or CLI flags are session-scoped. For permanent profile setup, use `databricks auth login -p ` and update `~/.databrickscfg` with `cluster_id` or `serverless_compute_id = auto`.
+
+## CLI Syntax Patterns
+
+**IMPORTANT**: Use `--json` for creating Unity Catalog objects. This is the most reliable syntax.
+
+```bash
+# ✅ CORRECT - use --json for create operations
+databricks catalogs create --json '{"name": "my_catalog"}'
+databricks schemas create --json '{"name": "my_schema", "catalog_name": "my_catalog"}'
+databricks volumes create --json '{"name": "my_volume", "catalog_name": "my_catalog", "schema_name": "my_schema", "volume_type": "MANAGED"}'
+```
+
+### Common CLI Patterns
+
+```bash
+# Get help for any command
+databricks --help
+databricks schemas create --help
+
+# List operations
+databricks catalogs list
+databricks schemas list CATALOG_NAME
+databricks volumes list CATALOG_NAME.SCHEMA_NAME
+databricks clusters list
+databricks warehouses list
+
+# Create operations (use --json)
+databricks catalogs create --json '{"name": "my_catalog"}'
+databricks schemas create --json '{"name": "my_schema", "catalog_name": "my_catalog"}'
+databricks volumes create --json '{"name": "my_volume", "catalog_name": "my_catalog", "schema_name": "my_schema", "volume_type": "MANAGED"}'
+
+# Delete operations (use full name)
+databricks catalogs delete CATALOG_NAME
+databricks schemas delete CATALOG_NAME.SCHEMA_NAME
+databricks volumes delete CATALOG_NAME.SCHEMA_NAME.VOLUME_NAME
+```
+
+### SQL Execution via CLI
+
+```bash
+# Run SQL query
+databricks experimental aitools tools query --warehouse WAREHOUSE_ID "SELECT * FROM catalog.schema.table LIMIT 10"
+
+# Create objects via SQL (alternative approach)
+databricks experimental aitools tools query --warehouse WAREHOUSE_ID "CREATE CATALOG my_catalog"
+databricks experimental aitools tools query --warehouse WAREHOUSE_ID "CREATE SCHEMA my_catalog.my_schema"
+databricks experimental aitools tools query --warehouse WAREHOUSE_ID "CREATE VOLUME my_catalog.my_schema.my_volume"
+```
diff --git a/databricks-skills/databricks-dbsql/SKILL.md b/databricks-skills/databricks-dbsql/SKILL.md
index 24bf2694..043228b9 100644
--- a/databricks-skills/databricks-dbsql/SKILL.md
+++ b/databricks-skills/databricks-dbsql/SKILL.md
@@ -297,4 +297,4 @@ Load these for detailed syntax, full parameter lists, and advanced patterns:
- **Star schema in Gold layer** for BI; OBT acceptable in Silver
- **Define PK/FK constraints** on dimensional models for query optimization
- **Use `COLLATE UTF8_LCASE`** for user-facing string columns that need case-insensitive search
-- **Use MCP tools** (`execute_sql`, `execute_sql_multi`) to test and validate all SQL before deploying
+- **Test SQL via CLI** (`databricks experimental aitools tools query`) or notebooks before deploying
diff --git a/databricks-skills/databricks-docs/SKILL.md b/databricks-skills/databricks-docs/SKILL.md
index ceca11e0..8e9d68d5 100644
--- a/databricks-skills/databricks-docs/SKILL.md
+++ b/databricks-skills/databricks-docs/SKILL.md
@@ -5,7 +5,7 @@ description: "Databricks documentation reference via llms.txt index. Use when ot
# Databricks Documentation Reference
-This skill provides access to the complete Databricks documentation index via llms.txt - use it as a **reference resource** to supplement other skills and inform your use of MCP tools.
+This skill provides access to the complete Databricks documentation index via llms.txt - use it as a **reference resource** to supplement other skills.
## Role of This Skill
@@ -13,10 +13,10 @@ This is a **reference skill**, not an action skill. Use it to:
- Look up documentation when other skills don't cover a topic
- Get authoritative guidance on Databricks concepts and APIs
-- Find detailed information to inform how you use MCP tools
+- Find detailed information to inform CLI commands and SDK usage
- Discover features and capabilities you may not know about
-**Always prefer using MCP tools for actions** (execute_sql, manage_pipeline, etc.) and **load specific skills for workflows** (databricks-python-sdk, databricks-spark-declarative-pipelines, etc.). Use this skill when you need reference documentation.
+**Always prefer using CLI/SDK for actions** and **load specific skills for workflows** (databricks-python-sdk, databricks-spark-declarative-pipelines, etc.). Use this skill when you need reference documentation.
## How to Use
@@ -28,7 +28,7 @@ Use WebFetch to retrieve this index, then:
1. Search for relevant sections/links
2. Fetch specific documentation pages for detailed guidance
-3. Apply what you learn using the appropriate MCP tools
+3. Apply what you learn using the appropriate CLI commands or SDK
## Documentation Structure
@@ -47,7 +47,7 @@ The llms.txt file is organized by category:
1. Load `databricks-spark-declarative-pipelines` skill for workflow patterns
2. Use this skill to fetch docs if you need clarification on specific DLT features
-3. Use `manage_pipeline(action="create_or_update")` MCP tool to actually create the pipeline
+3. Use `databricks pipelines create` CLI command to create the pipeline
**Scenario:** User asks about an unfamiliar Databricks feature
diff --git a/databricks-skills/databricks-execution-compute/SKILL.md b/databricks-skills/databricks-execution-compute/SKILL.md
index c3518385..151a467d 100644
--- a/databricks-skills/databricks-execution-compute/SKILL.md
+++ b/databricks-skills/databricks-execution-compute/SKILL.md
@@ -27,6 +27,7 @@ Run code on Databricks. Three execution modes—choose based on workload.
### Decision Flow
+Prefer Databricks Connect for all spark-based workload.
```
Spark-based code? → Databricks Connect (fastest)
└─ Python 3.12 missing? → Install it + databricks-connect
@@ -42,7 +43,7 @@ Scala/R? → Interactive Cluster (list and ask which one to use)
**Read the reference file for your chosen mode before proceeding.**
-### Databricks Connect (no MCP tool, run locally) → [reference](references/1-databricks-connect.md)
+### Databricks Connect (run locally) → [reference](references/1-databricks-connect.md)
```bash
python my_spark_script.py
@@ -50,30 +51,48 @@ python my_spark_script.py
### Serverless Job → [reference](references/2-serverless-job.md)
-```python
-execute_code(file_path="/path/to/script.py")
+```bash
+# Create and run a job with serverless compute
+databricks jobs create --json '{
+ "name": "my-script-job",
+ "tasks": [{
+ "task_key": "main",
+ "spark_python_task": {"python_file": "/Workspace/Users/me/script.py"},
+ "environment_key": "default"
+ }],
+ "environments": [{"environment_key": "default", "spec": {"client": "4"}}]
+}'
+
+# Run the job
+databricks jobs run-now --job-id JOB_ID
```
### Interactive Cluster → [reference](references/3-interactive-cluster.md)
-```python
-# Check for running clusters first (or use the one instructed)
-list_compute(resource="clusters")
-# Ask the customer which one to use
-
-# Run code, reuse context_id for follow-up MCP call
-result = execute_code(code="...", compute_type="cluster", cluster_id="...")
-execute_code(code="...", context_id=result["context_id"], cluster_id=result["cluster_id"])
+```bash
+# List running clusters
+databricks clusters list --output json | jq '.[] | select(.state == "RUNNING")'
+
+# Run a notebook or script on a cluster
+databricks workspace import /Workspace/Users/me/script.py --file ./script.py
+databricks jobs create --json '{
+ "name": "cluster-job",
+ "tasks": [{
+ "task_key": "main",
+ "existing_cluster_id": "CLUSTER_ID",
+ "spark_python_task": {"python_file": "/Workspace/Users/me/script.py"}
+ }]
+}'
```
-## MCP Tools
+## CLI Commands
-| Tool | For | Purpose |
-|------|-----|---------|
-| `execute_code` | Serverless, Interactive | Run code remotely |
-| `list_compute` | Interactive | List clusters, check status, auto-select running cluster |
-| `manage_cluster` | Interactive | Create, start, terminate, delete. **COSTLY:** `start` takes 3-8 min—ask user |
-| `manage_sql_warehouse` | SQL | Create, modify, delete SQL warehouses |
+| Command | For | Purpose |
+|---------|-----|---------|
+| `databricks jobs create/run-now` | Serverless, Cluster | Run code remotely |
+| `databricks clusters list` | Interactive | List clusters, check status |
+| `databricks clusters create/start/delete` | Interactive | Manage clusters. **COSTLY:** `start` takes 3-8 min |
+| `databricks warehouses create/list` | SQL | Manage SQL warehouses |
## Related Skills
diff --git a/databricks-skills/databricks-execution-compute/references/1-databricks-connect.md b/databricks-skills/databricks-execution-compute/references/1-databricks-connect.md
index 838d2a7d..39be79a4 100644
--- a/databricks-skills/databricks-execution-compute/references/1-databricks-connect.md
+++ b/databricks-skills/databricks-execution-compute/references/1-databricks-connect.md
@@ -30,16 +30,12 @@ auth_type = databricks-cli
## Usage Pattern
```python
-from databricks.connect import DatabricksSession, DatabricksEnv
-
-# Declare dependencies installed on serverless compute
-# CRITICAL: Include ALL packages used inside UDFs (pandas/numpy are there by default)
-env = DatabricksEnv().withDependencies("faker", "holidays")
+from databricks.connect import DatabricksSession
+# Install dependencies locally first: uv pip install faker holidays
spark = (
DatabricksSession.builder
- .profile("my-workspace") # optional: run on a specific profile from ~/.databrickscfg instead of default
- .withEnvironment(env)
+ .profile("my-workspace") # optional: use a specific profile from ~/.databrickscfg
.serverless(True)
.getOrCreate()
)
@@ -54,9 +50,8 @@ df.write.mode('overwrite').saveAsTable("catalog.schema.table")
| Issue | Solution |
|-------|----------|
| `Python 3.12 required` | create venv with correct python version |
-| `DatabricksEnv not found` | Upgrade to databricks-connect >= 16.4 |
| `serverless_compute_id` error | Add `serverless_compute_id = auto` to ~/.databrickscfg |
-| `ModuleNotFoundError` inside UDF | Add the package to `withDependencies()` |
+| `ModuleNotFoundError` inside UDF | Install the package locally: `uv pip install ` |
| `PERSIST TABLE not supported` | Don't use `.cache()` or `.persist()` with serverless |
| `broadcast` is used | Don't broadcast small DF using spark connect, have a small python list instead or join small DF |
@@ -68,5 +63,5 @@ Switch to **[Serverless Job](2-serverless-job.md)** when:
- Non-Spark Python code (pure sklearn, pytorch, etc.)
Switch to **[Interactive Cluster](3-interactive-cluster.md)** when:
-- Need state across multiple separate MCP tool calls
+- Need state across multiple separate tool calls
- Need Scala or R support
diff --git a/databricks-skills/databricks-execution-compute/references/2-serverless-job.md b/databricks-skills/databricks-execution-compute/references/2-serverless-job.md
index 4be8801c..6cc29fd9 100644
--- a/databricks-skills/databricks-execution-compute/references/2-serverless-job.md
+++ b/databricks-skills/databricks-execution-compute/references/2-serverless-job.md
@@ -72,5 +72,5 @@ Switch to **[Databricks Connect](1-databricks-connect.md)** when:
- Need local debugging with breakpoints
Switch to **[Interactive Cluster](3-interactive-cluster.md)** when:
-- Need state across multiple MCP tool calls
+- Need state across multiple tool calls
- Need Scala or R support
diff --git a/databricks-skills/databricks-execution-compute/references/3-interactive-cluster.md b/databricks-skills/databricks-execution-compute/references/3-interactive-cluster.md
index aa73ea90..fbff2469 100644
--- a/databricks-skills/databricks-execution-compute/references/3-interactive-cluster.md
+++ b/databricks-skills/databricks-execution-compute/references/3-interactive-cluster.md
@@ -1,6 +1,6 @@
# Interactive Cluster Execution
-**Use when:** You have an existing running cluster and need to preserve state across multiple MCP tool calls, or need Scala/R support.
+**Use when:** You have an existing running cluster and need to preserve state across multiple tool calls, or need Scala/R support.
## When to Choose Interactive Cluster
@@ -20,8 +20,8 @@
**Starting a cluster takes 3-8 minutes and costs money.** Always check first:
-```python
-list_compute(resource="clusters")
+```bash
+python scripts/compute.py list-compute --resource clusters
```
If no cluster is running, ask the user:
@@ -34,58 +34,80 @@ If no cluster is running, ask the user:
### First Command: Creates Context
-```python
-result = execute_code(
- code="import pandas as pd\ndf = pd.DataFrame({'a': [1, 2, 3]})",
- compute_type="cluster",
- cluster_id="1234-567890-abcdef"
-)
-# result contains context_id for reuse
+```bash
+python scripts/compute.py execute-code \
+ --code "import pandas as pd; df = pd.DataFrame({'a': [1, 2, 3]}); print(df)" \
+ --compute-type cluster \
+ --cluster-id "1234-567890-abcdef"
+```
+
+Response includes `context_id` for reuse:
+```json
+{
+ "success": true,
+ "output": " a\n0 1\n1 2\n2 3",
+ "context_id": "ctx_abc123",
+ "cluster_id": "1234-567890-abcdef"
+}
```
### Follow-up Commands: Reuse Context
-```python
+```bash
# Variables from first command still available
-execute_code(
- code="print(df.shape)", # df exists
- context_id=result["context_id"],
- cluster_id=result["cluster_id"]
-)
+python scripts/compute.py execute-code \
+ --code "print(df.shape)" \
+ --compute-type cluster \
+ --cluster-id "1234-567890-abcdef" \
+ --context-id "ctx_abc123"
```
### Auto-Select Best Running Cluster
-```python
-best_cluster = list_compute(resource="clusters", auto_select=True)
-execute_code(
- code="spark.range(100).show()",
- compute_type="cluster",
- cluster_id=best_cluster["cluster_id"]
-)
+```bash
+# Get best running cluster
+python scripts/compute.py list-compute --auto-select
+# Returns: {"cluster_id": "1234-567890-abcdef"}
+
+# Then execute on it
+python scripts/compute.py execute-code \
+ --code "spark.range(100).show()" \
+ --compute-type cluster \
+ --cluster-id "1234-567890-abcdef"
```
## Language Support
-```python
-execute_code(code='println("Hello")', compute_type="cluster", language="scala")
-execute_code(code="SELECT * FROM table LIMIT 10", compute_type="cluster", language="sql")
-execute_code(code='print("Hello")', compute_type="cluster", language="r")
+```bash
+# Scala
+python scripts/compute.py execute-code --code 'println("Hello")' --compute-type cluster --language scala --cluster-id ...
+
+# SQL
+python scripts/compute.py execute-code --code "SELECT * FROM table LIMIT 10" --compute-type cluster --language sql --cluster-id ...
+
+# R
+python scripts/compute.py execute-code --code 'print("Hello")' --compute-type cluster --language r --cluster-id ...
```
## Installing Libraries
-Install pip packages directly in the execution context (pandas/numpy are there by default):
-
-```python
-# Install library
-execute_code(
- code="""%pip install faker
- dbutils.library.restartPython()""", # Restart Python to pick up new packages (if needed)
- compute_type="cluster",
- cluster_id="...",
- context_id="..."
-)
+Install pip packages directly in the execution context:
+
+```bash
+python scripts/compute.py execute-code \
+ --code "%pip install faker" \
+ --compute-type cluster \
+ --cluster-id "..." \
+ --context-id "..."
+```
+
+If needed, restart Python to pick up new packages:
+```bash
+python scripts/compute.py execute-code \
+ --code "dbutils.library.restartPython()" \
+ --compute-type cluster \
+ --cluster-id "..." \
+ --context-id "..."
```
## Context Lifecycle
@@ -93,32 +115,31 @@ execute_code(
**Keep alive (default):** Context persists until cluster terminates.
**Destroy when done:**
-```python
-execute_code(
- code="print('Done!')",
- compute_type="cluster",
- destroy_context_on_completion=True
-)
+```bash
+python scripts/compute.py execute-code \
+ --code "print('Done!')" \
+ --compute-type cluster \
+ --cluster-id "..." \
+ --destroy-context
```
-## Handling No Running Cluster
+## Managing Clusters
-When no cluster is running, `execute_code` returns:
-```json
-{
- "success": false,
- "error": "No running cluster available",
- "startable_clusters": [{"cluster_id": "...", "cluster_name": "...", "state": "TERMINATED"}],
- "suggestions": ["Start a terminated cluster", "Use serverless instead"]
-}
-```
+```bash
+# List all clusters
+python scripts/compute.py list-compute --resource clusters
+
+# Get specific cluster status
+python scripts/compute.py list-compute --cluster-id "1234-567890-abcdef"
+
+# Start a cluster (WITH USER APPROVAL ONLY - costs money, 3-8min startup)
+python scripts/compute.py manage-cluster --action start --cluster-id "1234-567890-abcdef"
-### Starting a Cluster (With User Approval Only)
+# Terminate a cluster (reversible)
+python scripts/compute.py manage-cluster --action terminate --cluster-id "1234-567890-abcdef"
-```python
-manage_cluster(action="start", cluster_id="1234-567890-abcdef")
-# Poll until running (wait 20sec)
-list_compute(resource="clusters", cluster_id="1234-567890-abcdef")
+# Create a new cluster
+python scripts/compute.py manage-cluster --action create --name "my-cluster" --num-workers 2
```
## Common Issues
@@ -127,7 +148,7 @@ list_compute(resource="clusters", cluster_id="1234-567890-abcdef")
|-------|----------|
| "No running cluster" | Ask user to start or use serverless |
| Context not found | Context expired; create new one |
-| Library not found | `%pip install ` then if needed `dbutils.library.restartPython()` |
+| Library not found | `%pip install ` then restart Python if needed |
## When NOT to Use
diff --git a/databricks-skills/databricks-execution-compute/scripts/compute.py b/databricks-skills/databricks-execution-compute/scripts/compute.py
new file mode 100644
index 00000000..0e584f3c
--- /dev/null
+++ b/databricks-skills/databricks-execution-compute/scripts/compute.py
@@ -0,0 +1,668 @@
+#!/usr/bin/env python3
+"""Compute CLI - Execute code and manage compute resources on Databricks.
+
+Standalone script with no external dependencies beyond databricks-sdk.
+
+Commands:
+- execute-code: Run code on serverless or cluster compute
+- list-compute: List clusters, node types, or spark versions
+- manage-cluster: Create, start, terminate, or delete clusters
+
+Requires: pip install databricks-sdk
+"""
+
+import argparse
+import base64
+import json
+import uuid
+from dataclasses import dataclass
+from datetime import timedelta
+from typing import Any, Dict, List, Optional
+
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.compute import (
+ ClusterSource,
+ CommandStatus,
+ ContextStatus,
+ Environment,
+ Language,
+ ListClustersFilterBy,
+ ResultType,
+ State,
+)
+from databricks.sdk.service.jobs import (
+ JobEnvironment,
+ NotebookTask,
+ RunResultState,
+ Source,
+ SubmitTask,
+)
+from databricks.sdk.service.workspace import ImportFormat, Language as WsLang
+
+
+# ---------------------------------------------------------------------------
+# Authentication
+# ---------------------------------------------------------------------------
+
+def get_workspace_client() -> WorkspaceClient:
+ """Get authenticated WorkspaceClient using standard auth chain."""
+ return WorkspaceClient()
+
+
+def get_current_username() -> str:
+ """Get the current user's username."""
+ w = get_workspace_client()
+ return w.current_user.me().user_name
+
+
+# ---------------------------------------------------------------------------
+# Exceptions
+# ---------------------------------------------------------------------------
+
+class NoRunningClusterError(Exception):
+ """Raised when no running cluster is available."""
+
+ def __init__(self, message: str, suggestions: List[str] = None, startable_clusters: List[Dict] = None):
+ super().__init__(message)
+ self.suggestions = suggestions or []
+ self.startable_clusters = startable_clusters or []
+
+
+# ---------------------------------------------------------------------------
+# Result Classes
+# ---------------------------------------------------------------------------
+
+@dataclass
+class ExecutionResult:
+ """Result from cluster command execution."""
+ success: bool
+ output: str = ""
+ error: str = ""
+ cluster_id: str = ""
+ context_id: str = ""
+ status: str = ""
+ result_type: str = ""
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "success": self.success,
+ "output": self.output,
+ "error": self.error,
+ "cluster_id": self.cluster_id,
+ "context_id": self.context_id,
+ "status": self.status,
+ "result_type": self.result_type,
+ }
+
+
+@dataclass
+class ServerlessRunResult:
+ """Result from serverless code execution."""
+ success: bool
+ output: str = ""
+ error: str = ""
+ run_id: int = 0
+ run_page_url: str = ""
+ state: str = ""
+ execution_duration_ms: int = 0
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "success": self.success,
+ "output": self.output,
+ "error": self.error,
+ "run_id": self.run_id,
+ "run_page_url": self.run_page_url,
+ "state": self.state,
+ "execution_duration_ms": self.execution_duration_ms,
+ }
+
+
+# ---------------------------------------------------------------------------
+# Cluster Execution
+# ---------------------------------------------------------------------------
+
+def list_clusters() -> List[Dict[str, Any]]:
+ """List interactive clusters created by humans (UI/API, not jobs)."""
+ w = get_workspace_client()
+ clusters = []
+ # Filter to only UI and API created clusters (interactive, human-created)
+ # Excludes JOB clusters (created by jobs) and other system clusters
+ filter_by = ListClustersFilterBy(
+ cluster_sources=[ClusterSource.UI, ClusterSource.API]
+ )
+ for c in w.clusters.list(filter_by=filter_by, page_size=100):
+ clusters.append({
+ "cluster_id": c.cluster_id,
+ "cluster_name": c.cluster_name,
+ "state": c.state.value if c.state else "UNKNOWN",
+ "creator_user_name": c.creator_user_name,
+ "spark_version": c.spark_version,
+ "node_type_id": c.node_type_id,
+ "num_workers": c.num_workers,
+ })
+ return clusters
+
+
+def get_best_cluster() -> str:
+ """Get the best running interactive cluster ID, or raise NoRunningClusterError."""
+ w = get_workspace_client()
+ running = []
+ startable = []
+
+ # Filter to only interactive clusters (UI/API created)
+ filter_by = ListClustersFilterBy(
+ cluster_sources=[ClusterSource.UI, ClusterSource.API]
+ )
+ for c in w.clusters.list(filter_by=filter_by, page_size=100):
+ info = {
+ "cluster_id": c.cluster_id,
+ "cluster_name": c.cluster_name,
+ "state": c.state.value if c.state else "UNKNOWN",
+ }
+ if c.state == State.RUNNING:
+ running.append(info)
+ elif c.state in (State.TERMINATED, State.PENDING):
+ startable.append(info)
+
+ if running:
+ return running[0]["cluster_id"]
+
+ raise NoRunningClusterError(
+ "No running cluster available.",
+ suggestions=[
+ "Start an existing cluster with: python compute.py manage-cluster --action start --cluster-id ",
+ "Use serverless compute: python compute.py execute-code --compute-type serverless --code '...'",
+ ],
+ startable_clusters=startable,
+ )
+
+
+def start_cluster(cluster_id: str) -> Dict[str, Any]:
+ """Start a cluster and wait for it to be running."""
+ w = get_workspace_client()
+ w.clusters.start(cluster_id=cluster_id)
+ # Don't wait - just return immediately
+ return {"success": True, "cluster_id": cluster_id, "message": "Cluster start initiated"}
+
+
+def get_cluster_status(cluster_id: str) -> Dict[str, Any]:
+ """Get the status of a specific cluster."""
+ w = get_workspace_client()
+ c = w.clusters.get(cluster_id=cluster_id)
+ return {
+ "cluster_id": c.cluster_id,
+ "cluster_name": c.cluster_name,
+ "state": c.state.value if c.state else "UNKNOWN",
+ "state_message": c.state_message,
+ "creator_user_name": c.creator_user_name,
+ "spark_version": c.spark_version,
+ "node_type_id": c.node_type_id,
+ "num_workers": c.num_workers,
+ }
+
+
+def _get_or_create_context(w: WorkspaceClient, cluster_id: str, context_id: Optional[str], language: str) -> str:
+ """Get existing context or create a new one."""
+ lang_map = {"python": Language.PYTHON, "scala": Language.SCALA, "sql": Language.SQL, "r": Language.R}
+ lang = lang_map.get(language.lower(), Language.PYTHON)
+
+ if context_id:
+ # Verify context exists
+ try:
+ status = w.command_execution.context_status(cluster_id=cluster_id, context_id=context_id)
+ if status.status == ContextStatus.RUNNING:
+ return context_id
+ except Exception:
+ pass # Context doesn't exist, create new one
+
+ # Create new context
+ ctx = w.command_execution.create(cluster_id=cluster_id, language=lang).result()
+ return ctx.id
+
+
+def execute_databricks_command(
+ code: str,
+ cluster_id: Optional[str] = None,
+ context_id: Optional[str] = None,
+ language: str = "python",
+ timeout: int = 120,
+ destroy_context_on_completion: bool = False,
+) -> ExecutionResult:
+ """Execute code on a Databricks cluster using Command Execution API."""
+ w = get_workspace_client()
+
+ # Get cluster ID if not provided
+ if not cluster_id:
+ cluster_id = get_best_cluster()
+
+ # Get or create context
+ ctx_id = _get_or_create_context(w, cluster_id, context_id, language)
+
+ # Execute command
+ lang_map = {"python": Language.PYTHON, "scala": Language.SCALA, "sql": Language.SQL, "r": Language.R}
+ lang = lang_map.get(language.lower(), Language.PYTHON)
+
+ try:
+ cmd = w.command_execution.execute(
+ cluster_id=cluster_id,
+ context_id=ctx_id,
+ language=lang,
+ command=code,
+ ).result(timeout=timedelta(seconds=timeout))
+
+ # Parse results
+ output = ""
+ error = ""
+ result_type = cmd.results.result_type.value if cmd.results and cmd.results.result_type else ""
+
+ if cmd.results:
+ if cmd.results.result_type == ResultType.TEXT:
+ output = cmd.results.data or ""
+ elif cmd.results.result_type == ResultType.TABLE:
+ output = json.dumps(cmd.results.data) if cmd.results.data else ""
+ elif cmd.results.result_type == ResultType.ERROR:
+ error = cmd.results.cause or str(cmd.results.data) or "Unknown error"
+
+ success = cmd.status == CommandStatus.FINISHED and cmd.results.result_type != ResultType.ERROR
+
+ return ExecutionResult(
+ success=success,
+ output=output,
+ error=error,
+ cluster_id=cluster_id,
+ context_id=ctx_id,
+ status=cmd.status.value if cmd.status else "",
+ result_type=result_type,
+ )
+
+ finally:
+ if destroy_context_on_completion and ctx_id:
+ try:
+ w.command_execution.destroy(cluster_id=cluster_id, context_id=ctx_id)
+ except Exception:
+ pass
+
+
+# ---------------------------------------------------------------------------
+# Serverless Execution
+# ---------------------------------------------------------------------------
+
+def run_code_on_serverless(
+ code: str,
+ language: str = "python",
+ timeout: int = 1800,
+) -> ServerlessRunResult:
+ """Run code on serverless compute using Jobs API runs/submit."""
+ w = get_workspace_client()
+
+ # Create temp notebook
+ username = get_current_username()
+ notebook_name = f"_tmp_serverless_{uuid.uuid4().hex[:8]}"
+ notebook_path = f"/Workspace/Users/{username}/.tmp/{notebook_name}"
+
+ # Ensure directory exists
+ try:
+ w.workspace.mkdirs(f"/Workspace/Users/{username}/.tmp")
+ except Exception:
+ pass
+
+ # Upload notebook content
+ if language.lower() == "sql":
+ notebook_content = f"-- Databricks notebook source\n{code}"
+ else:
+ notebook_content = f"# Databricks notebook source\n{code}"
+
+ content_b64 = base64.b64encode(notebook_content.encode()).decode()
+
+ ws_lang_map = {"python": WsLang.PYTHON, "sql": WsLang.SQL}
+ ws_lang = ws_lang_map.get(language.lower(), WsLang.PYTHON)
+
+ w.workspace.import_(
+ path=notebook_path,
+ content=content_b64,
+ format=ImportFormat.SOURCE,
+ language=ws_lang,
+ overwrite=True,
+ )
+
+ try:
+ # Submit run
+ run = w.jobs.submit(
+ run_name=f"serverless-run-{uuid.uuid4().hex[:8]}",
+ tasks=[
+ SubmitTask(
+ task_key="main",
+ notebook_task=NotebookTask(
+ notebook_path=notebook_path,
+ source=Source.WORKSPACE,
+ ),
+ environment_key="default",
+ )
+ ],
+ environments=[
+ JobEnvironment(
+ environment_key="default",
+ spec=Environment(client="1"),
+ )
+ ],
+ ).result(timeout=timedelta(seconds=timeout))
+
+ # Get run output
+ run_output = w.jobs.get_run_output(run_id=run.tasks[0].run_id)
+
+ output = ""
+ error = ""
+ success = run.state.result_state == RunResultState.SUCCESS
+
+ if run_output.notebook_output and run_output.notebook_output.result:
+ output = run_output.notebook_output.result
+ if run_output.error:
+ error = run_output.error
+
+ return ServerlessRunResult(
+ success=success,
+ output=output,
+ error=error,
+ run_id=run.run_id,
+ run_page_url=run.run_page_url or "",
+ state=run.state.result_state.value if run.state and run.state.result_state else "",
+ execution_duration_ms=run.execution_duration or 0,
+ )
+
+ finally:
+ # Cleanup temp notebook
+ try:
+ w.workspace.delete(notebook_path)
+ except Exception:
+ pass
+
+
+# ---------------------------------------------------------------------------
+# Cluster Management
+# ---------------------------------------------------------------------------
+
+def create_cluster(
+ name: str,
+ num_workers: int = 1,
+ autotermination_minutes: int = 120,
+ spark_version: Optional[str] = None,
+ node_type_id: Optional[str] = None,
+) -> Dict[str, Any]:
+ """Create a new cluster."""
+ w = get_workspace_client()
+
+ # Get defaults if not provided
+ if not spark_version:
+ versions = list(w.clusters.spark_versions())
+ # Pick latest LTS
+ for v in versions:
+ if "LTS" in v.name and "ML" not in v.name:
+ spark_version = v.key
+ break
+ if not spark_version and versions:
+ spark_version = versions[0].key
+
+ if not node_type_id:
+ node_types = list(w.clusters.list_node_types().node_types)
+ # Pick smallest available
+ for nt in sorted(node_types, key=lambda x: x.memory_mb or 0):
+ if nt.is_deprecated is not True:
+ node_type_id = nt.node_type_id
+ break
+
+ cluster = w.clusters.create(
+ cluster_name=name,
+ spark_version=spark_version,
+ node_type_id=node_type_id,
+ num_workers=num_workers,
+ autotermination_minutes=autotermination_minutes,
+ ).result()
+
+ return {
+ "success": True,
+ "cluster_id": cluster.cluster_id,
+ "cluster_name": name,
+ "message": "Cluster created",
+ }
+
+
+def terminate_cluster(cluster_id: str) -> Dict[str, Any]:
+ """Terminate a cluster (can be restarted)."""
+ w = get_workspace_client()
+ w.clusters.delete(cluster_id=cluster_id)
+ return {"success": True, "cluster_id": cluster_id, "message": "Cluster terminated"}
+
+
+def delete_cluster(cluster_id: str) -> Dict[str, Any]:
+ """Permanently delete a cluster."""
+ w = get_workspace_client()
+ w.clusters.permanent_delete(cluster_id=cluster_id)
+ return {"success": True, "cluster_id": cluster_id, "message": "Cluster permanently deleted"}
+
+
+def list_node_types() -> List[Dict[str, Any]]:
+ """List available node types."""
+ w = get_workspace_client()
+ result = []
+ for nt in w.clusters.list_node_types().node_types:
+ result.append({
+ "node_type_id": nt.node_type_id,
+ "memory_mb": nt.memory_mb,
+ "num_cores": nt.num_cores,
+ "description": nt.description,
+ "is_deprecated": nt.is_deprecated,
+ })
+ return result
+
+
+def list_spark_versions() -> List[Dict[str, Any]]:
+ """List available Spark versions."""
+ w = get_workspace_client()
+ result = []
+ response = w.clusters.spark_versions()
+ for v in response.versions or []:
+ result.append({
+ "key": v.key,
+ "name": v.name,
+ })
+ return result
+
+
+# ---------------------------------------------------------------------------
+# CLI Commands
+# ---------------------------------------------------------------------------
+
+def _none_if_empty(value):
+ """Convert empty strings to None."""
+ return None if value == "" else value
+
+
+def _no_cluster_error_response(e: NoRunningClusterError) -> Dict[str, Any]:
+ """Build a structured error response when no running cluster is available."""
+ return {
+ "success": False,
+ "error": str(e),
+ "suggestions": e.suggestions,
+ "startable_clusters": e.startable_clusters,
+ }
+
+
+def cmd_execute_code(args):
+ """Execute code on Databricks via serverless or cluster compute."""
+ code = _none_if_empty(args.code)
+ file_path = _none_if_empty(args.file)
+ cluster_id = _none_if_empty(args.cluster_id)
+ context_id = _none_if_empty(args.context_id)
+ language = _none_if_empty(args.language) or "python"
+ compute_type = args.compute_type
+ timeout = args.timeout
+ destroy_context = args.destroy_context
+
+ if not code and not file_path:
+ return {"success": False, "error": "Either --code or --file must be provided."}
+
+ # Read code from file if provided
+ if file_path and not code:
+ try:
+ with open(file_path, "r", encoding="utf-8") as f:
+ code = f.read()
+ except FileNotFoundError:
+ return {"success": False, "error": f"File not found: {file_path}"}
+
+ # Resolve "auto" compute type
+ if compute_type == "auto":
+ if cluster_id or context_id:
+ compute_type = "cluster"
+ elif language.lower() in ("scala", "r"):
+ compute_type = "cluster"
+ else:
+ compute_type = "serverless"
+
+ # Serverless execution
+ if compute_type == "serverless":
+ default_timeout = timeout if timeout else 1800
+ result = run_code_on_serverless(
+ code=code,
+ language=language,
+ timeout=default_timeout,
+ )
+ return result.to_dict()
+
+ # Cluster execution
+ default_timeout = timeout if timeout else 120
+ try:
+ result = execute_databricks_command(
+ code=code,
+ cluster_id=cluster_id,
+ context_id=context_id,
+ language=language,
+ timeout=default_timeout,
+ destroy_context_on_completion=destroy_context,
+ )
+ return result.to_dict()
+ except NoRunningClusterError as e:
+ return _no_cluster_error_response(e)
+
+
+def cmd_list_compute(args):
+ """List compute resources: clusters, node types, or spark versions."""
+ resource = args.resource.lower()
+ cluster_id = _none_if_empty(args.cluster_id)
+ auto_select = args.auto_select
+
+ if resource == "clusters":
+ if cluster_id:
+ return get_cluster_status(cluster_id)
+ if auto_select:
+ try:
+ best = get_best_cluster()
+ return {"cluster_id": best}
+ except NoRunningClusterError as e:
+ return _no_cluster_error_response(e)
+ return {"clusters": list_clusters()}
+
+ elif resource == "node_types":
+ return {"node_types": list_node_types()}
+
+ elif resource == "spark_versions":
+ return {"spark_versions": list_spark_versions()}
+
+ else:
+ return {"success": False, "error": f"Unknown resource: {resource}. Use: clusters, node_types, spark_versions"}
+
+
+def cmd_manage_cluster(args):
+ """Create, start, terminate, or delete a cluster."""
+ action = args.action.lower()
+ cluster_id = _none_if_empty(args.cluster_id)
+ name = _none_if_empty(args.name)
+
+ if action == "create":
+ if not name:
+ return {"success": False, "error": "name is required for create action."}
+ return create_cluster(
+ name=name,
+ num_workers=args.num_workers or 1,
+ autotermination_minutes=args.autotermination_minutes or 120,
+ )
+
+ elif action == "start":
+ if not cluster_id:
+ return {"success": False, "error": "cluster_id is required for start action."}
+ return start_cluster(cluster_id)
+
+ elif action == "terminate":
+ if not cluster_id:
+ return {"success": False, "error": "cluster_id is required for terminate action."}
+ return terminate_cluster(cluster_id)
+
+ elif action == "delete":
+ if not cluster_id:
+ return {"success": False, "error": "cluster_id is required for delete action."}
+ return delete_cluster(cluster_id)
+
+ elif action == "get":
+ if not cluster_id:
+ return {"success": False, "error": "cluster_id is required for get action."}
+ try:
+ return get_cluster_status(cluster_id)
+ except Exception as e:
+ if "does not exist" in str(e).lower():
+ return {"success": True, "cluster_id": cluster_id, "state": "DELETED", "exists": False}
+ return {"success": False, "error": str(e)}
+
+ else:
+ return {"success": False, "error": f"Unknown action: {action}. Use: create, start, terminate, delete, get"}
+
+
+# ---------------------------------------------------------------------------
+# CLI Setup
+# ---------------------------------------------------------------------------
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Execute code and manage compute on Databricks",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+ subparsers = parser.add_subparsers(dest="command", required=True)
+
+ # execute-code
+ exec_parser = subparsers.add_parser("execute-code", help="Run code on Databricks")
+ exec_parser.add_argument("--code", help="Code to execute")
+ exec_parser.add_argument("--file", help="File to execute")
+ exec_parser.add_argument("--compute-type", default="auto", choices=["auto", "serverless", "cluster"],
+ help="Compute type (default: auto)")
+ exec_parser.add_argument("--cluster-id", help="Cluster ID (for cluster compute)")
+ exec_parser.add_argument("--context-id", help="Context ID (reuse existing context)")
+ exec_parser.add_argument("--language", default="python", choices=["python", "scala", "sql", "r"],
+ help="Language (default: python)")
+ exec_parser.add_argument("--timeout", type=int, help="Timeout in seconds")
+ exec_parser.add_argument("--destroy-context", action="store_true", help="Destroy context after execution")
+ exec_parser.set_defaults(func=cmd_execute_code)
+
+ # list-compute
+ list_parser = subparsers.add_parser("list-compute", help="List compute resources")
+ list_parser.add_argument("--resource", default="clusters", choices=["clusters", "node_types", "spark_versions"],
+ help="Resource to list (default: clusters)")
+ list_parser.add_argument("--cluster-id", help="Get specific cluster status")
+ list_parser.add_argument("--auto-select", action="store_true", help="Return best running cluster")
+ list_parser.set_defaults(func=cmd_list_compute)
+
+ # manage-cluster
+ manage_parser = subparsers.add_parser("manage-cluster", help="Manage clusters")
+ manage_parser.add_argument("--action", required=True, choices=["create", "start", "terminate", "delete", "get"],
+ help="Action to perform")
+ manage_parser.add_argument("--cluster-id", help="Cluster ID")
+ manage_parser.add_argument("--name", help="Cluster name (for create)")
+ manage_parser.add_argument("--num-workers", type=int, help="Number of workers (for create)")
+ manage_parser.add_argument("--autotermination-minutes", type=int, help="Auto-termination minutes (for create)")
+ manage_parser.set_defaults(func=cmd_manage_cluster)
+
+ args = parser.parse_args()
+ result = args.func(args)
+ print(json.dumps(result, indent=2, default=str))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/databricks-skills/databricks-genie/SKILL.md b/databricks-skills/databricks-genie/SKILL.md
index 82332476..a8652ebc 100644
--- a/databricks-skills/databricks-genie/SKILL.md
+++ b/databricks-skills/databricks-genie/SKILL.md
@@ -5,196 +5,205 @@ description: "Create and query Databricks Genie Spaces for natural language SQL
# Databricks Genie
-Create, manage, and query Databricks Genie Spaces - natural language interfaces for SQL-based data exploration.
+Create, manage, and query Genie Spaces - natural language interfaces for SQL-based data exploration.
## Overview
Genie Spaces allow users to ask natural language questions about structured data in Unity Catalog. The system translates questions into SQL queries, executes them on a SQL warehouse, and presents results conversationally.
-## When to Use This Skill
-
-Use this skill when:
-- Creating a new Genie Space for data exploration
-- Adding sample questions to guide users
-- Connecting Unity Catalog tables to a conversational interface
-- Asking questions to a Genie Space programmatically (Conversation API)
-- Exporting a Genie Space configuration (serialized_space) for backup or migration
-- Importing / cloning a Genie Space from a serialized payload
-- Migrating a Genie Space between workspaces or environments (dev → staging → prod)
- - Only supports catalog remapping where catalog names differ across environments
- - Not supported for schema and/or table names that differ across environments
- - Not including migration of tables between environments (only migration of Genie Spaces)
-
-## MCP Tools
-
-| Tool | Purpose |
-|------|---------|
-| `manage_genie` | Create, get, list, delete, export, and import Genie Spaces |
-| `ask_genie` | Ask natural language questions to a Genie Space |
-| `get_table_stats_and_schema` | Inspect table schemas before creating a space |
-| `execute_sql` | Test SQL queries directly |
-
-### manage_genie - Space Management
-
-| Action | Description | Required Params |
-|--------|-------------|-----------------|
-| `create_or_update` | Idempotent create/update a space | display_name, table_identifiers (or serialized_space) |
-| `get` | Get space details | space_id |
-| `list` | List all spaces | (none) |
-| `delete` | Delete a space | space_id |
-| `export` | Export space config for migration/backup | space_id |
-| `import` | Import space from serialized config | warehouse_id, serialized_space |
-
-**Example tool calls:**
-```
-# MCP Tool: manage_genie
-# Create a new space
-manage_genie(
- action="create_or_update",
- display_name="Sales Analytics",
- table_identifiers=["catalog.schema.customers", "catalog.schema.orders"],
- description="Explore sales data with natural language",
- sample_questions=["What were total sales last month?"]
-)
-
-# MCP Tool: manage_genie
-# Get space details with full config
-manage_genie(action="get", space_id="space_123", include_serialized_space=True)
-
-# MCP Tool: manage_genie
-# List all spaces
-manage_genie(action="list")
-
-# MCP Tool: manage_genie
-# Export for migration
-exported = manage_genie(action="export", space_id="space_123")
-
-# MCP Tool: manage_genie
-# Import to new workspace
-manage_genie(
- action="import",
- warehouse_id="warehouse_456",
- serialized_space=exported["serialized_space"],
- title="Sales Analytics (Prod)"
-)
-```
+## Creating a Genie Space
-### ask_genie - Conversation API (Query)
+### Step 1: Understand the Data
-Ask natural language questions to a Genie Space. Pass `conversation_id` for follow-up questions.
+Before creating a Genie Space, explore the available tables to:
+- **Select relevant tables** — typically gold layer (aggregated KPIs) and sometimes silver layer (cleaned facts) or metric views
+- **Understand the story** — what business questions can this data answer? What insights can users discover?
+- **Design meaningful sample questions** — questions should reflect real use cases and lead to actionable insights in the data
-```
-# MCP Tool: ask_genie
-# Start a new conversation
-result = ask_genie(
- space_id="space_123",
- question="What were total sales last month?"
-)
-# Returns: {question, conversation_id, message_id, status, sql, columns, data, row_count}
-
-# MCP Tool: ask_genie
-# Follow-up question in same conversation
-result = ask_genie(
- space_id="space_123",
- question="Break that down by region",
- conversation_id=result["conversation_id"]
-)
+```bash
+# Discover table schemas, columns, and sample values
+databricks experimental aitools tools discover-schema catalog.schema.gold_sales catalog.schema.gold_customers
+
+# Run SQL queries to explore the data and understand relationships
+databricks sql exec "SELECT * FROM catalog.schema.gold_sales LIMIT 10"
+databricks sql exec "DESCRIBE TABLE catalog.schema.gold_sales"
```
-## Quick Start
+### Step 2: Create the Space
-### 1. Inspect Your Tables
+Define your space in a local JSON file (e.g., `genie_space.json`) for version control and easy iteration. See "serialized_space Format" below for the full structure.
-Before creating a Genie Space, understand your data:
+```bash
+# List all Genie Spaces
+databricks genie list-spaces
-```
-# MCP Tool: get_table_stats_and_schema
-get_table_stats_and_schema(
- catalog="my_catalog",
- schema="sales",
- table_stat_level="SIMPLE"
-)
-```
+# Create a Genie Space from a local file
+# IMPORTANT: sample_questions require a 32-char hex "id" and "question" must be an array
+databricks genie create-space --json "{
+ \"warehouse_id\": \"WAREHOUSE_ID\",
+ \"title\": \"Sales Analytics\",
+ \"description\": \"Explore sales data\",
+ \"parent_path\": \"/Workspace/Users/you@company.com/genie_spaces\",
+ \"serialized_space\": $(cat genie_space.json | jq -c '.' | jq -Rs '.')
+}"
-### 2. Create the Genie Space
+# Get space details (with full config)
+databricks genie get-space SPACE_ID --include-serialized-space
-```
-# MCP Tool: manage_genie
-manage_genie(
- action="create_or_update",
- display_name="Sales Analytics",
- table_identifiers=[
- "my_catalog.sales.customers",
- "my_catalog.sales.orders"
- ],
- description="Explore sales data with natural language",
- sample_questions=[
- "What were total sales last month?",
- "Who are our top 10 customers?"
- ]
-)
+# Delete a Genie Space
+databricks genie trash-space SPACE_ID
```
-### 3. Ask Questions (Conversation API)
+### Step 3: Test and Iterate
+Use `scripts/conversation.py` (see Conversation API section below) to test questions and verify answers are accurate.
+
+If answers are inaccurate or incomplete, improve the space — see "Improving a Genie Space" below.
+
+### Export & Import
+
+```bash
+# Export space configuration
+databricks genie export-space SPACE_ID > exported.json
+
+# Import space from exported config
+databricks genie import-space --json @exported.json
```
-# MCP Tool: ask_genie
-ask_genie(
- space_id="your_space_id",
- question="What were total sales last month?"
-)
-# Returns: SQL, columns, data, row_count
-```
-### 4. Export & Import (Clone / Migrate)
+### Improving a Genie Space
+
+When Genie answers are inaccurate or incomplete, improve the space by updating questions, SQL examples, or instructions:
-Export a space (preserves all tables, instructions, SQL examples, and layout):
+```bash
+# 1. Edit your local genie_space.json (add questions, fix SQL examples, improve instructions)
+# 2. Push updates back to the space
+databricks genie update-space SPACE_ID --json "{\"serialized_space\": $(cat genie_space.json | jq -c '.' | jq -Rs '.')}"
```
-# MCP Tool: manage_genie
-exported = manage_genie(action="export", space_id="your_space_id")
-# exported["serialized_space"] contains the full config
+
+## serialized_space Format
+
+The `serialized_space` field is a JSON string containing the full space configuration.
+
+### Structure
+
+```json
+{
+ "version": 2,
+ "config": {
+ "sample_questions": [...]
+ },
+ "data_sources": {
+ "tables": [{"identifier": "catalog.schema.table"}]
+ },
+ "instructions": {
+ "example_question_sqls": [...],
+ "text_instructions": [...]
+ }
+}
```
-Clone to a new space (same catalog):
+### Field Format Requirements
+
+**IMPORTANT:** All items in `sample_questions`, `example_question_sqls`, and `text_instructions` require a unique `id` field.
+| Field | Format |
+|-------|--------|
+| `config.sample_questions[]` | `{"id": "32hexchars", "question": ["..."]}` |
+| `instructions.example_question_sqls[]` | `{"id": "32hexchars", "question": ["..."], "sql": ["..."]}` |
+| `instructions.text_instructions[]` | `{"id": "32hexchars", "content": ["..."]}` |
+
+- **ID format:** 32-character lowercase hex UUID without hyphens.
+- **Text fields are arrays:** `question`, `sql`, and `content` are arrays of strings, not plain strings.
+
+### Text Instructions
+
+`text_instructions` make the Genie Space more reliable by explaining:
+- **Where to find information** — which tables contain which metrics
+- **How to answer specific questions** — when a user asks X, use table Y with filter Z
+- **Business context** — definitions, thresholds, and domain knowledge
+
+Well-crafted instructions significantly improve answer accuracy.
+
+### Complete Example
+
+This example shows a properly formatted `serialized_space` with sample questions, SQL examples, and text instructions. Note that every item has a unique 32-char hex `id` and all text fields are arrays:
+
+```json
+{
+ "version": 2,
+ "config": {
+ "sample_questions": [
+ {"id": "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4", "question": ["What is our current on-time performance?"]},...
+ ]
+ },
+ "data_sources": {
+ "tables": [
+ {"identifier": "catalog.ops.gold_otp_summary"},...
+ ]
+ },
+ "instructions": {
+ "example_question_sqls": [
+ {
+ "id": "b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5",
+ "question": ["What is our on-time performance?"],
+ "sql": ["SELECT flight_date, ROUND(SUM(on_time_count) * 100.0 / SUM(total_flights), 1) AS otp_pct\n", "FROM catalog.ops.gold_otp_summary\n", "WHERE flight_date >= date_sub(current_date(), 7)\n", "GROUP BY flight_date ORDER BY flight_date"]
+ }
+ ],
+ "text_instructions": [
+ {
+ "id": "c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6",
+ "content": [
+ "On-time performance (OTP) questions: Use gold_otp_summary table. OTP target is 85%.\n",
+ "Delay analysis questions: Use gold_delay_analysis table. Filter by delay_code for specific delay types.\n",
+ "When asked about 'this week' or 'recent': Use flight_date >= date_sub(current_date(), 7).\n",
+ "When comparing aircraft: Join with gold_aircraft_reliability on tail_number."
+ ]
+ }
+ ]
+ }
+}
```
-# MCP Tool: manage_genie
-manage_genie(
- action="import",
- warehouse_id=exported["warehouse_id"],
- serialized_space=exported["serialized_space"],
- title=exported["title"], # override title; omit to keep original
- description=exported["description"],
-)
+
+
+## Cross-Workspace Migration
+
+When migrating between workspaces, catalog names often differ. Export the space, remap with `sed`, then import:
+
+```bash
+sed -i '' 's/source_catalog/target_catalog/g' genie_space.json
```
-> **Cross-workspace migration:** Each MCP server is workspace-scoped. Configure one server entry per workspace profile in your IDE's MCP config, then `manage_genie(action="export")` from the source server and `manage_genie(action="import")` via the target server. See [spaces.md §Migration](spaces.md#migrating-across-workspaces-with-catalog-remapping) for the full workflow.
+Use `DATABRICKS_CONFIG_PROFILE=profile_name` to target different workspaces.
-## Reference Files
+## Conversation API
-- [spaces.md](spaces.md) - Creating and managing Genie Spaces
-- [conversation.md](conversation.md) - Asking questions via the Conversation API
+Use `scripts/conversation.py` to ask questions programmatically:
-## Prerequisites
+```bash
+# Ask a question
+python scripts/conversation.py ask SPACE_ID "What were total sales last month?"
-Before creating a Genie Space:
+# Follow-up in same conversation (Genie remembers context)
+python scripts/conversation.py ask SPACE_ID "Break down by region" --conversation-id CONV_ID
-1. **Tables in Unity Catalog** - Bronze/silver/gold tables with the data
-2. **SQL Warehouse** - A warehouse to execute queries (auto-detected if not specified)
+# With timeout for complex queries
+python scripts/conversation.py ask SPACE_ID "Complex query" --timeout 120
+```
-### Creating Tables
+Start a new conversation for unrelated topics. Use `--conversation-id` only for follow-ups on the same topic.
-Use these skills in sequence:
-1. `databricks-synthetic-data-gen` - Generate raw parquet files
-2. `databricks-spark-declarative-pipelines` - Create bronze/silver/gold tables
+## Troubleshooting
-## Common Issues
+| Issue | Solution |
+|-------|----------|
+| `sample_question.id must be provided` | Add 32-char hex UUID `id` to each sample question |
+| `Expected an array for question` | Use `"question": ["text"]` not `"question": "text"` |
+| No warehouse available | Create a SQL warehouse or provide `warehouse_id` |
+| Empty `serialized_space` on export | Requires CAN EDIT permission on the space |
+| Tables not found after migration | Remap catalog name in `serialized_space` before import |
-See [spaces.md §Troubleshooting](spaces.md#troubleshooting) for a full list of issues and solutions.
## Related Skills
-- **[databricks-agent-bricks](../databricks-agent-bricks/SKILL.md)** - Use Genie Spaces as agents inside Supervisor Agents
-- **[databricks-synthetic-data-gen](../databricks-synthetic-data-gen/SKILL.md)** - Generate raw parquet data to populate tables for Genie
-- **[databricks-spark-declarative-pipelines](../databricks-spark-declarative-pipelines/SKILL.md)** - Build bronze/silver/gold tables consumed by Genie Spaces
-- **[databricks-unity-catalog](../databricks-unity-catalog/SKILL.md)** - Manage the catalogs, schemas, and tables Genie queries
+- **[databricks-synthetic-data-gen](../databricks-synthetic-data-gen/SKILL.md)** - Generate data for Genie tables
+- **[databricks-spark-declarative-pipelines](../databricks-spark-declarative-pipelines/SKILL.md)** - Build bronze/silver/gold tables
diff --git a/databricks-skills/databricks-genie/conversation.md b/databricks-skills/databricks-genie/conversation.md
deleted file mode 100644
index e4320e8b..00000000
--- a/databricks-skills/databricks-genie/conversation.md
+++ /dev/null
@@ -1,239 +0,0 @@
-# Genie Conversations
-
-Use the Genie Conversation API to ask natural language questions to a curated Genie Space.
-
-## Overview
-
-The `ask_genie` tool allows you to programmatically send questions to a Genie Space and receive SQL-generated answers. Instead of writing SQL directly, you delegate the query generation to Genie, which has been curated with business logic, instructions, and certified queries.
-
-## When to Use `ask_genie`
-
-### Use `ask_genie` When:
-
-| Scenario | Why |
-|----------|-----|
-| Genie Space has curated business logic | Genie knows rules like "active customer = ordered in 90 days" |
-| User explicitly says "ask Genie" or "use my Genie Space" | User intent to use their curated space |
-| Complex business metrics with specific definitions | Genie has certified queries for official metrics |
-| Testing a Genie Space after creating it | Validate the space works correctly |
-| User wants conversational data exploration | Genie handles context for follow-up questions |
-
-### Use Direct SQL (`execute_sql`) Instead When:
-
-| Scenario | Why |
-|----------|-----|
-| Simple ad-hoc query | Direct SQL is faster, no curation needed |
-| You already have the exact SQL | No need for Genie to regenerate |
-| Genie Space doesn't exist for this data | Can't use Genie without a space |
-| Need precise control over the query | Direct SQL gives exact control |
-
-## MCP Tools
-
-| Tool | Purpose |
-|------|---------|
-| `ask_genie` | Ask a question or follow-up (`conversation_id` optional) |
-
-## Basic Usage
-
-### Ask a Question
-
-```python
-ask_genie(
- space_id="01abc123...",
- question="What were total sales last month?"
-)
-```
-
-**Response:**
-```python
-{
- "question": "What were total sales last month?",
- "conversation_id": "conv_xyz789",
- "message_id": "msg_123",
- "status": "COMPLETED",
- "sql": "SELECT SUM(total_amount) AS total_sales FROM orders WHERE order_date >= DATE_TRUNC('month', CURRENT_DATE - INTERVAL 1 MONTH) AND order_date < DATE_TRUNC('month', CURRENT_DATE)",
- "columns": ["total_sales"],
- "data": [[125430.50]],
- "row_count": 1
-}
-```
-
-### Ask Follow-up Questions
-
-Use the `conversation_id` from the first response to ask follow-up questions with context:
-
-```python
-# First question
-result = ask_genie(
- space_id="01abc123...",
- question="What were total sales last month?"
-)
-
-# Follow-up (uses context from first question)
-ask_genie(
- space_id="01abc123...",
- question="Break that down by region",
- conversation_id=result["conversation_id"]
-)
-```
-
-Genie remembers the context, so "that" refers to "total sales last month".
-
-## Response Fields
-
-| Field | Description |
-|-------|-------------|
-| `question` | The original question asked |
-| `conversation_id` | ID for follow-up questions |
-| `message_id` | Unique message identifier |
-| `status` | `COMPLETED`, `FAILED`, `CANCELLED`, `TIMEOUT` |
-| `sql` | The SQL query Genie generated |
-| `columns` | List of column names in result |
-| `data` | Query results as list of rows |
-| `row_count` | Number of rows returned |
-| `text_response` | Text explanation (if Genie asks for clarification) |
-| `error` | Error message (if status is not COMPLETED) |
-
-## Handling Responses
-
-### Successful Response
-
-```python
-result = ask_genie(space_id, "Who are our top 10 customers?")
-
-if result["status"] == "COMPLETED":
- print(f"SQL: {result['sql']}")
- print(f"Rows: {result['row_count']}")
- for row in result["data"]:
- print(row)
-```
-
-### Failed Response
-
-```python
-result = ask_genie(space_id, "What is the meaning of life?")
-
-if result["status"] == "FAILED":
- print(f"Error: {result['error']}")
- # Genie couldn't answer - may need to rephrase or use direct SQL
-```
-
-### Timeout
-
-```python
-result = ask_genie(space_id, question, timeout_seconds=60)
-
-if result["status"] == "TIMEOUT":
- print("Query took too long - try a simpler question or increase timeout")
-```
-
-## Example Workflows
-
-### Workflow 1: User Asks to Use Genie
-
-```
-User: "Ask my Sales Genie what the churn rate is"
-
-Claude:
-1. Identifies user wants to use Genie (explicit request)
-2. Calls ask_genie(space_id="sales_genie_id", question="What is the churn rate?")
-3. Returns: "Based on your Sales Genie, the churn rate is 4.2%.
- Genie used this SQL: SELECT ..."
-```
-
-### Workflow 2: Testing a New Genie Space
-
-```
-User: "I just created a Genie Space for HR data. Can you test it?"
-
-Claude:
-1. Gets the space_id from the user or recent manage_genie(action="create_or_update") result
-2. Calls ask_genie with test questions:
- - "How many employees do we have?"
- - "What is the average salary by department?"
-3. Reports results: "Your HR Genie is working. It correctly answered..."
-```
-
-### Workflow 3: Data Exploration with Follow-ups
-
-```
-User: "Use my analytics Genie to explore sales trends"
-
-Claude:
-1. ask_genie(space_id, "What were total sales by month this year?")
-2. User: "Which month had the highest growth?"
-3. ask_genie(space_id, "Which month had the highest growth?", conversation_id=conv_id)
-4. User: "What products drove that growth?"
-5. ask_genie(space_id, "What products drove that growth?", conversation_id=conv_id)
-```
-
-## Best Practices
-
-### Start New Conversations for New Topics
-
-Don't reuse conversations across unrelated questions:
-
-```python
-# Good: New conversation for new topic
-result1 = ask_genie(space_id, "What were sales last month?") # New conversation
-result2 = ask_genie(space_id, "How many employees do we have?") # New conversation
-
-# Good: Follow-up for related question
-result1 = ask_genie(space_id, "What were sales last month?")
-result2 = ask_genie(space_id, "Break that down by product",
- conversation_id=result1["conversation_id"]) # Related follow-up
-```
-
-### Handle Clarification Requests
-
-Genie may ask for clarification instead of returning results:
-
-```python
-result = ask_genie(space_id, "Show me the data")
-
-if result.get("text_response"):
- # Genie is asking for clarification
- print(f"Genie asks: {result['text_response']}")
- # Rephrase with more specifics
-```
-
-### Set Appropriate Timeouts
-
-- Simple aggregations: 30-60 seconds
-- Complex joins: 60-120 seconds
-- Large data scans: 120+ seconds
-
-```python
-# Quick question
-ask_genie(space_id, "How many orders today?", timeout_seconds=30)
-
-# Complex analysis
-ask_genie(space_id, "Calculate customer lifetime value for all customers",
- timeout_seconds=180)
-```
-
-## Troubleshooting
-
-### "Genie Space not found"
-
-- Verify the `space_id` is correct
-- Check you have access to the space
-- Use `manage_genie(action="get", space_id=...)` to verify it exists
-
-### "Query timed out"
-
-- Increase `timeout_seconds`
-- Simplify the question
-- Check if the SQL warehouse is running
-
-### "Failed to generate SQL"
-
-- Rephrase the question more clearly
-- Check if the question is answerable with the available tables
-- Add more instructions/curation to the Genie Space
-
-### Unexpected Results
-
-- Review the generated SQL in the response
-- Add SQL instructions to the Genie Space via the Databricks UI
-- Add sample questions that demonstrate correct patterns
diff --git a/databricks-skills/databricks-genie/scripts/conversation.py b/databricks-skills/databricks-genie/scripts/conversation.py
new file mode 100644
index 00000000..e1a670ff
--- /dev/null
+++ b/databricks-skills/databricks-genie/scripts/conversation.py
@@ -0,0 +1,171 @@
+#!/usr/bin/env python3
+"""
+Genie Conversation API - CLI interface for asking questions to Genie Spaces.
+
+Usage:
+ python conversation.py ask SPACE_ID "What were total sales last month?"
+ python conversation.py ask SPACE_ID "Break that down by region" --conversation-id CONV_ID
+ python conversation.py ask SPACE_ID "Complex query" --timeout 120
+
+Requires: databricks-sdk package
+"""
+
+import argparse
+import json
+import sys
+import time
+from typing import Any, Dict, Optional
+
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.dashboards import GenieMessage
+
+
+def ask_genie(
+ space_id: str,
+ question: str,
+ conversation_id: Optional[str] = None,
+ timeout_seconds: int = 60,
+) -> Dict[str, Any]:
+ """Ask a question to a Genie Space.
+
+ Args:
+ space_id: The Genie Space ID
+ question: Natural language question to ask
+ conversation_id: Optional conversation ID for follow-up questions
+ timeout_seconds: Maximum time to wait for response (default: 60)
+
+ Returns:
+ Dict with question, conversation_id, message_id, status, sql, columns, data, row_count
+ """
+ client = WorkspaceClient()
+
+ # Start or continue conversation
+ if conversation_id:
+ response = client.genie.start_conversation_and_wait(
+ space_id=space_id,
+ content=question,
+ conversation_id=conversation_id,
+ )
+ else:
+ response = client.genie.start_conversation_and_wait(
+ space_id=space_id,
+ content=question,
+ )
+
+ # Extract conversation and message IDs
+ conv_id = response.conversation_id if hasattr(response, 'conversation_id') else None
+ msg_id = response.message_id if hasattr(response, 'message_id') else None
+
+ # Poll for completion
+ start_time = time.time()
+ while True:
+ if time.time() - start_time > timeout_seconds:
+ return {
+ "question": question,
+ "conversation_id": conv_id,
+ "message_id": msg_id,
+ "status": "TIMEOUT",
+ "error": f"Query timed out after {timeout_seconds} seconds",
+ }
+
+ # Get message details
+ message = client.genie.get_message(
+ space_id=space_id,
+ conversation_id=conv_id,
+ message_id=msg_id,
+ )
+
+ status = message.status.value if hasattr(message.status, 'value') else str(message.status)
+
+ if status == "COMPLETED":
+ # Extract results
+ result = {
+ "question": question,
+ "conversation_id": conv_id,
+ "message_id": msg_id,
+ "status": "COMPLETED",
+ }
+
+ # Get SQL and data from attachments
+ if message.attachments:
+ for attachment in message.attachments:
+ if hasattr(attachment, 'query') and attachment.query:
+ result["sql"] = attachment.query.query
+ if hasattr(attachment, 'text') and attachment.text:
+ result["text_response"] = attachment.text.content
+
+ # Get query result if available
+ if hasattr(message, 'query_result') and message.query_result:
+ qr = message.query_result
+ if hasattr(qr, 'columns'):
+ result["columns"] = [c.name for c in qr.columns]
+ if hasattr(qr, 'data_array'):
+ result["data"] = qr.data_array
+ result["row_count"] = len(qr.data_array)
+
+ return result
+
+ elif status in ["FAILED", "CANCELLED"]:
+ error_msg = ""
+ if message.attachments:
+ for attachment in message.attachments:
+ if hasattr(attachment, 'text') and attachment.text:
+ error_msg = attachment.text.content
+ return {
+ "question": question,
+ "conversation_id": conv_id,
+ "message_id": msg_id,
+ "status": status,
+ "error": error_msg or f"Query {status.lower()}",
+ }
+
+ # Still processing, wait and retry
+ time.sleep(2)
+
+
+def _print_json(data: Any) -> None:
+ """Print data as formatted JSON."""
+ print(json.dumps(data, indent=2, default=str))
+
+
+def main():
+ """CLI entry point."""
+ parser = argparse.ArgumentParser(
+ description="Ask questions to a Genie Space",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog=__doc__,
+ )
+ subparsers = parser.add_subparsers(dest="command", required=True)
+
+ # ask command
+ ask_parser = subparsers.add_parser("ask", help="Ask a question to a Genie Space")
+ ask_parser.add_argument("space_id", help="The Genie Space ID")
+ ask_parser.add_argument("question", help="Natural language question to ask")
+ ask_parser.add_argument(
+ "--conversation-id", "-c",
+ help="Conversation ID for follow-up questions",
+ )
+ ask_parser.add_argument(
+ "--timeout", "-t",
+ type=int,
+ default=60,
+ help="Timeout in seconds (default: 60)",
+ )
+
+ args = parser.parse_args()
+
+ if args.command == "ask":
+ result = ask_genie(
+ space_id=args.space_id,
+ question=args.question,
+ conversation_id=args.conversation_id,
+ timeout_seconds=args.timeout,
+ )
+ _print_json(result)
+ else:
+ parser.print_help()
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/databricks-skills/databricks-genie/spaces.md b/databricks-skills/databricks-genie/spaces.md
deleted file mode 100644
index ff8acb60..00000000
--- a/databricks-skills/databricks-genie/spaces.md
+++ /dev/null
@@ -1,395 +0,0 @@
-# Creating Genie Spaces
-
-This guide covers creating and managing Genie Spaces for SQL-based data exploration.
-
-## What is a Genie Space?
-
-A Genie Space connects to Unity Catalog tables and translates natural language questions into SQL — understanding schemas, generating queries, executing them on a SQL warehouse, and presenting results conversationally.
-
-## Creation Workflow
-
-### Step 1: Inspect Table Schemas (Required)
-
-**Before creating a Genie Space, you MUST inspect the table schemas** to understand what data is available:
-
-```python
-get_table_stats_and_schema(
- catalog="my_catalog",
- schema="sales",
- table_stat_level="SIMPLE"
-)
-```
-
-This returns:
-- Table names and row counts
-- Column names and data types
-- Sample values and cardinality
-- Null counts and statistics
-
-### Step 2: Analyze and Plan
-
-Based on the schema information:
-
-1. **Select relevant tables** - Choose tables that support the user's use case
-2. **Identify key columns** - Note date columns, metrics, dimensions, and foreign keys
-3. **Understand relationships** - How do tables join together?
-4. **Plan sample questions** - What questions can this data answer?
-
-### Step 3: Create the Genie Space
-
-Create the space with content tailored to the actual data:
-
-```python
-manage_genie(
- action="create_or_update",
- display_name="Sales Analytics",
- table_identifiers=[
- "my_catalog.sales.customers",
- "my_catalog.sales.orders",
- "my_catalog.sales.products"
- ],
- description="""Explore retail sales data with three related tables:
-- customers: Customer demographics including region, segment, and signup date
-- orders: Transaction history with order_date, total_amount, and status
-- products: Product catalog with category, price, and inventory
-
-Tables join on customer_id and product_id.""",
- sample_questions=[
- "What were total sales last month?",
- "Who are our top 10 customers by total_amount?",
- "How many orders were placed in Q4 by region?",
- "What's the average order value by customer segment?",
- "Which product categories have the highest revenue?",
- "Show me customers who haven't ordered in 90 days"
- ]
-)
-```
-
-## Why This Workflow Matters
-
-**Sample questions that reference actual column names** help Genie:
-- Learn the vocabulary of your data
-- Generate more accurate SQL queries
-- Provide better autocomplete suggestions
-
-**A description that explains table relationships** helps Genie:
-- Understand how to join tables correctly
-- Know which table contains which information
-- Provide more relevant answers
-
-## Auto-Detection of Warehouse
-
-When `warehouse_id` is not specified, the tool:
-
-1. Lists all SQL warehouses in the workspace
-2. Prioritizes by:
- - **Running** warehouses first (already available)
- - **Starting** warehouses second
- - **Smaller sizes** preferred (cost-efficient)
-3. Returns an error if no warehouses exist
-
-To use a specific warehouse, provide the `warehouse_id` explicitly.
-
-## Table Selection
-
-Choose tables carefully for best results:
-
-| Layer | Recommended | Why |
-|-------|-------------|-----|
-| Bronze | No | Raw data, may have quality issues |
-| Silver | Yes | Cleaned and validated |
-| Gold | Yes | Aggregated, optimized for analytics |
-
-### Tips for Table Selection
-
-- **Include related tables**: If users ask about customers and orders, include both
-- **Use descriptive column names**: `customer_name` is better than `cust_nm`
-- **Add table comments**: Genie uses metadata to understand the data
-
-## Sample Questions
-
-Sample questions help users understand what they can ask:
-
-**Good sample questions:**
-- "What were total sales last month?"
-- "Who are our top 10 customers by revenue?"
-- "How many orders were placed in Q4?"
-- "What's the average order value by region?"
-
-These appear in the Genie UI to guide users.
-
-## Best Practices
-
-### Table Design for Genie
-
-1. **Descriptive names**: Use `customer_lifetime_value` not `clv`
-2. **Add comments**: `COMMENT ON TABLE sales.customers IS 'Customer master data'`
-3. **Primary keys**: Define relationships clearly
-4. **Date columns**: Include proper date/timestamp columns for time-based queries
-
-### Description and Context
-
-Provide context in the description:
-
-```
-Explore retail sales data from our e-commerce platform. Includes:
-- Customers: demographics, segments, and account status
-- Orders: transaction history with amounts and dates
-- Products: catalog with categories and pricing
-
-Time range: Last 6 months of data
-```
-
-### Sample Questions
-
-Write sample questions that:
-- Cover common use cases
-- Demonstrate the data's capabilities
-- Use natural language (not SQL terms)
-
-## Updating a Genie Space
-
-`manage_genie(action="create_or_update")` handles both create and update automatically. There are two ways it locates an existing space to update:
-
-- **By `space_id`** (explicit, preferred): pass `space_id=` to target a specific space.
-- **By `display_name`** (implicit fallback): if `space_id` is omitted, the tool searches for a space with a matching name and updates it if found; otherwise it creates a new one.
-
-### Simple field updates (tables, questions, warehouse)
-
-To update metadata without a serialized config:
-
-```python
-manage_genie(
- action="create_or_update",
- display_name="Sales Analytics",
- space_id="01abc123...", # omit to match by name instead
- table_identifiers=[ # updated table list
- "my_catalog.sales.customers",
- "my_catalog.sales.orders",
- "my_catalog.sales.products",
- ],
- sample_questions=[ # updated sample questions
- "What were total sales last month?",
- "Who are our top 10 customers by revenue?",
- ],
- warehouse_id="abc123def456", # omit to keep current / auto-detect
- description="Updated description.",
-)
-```
-
-### Full config update via `serialized_space`
-
-To push a complete serialized configuration to an existing space (the dict contains all regular table metadata, plus it preserves all instructions, SQL examples, join specs, etc.):
-
-```python
-manage_genie(
- action="create_or_update",
- display_name="Sales Analytics", # overrides title embedded in serialized_space
- table_identifiers=[], # ignored when serialized_space is provided
- space_id="01abc123...", # target space to overwrite
- warehouse_id="abc123def456", # overrides warehouse embedded in serialized_space
- description="Updated description.", # overrides description embedded in serialized_space; omit to keep the one in the payload
- serialized_space=remapped_config, # JSON string from manage_genie(action="export") (after catalog remap if needed)
-)
-```
-
-> **Note:** When `serialized_space` is provided, `table_identifiers` and `sample_questions` are ignored — the full config comes from the serialized payload. However, `display_name`, `warehouse_id`, and `description` are still applied as top-level overrides on top of the serialized payload. Omit any of them to keep the values embedded in `serialized_space`.
-
-## Export, Import & Migration
-
-`manage_genie(action="export")` returns a dictionary with four top-level keys:
-
-| Key | Description |
-|-----|-------------|
-| `space_id` | ID of the exported space |
-| `title` | Display name of the space |
-| `description` | Description of the space |
-| `warehouse_id` | SQL warehouse associated with the space (workspace-specific — do **not** reuse across workspaces) |
-| `serialized_space` | JSON-encoded string with the full space configuration (see below) |
-
-This envelope enables cloning, backup, and cross-workspace migration. Use `manage_genie(action="export")` and `manage_genie(action="import")` for all export/import operations — no direct REST calls needed.
-
-### What is `serialized_space`?
-
-`serialized_space` is a JSON string (version 2) embedded inside the export envelope. Its top-level keys are:
-
-| Key | Contents |
-|-----|----------|
-| `version` | Schema version (currently `2`) |
-| `config` | Space-level config: `sample_questions` shown in the UI |
-| `data_sources` | `tables` array — each entry has a fully-qualified `identifier` (`catalog.schema.table`) and optional `column_configs` (format assistance, entity matching per column) |
-| `instructions` | `example_question_sqls` (certified Q&A pairs), `join_specs` (join relationships between tables), `sql_snippets` (`filters` and `measures` with display names and usage instructions) |
-| `benchmarks` | Evaluation Q&A pairs used to measure space quality |
-
-Catalog names appear **everywhere** inside `serialized_space` — in `data_sources.tables[].identifier`, SQL strings in `example_question_sqls`, `join_specs`, and `sql_snippets`. A single `.replace(src_catalog, tgt_catalog)` on the whole string is sufficient for catalog remapping.
-
-Minimum structure:
-```json
-{"version": 2, "data_sources": {"tables": [{"identifier": "catalog.schema.table"}]}}
-```
-
-### Exporting a Space
-
-Use `manage_genie(action="export")` to export the full configuration (requires CAN EDIT permission):
-
-```python
-exported = manage_genie(action="export", space_id="01abc123...")
-# Returns:
-# {
-# "space_id": "01abc123...",
-# "title": "Sales Analytics",
-# "description": "Explore sales data...",
-# "warehouse_id": "abc123def456",
-# "serialized_space": "{\"version\":2,\"data_sources\":{...},\"instructions\":{...}}"
-# }
-```
-
-You can also get `serialized_space` inline via `manage_genie(action="get")`:
-
-```python
-details = manage_genie(action="get", space_id="01abc123...", include_serialized_space=True)
-serialized = details["serialized_space"]
-```
-
-### Cloning a Space (Same Workspace)
-
-```python
-# Step 1: Export the source space
-source = manage_genie(action="export", space_id="01abc123...")
-
-# Step 2: Import as a new space
-manage_genie(
- action="import",
- warehouse_id=source["warehouse_id"],
- serialized_space=source["serialized_space"],
- title=source["title"], # override title; omit to keep original
- description=source["description"],
-)
-# Returns: {"space_id": "01def456...", "title": "Sales Analytics (Dev Copy)", "operation": "imported"}
-```
-
-### Migrating Across Workspaces with Catalog Remapping
-
-When migrating between environments (e.g. prod → dev), Unity Catalog names are often different. The `serialized_space` string contains the source catalog name **everywhere** — in table identifiers, SQL queries, join specs, and filter snippets. You must remap it before importing.
-
-**Agent workflow (3 steps):**
-
-**Step 1 — Export from source workspace:**
-```python
-exported = manage_genie(action="export", space_id="01f106e1239d14b28d6ab46f9c15e540")
-# exported keys: warehouse_id, title, description, serialized_space
-# exported["serialized_space"] contains all references to source catalog
-```
-
-**Step 2 — Remap catalog name in `serialized_space`:**
-
-The agent does this as an inline string substitution between the two MCP calls:
-```python
-modified_serialized = exported["serialized_space"].replace(
- "source_catalog_name", # e.g. "healthverity_claims_sample_patient_dataset"
- "target_catalog_name" # e.g. "healthverity_claims_sample_patient_dataset_dev"
-)
-```
-This replaces all occurrences — table identifiers, SQL FROM clauses, join specs, and filter snippets.
-
-**Step 3 — Import to target workspace:**
-```python
-manage_genie(
- action="import",
- warehouse_id="", # from manage_warehouse(action="list") on target
- serialized_space=modified_serialized,
- title=exported["title"],
- description=exported["description"]
-)
-```
-
-### Batch Migration of Multiple Spaces
-
-To migrate several spaces at once, loop through space IDs. The agent exports, remaps the catalog, then imports each:
-
-```
-For each space_id in [id1, id2, id3]:
- 1. exported = manage_genie(action="export", space_id=space_id)
- 2. modified = exported["serialized_space"].replace(src_catalog, tgt_catalog)
- 3. result = manage_genie(action="import", warehouse_id=wh_id, serialized_space=modified, title=exported["title"], description=exported["description"])
- 4. record result["space_id"] for updating databricks.yml
-```
-
-After migration, update `databricks.yml` with the new dev `space_id` values under the `dev` target's `genie_space_ids` variable.
-
-### Updating an Existing Space with New Config
-
-To push a serialized config to an already-existing space (rather than creating a new one), use `manage_genie(action="create_or_update")` with `space_id=` and `serialized_space=`. The export → remap → push pattern is identical to the migration steps above; just replace `manage_genie(action="import")` with `manage_genie(action="create_or_update", space_id=TARGET_SPACE_ID, ...)` as the final call.
-
-### Permissions Required
-
-| Operation | Required Permission |
-|-----------|-------------------|
-| `manage_genie(action="export")` / `manage_genie(action="get", include_serialized_space=True)` | CAN EDIT on source space |
-| `manage_genie(action="import")` | Can create items in target workspace folder |
-| `manage_genie(action="create_or_update")` with `serialized_space` (update) | CAN EDIT on target space |
-
-## Example End-to-End Workflow
-
-1. **Generate synthetic data** using `databricks-synthetic-data-gen` skill:
- - Creates parquet files in `/Volumes/catalog/schema/raw_data/`
-
-2. **Create tables** using `databricks-spark-declarative-pipelines` skill:
- - Creates `catalog.schema.bronze_*` → `catalog.schema.silver_*` → `catalog.schema.gold_*`
-
-3. **Inspect the tables**:
- ```python
- get_table_stats_and_schema(catalog="catalog", schema="schema")
- ```
-
-4. **Create the Genie Space**:
- - `display_name`: "My Data Explorer"
- - `table_identifiers`: `["catalog.schema.silver_customers", "catalog.schema.silver_orders"]`
-
-5. **Add sample questions** based on actual column names
-
-6. **Test** in the Databricks UI
-
-## Troubleshooting
-
-### No warehouse available
-
-- Create a SQL warehouse in the Databricks workspace
-- Or provide a specific `warehouse_id`
-
-### Queries are slow
-
-- Ensure the warehouse is running (not stopped)
-- Consider using a larger warehouse size
-- Check if tables are optimized (OPTIMIZE, Z-ORDER)
-
-### Poor query generation
-
-- Use descriptive column names
-- Add table and column comments
-- Include sample questions that demonstrate the vocabulary
-- Add instructions via the Databricks Genie UI
-
-### `manage_genie(action="export")` returns empty `serialized_space`
-
-Requires at least **CAN EDIT** permission on the space.
-
-### `manage_genie(action="import")` fails with permission error
-
-Ensure you have CREATE privileges in the target workspace folder.
-
-### Tables not found after migration
-
-Catalog name was not remapped — replace the source catalog name in `serialized_space` before calling `manage_genie(action="import")`. The catalog appears in table identifiers, SQL FROM clauses, join specs, and filter snippets; a single `.replace(src_catalog, tgt_catalog)` on the whole string covers all occurrences.
-
-### `manage_genie` lands in the wrong workspace
-
-Each MCP server is workspace-scoped. Set up two named MCP server entries (one per profile) in your IDE's MCP config instead of switching a single server's profile mid-session.
-
-### MCP server doesn't pick up profile change
-
-The MCP process reads `DATABRICKS_CONFIG_PROFILE` once at startup — editing the config file requires an IDE reload to take effect.
-
-### `manage_genie(action="import")` fails with JSON parse error
-
-The `serialized_space` string may contain multi-line SQL arrays with `\n` escape sequences. Flatten SQL arrays to single-line strings before passing to avoid double-escaping issues.
diff --git a/databricks-skills/databricks-jobs/task-types.md b/databricks-skills/databricks-jobs/task-types.md
index c5b06fbe..f7c3e043 100644
--- a/databricks-skills/databricks-jobs/task-types.md
+++ b/databricks-skills/databricks-jobs/task-types.md
@@ -618,7 +618,6 @@ Define reusable Python environments for serverless tasks with custom pip depende
> **IMPORTANT:** The `client` field is **required** in the environment `spec`. It specifies the
> base serverless environment version. Use `"4"` as the value. Without it, the API returns:
> `"Either base environment or version must be provided for environment"`.
-> The MCP `manage_jobs` tool (action="create") auto-injects `client: "4"` if omitted, but CLI/SDK calls require it explicitly.
### DABs YAML
diff --git a/databricks-skills/databricks-lakebase-autoscale/SKILL.md b/databricks-skills/databricks-lakebase-autoscale/SKILL.md
index f471765c..848e6e67 100644
--- a/databricks-skills/databricks-lakebase-autoscale/SKILL.md
+++ b/databricks-skills/databricks-lakebase-autoscale/SKILL.md
@@ -169,71 +169,6 @@ w.postgres.update_endpoint(
).wait()
```
-## MCP Tools
-
-The following MCP tools are available for managing Lakebase infrastructure. Use `type="autoscale"` for Lakebase Autoscaling.
-
-### manage_lakebase_database - Project Management
-
-| Action | Description | Required Params |
-|--------|-------------|-----------------|
-| `create_or_update` | Create or update a project | name |
-| `get` | Get project details (includes branches/endpoints) | name |
-| `list` | List all projects | (none, optional type filter) |
-| `delete` | Delete project and all branches/computes/data | name |
-
-**Example usage:**
-```python
-# Create an autoscale project
-manage_lakebase_database(
- action="create_or_update",
- name="my-app",
- type="autoscale",
- display_name="My Application",
- pg_version="17"
-)
-
-# Get project with branches
-manage_lakebase_database(action="get", name="my-app", type="autoscale")
-
-# Delete project
-manage_lakebase_database(action="delete", name="my-app", type="autoscale")
-```
-
-### manage_lakebase_branch - Branch Management
-
-| Action | Description | Required Params |
-|--------|-------------|-----------------|
-| `create_or_update` | Create/update branch with compute endpoint | project_name, branch_id |
-| `delete` | Delete branch and endpoints | name (full branch name) |
-
-**Example usage:**
-```python
-# Create a dev branch with 7-day TTL
-manage_lakebase_branch(
- action="create_or_update",
- project_name="my-app",
- branch_id="development",
- source_branch="production",
- ttl_seconds=604800, # 7 days
- autoscaling_limit_min_cu=0.5,
- autoscaling_limit_max_cu=4.0,
- scale_to_zero_seconds=300
-)
-
-# Delete branch
-manage_lakebase_branch(action="delete", name="projects/my-app/branches/development")
-```
-
-### generate_lakebase_credential - OAuth Tokens
-
-Generate OAuth token (~1hr) for PostgreSQL connections. Use as password with `sslmode=require`.
-
-```python
-# For autoscale endpoints
-generate_lakebase_credential(endpoint="projects/my-app/branches/production/endpoints/ep-primary")
-```
-
## Reference Files
- [projects.md](projects.md) - Project management patterns and settings
@@ -242,7 +177,9 @@ generate_lakebase_credential(endpoint="projects/my-app/branches/production/endpo
- [connection-patterns.md](connection-patterns.md) - Connection patterns for different use cases
- [reverse-etl.md](reverse-etl.md) - Synced tables from Delta Lake to Lakebase
-## CLI Quick Reference
+## CLI Commands
+
+### Project Management
```bash
# Create a project
@@ -256,18 +193,45 @@ databricks postgres list-projects
# Get project details
databricks postgres get-project projects/my-app
-# Create a branch
+# Delete a project
+databricks postgres delete-project projects/my-app
+```
+
+### Branch Management
+
+```bash
+# Create a branch with TTL
+databricks postgres create-branch projects/my-app development \
+ --json '{"spec": {"source_branch": "projects/my-app/branches/production", "ttl": {"seconds": 604800}}}'
+
+# Create a branch with no expiry
databricks postgres create-branch projects/my-app development \
--json '{"spec": {"source_branch": "projects/my-app/branches/production", "no_expiry": true}}'
# List branches
databricks postgres list-branches projects/my-app
+# Delete a branch
+databricks postgres delete-branch projects/my-app/branches/development
+```
+
+### Endpoint Management
+
+```bash
# Get endpoint details
databricks postgres get-endpoint projects/my-app/branches/production/endpoints/ep-primary
-# Delete a project
-databricks postgres delete-project projects/my-app
+# Update endpoint autoscaling limits
+databricks postgres update-endpoint projects/my-app/branches/production/endpoints/ep-primary \
+ --json '{"spec": {"autoscaling_limit_min_cu": 2.0, "autoscaling_limit_max_cu": 8.0}}'
+```
+
+### OAuth Credentials
+
+```bash
+# Generate database credential (for connections)
+databricks postgres generate-database-credential \
+ --endpoint projects/my-app/branches/production/endpoints/ep-primary
```
## Key Differences from Lakebase Provisioned
diff --git a/databricks-skills/databricks-lakebase-provisioned/SKILL.md b/databricks-skills/databricks-lakebase-provisioned/SKILL.md
index 7548219c..2dacbaa2 100644
--- a/databricks-skills/databricks-lakebase-provisioned/SKILL.md
+++ b/databricks-skills/databricks-lakebase-provisioned/SKILL.md
@@ -221,76 +221,14 @@ mlflow.langchain.log_model(
)
```
-## MCP Tools
-
-The following MCP tools are available for managing Lakebase infrastructure. Use `type="provisioned"` for Lakebase Provisioned.
-
-### manage_lakebase_database - Database Management
-
-| Action | Description | Required Params |
-|--------|-------------|-----------------|
-| `create_or_update` | Create or update a database | name |
-| `get` | Get database details | name |
-| `list` | List all databases | (none, optional type filter) |
-| `delete` | Delete database and resources | name |
-
-**Example usage:**
-```python
-# Create a provisioned database
-manage_lakebase_database(
- action="create_or_update",
- name="my-lakebase-instance",
- type="provisioned",
- capacity="CU_1"
-)
-
-# Get database details
-manage_lakebase_database(action="get", name="my-lakebase-instance", type="provisioned")
-
-# List all databases
-manage_lakebase_database(action="list")
-
-# Delete with cascade
-manage_lakebase_database(action="delete", name="my-lakebase-instance", type="provisioned", force=True)
-```
-
-### manage_lakebase_sync - Reverse ETL
-
-| Action | Description | Required Params |
-|--------|-------------|-----------------|
-| `create_or_update` | Set up reverse ETL from Delta to Lakebase | instance_name, source_table_name, target_table_name |
-| `delete` | Remove synced table (and optionally catalog) | table_name |
-
-**Example usage:**
-```python
-# Set up reverse ETL
-manage_lakebase_sync(
- action="create_or_update",
- instance_name="my-lakebase-instance",
- source_table_name="catalog.schema.delta_table",
- target_table_name="lakebase_catalog.schema.postgres_table",
- scheduling_policy="TRIGGERED" # or SNAPSHOT, CONTINUOUS
-)
-
-# Delete synced table
-manage_lakebase_sync(action="delete", table_name="lakebase_catalog.schema.postgres_table")
-```
-
-### generate_lakebase_credential - OAuth Tokens
-
-Generate OAuth token (~1hr) for PostgreSQL connections. Use as password with `sslmode=require`.
-
-```python
-# For provisioned instances
-generate_lakebase_credential(instance_names=["my-lakebase-instance"])
-```
-
## Reference Files
- [connection-patterns.md](connection-patterns.md) - Detailed connection patterns for different use cases
- [reverse-etl.md](reverse-etl.md) - Syncing data from Delta Lake to Lakebase
-## CLI Quick Reference
+## CLI Commands
+
+### Instance Management
```bash
# Create instance
@@ -301,11 +239,6 @@ databricks database create-database-instance \
# Get instance details
databricks database get-database-instance --name my-lakebase-instance
-# Generate credentials
-databricks database generate-database-credential \
- --request-id $(uuidgen) \
- --json '{"instance_names": ["my-lakebase-instance"]}'
-
# List instances
databricks database list-database-instances
@@ -314,6 +247,35 @@ databricks database stop-database-instance --name my-lakebase-instance
# Start instance
databricks database start-database-instance --name my-lakebase-instance
+
+# Delete instance
+databricks database delete-database-instance --name my-lakebase-instance
+```
+
+### OAuth Credentials
+
+```bash
+# Generate credentials for connection
+databricks database generate-database-credential \
+ --request-id $(uuidgen) \
+ --json '{"instance_names": ["my-lakebase-instance"]}'
+```
+
+### Reverse ETL (Synced Tables)
+
+Synced tables are managed via Unity Catalog SQL commands:
+
+```sql
+-- Create synced table from Delta to Lakebase
+CREATE TABLE lakebase_catalog.schema.target_table
+SYNC FROM catalog.schema.source_delta_table
+SCHEDULE TRIGGERED;
+
+-- List synced tables
+SHOW TABLES IN lakebase_catalog.schema;
+
+-- Drop synced table
+DROP TABLE lakebase_catalog.schema.target_table;
```
## Common Issues
diff --git a/databricks-skills/databricks-metric-views/SKILL.md b/databricks-skills/databricks-metric-views/SKILL.md
index bddc74ad..c020c2c5 100644
--- a/databricks-skills/databricks-metric-views/SKILL.md
+++ b/databricks-skills/databricks-metric-views/SKILL.md
@@ -95,72 +95,88 @@ ORDER BY ALL
| YAML Syntax | [yaml-reference.md](yaml-reference.md) | Complete YAML spec: dimensions, measures, joins, materialization |
| Patterns & Examples | [patterns.md](patterns.md) | Common patterns: star schema, snowflake, filtered measures, window measures, ratios |
-## MCP Tools
-
-Use the `manage_metric_views` tool for all metric view operations:
-
-| Action | Description |
-|--------|-------------|
-| `create` | Create a metric view with dimensions and measures |
-| `alter` | Update a metric view's YAML definition |
-| `describe` | Get the full definition and metadata |
-| `query` | Query measures grouped by dimensions |
-| `drop` | Drop a metric view |
-| `grant` | Grant SELECT privileges to users/groups |
-
-### Create via MCP
-
-```python
-manage_metric_views(
- action="create",
- full_name="catalog.schema.orders_metrics",
- source="catalog.schema.orders",
- or_replace=True,
- comment="Orders KPIs for sales analysis",
- filter_expr="order_date > '2020-01-01'",
- dimensions=[
- {"name": "Order Month", "expr": "DATE_TRUNC('MONTH', order_date)", "comment": "Month of order"},
- {"name": "Order Status", "expr": "status"},
- ],
- measures=[
- {"name": "Order Count", "expr": "COUNT(1)"},
- {"name": "Total Revenue", "expr": "SUM(total_price)", "comment": "Sum of total price"},
- ],
-)
+## SQL Operations
+
+### Create Metric View
+
+```sql
+CREATE OR REPLACE VIEW catalog.schema.orders_metrics
+WITH METRICS
+LANGUAGE YAML
+AS $$
+ version: 1.1
+ comment: "Orders KPIs for sales analysis"
+ source: catalog.schema.orders
+ filter: order_date > '2020-01-01'
+ dimensions:
+ - name: Order Month
+ expr: DATE_TRUNC('MONTH', order_date)
+ comment: "Month of order"
+ - name: Order Status
+ expr: status
+ measures:
+ - name: Order Count
+ expr: COUNT(1)
+ - name: Total Revenue
+ expr: SUM(total_price)
+ comment: "Sum of total price"
+$$;
```
-### Query via MCP
-
-```python
-manage_metric_views(
- action="query",
- full_name="catalog.schema.orders_metrics",
- query_measures=["Total Revenue", "Order Count"],
- query_dimensions=["Order Month"],
- where="extract(year FROM `Order Month`) = 2024",
- order_by="ALL",
- limit=100,
-)
+### Query Metric View
+
+```sql
+SELECT
+ `Order Month`,
+ MEASURE(`Total Revenue`) AS total_revenue,
+ MEASURE(`Order Count`) AS order_count
+FROM catalog.schema.orders_metrics
+WHERE extract(year FROM `Order Month`) = 2024
+GROUP BY ALL
+ORDER BY ALL
+LIMIT 100;
```
-### Describe via MCP
+### Describe Metric View
-```python
-manage_metric_views(
- action="describe",
- full_name="catalog.schema.orders_metrics",
-)
+```sql
+DESCRIBE TABLE EXTENDED catalog.schema.orders_metrics;
+
+-- Or get YAML definition
+SHOW CREATE TABLE catalog.schema.orders_metrics;
```
### Grant Access
-```python
-manage_metric_views(
- action="grant",
- full_name="catalog.schema.orders_metrics",
- principal="data-consumers",
- privileges=["SELECT"],
-)
+```sql
+GRANT SELECT ON VIEW catalog.schema.orders_metrics TO `data-consumers`;
+```
+
+### Drop Metric View
+
+```sql
+DROP VIEW IF EXISTS catalog.schema.orders_metrics;
+```
+
+### CLI Execution
+
+```bash
+# Execute SQL via CLI
+databricks experimental aitools tools query --warehouse WAREHOUSE_ID "
+CREATE OR REPLACE VIEW catalog.schema.orders_metrics
+WITH METRICS
+LANGUAGE YAML
+AS \$\$
+ version: 1.1
+ source: catalog.schema.orders
+ dimensions:
+ - name: Order Month
+ expr: DATE_TRUNC('MONTH', order_date)
+ measures:
+ - name: Total Revenue
+ expr: SUM(total_price)
+\$\$
+"
```
## YAML Spec Quick Reference
diff --git a/databricks-skills/databricks-metric-views/patterns.md b/databricks-skills/databricks-metric-views/patterns.md
index 48c7f9e3..430c1691 100644
--- a/databricks-skills/databricks-metric-views/patterns.md
+++ b/databricks-skills/databricks-metric-views/patterns.md
@@ -579,73 +579,81 @@ GROUP BY ALL
ORDER BY ALL
```
-## MCP Tool Examples
+## SQL Examples
### Create with joins
-```python
-manage_metric_views(
- action="create",
- full_name="catalog.schema.sales_metrics",
- source="catalog.schema.fact_sales",
- or_replace=True,
- joins=[
- {
- "name": "customer",
- "source": "catalog.schema.dim_customer",
- "on": "source.customer_id = customer.id"
- },
- {
- "name": "product",
- "source": "catalog.schema.dim_product",
- "on": "source.product_id = product.id"
- }
- ],
- dimensions=[
- {"name": "Customer Segment", "expr": "customer.segment"},
- {"name": "Product Category", "expr": "product.category"},
- {"name": "Sale Month", "expr": "DATE_TRUNC('MONTH', source.sale_date)"},
- ],
- measures=[
- {"name": "Total Revenue", "expr": "SUM(source.amount)"},
- {"name": "Order Count", "expr": "COUNT(1)"},
- {"name": "Unique Customers", "expr": "COUNT(DISTINCT source.customer_id)"},
- ],
-)
+```sql
+CREATE OR REPLACE VIEW catalog.schema.sales_metrics
+WITH METRICS
+LANGUAGE YAML
+AS $$
+ version: 1.1
+ source: catalog.schema.fact_sales
+ joins:
+ - name: customer
+ source: catalog.schema.dim_customer
+ on: source.customer_id = customer.id
+ - name: product
+ source: catalog.schema.dim_product
+ on: source.product_id = product.id
+ dimensions:
+ - name: Customer Segment
+ expr: customer.segment
+ - name: Product Category
+ expr: product.category
+ - name: Sale Month
+ expr: DATE_TRUNC('MONTH', source.sale_date)
+ measures:
+ - name: Total Revenue
+ expr: SUM(source.amount)
+ - name: Order Count
+ expr: COUNT(1)
+ - name: Unique Customers
+ expr: COUNT(DISTINCT source.customer_id)
+$$;
```
### Alter to add a new measure
-```python
-manage_metric_views(
- action="alter",
- full_name="catalog.schema.sales_metrics",
- source="catalog.schema.fact_sales",
- joins=[
- {"name": "customer", "source": "catalog.schema.dim_customer", "on": "source.customer_id = customer.id"},
- ],
- dimensions=[
- {"name": "Customer Segment", "expr": "customer.segment"},
- {"name": "Sale Month", "expr": "DATE_TRUNC('MONTH', source.sale_date)"},
- ],
- measures=[
- {"name": "Total Revenue", "expr": "SUM(source.amount)"},
- {"name": "Order Count", "expr": "COUNT(1)"},
- {"name": "Average Order Value", "expr": "AVG(source.amount)"}, # New measure
- ],
-)
+```sql
+-- Use CREATE OR REPLACE to update the metric view
+CREATE OR REPLACE VIEW catalog.schema.sales_metrics
+WITH METRICS
+LANGUAGE YAML
+AS $$
+ version: 1.1
+ source: catalog.schema.fact_sales
+ joins:
+ - name: customer
+ source: catalog.schema.dim_customer
+ on: source.customer_id = customer.id
+ dimensions:
+ - name: Customer Segment
+ expr: customer.segment
+ - name: Sale Month
+ expr: DATE_TRUNC('MONTH', source.sale_date)
+ measures:
+ - name: Total Revenue
+ expr: SUM(source.amount)
+ - name: Order Count
+ expr: COUNT(1)
+ - name: Average Order Value
+ expr: AVG(source.amount)
+$$;
```
### Query with filters
-```python
-manage_metric_views(
- action="query",
- full_name="catalog.schema.sales_metrics",
- query_measures=["Total Revenue", "Order Count"],
- query_dimensions=["Customer Segment", "Sale Month"],
- where="`Customer Segment` = 'Enterprise'",
- order_by="ALL",
- limit=50,
-)
+```sql
+SELECT
+ `Customer Segment`,
+ `Sale Month`,
+ MEASURE(`Total Revenue`) AS total_revenue,
+ MEASURE(`Order Count`) AS order_count
+FROM catalog.schema.sales_metrics
+WHERE `Customer Segment` = 'Enterprise'
+GROUP BY ALL
+ORDER BY ALL
+LIMIT 50;
```
diff --git a/databricks-skills/databricks-model-serving/1-classical-ml.md b/databricks-skills/databricks-model-serving/1-classical-ml.md
index 4b973e0a..42b6a016 100644
--- a/databricks-skills/databricks-model-serving/1-classical-ml.md
+++ b/databricks-skills/databricks-model-serving/1-classical-ml.md
@@ -140,16 +140,14 @@ endpoint = w.serving_endpoints.create_and_wait(
## Query the Endpoint
-### Via MCP Tool
-
-```
-manage_serving_endpoint(
- action="query",
- name="diabetes-predictor",
- dataframe_records=[
- {"age": 45, "bmi": 25.3, "bp": 120, "s1": 200}
- ]
-)
+### Via CLI
+
+```bash
+databricks serving-endpoints query diabetes-predictor --json '{
+ "dataframe_records": [
+ {"age": 45, "bmi": 25.3, "bp": 120, "s1": 200}
+ ]
+}'
```
### Via Python SDK
diff --git a/databricks-skills/databricks-model-serving/3-genai-agents.md b/databricks-skills/databricks-model-serving/3-genai-agents.md
index 4061dbab..66647687 100644
--- a/databricks-skills/databricks-model-serving/3-genai-agents.md
+++ b/databricks-skills/databricks-model-serving/3-genai-agents.md
@@ -221,10 +221,12 @@ for event in AGENT.predict_stream(request):
print(event)
```
-Run via MCP:
+Run via CLI:
-```
-execute_code(file_path="./my_agent/test_agent.py")
+```bash
+# Upload and run on Databricks
+databricks workspace import-dir ./my_agent /Workspace/Users//my_agent
+databricks jobs run-now --job-id # Job configured to run test_agent.py
```
## Logging the Agent
@@ -267,18 +269,16 @@ from databricks import agents
agents.deploy(
"main.agents.my_agent",
version="1",
- tags={"source": "mcp"}
+ tags={"source": "cli"}
)
# Takes ~15 minutes
```
## Query Deployed Agent
-```
-manage_serving_endpoint(
- action="query",
- name="my-agent-endpoint",
- messages=[{"role": "user", "content": "What is Databricks?"}],
- max_tokens=500
-)
+```bash
+databricks serving-endpoints query my-agent-endpoint --json '{
+ "messages": [{"role": "user", "content": "What is Databricks?"}],
+ "max_tokens": 500
+}'
```
diff --git a/databricks-skills/databricks-model-serving/5-development-testing.md b/databricks-skills/databricks-model-serving/5-development-testing.md
index 2a3806cf..71970aa9 100644
--- a/databricks-skills/databricks-model-serving/5-development-testing.md
+++ b/databricks-skills/databricks-model-serving/5-development-testing.md
@@ -1,8 +1,6 @@
# Development & Testing Workflow
-MCP-based workflow for developing and testing agents on Databricks.
-
-> **If MCP tools are not available**, use Databricks CLI or the Python SDK directly. See [Databricks CLI docs](https://docs.databricks.com/dev-tools/cli/) for `databricks workspace import` and `databricks clusters spark-submit` commands.
+CLI-based workflow for developing and testing agents on Databricks.
## Overview
@@ -13,17 +11,17 @@ MCP-based workflow for developing and testing agents on Databricks.
▼
┌─────────────────────────────────────────────────────────────┐
│ Step 2: Upload to workspace │
-│ → manage_workspace_files MCP tool │
+│ → databricks workspace import-dir │
└─────────────────────────────────────────────────────────────┘
▼
┌─────────────────────────────────────────────────────────────┐
│ Step 3: Install packages │
-│ → execute_code MCP tool │
+│ → databricks jobs (serverless with pip requirements) │
└─────────────────────────────────────────────────────────────┘
▼
┌─────────────────────────────────────────────────────────────┐
│ Step 4: Test agent (iterate) │
-│ → execute_code MCP tool (with file_path) │
+│ → databricks jobs run-now │
│ → If error: fix locally, re-upload, re-run │
└─────────────────────────────────────────────────────────────┘
```
@@ -85,17 +83,13 @@ print("Response:", result.model_dump(exclude_none=True))
## Step 2: Upload to Workspace
-Use the `manage_workspace_files` MCP tool:
+Use the Databricks CLI:
-```
-manage_workspace_files(
- action="upload",
- local_path="./my_agent",
- workspace_path="/Workspace/Users/you@company.com/my_agent"
-)
+```bash
+databricks workspace import-dir ./my_agent /Workspace/Users/you@company.com/my_agent
```
-This uploads all files in parallel.
+This uploads all files recursively.
## Step 3: Install Packages
@@ -135,8 +129,8 @@ execute_code(
1. Read the error from the output
2. Fix the local file (`agent.py` or `test_agent.py`)
-3. Re-upload: `manage_workspace_files(action="upload", ...)`
-4. Re-run: `execute_code(file_path=...)`
+3. Re-upload: `databricks workspace import-dir ./my_agent /Workspace/.../my_agent`
+4. Re-run the job
### Iteration Tips
@@ -188,13 +182,12 @@ print(response.content)
## Workflow Summary
-| Step | MCP Tool | Purpose |
-|------|----------|---------|
-| Upload files | `manage_workspace_files` (action="upload") | Sync local files to workspace |
-| Install packages | `execute_code` | Set up dependencies |
-| Restart Python | `execute_code` | Apply package changes |
-| Test agent | `execute_code` (with `file_path`) | Run test script |
-| Debug | `execute_code` | Quick checks |
+| Step | CLI Command | Purpose |
+|------|-------------|---------|
+| Upload files | `databricks workspace import-dir` | Sync local files to workspace |
+| Install packages | Job with pip requirements | Set up dependencies |
+| Test agent | `databricks jobs run-now` | Run test script |
+| Debug | Run notebook or script | Quick checks |
## Next Steps
diff --git a/databricks-skills/databricks-model-serving/6-logging-registration.md b/databricks-skills/databricks-model-serving/6-logging-registration.md
index cd687358..bfa643b9 100644
--- a/databricks-skills/databricks-model-serving/6-logging-registration.md
+++ b/databricks-skills/databricks-model-serving/6-logging-registration.md
@@ -60,10 +60,12 @@ uc_model_info = mlflow.register_model(
print(f"Registered: {uc_model_info.name} version {uc_model_info.version}")
```
-Run via MCP:
+Run via CLI:
-```
-execute_code(file_path="./my_agent/log_model.py")
+```bash
+# Upload and run on Databricks
+databricks workspace import-dir ./my_agent /Workspace/Users//my_agent
+databricks jobs run-now --job-id # Job configured to run log_model.py
```
## Resources for Auto Authentication
@@ -141,7 +143,7 @@ mlflow.models.predict(
)
```
-Run via MCP (in log_model.py or separate file):
+Run validation (in log_model.py or separate file):
```python
# validate_model.py
diff --git a/databricks-skills/databricks-model-serving/7-deployment.md b/databricks-skills/databricks-model-serving/7-deployment.md
index 666cb168..2f503112 100644
--- a/databricks-skills/databricks-model-serving/7-deployment.md
+++ b/databricks-skills/databricks-model-serving/7-deployment.md
@@ -2,7 +2,7 @@
Deploy models to serving endpoints. Uses async job-based approach for agents (deployment takes ~15 min).
-> **If MCP tools are not available**, use `databricks.agents.deploy()` directly in a notebook, or create jobs via CLI: `databricks jobs create --json @job.json`
+> Use `databricks.agents.deploy()` directly in a notebook, or create jobs via CLI: `databricks jobs create --json @job.json`
## Deployment Options
@@ -13,7 +13,7 @@ Deploy models to serving endpoints. Uses async job-based approach for agents (de
## GenAI Agent Deployment (Job-Based)
-Since agent deployment takes ~15 minutes, use a job to avoid MCP timeouts.
+Since agent deployment takes ~15 minutes, use a job for async deployment.
### Step 1: Create Deployment Script
@@ -32,7 +32,7 @@ print(f"Deploying {model_name} version {version}...")
deployment = agents.deploy(
model_name,
version,
- tags={"source": "mcp", "environment": "dev"}
+ tags={"source": "cli", "environment": "dev"}
)
print(f"Deployment complete!")
@@ -41,40 +41,39 @@ print(f"Endpoint: {deployment.endpoint_name}")
### Step 2: Create Deployment Job (One-Time)
-Use the `manage_jobs` MCP tool with action="create":
+Use the Databricks CLI:
-```
-manage_jobs(
- action="create",
- name="deploy-agent-job",
- tasks=[
- {
- "task_key": "deploy",
- "spark_python_task": {
- "python_file": "/Workspace/Users/you@company.com/my_agent/deploy_agent.py",
- "parameters": ["{{job.parameters.model_name}}", "{{job.parameters.version}}"]
- }
- }
- ],
- parameters=[
- {"name": "model_name", "default": "main.agents.my_agent"},
- {"name": "version", "default": "1"}
- ]
-)
+```bash
+databricks jobs create --json '{
+ "name": "deploy-agent-job",
+ "tasks": [{
+ "task_key": "deploy",
+ "spark_python_task": {
+ "python_file": "/Workspace/Users/you@company.com/my_agent/deploy_agent.py",
+ "parameters": ["{{job.parameters.model_name}}", "{{job.parameters.version}}"]
+ },
+ "new_cluster": {
+ "spark_version": "16.1.x-scala2.12",
+ "node_type_id": "i3.xlarge",
+ "num_workers": 0
+ }
+ }],
+ "parameters": [
+ {"name": "model_name", "default": "main.agents.my_agent"},
+ {"name": "version", "default": "1"}
+ ]
+}'
```
Save the returned `job_id`.
### Step 3: Run Deployment (Async)
-Use `manage_job_runs` with action="run_now" - returns immediately:
+Run the job - returns immediately:
-```
-manage_job_runs(
- action="run_now",
- job_id="",
- job_parameters={"model_name": "main.agents.my_agent", "version": "1"}
-)
+```bash
+databricks jobs run-now --job-id \
+ --params '{"model_name": "main.agents.my_agent", "version": "1"}'
```
Save the returned `run_id`.
@@ -83,14 +82,14 @@ Save the returned `run_id`.
Check job run status:
-```
-manage_job_runs(action="get", run_id="")
+```bash
+databricks jobs get-run --run-id
```
Or check endpoint directly:
-```
-manage_serving_endpoint(action="get", name="")
+```bash
+databricks serving-endpoints get
```
## Classical ML Deployment
@@ -163,7 +162,7 @@ deployment = agents.deploy(
"main.agents.my_agent",
"1",
endpoint_name="my-agent-endpoint", # Control the name
- tags={"source": "mcp", "environment": "dev"}
+ tags={"source": "cli", "environment": "dev"}
)
```
@@ -172,7 +171,7 @@ deployment = agents.deploy(
Endpoints created via `agents.deploy()` appear under **Serving** in the Databricks UI. If you don't see your endpoint:
1. **Check the filter** - The Serving page defaults to "Owned by me". If the deployment ran as a service principal (e.g., via a job), switch to "All" to see it.
-2. **Verify via API** - Use `manage_serving_endpoint(action="list")` or `manage_serving_endpoint(action="get", name="...")` to confirm the endpoint exists and check its state.
+2. **Verify via CLI** - Use `databricks serving-endpoints list` or `databricks serving-endpoints get ` to confirm the endpoint exists and check its state.
3. **Check the name** - The auto-generated name may not be what you expect. Print `deployment.endpoint_name` in the deploy script or check the job run output.
### Deployment Script with Explicit Naming
@@ -261,18 +260,18 @@ client.update_endpoint(
## Workflow Summary
-| Step | MCP Tool | Waits? |
-|------|----------|--------|
-| Upload deploy script | `manage_workspace_files` (action="upload") | Yes |
-| Create job (one-time) | `manage_jobs` (action="create") | Yes |
-| Run deployment | `manage_job_runs` (action="run_now") | **No** - returns immediately |
-| Check job status | `manage_job_runs` (action="get") | Yes |
-| Check endpoint status | `manage_serving_endpoint` (action="get") | Yes |
+| Step | CLI Command | Waits? |
+|------|-------------|--------|
+| Upload deploy script | `databricks workspace import-dir` | Yes |
+| Create job (one-time) | `databricks jobs create` | Yes |
+| Run deployment | `databricks jobs run-now` | **No** - returns immediately |
+| Check job status | `databricks jobs get-run` | Yes |
+| Check endpoint status | `databricks serving-endpoints get` | Yes |
## After Deployment
Once endpoint is READY:
-1. **Test with MCP**: `manage_serving_endpoint(action="query", name="...", messages=[...])`
+1. **Test with CLI**: `databricks serving-endpoints query --json '{"messages": [...]}'`
2. **Share with team**: Endpoint URL in Databricks UI
3. **Integrate in apps**: Use REST API or SDK
diff --git a/databricks-skills/databricks-model-serving/8-querying-endpoints.md b/databricks-skills/databricks-model-serving/8-querying-endpoints.md
index 4dfa2f91..2cebb0c1 100644
--- a/databricks-skills/databricks-model-serving/8-querying-endpoints.md
+++ b/databricks-skills/databricks-model-serving/8-querying-endpoints.md
@@ -2,87 +2,41 @@
Send requests to deployed Model Serving endpoints.
-> **If MCP tools are not available**, use the Python SDK or REST API examples below.
-
-## MCP Tools
+## CLI Commands
### Check Endpoint Status
Before querying, verify the endpoint is ready:
-```
-manage_serving_endpoint(action="get", name="my-agent-endpoint")
-```
-
-Response:
-```json
-{
- "name": "my-agent-endpoint",
- "state": "READY",
- "served_entities": [
- {"name": "my_agent-1", "entity_name": "main.agents.my_agent", "deployment_state": "READY"}
- ]
-}
+```bash
+databricks serving-endpoints get my-agent-endpoint
```
### Query Chat/Agent Endpoint
-```
-manage_serving_endpoint(
- action="query",
- name="my-agent-endpoint",
- messages=[
- {"role": "user", "content": "What is Databricks?"}
- ],
- max_tokens=500,
- temperature=0.7
-)
-```
-
-Response:
-```json
-{
- "choices": [
- {
- "message": {
- "role": "assistant",
- "content": "Databricks is a unified data intelligence platform..."
- },
- "finish_reason": "stop"
- }
- ],
- "usage": {
- "prompt_tokens": 10,
- "completion_tokens": 150,
- "total_tokens": 160
- }
-}
+```bash
+databricks serving-endpoints query my-agent-endpoint --json '{
+ "messages": [{"role": "user", "content": "What is Databricks?"}],
+ "max_tokens": 500,
+ "temperature": 0.7
+}'
```
### Query ML Model Endpoint
-```
-manage_serving_endpoint(
- action="query",
- name="sklearn-classifier",
- dataframe_records=[
- {"age": 25, "income": 50000, "credit_score": 720},
- {"age": 35, "income": 75000, "credit_score": 680}
- ]
-)
-```
-
-Response:
-```json
-{
- "predictions": [0.85, 0.72]
-}
+```bash
+databricks serving-endpoints query sklearn-classifier --json '{
+ "dataframe_records": [
+ {"age": 25, "income": 50000, "credit_score": 720},
+ {"age": 35, "income": 75000, "credit_score": 680}
+ ]
+}'
```
### List All Endpoints
-```
-manage_serving_endpoint(action="list", limit=20)
+```bash
+databricks serving-endpoints list
```
## Python SDK
diff --git a/databricks-skills/databricks-model-serving/9-package-requirements.md b/databricks-skills/databricks-model-serving/9-package-requirements.md
index f9ceb7a9..e5508b6a 100644
--- a/databricks-skills/databricks-model-serving/9-package-requirements.md
+++ b/databricks-skills/databricks-model-serving/9-package-requirements.md
@@ -137,24 +137,23 @@ export DATABRICKS_TOKEN="your-token"
export DATABRICKS_CONFIG_PROFILE="your-profile"
```
-## Installing Packages via MCP
+## Installing Packages
-Use `execute_code`:
+In a notebook or Python script on Databricks:
-```
-execute_code(
- code="%pip install -U mlflow==3.6.0 databricks-langchain langgraph==0.3.4 databricks-agents pydantic"
-)
+```python
+%pip install -U mlflow==3.6.0 databricks-langchain langgraph==0.3.4 databricks-agents pydantic
+dbutils.library.restartPython()
```
-Then restart Python:
+Or via job libraries configuration:
-```
-execute_code(
- code="dbutils.library.restartPython()",
- cluster_id="",
- context_id=""
-)
+```json
+"libraries": [
+ {"pypi": {"package": "mlflow==3.6.0"}},
+ {"pypi": {"package": "databricks-langchain"}},
+ {"pypi": {"package": "langgraph==0.3.4"}}
+]
```
## Checking Installed Versions
@@ -171,17 +170,13 @@ for pkg in packages:
print(f"{pkg}: NOT INSTALLED")
```
-Via MCP:
+In a notebook:
-```
-execute_code(
- code="""
+```python
import pkg_resources
for pkg in ['mlflow', 'langchain', 'langgraph', 'pydantic', 'databricks-langchain']:
try:
print(f"{pkg}: {pkg_resources.get_distribution(pkg).version}")
except:
print(f"{pkg}: NOT INSTALLED")
- """
-)
```
diff --git a/databricks-skills/databricks-model-serving/SKILL.md b/databricks-skills/databricks-model-serving/SKILL.md
index 74160298..bf520b5a 100644
--- a/databricks-skills/databricks-model-serving/SKILL.md
+++ b/databricks-skills/databricks-model-serving/SKILL.md
@@ -82,59 +82,40 @@ ALWAYS use exact endpoint names from this table. NEVER guess or abbreviate.
| Custom PyFunc | [2-custom-pyfunc.md](2-custom-pyfunc.md) | Custom preprocessing, signatures |
| GenAI Agents | [3-genai-agents.md](3-genai-agents.md) | ResponsesAgent, LangGraph |
| Tools Integration | [4-tools-integration.md](4-tools-integration.md) | UC Functions, Vector Search |
-| Development & Testing | [5-development-testing.md](5-development-testing.md) | MCP workflow, iteration |
+| Development & Testing | [5-development-testing.md](5-development-testing.md) | CLI workflow, iteration |
| Logging & Registration | [6-logging-registration.md](6-logging-registration.md) | mlflow.pyfunc.log_model |
| Deployment | [7-deployment.md](7-deployment.md) | Job-based async deployment |
-| Querying Endpoints | [8-querying-endpoints.md](8-querying-endpoints.md) | SDK, REST, MCP tools |
+| Querying Endpoints | [8-querying-endpoints.md](8-querying-endpoints.md) | CLI, SDK, REST |
| Package Requirements | [9-package-requirements.md](9-package-requirements.md) | DBR versions, pip |
---
## Quick Start: Deploy a GenAI Agent
-### Step 1: Install Packages (in notebook or via MCP)
+### Step 1: Install Packages (in notebook)
```python
%pip install -U mlflow==3.6.0 databricks-langchain langgraph==0.3.4 databricks-agents pydantic
dbutils.library.restartPython()
```
-Or via MCP:
-```
-execute_code(code="%pip install -U mlflow==3.6.0 databricks-langchain langgraph==0.3.4 databricks-agents pydantic")
-```
-
### Step 2: Create Agent File
Create `agent.py` locally with `ResponsesAgent` pattern (see [3-genai-agents.md](3-genai-agents.md)).
### Step 3: Upload to Workspace
-```
-manage_workspace_files(
- action="upload",
- local_path="./my_agent",
- workspace_path="/Workspace/Users/you@company.com/my_agent"
-)
+```bash
+databricks workspace import-dir ./my_agent /Workspace/Users/you@company.com/my_agent
```
### Step 4: Test Agent
-```
-execute_code(
- file_path="./my_agent/test_agent.py",
- cluster_id=""
-)
-```
+Run `test_agent.py` on a cluster to validate the agent works.
### Step 5: Log Model
-```
-execute_code(
- file_path="./my_agent/log_model.py",
- cluster_id=""
-)
-```
+Run `log_model.py` on a cluster to register the model in Unity Catalog.
### Step 6: Deploy (Async via Job)
@@ -142,12 +123,10 @@ See [7-deployment.md](7-deployment.md) for job-based deployment that doesn't tim
### Step 7: Query Endpoint
-```
-manage_serving_endpoint(
- action="query",
- name="my-agent-endpoint",
- messages=[{"role": "user", "content": "Hello!"}]
-)
+```bash
+databricks serving-endpoints query my-agent-endpoint --json '{
+ "messages": [{"role": "user", "content": "Hello!"}]
+}'
```
---
@@ -174,55 +153,50 @@ Then deploy via UI or SDK. See [1-classical-ml.md](1-classical-ml.md).
---
-## MCP Tools
+## CLI Commands
-> **If MCP tools are not available**, use the SDK/CLI examples in the reference files below.
+### Endpoint Management
-### Development & Testing
+```bash
+# List all serving endpoints
+databricks serving-endpoints list
-| Tool | Purpose |
-|------|---------|
-| `manage_workspace_files` (action="upload") | Upload agent files to workspace |
-| `execute_code` | Install packages, test agent, log model |
+# Get endpoint details and status
+databricks serving-endpoints get my-agent-endpoint
-### Deployment
+# Query a chat/agent endpoint
+databricks serving-endpoints query my-agent-endpoint --json '{
+ "messages": [{"role": "user", "content": "Hello!"}],
+ "max_tokens": 500
+}'
-| Tool | Purpose |
-|------|---------|
-| `manage_jobs` (action="create") | Create deployment job (one-time) |
-| `manage_job_runs` (action="run_now") | Kick off deployment (async) |
-| `manage_job_runs` (action="get") | Check deployment job status |
+# Query a traditional ML endpoint
+databricks serving-endpoints query sklearn-classifier --json '{
+ "dataframe_records": [{"age": 25, "income": 50000, "credit_score": 720}]
+}'
+```
-### manage_serving_endpoint - Querying
+### Workspace File Operations
-| Action | Description | Required Params |
-|--------|-------------|-----------------|
-| `get` | Check endpoint status (READY/NOT_READY/NOT_FOUND) | name |
-| `list` | List all endpoints | (none, optional limit) |
-| `query` | Send requests to endpoint | name + one of: messages, inputs, dataframe_records |
+```bash
+# Upload agent files to workspace
+databricks workspace import-dir ./my_agent /Workspace/Users/you@company.com/my_agent
-**Example usage:**
-```python
-# Check endpoint status
-manage_serving_endpoint(action="get", name="my-agent-endpoint")
+# List workspace files
+databricks workspace list /Workspace/Users/you@company.com/my_agent
+```
-# List all endpoints
-manage_serving_endpoint(action="list")
+### Jobs for Deployment
-# Query a chat/agent endpoint
-manage_serving_endpoint(
- action="query",
- name="my-agent-endpoint",
- messages=[{"role": "user", "content": "Hello!"}],
- max_tokens=500
-)
+```bash
+# Create a deployment job
+databricks jobs create --json @deploy_job.json
-# Query a traditional ML endpoint
-manage_serving_endpoint(
- action="query",
- name="sklearn-classifier",
- dataframe_records=[{"age": 25, "income": 50000, "credit_score": 720}]
-)
+# Run the deployment job
+databricks jobs run-now --job-id JOB_ID
+
+# Check job run status
+databricks jobs get-run --run-id RUN_ID
```
---
@@ -231,42 +205,27 @@ manage_serving_endpoint(
### Check Endpoint Status After Deployment
-```
-manage_serving_endpoint(action="get", name="my-agent-endpoint")
+```bash
+databricks serving-endpoints get my-agent-endpoint
```
-Returns:
-```json
-{
- "name": "my-agent-endpoint",
- "state": "READY",
- "served_entities": [...]
-}
-```
+Returns JSON with endpoint status (`READY`, `NOT_READY`, etc.).
### Query a Chat/Agent Endpoint
-```
-manage_serving_endpoint(
- action="query",
- name="my-agent-endpoint",
- messages=[
- {"role": "user", "content": "What is Databricks?"}
- ],
- max_tokens=500
-)
+```bash
+databricks serving-endpoints query my-agent-endpoint --json '{
+ "messages": [{"role": "user", "content": "What is Databricks?"}],
+ "max_tokens": 500
+}'
```
### Query a Traditional ML Endpoint
-```
-manage_serving_endpoint(
- action="query",
- name="sklearn-classifier",
- dataframe_records=[
- {"age": 25, "income": 50000, "credit_score": 720}
- ]
-)
+```bash
+databricks serving-endpoints query sklearn-classifier --json '{
+ "dataframe_records": [{"age": 25, "income": 50000, "credit_score": 720}]
+}'
```
---
diff --git a/databricks-skills/databricks-python-sdk/SKILL.md b/databricks-skills/databricks-python-sdk/SKILL.md
index eaf7cd66..4d03b5ce 100644
--- a/databricks-skills/databricks-python-sdk/SKILL.md
+++ b/databricks-skills/databricks-python-sdk/SKILL.md
@@ -91,7 +91,7 @@ databricks --profile MY_PROFILE clusters list
# Common commands
databricks clusters list
databricks jobs list
-databricks workspace ls /Users/me
+databricks workspace list /Users/me
```
---
diff --git a/databricks-skills/databricks-spark-declarative-pipelines/SKILL.md b/databricks-skills/databricks-spark-declarative-pipelines/SKILL.md
index a1bdd7c3..0efeb017 100644
--- a/databricks-skills/databricks-spark-declarative-pipelines/SKILL.md
+++ b/databricks-skills/databricks-spark-declarative-pipelines/SKILL.md
@@ -39,10 +39,10 @@ description: "Creates, configures, and updates Databricks Lakeflow Spark Declara
- When the user provides table schema and asks for code, respond directly with the code. Don't ask clarifying questions if the request is clear.
## Tools
-- List files in volume: `databricks fs ls dbfs:/Volumes/{catalog}/{schema}/{volume}/{path} --profile {PROFILE}`
-- Query data: `databricks experimental aitools tools query --profile {PROFILE} --warehouse abc123 "SELECT 1 FROM catalog.schema.table"`
-- Discover schema: `databricks experimental aitools tools discover-schema --profile {PROFILE} catalog.schema.table1 catalog.schema.table2`
-- Pipelines CLI: `databricks pipelines init|deploy|run|logs|stop` or use `databricks pipelines --help` for more options
+- List files in volume: `databricks fs ls /Volumes/{catalog}/{schema}/{volume}/{path}`
+- Query data: `databricks experimental aitools tools query --warehouse abc123 "SELECT 1 FROM catalog.schema.table"`
+- Discover schema: `databricks experimental aitools tools discover-schema catalog.schema.table1 catalog.schema.table2`
+- Pipelines CLI: `databricks pipelines create|get|delete|start-update|list-pipelines` or use `databricks pipelines --help` for more options
## Choose Your Workflow
@@ -83,15 +83,14 @@ Use this when the pipeline is **part of an existing DAB project**:
→ See [1-project-initialization.md](references/1-project-initialization.md) for adding pipelines to existing bundles
-### Option C: Rapid Iteration with MCP Tools (no bundle management)
+### Option C: Rapid Iteration with CLI (no bundle management, or you'll create the DAB at the end)
Use this when you need to **quickly create, test, and iterate** on a pipeline without managing bundle files:
- User wants to "just run a pipeline and see if it works"
- Part of a larger demo where bundle is managed separately, or the DAB bundle will be created at the end as you want to quickly test the project first
- Prototyping or experimenting with pipeline logic
-- User explicitly asks to use MCP tools
-→ See [2-mcp-approach.md](references/2-mcp-approach.md) for MCP-based workflow
+→ See [2-cli-approach.md](references/2-cli-approach.md) for CLI-based workflow
---
@@ -101,7 +100,7 @@ Before writing pipeline code, make sure you have:
```
- [ ] Language selected: Python or SQL
- [ ] Read the syntax basics: **SQL**: Always Read [sql/1-syntax-basics.md](references/sql/1-syntax-basics.md), **Python**: Always Read [python/1-syntax-basics.md](references/python/1-syntax-basics.md)
-- [ ] Workflow chosen: Standalone DAB / Existing DAB / MCP iteration
+- [ ] Workflow chosen: Standalone DAB / Existing DAB / CLI iteration
- [ ] Compute type: serverless (default) or classic
- [ ] Schema strategy: single schema with prefixes vs. multi-schema
- [ ] Consider [Multi-Schema Patterns](#multi-schema-patterns) and [Modern Defaults](#modern-defaults)
@@ -179,7 +178,7 @@ After choosing your workflow (see [Choose Your Workflow](#choose-your-workflow))
| Task | Guide |
|------|-------|
| **Setting up standalone pipeline project** | [1-project-initialization.md](references/1-project-initialization.md) |
-| **Rapid iteration with MCP tools** | [2-mcp-approach.md](references/2-mcp-approach.md) |
+| **Rapid iteration with CLI** | [2-cli-approach.md](references/2-cli-approach.md) |
| **Advanced configuration** | [3-advanced-configuration.md](references/3-advanced-configuration.md) |
| **Migrating from DLT** | [4-dlt-migration.md](references/4-dlt-migration.md) |
@@ -248,7 +247,7 @@ For detailed syntax, see [sql/1-syntax-basics.md](references/sql/1-syntax-basics
### Project Structure
- **Standalone pipeline projects**: Use `databricks pipelines init` for Asset Bundle with multi-environment support
- **Pipeline in existing bundle**: Add to `resources/*.pipeline.yml`
-- **Rapid iteration/prototyping**: Use MCP tools, formalize in bundle later
+- **Rapid iteration/prototyping**: Use CLI/SDK, formalize in bundle later
- See **[1-project-initialization.md](references/1-project-initialization.md)** for project setup details
### Minimal pipeline config pointers
@@ -278,30 +277,35 @@ For detailed examples, see **[3-advanced-configuration.md](references/3-advanced
## Post-Run Validation (Required)
-After running a pipeline (via DAB or MCP), you **MUST** validate both the execution status AND the actual data.
+After running a pipeline (via DAB or CLI), you **MUST** validate both the execution status AND the actual data.
### Step 1: Check Pipeline Execution Status
-**From MCP (`manage_pipeline(action="run")` or `manage_pipeline(action="create_or_update")`):**
-- Check `result["success"]` and `result["state"]`
-- If failed, check `result["message"]` and `result["errors"]` for details
+```bash
+# Get pipeline status and details (pipeline_id is positional)
+databricks pipelines get
+
+# Get recent events/logs
+databricks pipelines list-pipeline-events
+```
**From DAB (`databricks bundle run`):**
- Check the command output for success/failure
-- Use `manage_pipeline(action="get", pipeline_id=...)` to get detailed status and recent events
+- Use `databricks pipelines get ` to get detailed status and recent events
### Step 2: Validate Output Data
Even if the pipeline reports SUCCESS, you **MUST** verify the data is correct:
+```bash
+# Check schema, row counts, sample data, and null counts for all tables
+databricks experimental aitools tools discover-schema \
+ my_catalog.my_schema.bronze_orders \
+ my_catalog.my_schema.silver_orders \
+ my_catalog.my_schema.gold_summary
```
-# MCP Tool: get_table_stats_and_schema - validates schema, row counts, and stats
-get_table_stats_and_schema(
- catalog="my_catalog",
- schema="my_schema",
- table_names=["bronze_*", "silver_*", "gold_*"] # Use glob patterns
-)
-```
+
+This returns per table: columns/types, 5 sample rows, total_rows count, and null counts per column.
**Check for:**
- Empty tables (row_count = 0) - indicates ingestion or filtering issues
@@ -314,7 +318,7 @@ get_table_stats_and_schema(
If validation reveals problems, trace upstream to find the root cause:
1. **Start from the problematic table** - identify what's wrong (empty, wrong counts, bad data)
-2. **Check its source table** - use `get_table_stats_and_schema` on the upstream table
+2. **Check its source table** - run `DESCRIBE` and `COUNT(*)` on the upstream table
3. **Trace back to bronze** - continue until you find where the issue originates
4. **Common causes:**
- Bronze empty → source files missing or path incorrect
@@ -324,7 +328,7 @@ If validation reveals problems, trace upstream to find the root cause:
5. **Fix the SQL/Python code**, re-upload, and re-run the pipeline
-**Do NOT use `execute_sql` with COUNT queries for validation** - `get_table_stats_and_schema` is faster and returns more information in a single call.
+**Use `discover-schema` for validation** - it returns schema, row counts, sample data, and null counts in a single call.
---
@@ -332,17 +336,18 @@ If validation reveals problems, trace upstream to find the root cause:
| Issue | Solution |
|-------|----------|
-| **Empty output tables** | Use `get_table_stats_and_schema` to check upstream sources. Verify source files exist and paths are correct. |
+| **"Only SQL, Scala and Python notebooks are supported"** | Use `{"file": {"path": "..."}}` instead of `{"notebook": {"path": "..."}}` for raw SQL files. `notebook` is for Databricks notebook format only. |
+| **Empty output tables** | Use `discover-schema` to check upstream tables. Verify source files exist and paths are correct. |
| **Pipeline stuck INITIALIZING** | Normal for serverless, wait a few minutes |
| **"Column not found"** | Check `schemaHints` match actual data |
| **Streaming reads fail** | For file ingestion in a streaming table, you must use the `STREAM` keyword with `read_files`: `FROM STREAM read_files(...)`. For table streams use `FROM stream(table)`. See [read_files — Usage in streaming tables](https://docs.databricks.com/aws/en/sql/language-manual/functions/read_files#usage-in-streaming-tables). |
-| **Timeout during run** | Increase `timeout`, or use `wait_for_completion=False` and check status with `manage_pipeline(action="get")` |
+| **Timeout during run** | Use `databricks pipelines get ` to check status |
| **MV doesn't refresh** | Enable row tracking on source tables |
| **SCD2: query column not found** | Lakeflow uses `__START_AT` and `__END_AT` (double underscore), not `START_AT`/`END_AT`. Use `WHERE __END_AT IS NULL` for current rows. See [sql/4-cdc-patterns.md](references/sql/4-cdc-patterns.md). |
| **AUTO CDC parse error at APPLY/SEQUENCE** | Put `APPLY AS DELETE WHEN` **before** `SEQUENCE BY`. Only list columns in `COLUMNS * EXCEPT (...)` that exist in the source (omit `_rescued_data` unless bronze uses rescue data). Omit `TRACK HISTORY ON *` if it causes "end of input" errors; default is equivalent. See [sql/4-cdc-patterns.md](references/sql/4-cdc-patterns.md). |
| **"Cannot create streaming table from batch query"** | In a streaming table query, use `FROM STREAM read_files(...)` so `read_files` leverages Auto Loader; `FROM read_files(...)` alone is batch. See [sql/2-ingestion.md](references/sql/2-ingestion.md) and [read_files — Usage in streaming tables](https://docs.databricks.com/aws/en/sql/language-manual/functions/read_files#usage-in-streaming-tables). |
-**For detailed errors**, the `result["message"]` from `manage_pipeline(action="create_or_update")` includes suggested next steps. Use `manage_pipeline(action="get", pipeline_id=...)` which includes recent events and error details.
+**For detailed errors**, use `databricks pipelines get ` which includes recent events, or `databricks pipelines list-pipeline-events ` for full event history.
---
diff --git a/databricks-skills/databricks-spark-declarative-pipelines/references/1-project-initialization.md b/databricks-skills/databricks-spark-declarative-pipelines/references/1-project-initialization.md
index fbab69b3..40cc8d4a 100644
--- a/databricks-skills/databricks-spark-declarative-pipelines/references/1-project-initialization.md
+++ b/databricks-skills/databricks-spark-declarative-pipelines/references/1-project-initialization.md
@@ -232,8 +232,8 @@ databricks bundle run customer_pipeline_etl
# Run specific target
databricks bundle run customer_pipeline_etl --target prod
-# Or use Pipeline API directly
-databricks pipelines start-update --pipeline-id
+# Or use Pipeline API directly (pipeline_id is positional)
+databricks pipelines start-update
```
---
@@ -429,7 +429,7 @@ pip install --upgrade databricks-cli
databricks catalogs list
# Create catalog if needed
-databricks catalogs create --name my_catalog
+databricks catalogs create --json '{"name": "my_catalog"}'
```
### "Language option not recognized"
@@ -576,7 +576,7 @@ For technical best practices (Liquid Clustering, serverless, etc.), see **[SKILL
## References
-- **[SKILL.md](../SKILL.md)** - Main development workflow and MCP tools
+- **[SKILL.md](../SKILL.md)** - Main development workflow and CLI commands
- **[Declarative Automation Bundles (DABs) Documentation](https://docs.databricks.com/dev-tools/bundles/)** - Official bundle reference
- **[Pipeline Configuration Reference](https://docs.databricks.com/aws/en/ldp/configure-pipeline)** - Pipeline settings
- **[Databricks CLI Reference](https://docs.databricks.com/dev-tools/cli/)** - CLI commands and options
diff --git a/databricks-skills/databricks-spark-declarative-pipelines/references/2-cli-approach.md b/databricks-skills/databricks-spark-declarative-pipelines/references/2-cli-approach.md
new file mode 100644
index 00000000..4fdd27df
--- /dev/null
+++ b/databricks-skills/databricks-spark-declarative-pipelines/references/2-cli-approach.md
@@ -0,0 +1,174 @@
+# Rapid Pipeline Iteration with CLI
+
+Use CLI commands to create, run, and iterate on **SDP pipelines**. This is the fastest approach for prototyping without managing bundle files.
+
+**IMPORTANT: Default to serverless pipelines.** Only use classic clusters if user explicitly requires R language, Spark RDD APIs, or JAR libraries.
+
+### Step 1: Write Pipeline Files Locally
+
+Create `.sql` or `.py` files in a local folder. For syntax examples, see:
+- [sql/1-syntax-basics.md](sql/1-syntax-basics.md) for SQL syntax
+- [python/1-syntax-basics.md](python/1-syntax-basics.md) for Python syntax
+
+### Step 2: Upload to Databricks Workspace
+
+```bash
+# Upload local folder to workspace
+databricks workspace import-dir ./my_pipeline /Workspace/Users/user@example.com/my_pipeline
+```
+
+### Step 3: Create Pipeline
+
+```bash
+# Create pipeline with JSON config
+# Use "file" - can point to a single .sql/.py file OR a directory (includes all files)
+databricks pipelines create --json '{
+ "name": "my_orders_pipeline",
+ "catalog": "my_catalog",
+ "schema": "my_schema",
+ "serverless": true,
+ "libraries": [
+ {"file": {"path": "/Workspace/Users/user@example.com/my_pipeline"}}
+ ],
+ "development": true
+}'
+
+# Or specify individual files:
+# "libraries": [
+# {"file": {"path": "/Workspace/.../bronze/ingest_orders.sql"}},
+# {"file": {"path": "/Workspace/.../silver/clean_orders.sql"}}
+# ]
+#
+# Legacy (avoid): {"notebook": {"path": "..."}} - use "file" instead
+```
+
+Save the returned `pipeline_id` for subsequent operations.
+
+### Step 4: Run Pipeline
+
+```bash
+# Start a full refresh run (pipeline_id is a positional argument)
+databricks pipelines start-update --full-refresh
+
+# Check run status
+databricks pipelines get
+```
+
+### Step 5: Validate Results
+
+**On Success** - Verify tables were created with correct data:
+
+```bash
+# Check schema, row counts, sample data, and null counts for all tables
+databricks experimental aitools tools discover-schema \
+ my_catalog.my_schema.bronze_orders \
+ my_catalog.my_schema.silver_orders \
+ my_catalog.my_schema.gold_summary
+```
+
+This returns per table: columns/types, 5 sample rows, total_rows count, and null counts.
+
+Or use Python for detailed stats:
+```python
+from databricks.sdk import WorkspaceClient
+
+w = WorkspaceClient()
+
+# Get table info
+table = w.tables.get("my_catalog.my_schema.bronze_orders")
+print(f"Columns: {len(table.columns)}")
+print(f"Created: {table.created_at}")
+```
+
+**On Failure** - Get pipeline events and errors:
+
+```bash
+# Get pipeline details with recent events (pipeline_id is positional)
+databricks pipelines get
+
+# Get specific run events
+databricks pipelines list-pipeline-events
+```
+
+### Step 6: Iterate Until Working
+
+1. Review errors from pipeline status or events
+2. Fix issues in local files
+3. Re-upload: `databricks workspace import-dir ./my_pipeline /Workspace/Users/user@example.com/my_pipeline --overwrite`
+4. Update and run: `databricks pipelines update --json '...'` then `databricks pipelines start-update `
+5. Repeat until pipeline completes successfully
+
+---
+
+## Quick Reference: CLI Commands
+
+### Pipeline Lifecycle
+
+| Command | Description |
+|---------|-------------|
+| `databricks pipelines create --json '{...}'` | Create new pipeline |
+| `databricks pipelines get PIPELINE_ID` | Get pipeline details and status |
+| `databricks pipelines update PIPELINE_ID --json '{...}'` | Update pipeline config |
+| `databricks pipelines delete PIPELINE_ID` | Delete a pipeline |
+| `databricks pipelines list-pipelines` | List all pipelines |
+
+### Run Management
+
+| Command | Description |
+|---------|-------------|
+| `databricks pipelines start-update PIPELINE_ID` | Start pipeline update |
+| `databricks pipelines start-update PIPELINE_ID --full-refresh` | Start with full refresh |
+| `databricks pipelines stop PIPELINE_ID` | Stop running pipeline |
+| `databricks pipelines list-pipeline-events PIPELINE_ID` | Get events/logs |
+| `databricks pipelines list-updates PIPELINE_ID` | List recent runs |
+
+### Supporting Commands
+
+| Command | Description |
+|---------|-------------|
+| `databricks workspace import-dir` | Upload files/folders to workspace |
+| `databricks workspace list` | List workspace files |
+| `databricks experimental aitools tools discover-schema` | Get schema, row counts, sample data, null counts |
+| `databricks experimental aitools tools query` | Run ad-hoc SQL queries |
+
+---
+
+## Python SDK Alternative
+
+For more programmatic control, use the Databricks SDK:
+
+```python
+from databricks.sdk import WorkspaceClient
+
+w = WorkspaceClient()
+
+# Create pipeline - use "file" to include all .sql/.py files in a directory
+pipeline = w.pipelines.create(
+ name="my_orders_pipeline",
+ catalog="my_catalog",
+ schema="my_schema",
+ serverless=True,
+ libraries=[
+ {"file": {"path": "/Workspace/Users/user@example.com/my_pipeline"}}
+ ],
+ development=True
+)
+print(f"Created pipeline: {pipeline.pipeline_id}")
+
+# Start update
+update = w.pipelines.start_update(
+ pipeline_id=pipeline.pipeline_id,
+ full_refresh=True
+)
+
+# Poll for completion
+import time
+while True:
+ status = w.pipelines.get(pipeline_id=pipeline.pipeline_id)
+ if status.state in ["IDLE", "FAILED"]:
+ print(f"Pipeline state: {status.state}")
+ break
+ time.sleep(10)
+```
+
+---
diff --git a/databricks-skills/databricks-spark-declarative-pipelines/references/2-mcp-approach.md b/databricks-skills/databricks-spark-declarative-pipelines/references/2-mcp-approach.md
deleted file mode 100644
index 87e0ed70..00000000
--- a/databricks-skills/databricks-spark-declarative-pipelines/references/2-mcp-approach.md
+++ /dev/null
@@ -1,163 +0,0 @@
-Use MCP tools to create, run, and iterate on **SDP pipelines**. The **primary tool is `manage_pipeline`** which handles the entire lifecycle.
-
-**IMPORTANT: Default to serverless pipelines.** Only use classic clusters if user explicitly requires R language, Spark RDD APIs, or JAR libraries.
-
-### Step 1: Write Pipeline Files Locally
-
-Create `.sql` or `.py` files in a local folder. For syntax examples, see:
-- [sql/1-syntax-basics.md](sql/1-syntax-basics.md) for SQL syntax
-- [python/1-syntax-basics.md](python/1-syntax-basics.md) for Python syntax
-
-### Step 2: Upload to Databricks Workspace
-
-```
-# MCP Tool: manage_workspace_files
-manage_workspace_files(
- action="upload",
- local_path="/path/to/my_pipeline",
- workspace_path="/Workspace/Users/user@example.com/my_pipeline"
-)
-```
-
-### Step 3: Create/Update and Run Pipeline
-
-Use **`manage_pipeline`** with `action="create_or_update"` to manage the resource:
-
-```
-# MCP Tool: manage_pipeline
-manage_pipeline(
- action="create_or_update",
- name="my_orders_pipeline",
- root_path="/Workspace/Users/user@example.com/my_pipeline",
- catalog="my_catalog",
- schema="my_schema",
- workspace_file_paths=[
- "/Workspace/Users/user@example.com/my_pipeline/bronze/ingest_orders.sql",
- "/Workspace/Users/user@example.com/my_pipeline/silver/clean_orders.sql",
- "/Workspace/Users/user@example.com/my_pipeline/gold/daily_summary.sql"
- ],
- start_run=True, # Automatically run after create/update
- wait_for_completion=True, # Wait for run to finish
- full_refresh=True # Reprocess all data
-)
-```
-
-**Result contains actionable information:**
-```json
-{
- "success": true,
- "pipeline_id": "abc-123",
- "pipeline_name": "my_orders_pipeline",
- "created": true,
- "state": "COMPLETED",
- "catalog": "my_catalog",
- "schema": "my_schema",
- "duration_seconds": 45.2,
- "message": "Pipeline created and completed successfully in 45.2s. Tables written to my_catalog.my_schema",
- "error_message": null,
- "errors": []
-}
-```
-
-### Alternative: Run Pipeline Separately
-
-If you want to run an existing pipeline or control the run separately:
-
-```
-# MCP Tool: manage_pipeline_run
-manage_pipeline_run(
- action="start",
- pipeline_id="",
- full_refresh=True,
- wait=True, # Wait for completion
- timeout=1800 # 30 minute timeout
-)
-```
-
-### Step 4: Validate Results
-
-**On Success** - Use `get_table_stats_and_schema` to verify tables (NOT manual SQL COUNT queries):
-```
-# MCP Tool: get_table_stats_and_schema
-get_table_stats_and_schema(
- catalog="my_catalog",
- schema="my_schema",
- table_names=["bronze_orders", "silver_orders", "gold_daily_summary"]
-)
-# Returns schema, row counts, and column stats for all tables in one call
-```
-
-**On Failure** - Check `run_result["message"]` for suggested next steps, then get detailed errors:
-```
-# MCP Tool: manage_pipeline
-manage_pipeline(action="get", pipeline_id="")
-# Returns pipeline details enriched with recent events and error messages
-
-# Or get events/logs directly:
-# MCP Tool: manage_pipeline_run
-manage_pipeline_run(
- action="get_events",
- pipeline_id="",
- event_log_level="ERROR", # ERROR, WARN, or INFO
- max_results=10
-)
-```
-
-### Step 5: Iterate Until Working
-
-1. Review errors from run result or `manage_pipeline(action="get")`
-2. Fix issues in local files
-3. Re-upload with `manage_workspace_files(action="upload")`
-4. Run `manage_pipeline(action="create_or_update", start_run=True)` again (it will update, not recreate)
-5. Repeat until `result["success"] == True`
-
----
-
-## Quick Reference: MCP Tools
-
-### manage_pipeline - Pipeline Lifecycle
-
-| Action | Description | Required Params |
-|--------|-------------|-----------------|
-| `create` | Create new pipeline | name, root_path, catalog, schema, workspace_file_paths |
-| `create_or_update` | **Main entry point.** Idempotent create/update, optionally run | name, root_path, catalog, schema, workspace_file_paths |
-| `get` | Get pipeline details by ID | pipeline_id |
-| `update` | Update pipeline config | pipeline_id + fields to change |
-| `delete` | Delete a pipeline | pipeline_id |
-| `find_by_name` | Find pipeline by name | name |
-
-**create_or_update options:**
-- `start_run=True`: Automatically run after create/update
-- `wait_for_completion=True`: Block until run finishes
-- `full_refresh=True`: Reprocess all data (default)
-- `timeout=1800`: Max wait time in seconds
-
-### manage_pipeline_run - Run Management
-
-| Action | Description | Required Params |
-|--------|-------------|-----------------|
-| `start` | Start pipeline update | pipeline_id |
-| `get` | Get run status | pipeline_id, update_id |
-| `stop` | Stop running pipeline | pipeline_id |
-| `get_events` | Get events/logs for debugging | pipeline_id |
-
-**start options:**
-- `wait=True`: Block until complete (default)
-- `full_refresh=True`: Reprocess all data
-- `validate_only=True`: Dry run without writing data
-- `refresh_selection=["table1", "table2"]`: Refresh specific tables only
-
-**get_events options:**
-- `event_log_level`: "ERROR", "WARN" (default), "INFO"
-- `max_results`: Number of events (default 5)
-- `update_id`: Filter to specific run
-
-### Supporting Tools
-
-| Tool | Description |
-|------|-------------|
-| `manage_workspace_files(action="upload")` | Upload files/folders to workspace |
-| `get_table_stats_and_schema` | **Use this to validate tables** - returns schema, row counts, and stats in one call |
-| `execute_sql` | Run ad-hoc SQL to inspect actual data content (not for row counts) |
-
----
diff --git a/databricks-skills/databricks-spark-declarative-pipelines/references/3-advanced-configuration.md b/databricks-skills/databricks-spark-declarative-pipelines/references/3-advanced-configuration.md
index b637f469..6a349f78 100644
--- a/databricks-skills/databricks-spark-declarative-pipelines/references/3-advanced-configuration.md
+++ b/databricks-skills/databricks-spark-declarative-pipelines/references/3-advanced-configuration.md
@@ -1,13 +1,13 @@
-# Advanced Pipeline Configuration (`extra_settings`)
+# Advanced Pipeline Configuration
-By default, pipelines are created with **serverless compute and Unity Catalog**. Use the `extra_settings` parameter only for advanced use cases.
+By default, pipelines are created with **serverless compute and Unity Catalog**. Use advanced configuration options only when needed.
-**CRITICAL: Do NOT use `extra_settings` to set `serverless=false` unless the user explicitly requires:**
+**CRITICAL: Do NOT set `serverless=false` unless the user explicitly requires:**
- R language support
- Spark RDD APIs
- JAR libraries or Maven coordinates
-## When to Use `extra_settings`
+## When to Use Advanced Configuration
- **Development mode**: Faster iteration with relaxed validation
- **Continuous pipelines**: Real-time streaming instead of triggered runs
@@ -16,7 +16,9 @@ By default, pipelines are created with **serverless compute and Unity Catalog**.
- **Python dependencies**: Install pip packages for serverless pipelines
- **Classic clusters** (rare): Only if user explicitly needs R, RDD APIs, or JARs
-## `extra_settings` Parameter Reference
+## Pipeline JSON Configuration Reference
+
+These fields can be passed to `databricks pipelines create --json '{...}'` or `databricks pipelines update --json '{...}'`.
### Top-Level Fields
@@ -157,198 +159,159 @@ Install pip dependencies for serverless pipelines:
## Configuration Examples
+All examples use `databricks pipelines create --json '{...}'`. For updates, use `databricks pipelines update --json '{...}'`.
+
### Development Mode Pipeline
-Use `manage_pipeline(action="create_or_update")` tool with:
-- `name`: "my_dev_pipeline"
-- `root_path`: "/Workspace/Users/user@example.com/my_pipeline"
-- `catalog`: "dev_catalog"
-- `schema`: "dev_schema"
-- `workspace_file_paths`: [...]
-- `start_run`: true
-- `extra_settings`:
-```json
-{
- "development": true,
- "tags": {"environment": "development", "owner": "data-team"}
-}
+```bash
+databricks pipelines create --json '{
+ "name": "my_dev_pipeline",
+ "catalog": "dev_catalog",
+ "schema": "dev_schema",
+ "serverless": true,
+ "development": true,
+ "libraries": [{"file": {"path": "/Workspace/Users/user@example.com/my_pipeline"}}],
+ "tags": {"environment": "development", "owner": "data-team"}
+}'
```
### Non-Serverless with Dedicated Cluster
-Use `manage_pipeline(action="create_or_update")` tool with `extra_settings`:
-```json
-{
- "serverless": false,
- "clusters": [{
- "label": "default",
- "num_workers": 4,
- "node_type_id": "i3.xlarge",
- "custom_tags": {"cost_center": "analytics"}
- }],
- "photon": true,
- "edition": "ADVANCED"
-}
+```bash
+databricks pipelines create --json '{
+ "name": "my_pipeline",
+ "catalog": "my_catalog",
+ "schema": "my_schema",
+ "serverless": false,
+ "photon": true,
+ "edition": "ADVANCED",
+ "clusters": [{
+ "label": "default",
+ "num_workers": 4,
+ "node_type_id": "i3.xlarge",
+ "custom_tags": {"cost_center": "analytics"}
+ }],
+ "libraries": [{"file": {"path": "/Workspace/Users/user@example.com/my_pipeline"}}]
+}'
```
### Continuous Streaming Pipeline
-Use `manage_pipeline(action="create_or_update")` tool with `extra_settings`:
-```json
-{
- "continuous": true,
- "configuration": {
- "spark.sql.shuffle.partitions": "auto"
- }
-}
-```
-
-### Using Instance Pool
-
-Use `manage_pipeline(action="create_or_update")` tool with `extra_settings`:
-```json
-{
- "serverless": false,
- "clusters": [{
- "label": "default",
- "instance_pool_id": "0727-104344-hauls13-pool-xyz",
- "num_workers": 2,
- "custom_tags": {"project": "analytics"}
- }]
-}
-```
-
-### Custom Event Log Location
-
-Use `manage_pipeline(action="create_or_update")` tool with `extra_settings`:
-```json
-{
- "event_log": {
- "catalog": "audit_catalog",
- "schema": "pipeline_logs",
- "name": "my_pipeline_events"
- }
-}
+```bash
+databricks pipelines create --json '{
+ "name": "my_streaming_pipeline",
+ "catalog": "my_catalog",
+ "schema": "my_schema",
+ "serverless": true,
+ "continuous": true,
+ "configuration": {"spark.sql.shuffle.partitions": "auto"},
+ "libraries": [{"file": {"path": "/Workspace/Users/user@example.com/my_pipeline"}}]
+}'
```
### Pipeline with Email Notifications
-Use `manage_pipeline(action="create_or_update")` tool with `extra_settings`:
-```json
-{
- "notifications": [{
- "email_recipients": ["team@example.com", "oncall@example.com"],
- "alerts": ["on-update-failure", "on-update-fatal-failure", "on-flow-failure"]
- }]
-}
+```bash
+databricks pipelines create --json '{
+ "name": "my_pipeline",
+ "catalog": "my_catalog",
+ "schema": "my_schema",
+ "serverless": true,
+ "libraries": [{"file": {"path": "/Workspace/Users/user@example.com/my_pipeline"}}],
+ "notifications": [{
+ "email_recipients": ["team@example.com", "oncall@example.com"],
+ "alerts": ["on-update-failure", "on-update-fatal-failure", "on-flow-failure"]
+ }]
+}'
```
### Production Pipeline with Autoscaling
-Use `manage_pipeline(action="create_or_update")` tool with `extra_settings`:
-```json
-{
- "serverless": false,
- "development": false,
- "photon": true,
- "edition": "ADVANCED",
- "clusters": [{
- "label": "default",
- "autoscale": {
- "min_workers": 2,
- "max_workers": 8,
- "mode": "ENHANCED"
- },
- "node_type_id": "i3.xlarge",
- "spark_conf": {
- "spark.sql.adaptive.enabled": "true"
- },
- "custom_tags": {"environment": "production"}
- }],
- "notifications": [{
- "email_recipients": ["data-team@example.com"],
- "alerts": ["on-update-failure"]
- }]
-}
+```bash
+databricks pipelines create --json '{
+ "name": "prod_pipeline",
+ "catalog": "prod_catalog",
+ "schema": "prod_schema",
+ "serverless": false,
+ "development": false,
+ "photon": true,
+ "edition": "ADVANCED",
+ "clusters": [{
+ "label": "default",
+ "autoscale": {"min_workers": 2, "max_workers": 8, "mode": "ENHANCED"},
+ "node_type_id": "i3.xlarge",
+ "spark_conf": {"spark.sql.adaptive.enabled": "true"},
+ "custom_tags": {"environment": "production"}
+ }],
+ "libraries": [{"file": {"path": "/Workspace/Users/user@example.com/my_pipeline"}}],
+ "notifications": [{"email_recipients": ["data-team@example.com"], "alerts": ["on-update-failure"]}]
+}'
```
-### Run as Service Principal
+### Serverless with Python Dependencies
-Use `manage_pipeline(action="create_or_update")` tool with `extra_settings`:
-```json
-{
- "run_as": {
- "service_principal_name": "00000000-0000-0000-0000-000000000000"
- }
-}
+```bash
+databricks pipelines create --json '{
+ "name": "ml_pipeline",
+ "catalog": "my_catalog",
+ "schema": "my_schema",
+ "serverless": true,
+ "libraries": [{"file": {"path": "/Workspace/Users/user@example.com/my_pipeline"}}],
+ "environment": {
+ "dependencies": ["scikit-learn==1.3.0", "pandas>=2.0.0", "requests"]
+ }
+}'
```
### Continuous Pipeline with Restart Window
-Use `manage_pipeline(action="create_or_update")` tool with `extra_settings`:
-```json
-{
- "continuous": true,
- "restart_window": {
- "start_hour": 2,
- "days_of_week": ["SATURDAY", "SUNDAY"],
- "time_zone_id": "America/Los_Angeles"
- }
-}
+```bash
+databricks pipelines create --json '{
+ "name": "realtime_pipeline",
+ "catalog": "my_catalog",
+ "schema": "my_schema",
+ "serverless": true,
+ "continuous": true,
+ "libraries": [{"file": {"path": "/Workspace/Users/user@example.com/my_pipeline"}}],
+ "restart_window": {
+ "start_hour": 2,
+ "days_of_week": ["SATURDAY", "SUNDAY"],
+ "time_zone_id": "America/Los_Angeles"
+ }
+}'
```
-### Serverless with Python Dependencies
+### Custom Event Log Location
-Use `manage_pipeline(action="create_or_update")` tool with `extra_settings`:
-```json
-{
- "serverless": true,
- "environment": {
- "dependencies": [
- "scikit-learn==1.3.0",
- "pandas>=2.0.0",
- "requests"
- ]
- }
-}
+```bash
+databricks pipelines create --json '{
+ "name": "my_pipeline",
+ "catalog": "my_catalog",
+ "schema": "my_schema",
+ "serverless": true,
+ "libraries": [{"file": {"path": "/Workspace/Users/user@example.com/my_pipeline"}}],
+ "event_log": {
+ "catalog": "audit_catalog",
+ "schema": "pipeline_logs",
+ "name": "my_pipeline_events"
+ }
+}'
```
-### Update Existing Pipeline by ID
+### Update Existing Pipeline
-If you have a pipeline ID from the Databricks UI, you can force an update by including `id` in `extra_settings`:
-```json
-{
- "id": "554f4497-4807-4182-bff0-ffac4bb4f0ce"
-}
-```
+```bash
+# Update pipeline configuration
+databricks pipelines update --json '{
+ "name": "updated_pipeline_name",
+ "development": false,
+ "notifications": [{"email_recipients": ["team@example.com"], "alerts": ["on-update-failure"]}]
+}'
-### Full JSON Export from Databricks UI
-
-You can copy pipeline settings from the Databricks UI (Pipeline Settings > JSON) and pass them directly as `extra_settings`. Invalid fields like `pipeline_type` are automatically filtered:
-
-```json
-{
- "id": "554f4497-4807-4182-bff0-ffac4bb4f0ce",
- "pipeline_type": "WORKSPACE",
- "continuous": false,
- "development": true,
- "photon": false,
- "edition": "ADVANCED",
- "channel": "CURRENT",
- "clusters": [{
- "label": "default",
- "num_workers": 1,
- "instance_pool_id": "0727-104344-pool-xyz"
- }],
- "configuration": {
- "catalog": "main",
- "schema": "my_schema"
- }
-}
+# Then run it
+databricks pipelines start-update --full-refresh
```
-**Note**: Explicit tool parameters (`name`, `root_path`, `catalog`, `schema`, `workspace_file_paths`) always take precedence over values in `extra_settings`.
-
---
## Multi-Schema Patterns
diff --git a/databricks-skills/databricks-synthetic-data-gen/SKILL.md b/databricks-skills/databricks-synthetic-data-gen/SKILL.md
index c046e488..3e6f5b71 100644
--- a/databricks-skills/databricks-synthetic-data-gen/SKILL.md
+++ b/databricks-skills/databricks-synthetic-data-gen/SKILL.md
@@ -126,26 +126,30 @@ Show a clear specification with **the business story and your assumptions surfac
**Do NOT proceed to code generation until user approves the plan, including the catalog.**
-### Post-Generation Checklist
+### Post-Generation Validation
-After generating data, use `get_volume_folder_details` to validate the output matches requirements:
-- Row counts match the plan
-- Schema matches expected columns and types
-- Data distributions look reasonable (check column stats)
+Use `databricks experimental aitools tools query` to validate generated data (row counts, distributions, referential integrity). Query parquet files directly:
-## Use Databricks Connect Spark + Faker Pattern
+```bash
+databricks experimental aitools tools query --warehouse $WAREHOUSE_ID "
+SELECT COUNT(*) FROM parquet.\`/Volumes/CATALOG/SCHEMA/raw_data/customers\`
+"
+```
+
+See [references/2-troubleshooting.md](references/2-troubleshooting.md) for full validation examples.
+
+## Use Databricks Connect Spark + Faker Pattern
```python
-from databricks.connect import DatabricksSession, DatabricksEnv
+from databricks.connect import DatabricksSession
from pyspark.sql import functions as F
from pyspark.sql.types import StringType
import pandas as pd
-# Setup serverless with dependencies (MUST list all libs used in UDFs)
-env = DatabricksEnv().withDependencies("faker", "holidays")
-spark = DatabricksSession.builder.withEnvironment(env).serverless(True).getOrCreate()
+# Setup serverless Spark session
+spark = DatabricksSession.builder.serverless(True).getOrCreate()
-# Pandas UDF pattern - import lib INSIDE the function
+# Pandas UDF pattern - import lib INSIDE the function (libs must be installed locally)
@F.pandas_udf(StringType())
def fake_name(ids: pd.Series) -> pd.Series:
from faker import Faker # Import inside UDF
@@ -248,9 +252,7 @@ uv pip install "databricks-connect>=16.4,<17.4" faker numpy pandas holidays
| Issue | Solution |
|-------|----------|
-| `ImportError: cannot import name 'DatabricksEnv'` | Upgrade: `uv pip install "databricks-connect>=16.4"` |
-| Python 3.11 instead of 3.12 | Python 3.12 required. Use `uv` to create env with correct version |
-| `ModuleNotFoundError: faker` | Add to `withDependencies()`, import inside UDF |
+| `ModuleNotFoundError: faker` | Install locally: `uv pip install faker`, import inside UDF |
| Faker UDF is slow | Use `pandas_udf` for batch processing |
| Out of memory | Increase `numPartitions` in `spark.range()` |
| Referential integrity errors | Write master table to Delta first, read back for FK joins |
diff --git a/databricks-skills/databricks-synthetic-data-gen/references/2-troubleshooting.md b/databricks-skills/databricks-synthetic-data-gen/references/2-troubleshooting.md
index 420b3500..793b64f7 100644
--- a/databricks-skills/databricks-synthetic-data-gen/references/2-troubleshooting.md
+++ b/databricks-skills/databricks-synthetic-data-gen/references/2-troubleshooting.md
@@ -12,31 +12,16 @@ Common issues and solutions for synthetic data generation.
| Mode | Solution |
|------|----------|
-| **DB Connect 16.4+** | Use `DatabricksEnv().withDependencies("faker", "pandas", ...)` |
-| **Older DB Connect with Serverless** | Create job with `environments` parameter |
-| **Databricks Runtime** | Use Databricks CLI to install `faker holidays` |
+| **DB Connect with Serverless** | Install libs locally (`uv pip install faker`), use `DatabricksSession.builder.serverless(True)` |
+| **Databricks Runtime** | Use Databricks CLI to install `faker holidays` |
| **Classic cluster** | Use Databricks CLI to install libraries. `databricks libraries install --json '{"cluster_id": "", "libraries": [{"pypi": {"package": "faker"}}, {"pypi": {"package": "holidays"}}]}'` |
```python
-# For DB Connect 16.4+
-from databricks.connect import DatabricksSession, DatabricksEnv
+# For DB Connect with serverless
+from databricks.connect import DatabricksSession
-env = DatabricksEnv().withDependencies("faker", "pandas", "numpy", "holidays")
-spark = DatabricksSession.builder.withEnvironment(env).serverless(True).getOrCreate()
-```
-
-### DatabricksEnv not found
-
-**Problem:** Using older databricks-connect version.
-
-**Solution:** Upgrade to 16.4+ or use job-based approach:
-
-```bash
-# Upgrade (prefer uv, fall back to pip)
-uv pip install "databricks-connect>=16.4,<17.4"
-# or: pip install "databricks-connect>=16.4,<17.4"
-
-# Or use job with environments parameter instead
+# Install dependencies locally first: uv pip install faker pandas numpy holidays
+spark = DatabricksSession.builder.serverless(True).getOrCreate()
```
### serverless_compute_id error
@@ -300,25 +285,57 @@ resolution_hours = np.random.exponential(scale=resolution_scale[priority])
## Validation Steps
-After generation, verify your data:
+After generation, validate using SQL queries via Databricks CLI:
-```python
-# 1. Check row counts
-print(f"Customers: {customers_df.count():,}")
-print(f"Orders: {orders_df.count():,}")
-
-# 2. Verify distributions
-customers_df.groupBy("tier").count().show()
-orders_df.describe("amount").show()
-
-# 3. Check referential integrity
-orphans = orders_df.join(
- customers_df,
- orders_df.customer_id == customers_df.customer_id,
- "left_anti"
-)
-print(f"Orphan orders: {orphans.count()}")
+```bash
+# Set your warehouse ID
+WAREHOUSE_ID="your-warehouse-id"
+VOLUME_PATH="/Volumes/CATALOG/SCHEMA/raw_data"
-# 4. Verify date range
-orders_df.select(F.min("order_date"), F.max("order_date")).show()
+# 1. Check row counts
+databricks experimental aitools tools query --warehouse $WAREHOUSE_ID "
+SELECT 'customers' as table_name, COUNT(*) as row_count FROM parquet.\`${VOLUME_PATH}/customers\`
+UNION ALL
+SELECT 'orders', COUNT(*) FROM parquet.\`${VOLUME_PATH}/orders\`
+"
+
+# 2. Preview schema and sample data
+databricks experimental aitools tools query --warehouse $WAREHOUSE_ID "
+DESCRIBE SELECT * FROM parquet.\`${VOLUME_PATH}/customers\`
+"
+
+databricks experimental aitools tools query --warehouse $WAREHOUSE_ID "
+SELECT * FROM parquet.\`${VOLUME_PATH}/customers\` LIMIT 5
+"
+
+# 3. Verify distributions
+databricks experimental aitools tools query --warehouse $WAREHOUSE_ID "
+SELECT tier, COUNT(*) as count, ROUND(COUNT(*) * 100.0 / SUM(COUNT(*)) OVER(), 1) as pct
+FROM parquet.\`${VOLUME_PATH}/customers\`
+GROUP BY tier ORDER BY tier
+"
+
+# 4. Check amount statistics
+databricks experimental aitools tools query --warehouse $WAREHOUSE_ID "
+SELECT
+ MIN(amount) as min_amount,
+ MAX(amount) as max_amount,
+ ROUND(AVG(amount), 2) as avg_amount,
+ ROUND(STDDEV(amount), 2) as stddev_amount
+FROM parquet.\`${VOLUME_PATH}/orders\`
+"
+
+# 5. Check referential integrity
+databricks experimental aitools tools query --warehouse $WAREHOUSE_ID "
+SELECT COUNT(*) as orphan_orders
+FROM parquet.\`${VOLUME_PATH}/orders\` o
+LEFT JOIN parquet.\`${VOLUME_PATH}/customers\` c ON o.customer_id = c.customer_id
+WHERE c.customer_id IS NULL
+"
+
+# 6. Verify date range
+databricks experimental aitools tools query --warehouse $WAREHOUSE_ID "
+SELECT MIN(order_date) as min_date, MAX(order_date) as max_date
+FROM parquet.\`${VOLUME_PATH}/orders\`
+"
```
diff --git a/databricks-skills/databricks-synthetic-data-gen/scripts/generate_synthetic_data.py b/databricks-skills/databricks-synthetic-data-gen/scripts/generate_synthetic_data.py
index b9f953fa..b36edb8e 100644
--- a/databricks-skills/databricks-synthetic-data-gen/scripts/generate_synthetic_data.py
+++ b/databricks-skills/databricks-synthetic-data-gen/scripts/generate_synthetic_data.py
@@ -6,9 +6,9 @@
- Direct write to Unity Catalog
- Works with serverless and classic compute
-Auto-detects environment and uses:
-- DatabricksEnv with managed dependencies if databricks-connect >= 16.4 (local)
-- Standard session if running on Databricks Runtime or older databricks-connect
+Prerequisites:
+- Install dependencies locally: uv pip install faker pandas numpy holidays databricks-connect
+- Configure ~/.databrickscfg with serverless_compute_id = auto
"""
import sys
import os
@@ -61,105 +61,23 @@
REGION_PROBS = [0.4, 0.25, 0.2, 0.15]
# =============================================================================
-# ENVIRONMENT DETECTION AND SESSION CREATION
+# SESSION CREATION
# =============================================================================
-def is_databricks_runtime():
- """Check if running on Databricks Runtime vs locally."""
- return "DATABRICKS_RUNTIME_VERSION" in os.environ
-
-def get_databricks_connect_version():
- """Get databricks-connect version as (major, minor) tuple or None."""
- try:
- import importlib.metadata
- version_str = importlib.metadata.version('databricks-connect')
- parts = version_str.split('.')
- return (int(parts[0]), int(parts[1]))
- except Exception:
- return None
-
-# Detect environment
-on_runtime = is_databricks_runtime()
-db_version = get_databricks_connect_version()
+from databricks.connect import DatabricksSession
print("=" * 80)
-print("ENVIRONMENT DETECTION")
+print("CREATING SPARK SESSION")
print("=" * 80)
-print(f"Running on Databricks Runtime: {on_runtime}")
-if db_version:
- print(f"databricks-connect version: {db_version[0]}.{db_version[1]}")
-else:
- print("databricks-connect: not available")
-
-# Use DatabricksEnv with managed dependencies if:
-# - Running locally (not on Databricks Runtime)
-# - databricks-connect >= 16.4
-use_managed_deps = (not on_runtime) and db_version and db_version >= (16, 4)
-
-if use_managed_deps:
- print("Using DatabricksEnv with managed dependencies")
- print("=" * 80)
- from databricks.connect import DatabricksSession, DatabricksEnv
-
- env = DatabricksEnv().withDependencies("faker", "pandas", "numpy", "holidays")
-
- if USE_SERVERLESS:
- spark = DatabricksSession.builder.withEnvironment(env).serverless(True).getOrCreate()
- print("Connected to serverless compute with managed dependencies!")
- else:
- if not CLUSTER_ID:
- raise ValueError("CLUSTER_ID must be set when USE_SERVERLESS=False")
- spark = DatabricksSession.builder.withEnvironment(env).clusterId(CLUSTER_ID).getOrCreate()
- print(f"Connected to cluster with managed dependencies!")
-else:
- print("Using standard session (dependencies must be pre-installed)")
- print("=" * 80)
-
- # Check that UDF dependencies are available
- print("\nChecking UDF dependencies...")
- missing_deps = []
-
- try:
- from faker import Faker
- print(" faker: OK")
- except ImportError:
- missing_deps.append("faker")
- print(" faker: MISSING")
-
- try:
- import pandas as pd
- print(" pandas: OK")
- except ImportError:
- missing_deps.append("pandas")
- print(" pandas: MISSING")
-
- if missing_deps:
- print("\n" + "=" * 80)
- print("ERROR: Missing dependencies for UDFs")
- print("=" * 80)
- print(f"Missing: {', '.join(missing_deps)}")
- if on_runtime:
- print('\nSolution: Install libraries via Databricks CLI:')
- print(' databricks libraries install --json \'{"cluster_id": "", "libraries": [{"pypi": {"package": "faker"}}, {"pypi": {"package": "holidays"}}]}\'')
- else:
- print("\nSolution: Upgrade to databricks-connect >= 16.4 for managed deps")
- print(" Or create a job with environment settings")
- print("=" * 80)
- sys.exit(1)
- print("\nAll dependencies available")
- print("=" * 80)
-
- from databricks.connect import DatabricksSession
-
- if USE_SERVERLESS:
- spark = DatabricksSession.builder.serverless(True).getOrCreate()
- print("Connected to serverless compute")
- else:
- if not CLUSTER_ID:
- raise ValueError("CLUSTER_ID must be set when USE_SERVERLESS=False")
- spark = DatabricksSession.builder.clusterId(CLUSTER_ID).getOrCreate()
- print(f"Connected to cluster ")
+if USE_SERVERLESS:
+ spark = DatabricksSession.builder.serverless(True).getOrCreate()
+ print("Connected to serverless compute")
+else:
+ if not CLUSTER_ID:
+ raise ValueError("CLUSTER_ID must be set when USE_SERVERLESS=False")
+ spark = DatabricksSession.builder.clusterId(CLUSTER_ID).getOrCreate()
+ print(f"Connected to cluster {CLUSTER_ID}")
# Import Faker for UDF definitions
from faker import Faker
@@ -260,10 +178,6 @@ def generate_lognormal_amount(tiers: pd.Series) -> pd.Series:
customers_df.write.mode(WRITE_MODE).parquet(f"{VOLUME_PATH}/customers")
print(f" Saved customers to {VOLUME_PATH}/customers")
-# Show tier distribution
-print("\n Tier distribution:")
-customers_df.groupBy("tier").count().orderBy("tier").show()
-
# =============================================================================
# GENERATE ORDERS (Child Table with Referential Integrity)
# =============================================================================
@@ -366,10 +280,6 @@ def generate_lognormal_amount(tiers: pd.Series) -> pd.Series:
orders_final.write.mode(WRITE_MODE).parquet(f"{VOLUME_PATH}/orders")
print(f" Saved orders to {VOLUME_PATH}/orders")
-# Show status distribution
-print("\n Status distribution:")
-orders_final.groupBy("status").count().orderBy("status").show()
-
# =============================================================================
# CLEANUP AND SUMMARY
# =============================================================================
diff --git a/databricks-skills/databricks-unity-catalog/6-volumes.md b/databricks-skills/databricks-unity-catalog/6-volumes.md
index 497b6090..179baa67 100644
--- a/databricks-skills/databricks-unity-catalog/6-volumes.md
+++ b/databricks-skills/databricks-unity-catalog/6-volumes.md
@@ -37,18 +37,16 @@ All volume operations use the path format:
---
-## MCP Tools
-
-| Tool | Usage |
-|------|-------|
-| `list_volume_files` | `list_volume_files(volume_path="/Volumes/catalog/schema/volume/path/")` |
-| `get_volume_folder_details` | `get_volume_folder_details(volume_path="catalog/schema/volume/path", format="parquet")` - schema, row counts, stats |
-| `upload_to_volume` | `upload_to_volume(local_path="/tmp/data/*", volume_path="/Volumes/.../dest")` - supports files, folders, globs |
-| `download_from_volume` | `download_from_volume(volume_path="/Volumes/.../file.csv", local_path="/tmp/file.csv")` |
-| `create_volume_directory` | `create_volume_directory(volume_path="/Volumes/.../new_folder")` - creates parents like `mkdir -p` |
-| `delete_volume_file` | `delete_volume_file(volume_path="/Volumes/.../file.csv")` |
-| `delete_volume_directory` | `delete_volume_directory(volume_path="/Volumes/.../folder")` - directory must be empty |
-| `get_volume_file_info` | `get_volume_file_info(volume_path="/Volumes/.../file.csv")` - returns size, modified date |
+## CLI Commands
+
+| Command | Description |
+|---------|-------------|
+| `databricks fs ls /Volumes/catalog/schema/volume/path/` | List files in a volume |
+| `databricks fs cp /tmp/data/* /Volumes/.../dest --recursive` | Upload files/folders to volume |
+| `databricks fs cp /Volumes/.../file.csv /tmp/file.csv` | Download files from volume |
+| `databricks fs mkdirs /Volumes/.../new_folder` | Create directory (like `mkdir -p`) |
+| `databricks fs rm /Volumes/.../file.csv` | Delete file |
+| `databricks fs rm /Volumes/.../folder --recursive` | Delete directory recursively |
---
diff --git a/databricks-skills/databricks-unity-catalog/7-data-profiling.md b/databricks-skills/databricks-unity-catalog/7-data-profiling.md
index 23a2b62f..cf6c3ec1 100644
--- a/databricks-skills/databricks-unity-catalog/7-data-profiling.md
+++ b/databricks-skills/databricks-unity-catalog/7-data-profiling.md
@@ -36,55 +36,42 @@ Supported `AggregationGranularity` values: `AGGREGATION_GRANULARITY_5_MINUTES`,
---
-## MCP Tools
+## CLI & SQL Commands
-Use the `manage_uc_monitors` tool for all monitor operations:
+### Create a Monitor (SQL)
-| Action | Description |
-|--------|-------------|
-| `create` | Create a quality monitor on a table |
-| `get` | Get monitor details and status |
-| `run_refresh` | Trigger a metric refresh |
-| `list_refreshes` | List refresh history |
-| `delete` | Delete the monitor (assets are not deleted) |
-
-### Create a Monitor
+```sql
+CREATE OR REPLACE QUALITY MONITOR catalog.schema.my_table
+OPTIONS (
+ OUTPUT_SCHEMA 'catalog.schema'
+);
+```
-> **Note:** The MCP tool currently only creates **snapshot** monitors. For TimeSeries or InferenceLog monitors, use the Python SDK directly (see below).
+### Get Monitor Status (SQL)
-```python
-manage_uc_monitors(
- action="create",
- table_name="catalog.schema.my_table",
- output_schema_name="catalog.schema",
-)
+```sql
+DESCRIBE QUALITY MONITOR catalog.schema.my_table;
```
-### Get Monitor Status
+### Trigger a Refresh (SQL)
-```python
-manage_uc_monitors(
- action="get",
- table_name="catalog.schema.my_table",
-)
+```sql
+REFRESH QUALITY MONITOR catalog.schema.my_table;
```
-### Trigger a Refresh
+### Delete a Monitor (SQL)
-```python
-manage_uc_monitors(
- action="run_refresh",
- table_name="catalog.schema.my_table",
-)
+```sql
+DROP QUALITY MONITOR catalog.schema.my_table;
```
-### Delete a Monitor
+### Execute via CLI
-```python
-manage_uc_monitors(
- action="delete",
- table_name="catalog.schema.my_table",
-)
+```bash
+databricks experimental aitools tools query --warehouse WAREHOUSE_ID "
+CREATE OR REPLACE QUALITY MONITOR catalog.schema.my_table
+OPTIONS (OUTPUT_SCHEMA 'catalog.schema')
+"
```
---
@@ -300,7 +287,7 @@ LIMIT 100;
---
> **Note:** Data profiling was formerly known as Lakehouse Monitoring. The legacy SDK accessor
-> `w.lakehouse_monitors` and the MCP tool `manage_uc_monitors` still use the previous API.
+> `w.lakehouse_monitors` still uses the previous API. Use `w.data_quality` for the new API.
## Resources
diff --git a/databricks-skills/databricks-unity-catalog/SKILL.md b/databricks-skills/databricks-unity-catalog/SKILL.md
index 2e3d05fa..bbc77a6f 100644
--- a/databricks-skills/databricks-unity-catalog/SKILL.md
+++ b/databricks-skills/databricks-unity-catalog/SKILL.md
@@ -29,15 +29,41 @@ Use this skill when:
## Quick Start
-### Volume File Operations (MCP Tools)
+### Create Unity Catalog Objects (CLI)
-| Tool | Usage |
-|------|-------|
-| `list_volume_files` | `list_volume_files(volume_path="/Volumes/catalog/schema/volume/path/")` |
-| `get_volume_folder_details` | `get_volume_folder_details(volume_path="catalog/schema/volume/path", format="parquet")` - schema, row counts, stats |
-| `upload_to_volume` | `upload_to_volume(local_path="/tmp/data/*", volume_path="/Volumes/.../dest")` |
-| `download_from_volume` | `download_from_volume(volume_path="/Volumes/.../file.csv", local_path="/tmp/file.csv")` |
-| `create_volume_directory` | `create_volume_directory(volume_path="/Volumes/.../new_folder")` |
+**IMPORTANT**: Use `--json` for creating UC objects. Positional args vary by command and version.
+
+```bash
+# Create a catalog
+databricks catalogs create --json '{"name": "my_catalog"}'
+
+# Create a schema
+databricks schemas create --json '{"name": "my_schema", "catalog_name": "my_catalog"}'
+
+# Create a volume
+databricks volumes create --json '{"name": "my_volume", "catalog_name": "my_catalog", "schema_name": "my_schema", "volume_type": "MANAGED"}'
+
+# List catalogs, schemas, volumes
+databricks catalogs list
+databricks schemas list my_catalog
+databricks volumes list my_catalog.my_schema
+```
+
+### Volume File Operations (CLI)
+
+```bash
+# List files in a volume
+databricks fs ls /Volumes/catalog/schema/volume/path/
+
+# Upload files to a volume
+databricks fs cp /tmp/data/* /Volumes/catalog/schema/volume/dest/ --recursive
+
+# Download files from a volume
+databricks fs cp /Volumes/catalog/schema/volume/file.csv /tmp/file.csv
+
+# Create a directory in a volume
+databricks fs mkdirs /Volumes/catalog/schema/volume/new_folder
+```
### Enable System Tables Access
@@ -71,20 +97,17 @@ WHERE usage_date >= current_date() - 30
GROUP BY workspace_id, sku_name;
```
-## MCP Tool Integration
+## SQL Queries via CLI
-Use `mcp__databricks__execute_sql` for system table queries:
+Use `databricks experimental aitools tools query` for system table queries:
-```python
-# Query lineage
-mcp__databricks__execute_sql(
- sql_query="""
- SELECT source_table_full_name, target_table_full_name
- FROM system.access.table_lineage
- WHERE event_date >= current_date() - 7
- """,
- catalog="system"
-)
+```bash
+# Query lineage via CLI
+databricks experimental aitools tools query --warehouse WAREHOUSE_ID "
+ SELECT source_table_full_name, target_table_full_name
+ FROM system.access.table_lineage
+ WHERE event_date >= current_date() - 7
+"
```
## Best Practices
diff --git a/databricks-skills/databricks-unstructured-pdf-generation/SKILL.md b/databricks-skills/databricks-unstructured-pdf-generation/SKILL.md
index 92322fd0..5b10479d 100644
--- a/databricks-skills/databricks-unstructured-pdf-generation/SKILL.md
+++ b/databricks-skills/databricks-unstructured-pdf-generation/SKILL.md
@@ -7,331 +7,108 @@ description: "Generate PDF documents from HTML and upload to Unity Catalog volum
Convert HTML content to PDF documents and upload them to Unity Catalog Volumes.
-## Overview
+## Workflow
-The `generate_and_upload_pdf` MCP tool converts HTML to PDF and uploads to a Unity Catalog Volume. You (the LLM) generate the HTML content, and the tool handles conversion and upload.
+1. Write HTML files to `./raw_data/html/` (write multiple files in parallel for speed)
+2. Convert HTML → PDF using `scripts/pdf_generator.py` (parallel conversion)
+3. Upload PDFs to Unity Catalog volume using `databricks fs cp`
+4. Generate `doc_questions.json` with test questions for each document
-## Tool Signature
+## Dependencies
-```
-generate_and_upload_pdf(
- html_content: str, # Complete HTML document
- filename: str, # PDF filename (e.g., "report.pdf")
- catalog: str, # Unity Catalog name
- schema: str, # Schema name
- volume: str = "raw_data", # Volume name (default: "raw_data")
- folder: str = None, # Optional subfolder
-)
-```
-
-**Returns:**
-```json
-{
- "success": true,
- "volume_path": "/Volumes/catalog/schema/volume/filename.pdf",
- "error": null
-}
+```bash
+uv pip install plutoprint
```
-## Quick Start
+## Step 1: Write HTML Files
-Generate a simple PDF:
-
-```
-generate_and_upload_pdf(
- html_content='''
-
-
-
-
-
- Quarterly Report Q1 2024
-
-
Executive Summary
-
Revenue increased 15% year-over-year...
-
-
-''',
- filename="q1_report.pdf",
- catalog="my_catalog",
- schema="my_schema"
-)
+```bash
+mkdir -p ./raw_data/html
```
-## Performance: Generate Multiple PDFs in Parallel
+Write HTML documents to `./raw_data/html/filename.html`. Use subdirectories to organize (structure is preserved).
-**IMPORTANT**: PDF generation and upload can take 2-5 seconds per document. When generating multiple PDFs, **call the tool in parallel** to maximize throughput.
-
-### Example: Generate 5 PDFs in Parallel
-
-Make 5 simultaneous `generate_and_upload_pdf` calls:
+## Step 2: Convert to PDF
+```bash
+# Convert entire folder (parallel, 4 workers)
+python scripts/pdf_generator.py convert --input ./raw_data/html --output ./raw_data/pdf
```
-# Call 1
-generate_and_upload_pdf(
- html_content="...Employee Handbook content...",
- filename="employee_handbook.pdf",
- catalog="hr_catalog", schema="policies", folder="2024"
-)
-# Call 2 (parallel)
-generate_and_upload_pdf(
- html_content="...Leave Policy content...",
- filename="leave_policy.pdf",
- catalog="hr_catalog", schema="policies", folder="2024"
-)
+Skips files where PDF exists and is newer than HTML. Use `--force` to reconvert all.
-# Call 3 (parallel)
-generate_and_upload_pdf(
- html_content="...Code of Conduct content...",
- filename="code_of_conduct.pdf",
- catalog="hr_catalog", schema="policies", folder="2024"
-)
+## Step 3: Upload to Volume
-# Call 4 (parallel)
-generate_and_upload_pdf(
- html_content="...Benefits Guide content...",
- filename="benefits_guide.pdf",
- catalog="hr_catalog", schema="policies", folder="2024"
-)
-
-# Call 5 (parallel)
-generate_and_upload_pdf(
- html_content="...Remote Work Policy content...",
- filename="remote_work_policy.pdf",
- catalog="hr_catalog", schema="policies", folder="2024"
-)
+```bash
+databricks fs cp -r ./raw_data/pdf /Volumes/my_catalog/my_schema/raw_data/
```
-By calling these in parallel (not sequentially), 5 PDFs that would take 15-25 seconds sequentially complete in 3-5 seconds total.
-
-## HTML Best Practices
+## Step 4: Generate Test Questions
-### Use Complete HTML5 Structure
+Create `./raw_data/pdf/pdf_eval_questions.json` with questions for Knowledge Assistant evaluation or MAS:
-Always include the full HTML structure:
-
-```html
-
-
-
-
-
-
-
-
-
+```json
+{
+ "api_errors_guide.pdf": {
+ "question": "What is the solution for error ERR-4521?",
+ "expected_fact": "Call /api/v2/auth/refresh with refresh_token before the 3600s TTL expires"
+ },
+ "installation_manual.pdf": {
+ "question": "What port does the service use by default?",
+ "expected_fact": "Port 8443 for HTTPS, configurable via CONFIG_PORT environment variable"
+ }
+}
```
-### CSS Features Supported
+This JSON can be used to build KA test cases and validate retrieval accuracy.
-PlutoPrint supports modern CSS3:
-- Flexbox and Grid layouts
-- CSS variables (`--var-name`)
-- Web fonts (system fonts recommended)
-- Colors, backgrounds, borders
-- Tables with styling
+## Document Content Guidelines
-### CSS to Avoid
+When generating documents for Knowledge Assistant testing or demos:
-- Animations and transitions (static PDF)
-- Interactive elements (forms, hover effects)
-- External resources (images via URL) - use embedded base64 if needed
+- **Multi-page documents**: Each PDF should be several pages with substantial content
+- **Specific error codes and solutions**: Include product-specific error codes, causes, and resolution steps
+- **Technical details**: API endpoints, configuration parameters, version numbers, specific commands
+- **Simple CSS**: Keep styling minimal for fast HTML creation and reliable PDF conversion
+- **Queryable facts**: Include details a KA must read the document to answer (not general knowledge)
-### Professional Document Template
+**Good document types:**
+- Product user manuals with troubleshooting sections
+- API error reference guides (error codes, causes, solutions)
+- Installation/configuration guides with specific steps
+- Technical specifications with version-specific details
-```html
-
-
-
-
-
-
- Document Title
+**Example content:** Instead of generic "Connection failed" errors, write:
+- "Error ERR-4521: OAuth token expired. Cause: Token TTL exceeded 3600s default. Solution: Call `/api/v2/auth/refresh` with your refresh_token before expiration. See Section 4.2 for token lifecycle management."
- Section 1
- Content here...
+## CLI Reference
-
- Important: Key information highlighted here.
-
-
- Data Table
-
- Column 1 Column 2 Column 3
- Data Data Data
-
-
-
-
-
```
+python scripts/pdf_generator.py convert [OPTIONS]
-## Common Patterns
-
-### Pattern 1: Technical Documentation
-
-Generate API documentation, user guides, or technical specs:
-
-```
-generate_and_upload_pdf(
- html_content='''
-
-
-
- API Reference
-
-
GET /api/v1/users
-
Returns a list of all users.
-
- Request Headers
- Authorization: Bearer {token}
-Content-Type: application/json
-
-''',
- filename="api_reference.pdf",
- catalog="docs_catalog",
- schema="api_docs"
-)
+ --input, -i Input HTML file or folder (required)
+ --output, -o Output folder for PDFs (required)
+ --force, -f Force reconvert (ignore timestamps)
+ --workers, -w Parallel workers (default: 4)
```
-### Pattern 2: Business Reports
-
-```
-generate_and_upload_pdf(
- html_content='''
-
-
-
- Q1 2024 Performance Report
-
-
-
-''',
- filename="q1_2024_report.pdf",
- catalog="finance",
- schema="reports",
- folder="quarterly"
-)
-```
+## Folder Structure
-### Pattern 3: HR Policies
+Subfolder structure is preserved:
```
-generate_and_upload_pdf(
- html_content='''
-
-
-
- Employee Leave Policy
- Effective: January 1, 2024
-
-
-
1. Annual Leave
-
All full-time employees are entitled to 20 days of paid annual leave per calendar year.
-
-
-
- Note: Leave requests must be submitted at least 2 weeks in advance.
-
-
-''',
- filename="leave_policy.pdf",
- catalog="hr_catalog",
- schema="policies"
-)
+./raw_data/html/ ./raw_data/pdf/
+├── report.html → ├── report.pdf
+├── quarterly/ ├── quarterly/
+│ └── q1.html → │ └── q1.pdf
+└── legal/ └── legal/
+ └── terms.html → └── terms.pdf
```
-## Workflow for Multiple Documents
-
-When asked to generate multiple PDFs:
-
-1. **Plan the documents**: Determine titles, content structure for each
-2. **Generate HTML for each**: Create complete HTML documents
-3. **Call tool in parallel**: Make multiple simultaneous `generate_and_upload_pdf` calls
-4. **Report results**: Summarize successful uploads and any errors
-
-## Prerequisites
-
-- Unity Catalog schema must exist
-- Volume must exist (default: `raw_data`)
-- User must have WRITE permission on the volume
-
## Troubleshooting
| Issue | Solution |
|-------|----------|
-| "Volume does not exist" | Create the volume first or use an existing one |
-| "Schema does not exist" | Create the schema or check the name |
-| PDF looks wrong | Check HTML/CSS syntax, use supported CSS features |
-| Slow generation | Call multiple PDFs in parallel, not sequentially |
+| "plutoprint not installed" | `uv pip install plutoprint` |
+| PDF looks wrong | Check HTML/CSS syntax |
+| "Volume does not exist" | `databricks volumes create catalog.schema.volume MANAGED` |
diff --git a/databricks-skills/databricks-unstructured-pdf-generation/scripts/pdf_generator.py b/databricks-skills/databricks-unstructured-pdf-generation/scripts/pdf_generator.py
new file mode 100644
index 00000000..e7808d13
--- /dev/null
+++ b/databricks-skills/databricks-unstructured-pdf-generation/scripts/pdf_generator.py
@@ -0,0 +1,276 @@
+#!/usr/bin/env python3
+"""
+PDF Generator - Convert HTML files to PDF locally.
+
+Usage:
+ # Convert single file
+ python pdf_generator.py convert --input ./raw_data/html/report.html --output ./raw_data/pdf
+
+ # Convert entire folder (parallel)
+ python pdf_generator.py convert --input ./raw_data/html --output ./raw_data/pdf
+
+ # Force reconvert (ignore timestamps)
+ python pdf_generator.py convert --input ./raw_data/html --output ./raw_data/pdf --force
+
+Requires: plutoprint
+ uv / pip install plutoprint
+"""
+
+import argparse
+import logging
+import sys
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Optional
+
+logging.basicConfig(level=logging.INFO, format="%(message)s")
+logger = logging.getLogger(__name__)
+
+MAX_WORKERS = 4
+
+
+@dataclass
+class ConversionResult:
+ """Result from converting HTML to PDF."""
+ html_path: str
+ pdf_path: Optional[str] = None
+ success: bool = False
+ skipped: bool = False
+ error: Optional[str] = None
+
+ def to_dict(self) -> dict:
+ return {
+ "html_path": self.html_path,
+ "pdf_path": self.pdf_path,
+ "success": self.success,
+ "skipped": self.skipped,
+ "error": self.error,
+ }
+
+
+@dataclass
+class BatchResult:
+ """Result from batch conversion."""
+ total: int = 0
+ converted: int = 0
+ skipped: int = 0
+ failed: int = 0
+ results: list = field(default_factory=list)
+
+ def to_dict(self) -> dict:
+ return {
+ "total": self.total,
+ "converted": self.converted,
+ "skipped": self.skipped,
+ "failed": self.failed,
+ "results": [r.to_dict() for r in self.results],
+ }
+
+
+def _needs_conversion(html_path: Path, pdf_path: Path) -> bool:
+ """Check if HTML needs to be converted (PDF missing or older than HTML).
+
+ Args:
+ html_path: Path to HTML file
+ pdf_path: Path to output PDF file
+
+ Returns:
+ True if conversion needed, False if PDF is up-to-date
+ """
+ if not pdf_path.exists():
+ return True
+
+ html_mtime = html_path.stat().st_mtime
+ pdf_mtime = pdf_path.stat().st_mtime
+
+ return html_mtime > pdf_mtime
+
+
+def convert_html_to_pdf(
+ html_path: Path,
+ pdf_path: Path,
+ force: bool = False,
+) -> ConversionResult:
+ """Convert a single HTML file to PDF.
+
+ Args:
+ html_path: Path to HTML file
+ pdf_path: Path to output PDF file
+ force: If True, convert even if PDF is up-to-date
+
+ Returns:
+ ConversionResult with success/skip/error status
+ """
+ result = ConversionResult(html_path=str(html_path))
+
+ # Check if conversion needed
+ if not force and not _needs_conversion(html_path, pdf_path):
+ result.skipped = True
+ result.success = True
+ result.pdf_path = str(pdf_path)
+ logger.debug(f"Skipped (up-to-date): {html_path.name}")
+ return result
+
+ # Ensure output directory exists
+ pdf_path.parent.mkdir(parents=True, exist_ok=True)
+
+ try:
+ import plutoprint
+
+ # Read HTML content
+ html_content = html_path.read_text(encoding="utf-8")
+
+ # Convert to PDF
+ book = plutoprint.Book(plutoprint.PAGE_SIZE_A4)
+ book.load_html(html_content)
+ book.write_to_pdf(str(pdf_path))
+
+ if pdf_path.exists():
+ result.success = True
+ result.pdf_path = str(pdf_path)
+ logger.info(f"Converted: {html_path.name} -> {pdf_path.name}")
+ else:
+ result.error = "PDF file not created"
+ logger.error(f"Failed: {html_path.name} - PDF not created")
+
+ except ImportError:
+ result.error = "plutoprint not installed. Run: pip install plutoprint"
+ logger.error(result.error)
+ except Exception as e:
+ result.error = str(e)
+ logger.error(f"Failed: {html_path.name} - {e}")
+
+ return result
+
+
+def convert_folder(
+ input_dir: Path,
+ output_dir: Path,
+ force: bool = False,
+ max_workers: int = MAX_WORKERS,
+) -> BatchResult:
+ """Convert all HTML files in a folder to PDF (parallel).
+
+ Preserves subfolder structure from input to output.
+
+ Args:
+ input_dir: Directory containing HTML files
+ output_dir: Directory for output PDF files
+ force: If True, convert even if PDFs are up-to-date
+ max_workers: Number of parallel workers (default: 4)
+
+ Returns:
+ BatchResult with counts and per-file results
+ """
+ batch = BatchResult()
+
+ # Find all HTML files
+ html_files = list(input_dir.rglob("*.html"))
+ batch.total = len(html_files)
+
+ if batch.total == 0:
+ logger.warning(f"No HTML files found in {input_dir}")
+ return batch
+
+ logger.info(f"Found {batch.total} HTML file(s) in {input_dir}")
+
+ def process_file(html_path: Path) -> ConversionResult:
+ # Compute relative path to preserve folder structure
+ relative_path = html_path.relative_to(input_dir)
+ pdf_relative = relative_path.with_suffix(".pdf")
+ pdf_path = output_dir / pdf_relative
+
+ return convert_html_to_pdf(html_path, pdf_path, force=force)
+
+ # Process files in parallel
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ futures = {executor.submit(process_file, f): f for f in html_files}
+
+ for future in as_completed(futures):
+ result = future.result()
+ batch.results.append(result)
+
+ if result.skipped:
+ batch.skipped += 1
+ elif result.success:
+ batch.converted += 1
+ else:
+ batch.failed += 1
+
+ logger.info(f"Done: {batch.converted} converted, {batch.skipped} skipped, {batch.failed} failed")
+ return batch
+
+
+def main():
+ """CLI entry point."""
+ parser = argparse.ArgumentParser(
+ description="Convert HTML files to PDF",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # Convert single file
+ python pdf_generator.py convert --input ./raw_data/html/report.html --output ./raw_data/pdf
+
+ # Convert entire folder (parallel)
+ python pdf_generator.py convert --input ./raw_data/html --output ./raw_data/pdf
+
+ # Force reconvert all
+ python pdf_generator.py convert --input ./raw_data/html --output ./raw_data/pdf --force
+ """,
+ )
+
+ subparsers = parser.add_subparsers(dest="command", help="Commands")
+
+ # Convert command
+ conv_parser = subparsers.add_parser("convert", help="Convert HTML to PDF")
+ conv_parser.add_argument("--input", "-i", required=True, help="Input HTML file or folder")
+ conv_parser.add_argument("--output", "-o", required=True, help="Output folder for PDFs")
+ conv_parser.add_argument("--force", "-f", action="store_true", help="Force reconvert (ignore timestamps)")
+ conv_parser.add_argument("--workers", "-w", type=int, default=MAX_WORKERS, help=f"Parallel workers (default: {MAX_WORKERS})")
+
+ args = parser.parse_args()
+
+ if args.command == "convert":
+ input_path = Path(args.input)
+ output_path = Path(args.output)
+
+ if not input_path.exists():
+ print(f"Error: Input path does not exist: {input_path}")
+ sys.exit(1)
+
+ if input_path.is_file():
+ # Single file conversion
+ if not input_path.suffix.lower() == ".html":
+ print(f"Error: Input file must be .html: {input_path}")
+ sys.exit(1)
+
+ pdf_path = output_path / input_path.with_suffix(".pdf").name
+ result = convert_html_to_pdf(input_path, pdf_path, force=args.force)
+
+ if result.skipped:
+ print(f"Skipped (up-to-date): {result.pdf_path}")
+ elif result.success:
+ print(f"Converted: {result.pdf_path}")
+ else:
+ print(f"Error: {result.error}")
+ sys.exit(1)
+ else:
+ # Folder conversion
+ batch = convert_folder(
+ input_path,
+ output_path,
+ force=args.force,
+ max_workers=args.workers,
+ )
+
+ print(f"\nSummary: {batch.converted} converted, {batch.skipped} skipped, {batch.failed} failed")
+ if batch.failed > 0:
+ sys.exit(1)
+ else:
+ parser.print_help()
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/databricks-skills/databricks-vector-search/SKILL.md b/databricks-skills/databricks-vector-search/SKILL.md
index 72068ec5..18cc7679 100644
--- a/databricks-skills/databricks-vector-search/SKILL.md
+++ b/databricks-skills/databricks-vector-search/SKILL.md
@@ -302,7 +302,7 @@ databricks vector-search indexes delete-index \
| **Embedding dimension mismatch** | Ensure query and index dimensions match |
| **Index not updating** | Check pipeline_type; use sync_index() for TRIGGERED |
| **Out of capacity** | Upgrade to Storage-Optimized (1B+ vectors) |
-| **`query_vector` truncated by MCP tool** | MCP tool calls serialize arrays as JSON and can truncate large vectors (e.g. 1024-dim). Use `query_text` instead (for managed embedding indexes), or use the Databricks SDK/CLI to pass raw vectors |
+| **`query_vector` truncated** | Large vectors (e.g. 1024-dim) can be truncated when serialized as JSON. Use `query_text` instead (for managed embedding indexes), or use the Databricks SDK to pass raw vectors |
## Embedding Models
@@ -323,112 +323,6 @@ embedding_source_columns=[
]
```
-## MCP Tools
-
-The following MCP tools are available for managing Vector Search infrastructure. For a full end-to-end walkthrough, see [end-to-end-rag.md](end-to-end-rag.md).
-
-### manage_vs_endpoint - Endpoint Management
-
-| Action | Description | Required Params |
-|--------|-------------|-----------------|
-| `create_or_update` | Create endpoint (STANDARD or STORAGE_OPTIMIZED). Idempotent | name |
-| `get` | Get endpoint details | name |
-| `list` | List all endpoints | (none) |
-| `delete` | Delete endpoint (indexes must be deleted first) | name |
-
-```python
-# Create or update an endpoint
-result = manage_vs_endpoint(action="create_or_update", name="my-vs-endpoint", endpoint_type="STANDARD")
-# Returns {"name": "my-vs-endpoint", "endpoint_type": "STANDARD", "created": True}
-
-# List all endpoints
-endpoints = manage_vs_endpoint(action="list")
-
-# Get specific endpoint
-endpoint = manage_vs_endpoint(action="get", name="my-vs-endpoint")
-```
-
-### manage_vs_index - Index Management
-
-| Action | Description | Required Params |
-|--------|-------------|-----------------|
-| `create_or_update` | Create index. Idempotent, auto-triggers sync for DELTA_SYNC | name, endpoint_name, primary_key |
-| `get` | Get index details | name |
-| `list` | List indexes. Optional endpoint_name filter | (none) |
-| `delete` | Delete index | name |
-
-```python
-# Create a Delta Sync index with managed embeddings
-result = manage_vs_index(
- action="create_or_update",
- name="catalog.schema.my_index",
- endpoint_name="my-vs-endpoint",
- primary_key="id",
- index_type="DELTA_SYNC",
- delta_sync_index_spec={
- "source_table": "catalog.schema.docs",
- "embedding_source_columns": [{"name": "content", "embedding_model_endpoint_name": "databricks-gte-large-en"}],
- "pipeline_type": "TRIGGERED"
- }
-)
-
-# Get a specific index
-index = manage_vs_index(action="get", name="catalog.schema.my_index")
-
-# List all indexes on an endpoint
-indexes = manage_vs_index(action="list", endpoint_name="my-vs-endpoint")
-
-# List all indexes across all endpoints
-all_indexes = manage_vs_index(action="list")
-```
-
-### query_vs_index - Query (Hot Path)
-
-Query index with `query_text`, `query_vector`, or hybrid (`query_type="HYBRID"`). Prefer `query_text` over `query_vector` — MCP tool calls can truncate large embedding arrays (1024-dim).
-
-```python
-# Query an index
-results = query_vs_index(
- index_name="catalog.schema.my_index",
- columns=["id", "content"],
- query_text="machine learning best practices",
- num_results=5
-)
-
-# Hybrid search (combines vector + keyword)
-results = query_vs_index(
- index_name="catalog.schema.my_index",
- columns=["id", "content"],
- query_text="SPARK-12345 memory error",
- query_type="HYBRID",
- num_results=10
-)
-```
-
-### manage_vs_data - Data Operations
-
-| Action | Description | Required Params |
-|--------|-------------|-----------------|
-| `upsert` | Insert/update records | index_name, inputs_json |
-| `delete` | Delete by primary key | index_name, primary_keys |
-| `scan` | Scan index contents | index_name |
-| `sync` | Trigger sync for TRIGGERED indexes | index_name |
-
-```python
-# Upsert data into a Direct Access index
-manage_vs_data(
- action="upsert",
- index_name="catalog.schema.my_index",
- inputs_json=[{"id": "doc1", "content": "...", "embedding": [0.1, 0.2, ...]}]
-)
-
-# Trigger manual sync for a TRIGGERED pipeline index
-manage_vs_data(action="sync", index_name="catalog.schema.my_index")
-
-# Scan index contents
-manage_vs_data(action="scan", index_name="catalog.schema.my_index", num_results=100)
-```
-
## Notes
- **Storage-Optimized is newer** — better for most use cases unless you need <100ms latency
@@ -436,7 +330,7 @@ manage_vs_data(action="scan", index_name="catalog.schema.my_index", num_results=
- **Hybrid search** — available for both Delta Sync and Direct Access indexes
- **`columns_to_sync` matters** — only synced columns are available in query results; include all columns you need
- **Filter syntax differs by endpoint** — Standard uses dict-format filters, Storage-Optimized uses SQL-like string filters. Use the `databricks-vectorsearch` package's `filters` parameter which accepts both formats
-- **Management vs runtime** — MCP tools above handle lifecycle management; for agent tool-calling at runtime, use `VectorSearchRetrieverTool` or the Databricks managed Vector Search MCP server
+- **Management vs runtime** — CLI and SDK handle lifecycle management; for agent tool-calling at runtime, use `VectorSearchRetrieverTool`
## Related Skills
diff --git a/databricks-skills/databricks-vector-search/end-to-end-rag.md b/databricks-skills/databricks-vector-search/end-to-end-rag.md
index a3808d1b..60691a2a 100644
--- a/databricks-skills/databricks-vector-search/end-to-end-rag.md
+++ b/databricks-skills/databricks-vector-search/end-to-end-rag.md
@@ -2,16 +2,16 @@
Build a complete Retrieval-Augmented Generation pipeline: prepare documents, create a vector index, query it, and wire it into an agent.
-## MCP Tools Used
+## CLI Commands Used
-| Tool | Step |
-|------|------|
-| `execute_sql` | Create source table, insert documents |
-| `manage_vs_endpoint(action="create")` | Create compute endpoint |
-| `manage_vs_index(action="create")` | Create Delta Sync index with managed embeddings |
-| `manage_vs_index(action="sync")` | Trigger index sync |
-| `manage_vs_index(action="get")` | Check index status |
-| `query_vs_index` | Test similarity search |
+| Command | Step |
+|---------|------|
+| `databricks experimental aitools tools query` | Create source table, insert documents |
+| `databricks vector-search endpoints create` | Create compute endpoint |
+| `databricks vector-search indexes create-index` | Create Delta Sync index with managed embeddings |
+| `databricks vector-search indexes sync-index` | Trigger index sync |
+| `databricks vector-search indexes get-index` | Check index status |
+| `databricks vector-search indexes query-index` | Test similarity search |
---
@@ -34,10 +34,10 @@ INSERT INTO catalog.schema.knowledge_base VALUES
('doc-003', 'Delta Lake', 'Delta Lake is an open-source storage layer...', 'storage', current_timestamp());
```
-Or via MCP:
+Or via CLI:
-```python
-execute_sql(sql_query="""
+```bash
+databricks experimental aitools tools query --warehouse WAREHOUSE_ID "
CREATE TABLE IF NOT EXISTS catalog.schema.knowledge_base (
doc_id STRING,
title STRING,
@@ -45,7 +45,7 @@ execute_sql(sql_query="""
category STRING,
updated_at TIMESTAMP DEFAULT current_timestamp()
)
-""")
+"
```
## Step 2: Create Vector Search Endpoint
diff --git a/databricks-skills/databricks-vector-search/troubleshooting-and-operations.md b/databricks-skills/databricks-vector-search/troubleshooting-and-operations.md
index 7dc4b8c9..23385adc 100644
--- a/databricks-skills/databricks-vector-search/troubleshooting-and-operations.md
+++ b/databricks-skills/databricks-vector-search/troubleshooting-and-operations.md
@@ -4,7 +4,7 @@ Operational guidance for monitoring, cost optimization, capacity planning, and m
## Monitoring Endpoint Status
-Use `manage_vs_endpoint(action="get")` (MCP tool) or `w.vector_search_endpoints.get_endpoint()` (SDK) to check endpoint health.
+Use `databricks vector-search endpoints get` (CLI) or `w.vector_search_endpoints.get_endpoint()` (SDK) to check endpoint health.
### Endpoint fields
@@ -34,7 +34,7 @@ print(f"Indexes: {endpoint.num_indexes}")
## Monitoring Index Status
-Use `manage_vs_index(action="get")` (MCP tool) or `w.vector_search_indexes.get_index()` (SDK) to check index health.
+Use `databricks vector-search indexes get-index` (CLI) or `w.vector_search_indexes.get_index()` (SDK) to check index health.
### Index fields
diff --git a/databricks-skills/databricks-zerobus-ingest/SKILL.md b/databricks-skills/databricks-zerobus-ingest/SKILL.md
index 22f90c55..668a4be2 100644
--- a/databricks-skills/databricks-zerobus-ingest/SKILL.md
+++ b/databricks-skills/databricks-zerobus-ingest/SKILL.md
@@ -120,54 +120,54 @@ You must always follow all the steps in the Workflow
## Workflow
0. **Display the plan of your execution**
-1. **Determinate the type of client**
-2. **Get schema** Always use 4-protobuf-schema.md. Execute using the `execute_code` MCP tool
-3. **Write Python code to a local file follow the instructions in the relevant guide to ingest with zerobus** in the project (e.g., `scripts/zerobus_ingest.py`).
-4. **Execute on Databricks** using the `execute_code` MCP tool (with `file_path` parameter)
-5. **If execution fails**: Edit the local file to fix the error, then re-execute
-6. **Reuse the context** for follow-up executions by passing the returned `cluster_id` and `context_id`
+1. **Determine the type of client**
+2. **Get schema** Always use 4-protobuf-schema.md
+3. **Write Python code to a local file** following the instructions in the relevant guide (e.g., `scripts/zerobus_ingest.py`)
+4. **Upload to workspace**: `databricks workspace import-dir ./scripts /Workspace/Users//scripts`
+5. **Execute on Databricks** using a job or notebook
+6. **If execution fails**: Edit the local file, re-upload, and re-execute
---
## Important
- Never install local packages
-- Always validate MCP server requirement before execution
- **Serverless limitation**: The Zerobus SDK cannot pip-install on serverless compute. Use classic compute clusters, or use the [Zerobus REST API](https://docs.databricks.com/aws/en/ingestion/zerobus-rest-api) (Beta) for notebook-based ingestion without the SDK.
- **Explicit table grants**: Service principals need explicit `MODIFY` and `SELECT` grants on the target table. Schema-level inherited permissions may not be sufficient for the `authorization_details` OAuth flow.
---
-### Context Reuse Pattern
+### Execution Workflow
-The first execution auto-selects a running cluster and creates an execution context. **Reuse this context for follow-up calls** - it's much faster (~1s vs ~15s) and shares variables/imports:
-
-**First execution** - use `execute_code` tool:
-- `file_path`: "scripts/zerobus_ingest.py"
-
-Returns: `{ success, output, error, cluster_id, context_id, ... }`
+**Step 1: Upload code to workspace**
+```bash
+databricks workspace import-dir ./scripts /Workspace/Users//scripts
+```
-Save `cluster_id` and `context_id` for follow-up calls.
+**Step 2: Create and run a job**
+```bash
+databricks jobs create --json '{
+ "name": "zerobus-ingest",
+ "tasks": [{
+ "task_key": "ingest",
+ "spark_python_task": {
+ "python_file": "/Workspace/Users//scripts/zerobus_ingest.py"
+ },
+ "new_cluster": {
+ "spark_version": "16.1.x-scala2.12",
+ "node_type_id": "i3.xlarge",
+ "num_workers": 0
+ }
+ }]
+}'
+
+databricks jobs run-now --job-id
+```
**If execution fails:**
-1. Read the error from the result
+1. Read the error from the job run output
2. Edit the local Python file to fix the issue
-3. Re-execute with same context using `execute_code` tool:
- - `file_path`: "scripts/zerobus_ingest.py"
- - `cluster_id`: ""
- - `context_id`: ""
-
-**Follow-up executions** reuse the context (faster, shares state):
-- `file_path`: "scripts/validate_ingestion.py"
-- `cluster_id`: ""
-- `context_id`: ""
-
-### Handling Failures
-
-When execution fails:
-1. Read the error from the result
-2. **Edit the local Python file** to fix the issue
-3. Re-execute using the same `cluster_id` and `context_id` (faster, keeps installed libraries)
-4. If the context is corrupted, omit `context_id` to create a fresh one
+3. Re-upload: `databricks workspace import-dir ./scripts /Workspace/Users//scripts`
+4. Re-run: `databricks jobs run-now --job-id `
---
@@ -175,14 +175,14 @@ When execution fails:
Databricks provides Spark, pandas, numpy, and common data libraries by default. **Only install a library if you get an import error.**
-Use `execute_code` tool:
-- `code`: "%pip install databricks-zerobus-ingest-sdk>=1.0.0"
-- `cluster_id`: ""
-- `context_id`: ""
-
-The library is immediately available in the same context.
+Add to the job configuration:
+```json
+"libraries": [
+ {"pypi": {"package": "databricks-zerobus-ingest-sdk>=1.0.0"}}
+]
+```
-**Note:** Keeping the same `context_id` means installed libraries persist across calls.
+Or use init scripts in the cluster configuration.
## 🚨 Critical Learning: Timestamp Format Fix
diff --git a/databricks-skills/install_skills.sh b/databricks-skills/install_skills.sh
index 7630615c..3613b00c 100755
--- a/databricks-skills/install_skills.sh
+++ b/databricks-skills/install_skills.sh
@@ -3,7 +3,7 @@
# Databricks Skills Installer
#
# Installs Databricks skills for Claude Code into your project.
-# These skills teach Claude how to work with Databricks using MCP tools.
+# These skills teach Claude how to work with Databricks using CLI, SDK, and REST APIs.
#
# Usage:
# # Install all skills (Databricks + MLflow + APX)
@@ -119,7 +119,7 @@ get_skill_extra_files() {
"databricks-bundles") echo "alerts_guidance.md SDP_guidance.md" ;;
"databricks-iceberg") echo "1-managed-iceberg-tables.md 2-uniform-and-compatibility.md 3-iceberg-rest-catalog.md 4-snowflake-interop.md 5-external-engine-interop.md" ;;
"databricks-app-apx") echo "backend-patterns.md best-practices.md frontend-patterns.md" ;;
- "databricks-app-python") echo "1-authorization.md 2-app-resources.md 3-frameworks.md 4-deployment.md 5-lakebase.md 6-mcp-approach.md examples/llm_config.py examples/fm-minimal-chat.py examples/fm-parallel-calls.py examples/fm-structured-outputs.py" ;;
+ "databricks-app-python") echo "1-authorization.md 2-app-resources.md 3-frameworks.md 4-deployment.md 5-lakebase.md 6-cli-approach.md examples/llm_config.py examples/fm-minimal-chat.py examples/fm-parallel-calls.py examples/fm-structured-outputs.py" ;;
"databricks-jobs") echo "task-types.md triggers-schedules.md notifications-monitoring.md examples.md" ;;
"databricks-python-sdk") echo "doc-index.md examples/1-authentication.py examples/2-clusters-and-jobs.py examples/3-sql-and-warehouses.py examples/4-unity-catalog.py examples/5-serving-and-vector-search.py" ;;
"databricks-unity-catalog") echo "5-system-tables.md" ;;
diff --git a/databricks-tools-core/.env.test.template b/databricks-tools-core/.env.test.template
deleted file mode 100644
index 4daac3d9..00000000
--- a/databricks-tools-core/.env.test.template
+++ /dev/null
@@ -1,22 +0,0 @@
-# Test environment configuration
-# Copy to .env.test and fill in your values
-
-# Databricks workspace configuration
-DATABRICKS_HOST=https://your-workspace.cloud.databricks.com
-DATABRICKS_TOKEN=your-databricks-token
-
-# LLM Provider Configuration
-LLM_PROVIDER=DATABRICKS
-
-# Databricks Foundation Models
-# Available endpoints: databricks-gpt-5-mini, databricks-gpt-5-2, databricks-claude-sonnet-4-5
-DATABRICKS_MODEL=databricks-gpt-5-mini
-DATABRICKS_MODEL_MINI=databricks-gpt-5-mini
-
-# Azure OpenAI (uncomment if using Azure)
-# LLM_PROVIDER=AZURE
-# AZURE_OPENAI_API_KEY=your-api-key
-# AZURE_OPENAI_ENDPOINT=https://your-resource.cognitiveservices.azure.com/
-# AZURE_OPENAI_API_VERSION=2024-08-01-preview
-# AZURE_OPENAI_DEPLOYMENT=gpt-4o
-# AZURE_OPENAI_DEPLOYMENT_MINI=gpt-4o-mini
diff --git a/databricks-tools-core/README.md b/databricks-tools-core/README.md
deleted file mode 100644
index 9d975487..00000000
--- a/databricks-tools-core/README.md
+++ /dev/null
@@ -1,445 +0,0 @@
-# Databricks Tools Core
-
-High-level, AI-assistant-friendly Python functions for building Databricks projects.
-
-## Overview
-
-The `databricks-tools-core` package provides reusable, opinionated functions for interacting with the Databricks platform. It is designed to be used by AI coding assistants (Claude Code, Cursor, etc.) and developers who want simple, high-level APIs for common Databricks operations.
-
-### Modules
-
-| Module | Description |
-|--------|-------------|
-| **sql/** | SQL execution, warehouse management, and table statistics |
-| **jobs/** | Job management and run operations (serverless by default) |
-| **unity_catalog/** | Unity Catalog operations (catalogs, schemas, tables) |
-| **compute/** | Compute and execution context operations |
-| **spark_declarative_pipelines/** | Spark Declarative Pipeline management |
-| **synthetic_data_generation/** | Test data generation utilities |
-
-## Installation
-
-### Using uv (recommended)
-
-```bash
-# Install the package
-uv pip install -e .
-
-# Install with dev dependencies
-uv pip install -e ".[dev]"
-```
-
-### Using pip
-
-```bash
-pip install -e .
-```
-
-## Authentication
-
-All functions use `get_workspace_client()` from the `auth` module, which supports multiple authentication methods:
-
-### Authentication Priority
-
-1. **Context variables** (for multi-user apps) - Set via `set_databricks_auth()`
-2. **Environment variables** - `DATABRICKS_HOST` and `DATABRICKS_TOKEN`
-3. **Config profile** - `DATABRICKS_CONFIG_PROFILE` or `~/.databrickscfg`
-
-### Single-User Mode (CLI, Scripts, Notebooks)
-
-```bash
-# Option 1: Environment variables
-export DATABRICKS_HOST="https://your-workspace.cloud.databricks.com"
-export DATABRICKS_TOKEN="your-token"
-
-# Option 2: Config profile
-export DATABRICKS_CONFIG_PROFILE="my-profile"
-```
-
-Then use functions directly:
-
-```python
-from databricks_tools_core.sql import execute_sql
-
-result = execute_sql("SELECT 1") # Uses env vars or config
-```
-
-### Multi-User Mode (Web Apps, APIs)
-
-For applications serving multiple users, use contextvars to set per-request credentials:
-
-```python
-from databricks_tools_core.auth import (
- set_databricks_auth,
- clear_databricks_auth,
- get_workspace_client,
-)
-
-# In your request handler
-async def handle_request(user_host: str, user_token: str):
- set_databricks_auth(user_host, user_token)
- try:
- # All functions now use this user's credentials
- result = execute_sql("SELECT current_user()")
-
- # Or get client directly
- client = get_workspace_client()
- warehouses = client.warehouses.list()
- finally:
- clear_databricks_auth()
-```
-
-**How it works:**
-
-```
-┌─────────────────────────────────────────────────────────────────┐
-│ Authentication Flow │
-├─────────────────────────────────────────────────────────────────┤
-│ │
-│ set_databricks_auth(host, token) │
-│ │ │
-│ ▼ │
-│ ┌─────────────────────────┐ │
-│ │ contextvars │ (async-safe, per-request) │
-│ │ _host_ctx = host │ │
-│ │ _token_ctx = token │ │
-│ └───────────┬─────────────┘ │
-│ │ │
-│ ▼ │
-│ get_workspace_client() │
-│ │ │
-│ ├─── Has context? ──► WorkspaceClient(host, token) │
-│ │ │
-│ └─── No context? ───► WorkspaceClient() (uses env/cfg) │
-│ │
-└─────────────────────────────────────────────────────────────────┘
-```
-
-This pattern is used by `databricks-mcp-app` to handle per-user authentication when deployed as a Databricks App, where each request includes the user's access token in headers.
-
-## Usage
-
-### SQL Execution
-
-Execute SQL queries on Databricks SQL Warehouses:
-
-```python
-from databricks_tools_core.sql import execute_sql, execute_sql_multi
-
-# Simple query (auto-selects warehouse if not specified)
-result = execute_sql("SELECT * FROM my_catalog.my_schema.customers LIMIT 10")
-# Returns: [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}, ...]
-
-# Query with specific warehouse and catalog/schema context
-result = execute_sql(
- sql_query="SELECT COUNT(*) as cnt FROM customers",
- warehouse_id="abc123def456",
- catalog="my_catalog",
- schema="my_schema",
-)
-
-# Execute multiple statements with dependency-aware parallelism
-result = execute_sql_multi(
- sql_content="""
- CREATE TABLE t1 AS SELECT 1 as id;
- CREATE TABLE t2 AS SELECT 2 as id;
- CREATE TABLE t3 AS SELECT * FROM t1 JOIN t2;
- """,
- catalog="my_catalog",
- schema="my_schema",
-)
-# t1 and t2 run in parallel, t3 waits for both
-```
-
-### Warehouse Management
-
-List and select SQL warehouses:
-
-```python
-from databricks_tools_core.sql import list_warehouses, get_best_warehouse
-
-# List warehouses (running ones first)
-warehouses = list_warehouses(limit=20)
-# Returns: [{"id": "...", "name": "...", "state": "RUNNING", ...}, ...]
-
-# Auto-select best available warehouse
-warehouse_id = get_best_warehouse()
-# Prefers: running shared endpoints > running warehouses > stopped warehouses
-```
-
-### Table Statistics
-
-Get detailed table information and column statistics:
-
-```python
-from databricks_tools_core.sql import get_table_stats_and_schema, TableStatLevel
-
-# Get all tables in a schema with basic stats
-result = get_table_stats_and_schema(
- catalog="my_catalog",
- schema="my_schema",
-)
-
-# Get specific tables (faster - no listing required)
-result = get_table_stats_and_schema(
- catalog="my_catalog",
- schema="my_schema",
- table_names=["customers", "orders"],
-)
-
-# Use glob patterns
-result = get_table_stats_and_schema(
- catalog="my_catalog",
- schema="my_schema",
- table_names=["raw_*", "gold_customers"], # Mix of patterns and exact names
-)
-
-# Control stat level
-result = get_table_stats_and_schema(
- catalog="my_catalog",
- schema="my_schema",
- table_names=["customers"],
- table_stat_level=TableStatLevel.DETAILED, # NONE, SIMPLE, or DETAILED
-)
-```
-
-**TableStatLevel options:**
-
-| Level | Description | Use Case |
-|-------|-------------|----------|
-| `NONE` | DDL only, no stats | Quick schema lookup |
-| `SIMPLE` | Basic stats (samples, min/max, cardinality) | Default, cached |
-| `DETAILED` | Full stats (histograms, percentiles, value counts) | Data profiling |
-
-### Jobs
-
-Create, manage, and run Databricks jobs. Uses serverless compute by default for optimal performance and cost.
-
-```python
-from databricks_tools_core.jobs import (
- create_job, get_job, list_jobs, update_job, delete_job,
- run_job_now, get_run, list_runs, cancel_run, wait_for_run,
-)
-
-# Create a job with notebook task (serverless by default)
-tasks = [
- {
- "task_key": "etl_task",
- "notebook_task": {
- "notebook_path": "/Workspace/ETL/process_data",
- "source": "WORKSPACE",
- },
- }
-]
-job = create_job(name="my_etl_job", tasks=tasks)
-print(f"Created job: {job['job_id']}")
-
-# Run the job immediately
-run_id = run_job_now(job_id=job["job_id"])
-
-# Wait for completion (with timeout)
-result = wait_for_run(run_id=run_id, timeout=3600)
-if result.success:
- print(f"Job completed in {result.duration_seconds}s")
-else:
- print(f"Job failed: {result.error_message}")
-
-# List recent runs
-runs = list_runs(job_id=job["job_id"], limit=10)
-
-# Cancel a running job
-cancel_run(run_id=run_id)
-```
-
-**Job Functions:**
-
-| Function | Description |
-|----------|-------------|
-| `create_job()` | Create a new job with tasks and settings |
-| `get_job()` | Get detailed job configuration |
-| `list_jobs()` | List jobs with optional name filter |
-| `find_job_by_name()` | Find job by exact name, returns job ID |
-| `update_job()` | Update job configuration |
-| `delete_job()` | Delete a job |
-
-**Run Functions:**
-
-| Function | Description |
-|----------|-------------|
-| `run_job_now()` | Trigger a job run, returns run ID |
-| `get_run()` | Get run status and details |
-| `get_run_output()` | Get run output and logs |
-| `list_runs()` | List runs with filters |
-| `cancel_run()` | Cancel a running job |
-| `wait_for_run()` | Wait for run completion, returns `JobRunResult` |
-
-**JobRunResult fields:**
-
-| Field | Type | Description |
-|-------|------|-------------|
-| `success` | bool | True if run completed successfully |
-| `lifecycle_state` | str | PENDING, RUNNING, TERMINATING, TERMINATED, SKIPPED, INTERNAL_ERROR |
-| `result_state` | str | SUCCESS, FAILED, TIMEDOUT, CANCELED |
-| `duration_seconds` | float | Total execution time |
-| `error_message` | str | Error message if failed |
-| `run_page_url` | str | Link to run in Databricks UI |
-
-### Unity Catalog Operations
-
-```python
-from databricks_tools_core.unity_catalog import catalogs, schemas, tables
-
-# List catalogs
-all_catalogs = catalogs.list_catalogs()
-
-# Create schema
-schema = schemas.create_schema(
- catalog_name="main",
- schema_name="my_schema",
- comment="Example schema"
-)
-
-# Create table
-from databricks.sdk.service.catalog import ColumnInfo, TableType
-
-table = tables.create_table(
- catalog_name="main",
- schema_name="my_schema",
- table_name="my_table",
- columns=[
- ColumnInfo(name="id", type_name="INT"),
- ColumnInfo(name="value", type_name="STRING")
- ],
- table_type=TableType.MANAGED
-)
-```
-
-## Architecture
-
-```
-databricks-tools-core/
-├── databricks_tools_core/
-│ ├── auth.py # Authentication (contextvars + env vars)
-│ ├── sql/ # SQL operations
-│ │ ├── sql.py # execute_sql, execute_sql_multi
-│ │ ├── warehouse.py # list_warehouses, get_best_warehouse
-│ │ ├── table_stats.py # get_table_stats_and_schema
-│ │ └── sql_utils/ # Internal utilities
-│ │ ├── executor.py # SQLExecutor class
-│ │ ├── parallel_executor.py # Multi-statement execution
-│ │ ├── dependency_analyzer.py # SQL dependency analysis
-│ │ ├── table_stats_collector.py # Stats collection with caching
-│ │ └── models.py # Pydantic models
-│ ├── jobs/ # Job operations
-│ │ ├── jobs.py # create_job, get_job, list_jobs, etc.
-│ │ ├── runs.py # run_job_now, get_run, wait_for_run, etc.
-│ │ └── models.py # JobRunResult, JobError, enums
-│ ├── unity_catalog/ # Unity Catalog operations
-│ ├── compute/ # Compute operations
-│ ├── spark_declarative_pipelines/ # SDP operations
-│ └── client.py # REST API client
-└── tests/ # Integration tests
-```
-
-This is a **pure Python library** with no MCP protocol dependencies. It can be used standalone in notebooks, scripts, or other Python projects.
-
-For MCP server functionality, see the `databricks-mcp-server` package which wraps these functions as MCP tools.
-
-## Testing
-
-The project includes comprehensive integration tests that run against a real Databricks workspace.
-
-### Prerequisites
-
-- A Databricks workspace with valid authentication configured
-- At least one running SQL warehouse
-- Permission to create catalogs/schemas/tables
-
-### Running Tests
-
-```bash
-# Install dev dependencies
-uv pip install -e ".[dev]"
-
-# Run all integration tests
-uv run pytest tests/integration/ -v
-
-# Run specific test file
-uv run pytest tests/integration/sql/test_sql.py -v
-
-# Run specific test class
-uv run pytest tests/integration/sql/test_table_stats.py::TestTableStatLevelDetailed -v
-
-# Run with more verbose output
-uv run pytest tests/integration/ -v --tb=long
-```
-
-### Test Structure
-
-```
-tests/
-├── conftest.py # Shared fixtures
-│ ├── workspace_client # WorkspaceClient fixture
-│ ├── test_catalog # Creates ai_dev_kit_test catalog
-│ ├── test_schema # Creates fresh test_schema (drops if exists)
-│ ├── warehouse_id # Gets best running warehouse
-│ └── test_tables # Creates sample tables with data
-└── integration/
- ├── sql/
- │ ├── test_warehouse.py # Warehouse listing tests
- │ ├── test_sql.py # SQL execution tests
- │ └── test_table_stats.py # Table statistics tests
- └── jobs/
- ├── conftest.py # Jobs-specific fixtures (test notebook, cleanup)
- ├── test_jobs.py # Job CRUD tests
- └── test_runs.py # Run operation tests
-```
-
-### Test Coverage
-
-| Test File | Coverage |
-|-----------|----------|
-| `test_warehouse.py` | `list_warehouses`, `get_best_warehouse` |
-| `test_sql.py` | `execute_sql`, `execute_sql_multi`, error handling, parallel execution |
-| `test_table_stats.py` | `get_table_stats_and_schema`, all stat levels, glob patterns, caching |
-| `test_jobs.py` | `list_jobs`, `find_job_by_name`, `create_job`, `get_job`, `update_job`, `delete_job` |
-| `test_runs.py` | `run_job_now`, `get_run`, `cancel_run`, `list_runs`, `wait_for_run` |
-
-### Test Fixtures
-
-The test suite uses session-scoped fixtures to minimize setup overhead:
-
-- **`test_catalog`**: Creates `ai_dev_kit_test` catalog (reuses if exists)
-- **`test_schema`**: Drops and recreates `test_schema` for clean state
-- **`test_tables`**: Creates `customers`, `orders`, `products` tables with sample data
-
-## Development
-
-```bash
-# Install dev dependencies
-uv pip install -e ".[dev]"
-
-# Run tests
-uv run pytest
-
-# Format code
-uv run black databricks_tools_core/
-
-# Lint code
-uv run ruff check databricks_tools_core/
-```
-
-## Dependencies
-
-### Core
-- `databricks-sdk>=0.20.0` - Official Databricks Python SDK
-- `requests>=2.31.0` - HTTP client
-- `pydantic>=2.0.0` - Data validation
-- `sqlglot>=20.0.0` - SQL parsing
-- `sqlfluff>=3.0.0` - SQL linting and formatting
-
-### Development
-- `pytest>=7.0.0` - Testing framework
-- `pytest-timeout>=2.0.0` - Test timeouts
-- `black>=23.0.0` - Code formatting
-- `ruff>=0.1.0` - Linting
diff --git a/databricks-tools-core/databricks_tools_core/__init__.py b/databricks-tools-core/databricks_tools_core/__init__.py
deleted file mode 100644
index ae00930d..00000000
--- a/databricks-tools-core/databricks_tools_core/__init__.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""
-Databricks Tools Core Library
-
-High-level, AI-assistant-friendly functions for building Databricks projects.
-Organized by product line for scalability.
-"""
-
-__version__ = "0.1.0"
-
-# Auth utilities
-from .auth import get_workspace_client, set_databricks_auth, clear_databricks_auth, get_current_username
-
-__all__ = [
- "get_workspace_client",
- "set_databricks_auth",
- "clear_databricks_auth",
- "get_current_username",
-]
diff --git a/databricks-tools-core/databricks_tools_core/agent_bricks/__init__.py b/databricks-tools-core/databricks_tools_core/agent_bricks/__init__.py
deleted file mode 100644
index 95f2de86..00000000
--- a/databricks-tools-core/databricks_tools_core/agent_bricks/__init__.py
+++ /dev/null
@@ -1,78 +0,0 @@
-"""
-Agent Bricks - Manage Genie Spaces, Knowledge Assistants, and Supervisor Agents.
-
-This module provides a unified interface for managing Agent Bricks resources:
-- Knowledge Assistants (KA): Document-based Q&A systems
-- Supervisor Agents (MAS): Multi-agent orchestration
-- Genie Spaces: SQL-based data exploration
-"""
-
-from .manager import AgentBricksManager, TileExampleQueue, get_tile_example_queue
-from .models import (
- # Enums
- EndpointStatus,
- Permission,
- TileType,
- # Data classes
- GenieIds,
- KAIds,
- MASIds,
- # TypedDicts
- BaseAgentDict,
- CuratedQuestionDict,
- EvaluationRunDict,
- GenieListInstructionsResponseDict,
- GenieListQuestionsResponseDict,
- GenieSpaceDict,
- InstructionDict,
- KnowledgeAssistantDict,
- KnowledgeAssistantExampleDict,
- KnowledgeAssistantListExamplesResponseDict,
- KnowledgeAssistantResponseDict,
- KnowledgeAssistantStatusDict,
- KnowledgeSourceDict,
- ListEvaluationRunsResponseDict,
- MultiAgentSupervisorDict,
- MultiAgentSupervisorExampleDict,
- MultiAgentSupervisorListExamplesResponseDict,
- MultiAgentSupervisorResponseDict,
- MultiAgentSupervisorStatusDict,
- TileDict,
-)
-
-__all__ = [
- # Main class
- "AgentBricksManager",
- # Background queue
- "TileExampleQueue",
- "get_tile_example_queue",
- # Enums
- "EndpointStatus",
- "Permission",
- "TileType",
- # Data classes
- "GenieIds",
- "KAIds",
- "MASIds",
- # TypedDicts
- "BaseAgentDict",
- "CuratedQuestionDict",
- "EvaluationRunDict",
- "GenieListInstructionsResponseDict",
- "GenieListQuestionsResponseDict",
- "GenieSpaceDict",
- "InstructionDict",
- "KnowledgeAssistantDict",
- "KnowledgeAssistantExampleDict",
- "KnowledgeAssistantListExamplesResponseDict",
- "KnowledgeAssistantResponseDict",
- "KnowledgeAssistantStatusDict",
- "KnowledgeSourceDict",
- "ListEvaluationRunsResponseDict",
- "MultiAgentSupervisorDict",
- "MultiAgentSupervisorExampleDict",
- "MultiAgentSupervisorListExamplesResponseDict",
- "MultiAgentSupervisorResponseDict",
- "MultiAgentSupervisorStatusDict",
- "TileDict",
-]
diff --git a/databricks-tools-core/databricks_tools_core/agent_bricks/manager.py b/databricks-tools-core/databricks_tools_core/agent_bricks/manager.py
deleted file mode 100644
index c9e279bc..00000000
--- a/databricks-tools-core/databricks_tools_core/agent_bricks/manager.py
+++ /dev/null
@@ -1,1609 +0,0 @@
-"""
-Agent Bricks Manager - Manage Genie Spaces, Knowledge Assistants, and Supervisor Agents.
-
-Unified wrapper for Agent Bricks tiles with operations for:
-- Knowledge Assistants (KA): Document-based Q&A systems
-- Supervisor Agents (MAS): Multi-agent orchestration
-- Genie Spaces: SQL-based data exploration
-"""
-
-import json
-import logging
-import re
-import threading
-import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Any, Dict, List, Optional, Tuple
-
-import requests
-from databricks.sdk import WorkspaceClient
-
-from ..auth import get_workspace_client, get_current_username
-from .models import (
- EndpointStatus,
- GenieIds,
- GenieListInstructionsResponseDict,
- GenieListQuestionsResponseDict,
- GenieSpaceDict,
- KAIds,
- KnowledgeAssistantListExamplesResponseDict,
- KnowledgeAssistantResponseDict,
- ListEvaluationRunsResponseDict,
- MASIds,
- MultiAgentSupervisorListExamplesResponseDict,
- MultiAgentSupervisorResponseDict,
- Permission,
- TileType,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class AgentBricksManager:
- """Unified wrapper for Agent Bricks tiles.
-
- Works with:
- - /2.0/knowledge-assistants* (KA)
- - /2.0/multi-agent-supervisors* (MAS)
- - /2.0/data-rooms* (Genie)
- - /2.0/tiles* (common operations)
-
- Key operations:
- Common tile ops:
- - delete(tile_id): Delete any tile
- - share(tile_id, changes): Share tile with users/groups
-
- KA-specific ops (prefixed with ka_):
- - ka_create(): Create KA with knowledge sources
- - ka_get(): Get KA by tile_id
- - ka_update(): Update KA configuration
- - ka_create_or_update(): Create or update a KA
- - ka_sync_sources(): Trigger re-index
- - ka_get_endpoint_status(): Get endpoint status
- - ka_add_examples_batch(): Add example questions
-
- MAS-specific ops (prefixed with mas_):
- - mas_create(): Create MAS with agents
- - mas_get(): Get MAS by tile_id
- - mas_update(): Update MAS configuration
- - mas_get_endpoint_status(): Get endpoint status
- - mas_add_examples_batch(): Add example questions
-
- Genie-specific ops (prefixed with genie_):
- - genie_create(): Create Genie space
- - genie_get(): Get Genie space
- - genie_update(): Update Genie space
- - genie_delete(): Delete Genie space
- - genie_export(): Export space with full serialized config
- - genie_import(): Create new space from serialized payload
- - genie_update_with_serialized_space(): Full update via serialized payload
- - genie_add_sample_questions_batch(): Add sample questions
- - genie_add_sql_instructions_batch(): Add SQL examples
- - genie_add_benchmarks_batch(): Add benchmarks
- """
-
- def __init__(
- self,
- client: Optional[WorkspaceClient] = None,
- default_timeout_s: int = 600,
- default_poll_s: float = 2.0,
- ):
- """
- Initialize the Agent Bricks Manager.
-
- Args:
- client: Optional WorkspaceClient (creates new one if not provided)
- default_timeout_s: Default timeout for polling operations
- default_poll_s: Default poll interval in seconds
- """
- self.w: WorkspaceClient = client or get_workspace_client()
- self.default_timeout_s = default_timeout_s
- self.default_poll_s = default_poll_s
-
- @staticmethod
- def sanitize_name(name: str) -> str:
- """Sanitize a name to ensure it's alphanumeric with only hyphens and underscores.
-
- Args:
- name: The original name
-
- Returns:
- Sanitized name that complies with Databricks naming requirements
- """
- # Replace spaces with underscores
- sanitized = name.replace(" ", "_")
-
- # Replace any character that is not alphanumeric, hyphen, or underscore
- sanitized = re.sub(r"[^a-zA-Z0-9_-]", "_", sanitized)
-
- # Remove consecutive underscores or hyphens
- sanitized = re.sub(r"[_-]{2,}", "_", sanitized)
-
- # Remove leading/trailing underscores or hyphens
- sanitized = sanitized.strip("_-")
-
- # If the name is empty after sanitization, use a default
- if not sanitized:
- sanitized = "knowledge_assistant"
-
- logger.debug(f"Sanitized name: '{name}' -> '{sanitized}'")
- return sanitized
-
- # ========================================================================
- # Common Tile Operations
- # ========================================================================
-
- def delete(self, tile_id: str) -> None:
- """Delete any tile (KA or MAS) by ID."""
- self._delete(f"/api/2.0/tiles/{tile_id}")
-
- def share(self, tile_id: str, changes: List[Dict[str, Any]]) -> None:
- """Share a tile with specified permissions.
-
- Args:
- tile_id: The tile ID
- changes: List of permission changes, each containing:
- - principal: User/group (e.g., "users:email@company.com")
- - add: List of permissions to grant
- - remove: List of permissions to revoke
-
- Example:
- >>> manager.share(
- ... tile_id,
- ... changes=[
- ... {
- ... "principal": "users:john@company.com",
- ... "add": [Permission.CAN_READ, Permission.CAN_RUN],
- ... "remove": []
- ... }
- ... ]
- ... )
- """
- # Convert Permission enums to strings
- processed_changes = []
- for change in changes:
- processed_change = {
- "principal": change["principal"],
- "add": [p.value if isinstance(p, Permission) else p for p in change.get("add", [])],
- "remove": [p.value if isinstance(p, Permission) else p for p in change.get("remove", [])],
- }
- processed_changes.append(processed_change)
-
- self._post(
- f"/api/2.0/knowledge-assistants/{tile_id}/share",
- {"changes": processed_changes},
- )
-
- # ========================================================================
- # Discovery & Listing
- # ========================================================================
-
- def list_all_agent_bricks(self, tile_type: Optional[TileType] = None, page_size: int = 100) -> List[Dict[str, Any]]:
- """List all agent bricks (tiles) in the workspace.
-
- Args:
- tile_type: Specific tile type to filter for. If None, returns all.
- page_size: Number of results per page.
-
- Returns:
- List of Tile dictionaries.
- """
- all_tiles = []
-
- # Build filter query
- filter_q = f"tile_type={tile_type.name}" if tile_type else None
- page_token = None
-
- while True:
- params = {"page_size": page_size}
- if filter_q:
- params["filter"] = filter_q
- if page_token:
- params["page_token"] = page_token
-
- resp = self._get("/api/2.0/tiles", params=params)
-
- for tile in resp.get("tiles", []):
- if tile_type:
- tile_type_value = tile.get("tile_type")
- if tile_type_value == tile_type.value or tile_type_value == tile_type.name:
- all_tiles.append(tile)
- else:
- all_tiles.append(tile)
-
- page_token = resp.get("next_page_token")
- if not page_token:
- break
-
- return all_tiles
-
- def find_by_name(self, name: str) -> Optional[KAIds]:
- """Find a KA by exact display name.
-
- Note: Names are sanitized (spaces→underscores) before lookup to match
- how the API stores them.
- """
- sanitized_name = self.sanitize_name(name)
- filter_q = f"name_contains={sanitized_name}&&tile_type=KA"
- page_token = None
- while True:
- params = {"filter": filter_q}
- if page_token:
- params["page_token"] = page_token
- resp = self._get("/api/2.0/tiles", params=params)
- for t in resp.get("tiles", []):
- if t.get("name") == sanitized_name:
- return KAIds(tile_id=t["tile_id"], name=sanitized_name)
- page_token = resp.get("next_page_token")
- if not page_token:
- break
- return None
-
- def mas_find_by_name(self, name: str) -> Optional[MASIds]:
- """Find a MAS by exact display name.
-
- Note: Names are sanitized (spaces→underscores) before lookup to match
- how the API stores them.
- """
- sanitized_name = self.sanitize_name(name)
- filter_q = f"name_contains={sanitized_name}&&tile_type=MAS"
- page_token = None
- while True:
- params = {"filter": filter_q}
- if page_token:
- params["page_token"] = page_token
- resp = self._get("/api/2.0/tiles", params=params)
- for t in resp.get("tiles", []):
- if t.get("name") == sanitized_name:
- return MASIds(tile_id=t["tile_id"], name=sanitized_name)
- page_token = resp.get("next_page_token")
- if not page_token:
- break
- return None
-
- def genie_find_by_name(self, display_name: str) -> Optional[GenieIds]:
- """Find a Genie space by exact display name."""
- page_token = None
- while True:
- params = {}
- if page_token:
- params["page_token"] = page_token
- resp = self._get("/api/2.0/data-rooms", params=params)
- for space in resp.get("data_rooms", []):
- if space.get("display_name") == display_name:
- return GenieIds(space_id=space["space_id"], display_name=display_name)
- page_token = resp.get("next_page_token")
- if not page_token:
- break
- return None
-
- # ========================================================================
- # Knowledge Assistant (KA) Operations
- # ========================================================================
-
- def ka_create(
- self,
- name: str,
- knowledge_sources: List[Dict[str, Any]],
- description: Optional[str] = None,
- instructions: Optional[str] = None,
- ) -> Dict[str, Any]:
- """Create a Knowledge Assistant with specified knowledge sources.
-
- Uses SDK's create_knowledge_assistant and create_knowledge_source.
-
- Args:
- name: Name for the KA
- knowledge_sources: List of knowledge source dictionaries:
- {
- "files_source": {
- "name": "source_name",
- "type": "files",
- "files": {"path": "/Volumes/catalog/schema/path"}
- }
- }
- description: Optional description
- instructions: Optional instructions
-
- Returns:
- Dict with tile_id, name, endpoint_name, and created sources
- """
- from databricks.sdk.service.knowledgeassistants import (
- KnowledgeAssistant as SDKKnowledgeAssistant,
- KnowledgeSource as SDKKnowledgeSource,
- FilesSpec,
- )
-
- sanitized_name = self.sanitize_name(name)
-
- # Create KA via SDK
- ka_obj = SDKKnowledgeAssistant(
- display_name=sanitized_name,
- description=description or "",
- instructions=instructions,
- )
- logger.debug(f"Creating KA with SDK: {ka_obj}")
- created_ka = self.w.knowledge_assistants.create_knowledge_assistant(ka_obj)
-
- # Add knowledge sources
- created_sources = []
- for source_dict in knowledge_sources:
- files_source = source_dict.get("files_source", {})
- if files_source:
- source_name = files_source.get("name", f"source_{sanitized_name}")
- source_path = files_source.get("files", {}).get("path", "")
- # API requires non-empty description - provide default if not specified
- source_description = files_source.get("description") or f"Knowledge source from {source_path}"
- source_obj = SDKKnowledgeSource(
- display_name=source_name,
- description=source_description,
- source_type="files",
- files=FilesSpec(path=source_path),
- )
- created_source = self.w.knowledge_assistants.create_knowledge_source(
- parent=created_ka.name,
- knowledge_source=source_obj,
- )
- created_sources.append(created_source)
-
- return {
- "tile_id": created_ka.id,
- "name": created_ka.display_name,
- "endpoint_name": created_ka.endpoint_name,
- "description": created_ka.description,
- "instructions": created_ka.instructions,
- "state": created_ka.state.value if created_ka.state else None,
- "sources": [{"id": s.id, "name": s.display_name, "path": s.files.path if s.files else None} for s in created_sources],
- }
-
- def ka_get(self, tile_id: str) -> Optional[Dict[str, Any]]:
- """Get KA by tile_id using SDK.
-
- Returns:
- Dict with KA info or None if not found.
- """
- try:
- ka = self.w.knowledge_assistants.get_knowledge_assistant(f"knowledge-assistants/{tile_id}")
- sources = list(self.w.knowledge_assistants.list_knowledge_sources(f"knowledge-assistants/{tile_id}"))
-
- return {
- "tile_id": ka.id,
- "name": ka.display_name,
- "endpoint_name": ka.endpoint_name,
- "description": ka.description,
- "instructions": ka.instructions,
- "state": ka.state.value if ka.state else None,
- "creator": ka.creator,
- "experiment_id": ka.experiment_id,
- "sources": [
- {
- "id": s.id,
- "name": s.display_name,
- "source_type": s.source_type,
- "path": s.files.path if s.files else None,
- "state": s.state.value if s.state else None,
- }
- for s in sources
- ],
- }
- except Exception as e:
- if "does not exist" in str(e).lower() or "not found" in str(e).lower():
- return None
- raise
-
- def ka_get_endpoint_status(self, tile_id: str) -> Optional[str]:
- """Get the state of a Knowledge Assistant.
-
- Returns:
- State string (ACTIVE, CREATING, FAILED) or None
- """
- ka = self.ka_get(tile_id)
- if not ka:
- return None
- return ka.get("state")
-
- def ka_is_ready_for_update(self, tile_id: str) -> bool:
- """Check if a KA is ready to be updated (state is ACTIVE)."""
- status = self.ka_get_endpoint_status(tile_id)
- return status == "ACTIVE"
-
- def ka_wait_for_ready_status(self, tile_id: str, timeout: int = 60, poll_interval: int = 5) -> bool:
- """Wait for a KA to be ready for updates.
-
- Returns:
- True if ready within timeout, False otherwise.
- """
- start_time = time.time()
- while time.time() - start_time < timeout:
- if self.ka_is_ready_for_update(tile_id):
- logger.info(f"KA {tile_id} is ready (status: {EndpointStatus.ONLINE.value})")
- return True
- current_status = self.ka_get_endpoint_status(tile_id)
- logger.info(f"KA {tile_id} status: {current_status}, waiting...")
- time.sleep(poll_interval)
-
- logger.warning(f"Timeout waiting for KA {tile_id} to be ready")
- return False
-
- def ka_update(
- self,
- tile_id: str,
- name: Optional[str] = None,
- description: Optional[str] = None,
- instructions: Optional[str] = None,
- knowledge_sources: Optional[List[Dict[str, Any]]] = None,
- ) -> Dict[str, Any]:
- """Update KA metadata and/or knowledge sources.
-
- Uses SDK's update_knowledge_assistant for metadata updates (API 2.1).
- Knowledge source updates require separate create/delete operations.
-
- Args:
- tile_id: The KA tile ID
- name: Optional new display name
- description: Optional new description
- instructions: Optional new instructions
- knowledge_sources: Optional new sources (currently ignored on update)
-
- Returns:
- Updated KA data
- """
- # Update metadata if provided using API 2.1 endpoint
- # Note: We use raw API call because SDK has a bug where FieldMask converts
- # display_name to displayName (camelCase), but the API expects snake_case.
- if name is not None or description is not None or instructions is not None:
- update_fields = []
- body: Dict[str, Any] = {}
-
- if name is not None:
- body["display_name"] = name
- update_fields.append("display_name")
- if description is not None:
- body["description"] = description
- update_fields.append("description")
- if instructions is not None:
- body["instructions"] = instructions
- update_fields.append("instructions")
-
- if update_fields:
- self._patch(
- f"/api/2.1/knowledge-assistants/{tile_id}",
- body,
- params={"update_mask": ",".join(update_fields)},
- )
-
- # Note: Knowledge source updates require separate SDK calls
- # (create_knowledge_source / delete_knowledge_source)
- # For now, we skip source updates on existing KAs
- if knowledge_sources is not None:
- logger.debug(
- "Knowledge source updates on existing KAs not yet implemented via SDK. "
- "Sources will remain unchanged."
- )
-
- return self.ka_get(tile_id)
-
- def ka_create_or_update(
- self,
- name: str,
- knowledge_sources: List[Dict[str, Any]],
- description: Optional[str] = None,
- instructions: Optional[str] = None,
- tile_id: Optional[str] = None,
- ) -> Dict[str, Any]:
- """Create or update a Knowledge Assistant.
-
- Args:
- name: Name for the KA
- knowledge_sources: List of knowledge source dictionaries
- description: Optional description
- instructions: Optional instructions
- tile_id: Optional existing tile_id to update
-
- Returns:
- KA data with 'operation' field ('created' or 'updated')
- """
- sanitized_name = self.sanitize_name(name)
- existing_ka = None
- operation = "created"
-
- if tile_id:
- existing_ka = self.ka_get(tile_id)
- if existing_ka:
- operation = "updated"
- else:
- # No tile_id provided - check if KA exists by name
- found = self.find_by_name(name)
- if found:
- tile_id = found.tile_id
- existing_ka = self.ka_get(tile_id)
- if existing_ka:
- operation = "updated"
-
- if existing_ka:
- if not self.ka_is_ready_for_update(tile_id):
- current_status = self.ka_get_endpoint_status(tile_id)
- raise Exception(
- f"Knowledge Assistant {tile_id} is not ready for update "
- f"(status: {current_status}). Wait or delete and create new."
- )
-
- result = self.ka_update(
- tile_id,
- name=sanitized_name,
- description=description,
- instructions=instructions,
- knowledge_sources=knowledge_sources,
- )
- else:
- result = self.ka_create(
- name=sanitized_name,
- knowledge_sources=knowledge_sources,
- description=description,
- instructions=instructions,
- )
-
- if result:
- result["operation"] = operation
- return result
-
- def ka_sync_sources(self, tile_id: str) -> None:
- """Trigger indexing/sync of all knowledge sources."""
- self.w.knowledge_assistants.sync_knowledge_sources(f"knowledge-assistants/{tile_id}")
-
- def ka_reconcile_model(self, tile_id: str) -> None:
- """Reconcile KA to latest model."""
- self._patch(f"/api/2.0/knowledge-assistants/{tile_id}/reconcile-model", {})
-
- def ka_wait_until_ready(
- self, tile_id: str, timeout_s: Optional[int] = None, poll_s: Optional[float] = None
- ) -> Dict[str, Any]:
- """Wait until KA is ready (not in CREATING state)."""
- timeout_s = timeout_s or self.default_timeout_s
- poll_s = poll_s or self.default_poll_s
- deadline = time.time() + timeout_s
-
- while True:
- ka = self.ka_get(tile_id)
- status = ka.get("state") if ka else None
- if status and status != "CREATING":
- return ka
- if time.time() >= deadline:
- return ka
- time.sleep(poll_s)
-
- def ka_wait_until_active(
- self, tile_id: str, timeout_s: Optional[int] = None, poll_s: Optional[float] = None
- ) -> Dict[str, Any]:
- """Wait for state==ACTIVE."""
- timeout_s = timeout_s or self.default_timeout_s
- poll_s = poll_s or self.default_poll_s
- deadline = time.time() + timeout_s
- start_time = time.time()
- last_status = None
- ka = None
-
- while True:
- try:
- ka = self.ka_get(tile_id)
- status = ka.get("state") if ka else None
-
- if status != last_status:
- elapsed = int(time.time() - start_time)
- logger.info(f"[{elapsed}s] KA state: {last_status} -> {status}")
- last_status = status
-
- if status == "ACTIVE":
- return ka
- except Exception as e:
- elapsed = int(time.time() - start_time)
- if "does not exist" in str(e) and elapsed < 60:
- logger.debug(f"[{elapsed}s] KA not yet available, waiting...")
- else:
- raise
-
- if time.time() >= deadline:
- if ka:
- return ka
- raise TimeoutError(f"KA {tile_id} was not found within {timeout_s} seconds")
- time.sleep(poll_s)
-
- # Alias for backward compatibility
- def ka_wait_until_endpoint_online(
- self, tile_id: str, timeout_s: Optional[int] = None, poll_s: Optional[float] = None
- ) -> Dict[str, Any]:
- """Wait for KA to be active. Alias for ka_wait_until_active."""
- return self.ka_wait_until_active(tile_id, timeout_s, poll_s)
-
- # ========================================================================
- # KA Examples Management
- # ========================================================================
-
- def ka_create_example(self, tile_id: str, question: str, guidelines: Optional[List[str]] = None) -> Dict[str, Any]:
- """Create an example question for the KA."""
- payload = {"tile_id": tile_id, "question": question}
- if guidelines:
- payload["guidelines"] = guidelines
- return self._post(f"/api/2.0/knowledge-assistants/{tile_id}/examples", payload)
-
- def ka_list_examples(
- self, tile_id: str, page_size: int = 100, page_token: Optional[str] = None
- ) -> KnowledgeAssistantListExamplesResponseDict:
- """List all examples for a KA."""
- params = {"page_size": page_size}
- if page_token:
- params["page_token"] = page_token
- return self._get(f"/api/2.0/knowledge-assistants/{tile_id}/examples", params=params)
-
- def ka_delete_example(self, tile_id: str, example_id: str) -> None:
- """Delete an example from the KA."""
- self._delete(f"/api/2.0/knowledge-assistants/{tile_id}/examples/{example_id}")
-
- def ka_add_examples_batch(self, tile_id: str, questions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
- """Add multiple example questions in parallel.
-
- Args:
- tile_id: The KA tile ID
- questions: List of {'question': str, 'guideline': Optional[str]}
-
- Returns:
- List of created examples
- """
- created_examples = []
-
- def create_example(q: Dict[str, Any]) -> Optional[Dict[str, Any]]:
- question_text = q.get("question", "")
- guideline = q.get("guideline")
- guidelines = [guideline] if guideline else None
-
- if not question_text:
- return None
- try:
- example = self.ka_create_example(tile_id, question_text, guidelines)
- logger.info(f"Added example: {question_text[:50]}...")
- return example
- except Exception as e:
- logger.error(f"Failed to add example '{question_text[:50]}...': {e}")
- return None
-
- max_workers = min(2, len(questions))
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- future_to_q = {executor.submit(create_example, q): q for q in questions}
- for future in as_completed(future_to_q):
- result = future.result()
- if result:
- created_examples.append(result)
-
- return created_examples
-
- def ka_list_evaluation_runs(
- self, tile_id: str, page_size: int = 100, page_token: Optional[str] = None
- ) -> ListEvaluationRunsResponseDict:
- """List all evaluation runs for a KA."""
- params = {"page_size": page_size}
- if page_token:
- params["page_token"] = page_token
- return self._get(f"/api/2.0/tiles/{tile_id}/evaluation-runs", params=params)
-
- # ========================================================================
- # KA Helper Methods
- # ========================================================================
-
- @staticmethod
- def ka_get_knowledge_sources_from_volumes(
- volume_paths: List[Tuple[str, Optional[str]]],
- ) -> List[Dict[str, Any]]:
- """Convert volume paths to knowledge source dictionaries.
-
- Args:
- volume_paths: List of (volume_path, description) tuples
-
- Returns:
- List of knowledge source dictionaries for KA API
-
- Example:
- >>> paths = [
- ... ('/Volumes/main/default/docs', 'Documentation'),
- ... ('/Volumes/main/default/guides', None)
- ... ]
- >>> sources = AgentBricksManager.ka_get_knowledge_sources_from_volumes(paths)
- """
- knowledge_sources = []
-
- for idx, (volume_path, _description) in enumerate(volume_paths):
- path_parts = volume_path.rstrip("/").split("/")
- source_name = path_parts[-1] if path_parts else f"source_{idx + 1}"
- source_name = source_name.replace(" ", "_").replace(".", "_")
-
- knowledge_source = {
- "files_source": {
- "name": source_name,
- "type": "files",
- "files": {"path": volume_path},
- }
- }
- knowledge_sources.append(knowledge_source)
-
- return knowledge_sources
-
- # ========================================================================
- # Supervisor Agent (MAS) Operations
- # ========================================================================
-
- def mas_create(
- self,
- name: str,
- agents: List[Dict[str, Any]],
- description: Optional[str] = None,
- instructions: Optional[str] = None,
- ) -> Dict[str, Any]:
- """Create a Supervisor Agent with specified agents.
-
- Args:
- name: Name for the Supervisor Agent
- agents: List of agent configurations (BaseAgent format):
- {
- "name": "Agent Name",
- "description": "Description",
- "agent_type": "genie", # or "ka", "app", etc.
- "genie_space": {"id": "space_id"} # or serving_endpoint, app
- }
- description: Optional description
- instructions: Optional instructions
-
- Returns:
- Supervisor Agent creation response
- """
- payload = {"name": self.sanitize_name(name), "agents": agents}
- if description:
- payload["description"] = description
- if instructions:
- payload["instructions"] = instructions
-
- logger.info(f"Creating MAS with name={name}, {len(agents)} agents")
- return self._post("/api/2.0/multi-agent-supervisors", payload)
-
- def mas_get(self, tile_id: str) -> Optional[MultiAgentSupervisorResponseDict]:
- """Get MAS by tile_id."""
- try:
- return self._get(f"/api/2.0/multi-agent-supervisors/{tile_id}")
- except Exception as e:
- if "does not exist" in str(e).lower() or "not found" in str(e).lower():
- return None
- raise
-
- def mas_update(
- self,
- tile_id: str,
- name: Optional[str] = None,
- description: Optional[str] = None,
- instructions: Optional[str] = None,
- agents: Optional[List[Dict[str, Any]]] = None,
- ) -> Dict[str, Any]:
- """Update a Supervisor Agent."""
- payload = {"tile_id": tile_id}
- if name:
- payload["name"] = self.sanitize_name(name)
- if description:
- payload["description"] = description
- if instructions:
- payload["instructions"] = instructions
- if agents:
- payload["agents"] = agents
-
- logger.info(f"Updating MAS {tile_id}")
- return self._patch(f"/api/2.0/multi-agent-supervisors/{tile_id}", payload)
-
- def mas_get_endpoint_status(self, tile_id: str) -> Optional[str]:
- """Get the endpoint status of a MAS."""
- mas = self.mas_get(tile_id)
- if not mas:
- return None
- return mas.get("multi_agent_supervisor", {}).get("status", {}).get("endpoint_status")
-
- # ========================================================================
- # MAS Examples Management
- # ========================================================================
-
- def mas_create_example(self, tile_id: str, question: str, guidelines: Optional[List[str]] = None) -> Dict[str, Any]:
- """Create an example question for the MAS."""
- payload = {"tile_id": tile_id, "question": question}
- if guidelines:
- payload["guidelines"] = guidelines
- return self._post(f"/api/2.0/multi-agent-supervisors/{tile_id}/examples", payload)
-
- def mas_list_examples(
- self, tile_id: str, page_size: int = 100, page_token: Optional[str] = None
- ) -> MultiAgentSupervisorListExamplesResponseDict:
- """List all examples for a MAS."""
- params = {"page_size": page_size}
- if page_token:
- params["page_token"] = page_token
- return self._get(f"/api/2.0/multi-agent-supervisors/{tile_id}/examples", params=params)
-
- def mas_update_example(
- self,
- tile_id: str,
- example_id: str,
- question: Optional[str] = None,
- guidelines: Optional[List[str]] = None,
- ) -> Dict[str, Any]:
- """Update an example in a MAS."""
- payload = {"tile_id": tile_id, "example_id": example_id}
- if question:
- payload["question"] = question
- if guidelines:
- payload["guidelines"] = guidelines
- return self._patch(f"/api/2.0/multi-agent-supervisors/{tile_id}/examples/{example_id}", payload)
-
- def mas_delete_example(self, tile_id: str, example_id: str) -> None:
- """Delete an example from the MAS."""
- self._delete(f"/api/2.0/multi-agent-supervisors/{tile_id}/examples/{example_id}")
-
- def mas_add_examples_batch(self, tile_id: str, questions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
- """Add multiple example questions in parallel."""
- created_examples = []
-
- def create_example(q: Dict[str, Any]) -> Optional[Dict[str, Any]]:
- question_text = q.get("question", "")
- guidelines = q.get("guideline")
- if guidelines and isinstance(guidelines, str):
- guidelines = [guidelines]
-
- if not question_text:
- return None
- try:
- example = self.mas_create_example(tile_id, question_text, guidelines)
- logger.info(f"Added MAS example: {question_text[:50]}...")
- return example
- except Exception as e:
- logger.error(f"Failed to add MAS example '{question_text[:50]}...': {e}")
- return None
-
- max_workers = min(2, len(questions))
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- future_to_q = {executor.submit(create_example, q): q for q in questions}
- for future in as_completed(future_to_q):
- result = future.result()
- if result:
- created_examples.append(result)
-
- return created_examples
-
- def mas_list_evaluation_runs(
- self, tile_id: str, page_size: int = 100, page_token: Optional[str] = None
- ) -> ListEvaluationRunsResponseDict:
- """List all evaluation runs for a MAS."""
- params = {"page_size": page_size}
- if page_token:
- params["page_token"] = page_token
- return self._get(f"/api/2.0/tiles/{tile_id}/evaluation-runs", params=params)
-
- # ========================================================================
- # Genie Space Operations
- # ========================================================================
-
- def genie_get(self, space_id: str) -> Optional[GenieSpaceDict]:
- """Get Genie space by ID."""
- try:
- return self._get(f"/api/2.0/data-rooms/{space_id}")
- except Exception as e:
- if "does not exist" in str(e).lower() or "not found" in str(e).lower():
- return None
- raise
-
- def genie_create(
- self,
- display_name: str,
- warehouse_id: str,
- table_identifiers: List[str],
- description: Optional[str] = None,
- parent_folder_path: Optional[str] = None,
- parent_folder_id: Optional[str] = None,
- create_dir: bool = True,
- run_as_type: str = "VIEWER",
- ) -> Dict[str, Any]:
- """Create a Genie space.
-
- Args:
- display_name: Display name for the space
- warehouse_id: SQL warehouse ID to use
- table_identifiers: List of tables (e.g., ["catalog.schema.table"])
- description: Optional description
- parent_folder_path: Optional workspace folder path
- parent_folder_id: Optional parent folder ID
- create_dir: Whether to create parent folder if missing
- run_as_type: Run as type (default: "VIEWER")
-
- Returns:
- Created Genie space data
- """
- if parent_folder_path and parent_folder_id:
- raise ValueError("Cannot specify both parent_folder_path and parent_folder_id")
-
- room_payload = {
- "display_name": display_name,
- "warehouse_id": warehouse_id,
- "table_identifiers": table_identifiers,
- "run_as_type": run_as_type,
- }
-
- if description:
- room_payload["description"] = description
-
- # Resolve parent folder
- if parent_folder_path:
- if create_dir:
- try:
- self.w.workspace.mkdirs(parent_folder_path)
- except Exception as e:
- logger.warning(f"Could not create directory {parent_folder_path}: {e}")
- raise
-
- try:
- folder_status = self._get("/api/2.0/workspace/get-status", params={"path": parent_folder_path})
- parent_folder_id = folder_status["object_id"]
- except Exception as e:
- raise ValueError(f"Failed to get folder ID for path '{parent_folder_path}': {str(e)}")
-
- if parent_folder_id:
- room_payload["parent_folder"] = f"folders/{parent_folder_id}"
-
- return self._post("/api/2.0/data-rooms/", room_payload)
-
- def genie_update(
- self,
- space_id: str,
- display_name: Optional[str] = None,
- description: Optional[str] = None,
- warehouse_id: Optional[str] = None,
- table_identifiers: Optional[List[str]] = None,
- sample_questions: Optional[List[str]] = None,
- ) -> Dict[str, Any]:
- """Update a Genie space.
-
- Args:
- space_id: The Genie space ID
- display_name: Optional new display name
- description: Optional new description
- warehouse_id: Optional new warehouse ID
- table_identifiers: Optional new table identifiers
- sample_questions: Optional sample questions (replaces all existing)
-
- Returns:
- Updated Genie space data
- """
- current_space = self.genie_get(space_id)
- if not current_space:
- raise ValueError(f"Genie space {space_id} not found")
-
- update_payload = {
- "id": space_id,
- "space_id": current_space.get("space_id", space_id),
- "display_name": display_name or current_space.get("display_name"),
- "warehouse_id": warehouse_id or current_space.get("warehouse_id"),
- "table_identifiers": table_identifiers
- if table_identifiers is not None
- else current_space.get("table_identifiers", []),
- "run_as_type": current_space.get("run_as_type", "VIEWER"),
- }
-
- if description is not None:
- update_payload["description"] = description
- elif current_space.get("description"):
- update_payload["description"] = current_space["description"]
-
- # Preserve timestamps and user info
- for field in [
- "created_timestamp",
- "last_updated_timestamp",
- "user_id",
- "folder_node_internal_name",
- ]:
- if current_space.get(field):
- update_payload[field] = current_space[field]
-
- result = self._patch(f"/api/2.0/data-rooms/{space_id}", update_payload)
-
- if sample_questions is not None:
- self.genie_update_sample_questions(space_id, sample_questions)
-
- return result
-
- def genie_delete(self, space_id: str) -> None:
- """Delete a Genie space."""
- self._delete(f"/api/2.0/data-rooms/{space_id}")
-
- def genie_export(self, space_id: str) -> Dict[str, Any]:
- """Export a Genie space with its full serialized configuration.
-
- Uses the public /api/2.0/genie/spaces endpoint with include_serialized_space=true.
- Requires at least CAN EDIT permission on the space.
-
- Args:
- space_id: The Genie space ID to export
-
- Returns:
- Dictionary with space metadata including:
- - space_id: The space ID
- - title: The space title
- - description: The description
- - warehouse_id: The SQL warehouse ID
- - serialized_space: JSON string with full space config (tables, instructions,
- SQL queries, layout). Pass this to genie_import() to clone/migrate the space.
- """
- return self._get(
- f"/api/2.0/genie/spaces/{space_id}",
- params={"include_serialized_space": "true"},
- )
-
- def genie_import(
- self,
- warehouse_id: str,
- serialized_space: str,
- title: Optional[str] = None,
- description: Optional[str] = None,
- parent_path: Optional[str] = None,
- ) -> Dict[str, Any]:
- """Create a new Genie space from a serialized payload (import/clone).
-
- Uses the public /api/2.0/genie/spaces endpoint with serialized_space in the body.
- The serialized_space string is obtained from genie_export().
-
- Args:
- warehouse_id: SQL warehouse ID to associate with the new space
- serialized_space: The JSON string from genie_export() containing the full
- space configuration (tables, instructions, SQL queries, layout)
- title: Optional title override (defaults to the exported space's title)
- description: Optional description override
- parent_path: Optional workspace folder path where the space will be registered
- (e.g., "/Workspace/Users/you@company.com/Genie Spaces")
-
- Returns:
- Dictionary with the newly created space details including space_id
- """
- payload: Dict[str, Any] = {
- "warehouse_id": warehouse_id,
- "serialized_space": serialized_space,
- }
- if title:
- payload["title"] = title
- if description:
- payload["description"] = description
- if parent_path:
- payload["parent_path"] = parent_path
- return self._post("/api/2.0/genie/spaces", payload)
-
- def genie_update_with_serialized_space(
- self,
- space_id: str,
- serialized_space: str,
- title: Optional[str] = None,
- description: Optional[str] = None,
- warehouse_id: Optional[str] = None,
- ) -> Dict[str, Any]:
- """Update a Genie space using a serialized payload (full replacement).
-
- Uses the public /api/2.0/genie/spaces/{space_id} endpoint (PATCH) with
- serialized_space in the body. This replaces the entire space configuration.
-
- Args:
- space_id: The Genie space ID to update
- serialized_space: The JSON string containing the new space configuration.
- Obtain from genie_export() or construct manually:
- '{"version":2,"data_sources":{"tables":[{"identifier":"cat.schema.table"}]}}'
- title: Optional title override
- description: Optional description override
- warehouse_id: Optional warehouse override
-
- Returns:
- Dictionary with the updated space details
- """
- payload: Dict[str, Any] = {"serialized_space": serialized_space}
- if title:
- payload["title"] = title
- if description:
- payload["description"] = description
- if warehouse_id:
- payload["warehouse_id"] = warehouse_id
- return self._patch(f"/api/2.0/genie/spaces/{space_id}", payload)
-
- def genie_list_questions(
- self, space_id: str, question_type: str = "SAMPLE_QUESTION"
- ) -> GenieListQuestionsResponseDict:
- """List curated questions for a Genie space."""
- return self._get(
- f"/api/2.0/data-rooms/{space_id}/curated-questions",
- params={"question_type": question_type},
- )
-
- def genie_list_instructions(self, space_id: str) -> GenieListInstructionsResponseDict:
- """List all instructions for a Genie space."""
- return self._get(f"/api/2.0/data-rooms/{space_id}/instructions")
-
- def genie_update_sample_questions(self, space_id: str, questions: List[str]) -> Dict[str, Any]:
- """Replace all sample questions for a Genie space.
-
- Args:
- space_id: The Genie space ID
- questions: New list of questions (replaces ALL existing)
-
- Returns:
- Batch action response
- """
- existing = self.genie_list_questions(space_id, question_type="SAMPLE_QUESTION")
- existing_ids = [
- q.get("curated_question_id") or q.get("id")
- for q in existing.get("curated_questions", [])
- if q.get("curated_question_id") or q.get("id")
- ]
-
- actions = []
-
- # Delete existing
- for question_id in existing_ids:
- actions.append({"action_type": "DELETE", "curated_question": {"id": question_id}})
-
- # Create new
- for question_text in questions:
- actions.append(
- {
- "action_type": "CREATE",
- "curated_question": {
- "data_room_id": space_id,
- "question_text": question_text,
- "question_type": "SAMPLE_QUESTION",
- },
- }
- )
-
- return self._post(
- f"/api/2.0/data-rooms/{space_id}/curated-questions/batch-actions",
- {"actions": actions},
- )
-
- def genie_add_sample_questions_batch(self, space_id: str, questions: List[str]) -> Dict[str, Any]:
- """Add multiple sample questions (without replacing existing)."""
- actions = [
- {
- "action_type": "CREATE",
- "curated_question": {
- "data_space_id": space_id,
- "question_text": q,
- "question_type": "SAMPLE_QUESTION",
- },
- }
- for q in questions
- ]
- return self._post(
- f"/api/2.0/data-rooms/{space_id}/curated-questions/batch-actions",
- {"actions": actions},
- )
-
- def genie_add_curated_question(
- self,
- space_id: str,
- question_text: str,
- question_type: str,
- answer_text: Optional[str] = None,
- ) -> Dict[str, Any]:
- """Add a curated question (low-level)."""
- curated_question = {
- "data_space_id": space_id,
- "question_text": question_text,
- "question_type": question_type,
- "is_deprecated": False,
- }
- if answer_text:
- curated_question["answer_text"] = answer_text
-
- return self._post(
- f"/api/2.0/data-rooms/{space_id}/curated-questions",
- {"curated_question": curated_question, "data_space_id": space_id},
- )
-
- def genie_add_sample_question(self, space_id: str, question_text: str) -> Dict[str, Any]:
- """Add a single sample question."""
- return self.genie_add_curated_question(space_id, question_text, "SAMPLE_QUESTION")
-
- def genie_add_instruction(self, space_id: str, title: str, content: str, instruction_type: str) -> Dict[str, Any]:
- """Add an instruction (low-level)."""
- payload = {"title": title, "content": content, "instruction_type": instruction_type}
- return self._post(f"/api/2.0/data-rooms/{space_id}/instructions", payload)
-
- def genie_add_text_instruction(self, space_id: str, content: str, title: str = "Notes") -> Dict[str, Any]:
- """Add general text instruction/notes."""
- return self.genie_add_instruction(space_id, title, content, "TEXT_INSTRUCTION")
-
- def genie_add_sql_instruction(self, space_id: str, title: str, content: str) -> Dict[str, Any]:
- """Add a SQL query example instruction."""
- return self.genie_add_instruction(space_id, title, content, "SQL_INSTRUCTION")
-
- def genie_add_sql_function(self, space_id: str, function_name: str) -> Dict[str, Any]:
- """Add a SQL function (certified answer)."""
- return self.genie_add_instruction(space_id, "SQL Function", function_name, "CERTIFIED_ANSWER")
-
- def genie_add_sql_instructions_batch(
- self, space_id: str, sql_instructions: List[Dict[str, str]]
- ) -> List[Dict[str, Any]]:
- """Add multiple SQL instructions.
-
- Args:
- space_id: The Genie space ID
- sql_instructions: List of {'title': str, 'content': str}
-
- Returns:
- List of created instructions
- """
- results = []
- for instr in sql_instructions:
- try:
- result = self.genie_add_sql_instruction(space_id, instr["title"], instr["content"])
- results.append(result)
- logger.info(f"Added SQL instruction: {instr['title'][:50]}...")
- except Exception as e:
- logger.error(f"Failed to add SQL instruction '{instr['title']}': {e}")
- return results
-
- def genie_add_sql_functions_batch(self, space_id: str, function_names: List[str]) -> List[Dict[str, Any]]:
- """Add multiple SQL functions (certified answers)."""
- results = []
- for func_name in function_names:
- try:
- result = self.genie_add_sql_function(space_id, func_name)
- results.append(result)
- logger.info(f"Added SQL function: {func_name}")
- except Exception as e:
- logger.error(f"Failed to add SQL function '{func_name}': {e}")
- return results
-
- def genie_add_benchmark(self, space_id: str, question_text: str, answer_text: str) -> Dict[str, Any]:
- """Add a benchmark question with expected answer."""
- return self.genie_add_curated_question(space_id, question_text, "BENCHMARK", answer_text)
-
- def genie_add_benchmarks_batch(self, space_id: str, benchmarks: List[Dict[str, str]]) -> List[Dict[str, Any]]:
- """Add multiple benchmarks.
-
- Args:
- space_id: The Genie space ID
- benchmarks: List of {'question_text': str, 'answer_text': str}
-
- Returns:
- List of created benchmarks
- """
- results = []
- for bm in benchmarks:
- try:
- result = self.genie_add_benchmark(space_id, bm["question_text"], bm["answer_text"])
- results.append(result)
- logger.info(f"Added benchmark: {bm['question_text'][:50]}...")
- except Exception as e:
- logger.error(f"Failed to add benchmark '{bm['question_text'][:50]}...': {e}")
- return results
-
- # ========================================================================
- # Low-level HTTP Wrappers
- # ========================================================================
-
- def _handle_response_error(self, response: requests.Response, method: str, path: str) -> None:
- """Extract detailed error from response and raise."""
- if response.status_code >= 400:
- try:
- error_data = response.json()
- error_msg = error_data.get("message", error_data.get("error", str(error_data)))
- detailed_error = f"{method} {path} failed: {error_msg}"
- logger.error(f"API Error: {detailed_error}\nFull response: {json.dumps(error_data, indent=2)}")
- raise Exception(detailed_error)
- except ValueError:
- error_text = response.text
- detailed_error = f"{method} {path} failed with status {response.status_code}: {error_text}"
- logger.error(f"API Error: {detailed_error}")
- raise Exception(detailed_error)
-
- def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
- headers = self.w.config.authenticate()
- url = f"{self.w.config.host}{path}"
- response = requests.get(url, headers=headers, params=params or {}, timeout=20)
- if response.status_code >= 400:
- self._handle_response_error(response, "GET", path)
- return response.json()
-
- def _post(self, path: str, body: Dict[str, Any], timeout: int = 300) -> Dict[str, Any]:
- headers = self.w.config.authenticate()
- headers["Content-Type"] = "application/json"
- url = f"{self.w.config.host}{path}"
- response = requests.post(url, headers=headers, json=body, timeout=timeout)
- if response.status_code >= 400:
- self._handle_response_error(response, "POST", path)
- return response.json()
-
- def _patch(
- self, path: str, body: Dict[str, Any], params: Optional[Dict[str, str]] = None
- ) -> Dict[str, Any]:
- headers = self.w.config.authenticate()
- headers["Content-Type"] = "application/json"
- url = f"{self.w.config.host}{path}"
- response = requests.patch(url, headers=headers, json=body, params=params, timeout=20)
- if response.status_code >= 400:
- self._handle_response_error(response, "PATCH", path)
- return response.json()
-
- def _put(self, path: str, body: Dict[str, Any]) -> Dict[str, Any]:
- headers = self.w.config.authenticate()
- headers["Content-Type"] = "application/json"
- url = f"{self.w.config.host}{path}"
- response = requests.put(url, headers=headers, json=body, timeout=20)
- if response.status_code >= 400:
- self._handle_response_error(response, "PUT", path)
- return response.json()
-
- def _delete(self, path: str) -> Dict[str, Any]:
- headers = self.w.config.authenticate()
- url = f"{self.w.config.host}{path}"
- response = requests.delete(url, headers=headers, timeout=20)
- if response.status_code >= 400:
- self._handle_response_error(response, "DELETE", path)
- return response.json()
-
- # ========================================================================
- # Warehouse Auto-Detection (for Genie)
- # ========================================================================
-
- def get_best_warehouse_id(self) -> Optional[str]:
- """Get the best available SQL warehouse ID for Genie spaces.
-
- Prioritizes running warehouses, then starting ones, preferring smaller sizes.
- Within the same state/size tier, warehouses owned by the current user are
- preferred (soft preference — no warehouses are excluded).
-
- Returns:
- Warehouse ID string, or None if no warehouses available.
- """
- try:
- warehouses = list(self.w.warehouses.list())
- if not warehouses:
- return None
-
- current_user = get_current_username()
-
- # Sort by state (RUNNING first) and size (smaller first),
- # with a soft preference for user-owned warehouses within each tier
- size_order = [
- "2X-Small",
- "X-Small",
- "Small",
- "Medium",
- "Large",
- "X-Large",
- "2X-Large",
- "3X-Large",
- "4X-Large",
- ]
-
- def sort_key(wh):
- state_priority = 0 if wh.state.value == "RUNNING" else (1 if wh.state.value == "STARTING" else 2)
- try:
- size_priority = size_order.index(wh.cluster_size)
- except ValueError:
- size_priority = 99
- # Soft preference: user-owned warehouses sort first (0) within same tier
- owner_priority = 0 if (current_user and (wh.creator_name or "").lower() == current_user.lower()) else 1
- return (state_priority, size_priority, owner_priority)
-
- warehouses_sorted = sorted(warehouses, key=sort_key)
- return warehouses_sorted[0].id if warehouses_sorted else None
- except Exception as e:
- logger.warning(f"Failed to get warehouses: {e}")
- return None
-
- # ========================================================================
- # Volume Scanning (for KA examples from PDF JSON files)
- # ========================================================================
-
- def scan_volume_for_examples(self, volume_path: str) -> List[Dict[str, Any]]:
- """Scan a volume folder for JSON files containing question/guideline pairs.
-
- These JSON files are typically created by the PDF generation tool and contain:
- - question: A question that can be answered by the document
- - guideline: How to evaluate if the answer is correct
-
- Args:
- volume_path: Path to the volume folder (e.g., "/Volumes/catalog/schema/volume/folder")
-
- Returns:
- List of dicts with 'question' and optionally 'guideline' keys
- """
- examples = []
- try:
- # List files in the volume
- files = list(self.w.files.list_directory_contents(volume_path))
-
- for file_info in files:
- if file_info.path and file_info.path.endswith(".json"):
- try:
- # Read the JSON file
- response = self.w.files.download(file_info.path)
- content = response.read().decode("utf-8")
- data = json.loads(content)
-
- # Extract question and guideline if present
- if "question" in data:
- example = {"question": data["question"]}
- if "guideline" in data:
- example["guideline"] = data["guideline"]
- examples.append(example)
- logger.debug(f"Found example in {file_info.path}: {data['question'][:50]}...")
- except Exception as e:
- logger.warning(f"Failed to read JSON file {file_info.path}: {e}")
- continue
-
- logger.info(f"Found {len(examples)} examples in {volume_path}")
- except Exception as e:
- logger.warning(f"Failed to scan volume {volume_path} for examples: {e}")
-
- return examples
-
-
-# ============================================================================
-# TileExampleQueue - Background queue for adding examples to tiles
-# ============================================================================
-
-
-class TileExampleQueue:
- """Background queue for adding examples to tiles (KA/MAS) that aren't ready yet.
-
- This queue polls tiles periodically and attempts to add examples once
- the endpoint status is ONLINE.
- """
-
- def __init__(self, poll_interval: float = 30.0, max_attempts: int = 120):
- """Initialize the queue.
-
- Args:
- poll_interval: Seconds between status checks (default: 30)
- max_attempts: Maximum poll attempts before giving up (default: 120 = 1 hour)
- """
- self.queue: Dict[str, Tuple[AgentBricksManager, List[Dict[str, Any]], str, float, int]] = {}
- self.lock = threading.Lock()
- self.running = False
- self.thread: Optional[threading.Thread] = None
- self.poll_interval = poll_interval
- self.max_attempts = max_attempts
-
- def enqueue(
- self,
- tile_id: str,
- manager: AgentBricksManager,
- questions: List[Dict[str, Any]],
- tile_type: str = "KA",
- ) -> None:
- """Add a tile and its questions to the processing queue.
-
- Args:
- tile_id: The tile ID
- manager: AgentBricksManager instance
- questions: List of question dictionaries
- tile_type: Type of tile ('KA' or 'MAS')
- """
- with self.lock:
- self.queue[tile_id] = (manager, questions, tile_type, time.time(), 0)
- logger.info(
- f"Enqueued {len(questions)} examples for {tile_type} {tile_id} (will add when endpoint is ready)"
- )
-
- # Start background thread if not running
- if not self.running:
- self.start()
-
- def start(self) -> None:
- """Start the background processing thread."""
- if not self.running:
- self.running = True
- self.thread = threading.Thread(target=self._process_loop, daemon=True)
- self.thread.start()
- logger.info("Started tile example queue background processor")
-
- def stop(self) -> None:
- """Stop the background processing thread."""
- self.running = False
- if self.thread:
- self.thread.join(timeout=5)
- logger.info("Stopped tile example queue background processor")
-
- def _process_loop(self) -> None:
- """Background loop that checks tile status and adds examples when ready."""
- while self.running:
- try:
- # Get snapshot of queue to process
- with self.lock:
- items_to_process = list(self.queue.items())
-
- # Process each tile
- for tile_id, (
- manager,
- questions,
- tile_type,
- enqueue_time,
- attempt_count,
- ) in items_to_process:
- try:
- # Check if max attempts exceeded
- if attempt_count >= self.max_attempts:
- elapsed_time = time.time() - enqueue_time
- logger.error(
- f"{tile_type} {tile_id} exceeded max attempts ({self.max_attempts}). "
- f"Elapsed: {elapsed_time:.0f}s. Removing from queue. "
- f"Failed to add {len(questions)} examples."
- )
- with self.lock:
- self.queue.pop(tile_id, None)
- continue
-
- # Increment attempt count
- with self.lock:
- if tile_id in self.queue:
- self.queue[tile_id] = (
- manager,
- questions,
- tile_type,
- enqueue_time,
- attempt_count + 1,
- )
-
- # Check endpoint status
- if tile_type == "KA":
- status = manager.ka_get_endpoint_status(tile_id)
- elif tile_type == "MAS":
- status = manager.mas_get_endpoint_status(tile_id)
- else:
- logger.error(f"Unknown tile type: {tile_type}")
- with self.lock:
- self.queue.pop(tile_id, None)
- continue
-
- logger.debug(
- f"{tile_type} {tile_id} status: {status} (attempt {attempt_count + 1}/{self.max_attempts})"
- )
-
- # Add examples if ONLINE
- if status == EndpointStatus.ONLINE.value:
- logger.info(f"{tile_type} {tile_id} is ONLINE, adding {len(questions)} examples...")
-
- if tile_type == "KA":
- created = manager.ka_add_examples_batch(tile_id, questions)
- else:
- created = manager.mas_add_examples_batch(tile_id, questions)
-
- elapsed_time = time.time() - enqueue_time
- logger.info(
- f"Added {len(created)} examples to {tile_type} {tile_id} "
- f"after {attempt_count + 1} attempts ({elapsed_time:.0f}s)"
- )
-
- with self.lock:
- self.queue.pop(tile_id, None)
-
- except Exception as e:
- logger.error(f"Error processing {tile_type} {tile_id}: {e}")
- with self.lock:
- self.queue.pop(tile_id, None)
-
- except Exception as e:
- logger.error(f"Error in queue processor: {e}")
-
- time.sleep(self.poll_interval)
-
-
-# Global singleton queue instance
-_tile_example_queue: Optional[TileExampleQueue] = None
-_queue_lock = threading.Lock()
-
-
-def get_tile_example_queue() -> TileExampleQueue:
- """Get or create the global tile example queue instance."""
- global _tile_example_queue
- if _tile_example_queue is None:
- with _queue_lock:
- if _tile_example_queue is None:
- _tile_example_queue = TileExampleQueue()
- return _tile_example_queue
diff --git a/databricks-tools-core/databricks_tools_core/agent_bricks/models.py b/databricks-tools-core/databricks_tools_core/agent_bricks/models.py
deleted file mode 100644
index 7f322742..00000000
--- a/databricks-tools-core/databricks_tools_core/agent_bricks/models.py
+++ /dev/null
@@ -1,254 +0,0 @@
-"""
-Agent Bricks Models - Type definitions for Agent Bricks API responses.
-
-Based on api_ka.proto and api_mas.proto definitions.
-"""
-
-from dataclasses import dataclass
-from enum import Enum
-from typing import Any, Dict, List, Optional, TypedDict
-
-
-# ============================================================================
-# Enums
-# ============================================================================
-
-
-class TileType(Enum):
- """Tile types from the protobuf definition."""
-
- UNSPECIFIED = 0
- KIE = 1 # Knowledge Indexing Engine
- T2T = 2 # Text to Text
- KA = 3 # Knowledge Assistant
- MAO = 4 # Deprecated
- MAS = 5 # Supervisor Agent (formerly Multi-Agent Supervisor)
-
-
-class EndpointStatus(Enum):
- """Vector Search Endpoint status values."""
-
- ONLINE = "ONLINE"
- OFFLINE = "OFFLINE"
- PROVISIONING = "PROVISIONING"
- NOT_READY = "NOT_READY"
-
-
-class Permission(Enum):
- """Standard Databricks permissions for sharing resources."""
-
- CAN_READ = "CAN_READ" # View/read access
- CAN_WRITE = "CAN_WRITE" # Write/edit access
- CAN_RUN = "CAN_RUN" # Execute/run access
- CAN_MANAGE = "CAN_MANAGE" # Full management access including ACLs
- CAN_VIEW = "CAN_VIEW" # View metadata only
-
-
-# ============================================================================
-# Data Classes
-# ============================================================================
-
-
-@dataclass(frozen=True)
-class KAIds:
- """Knowledge Assistant identifiers."""
-
- tile_id: str
- name: str
-
-
-@dataclass(frozen=True)
-class GenieIds:
- """Genie Space identifiers."""
-
- space_id: str
- display_name: str
-
-
-@dataclass(frozen=True)
-class MASIds:
- """Supervisor Agent identifiers."""
-
- tile_id: str
- name: str
-
-
-# ============================================================================
-# TypedDict Definitions (from api_ka.proto and api_mas.proto)
-# ============================================================================
-
-
-class TileDict(TypedDict, total=False):
- """Tile metadata common to KA and MAS."""
-
- tile_id: str
- name: str
- description: Optional[str]
- instructions: Optional[str]
- tile_type: str
- created_timestamp_ms: int
- last_updated_timestamp_ms: int
- user_id: str
-
-
-class KnowledgeSourceDict(TypedDict, total=False):
- """Knowledge source configuration for KA."""
-
- knowledge_source_id: str
- files_source: Dict[str, Any] # Contains: name, type, files: {path: ...}
-
-
-class KnowledgeAssistantStatusDict(TypedDict):
- """KA endpoint status."""
-
- endpoint_status: str # ONLINE, OFFLINE, PROVISIONING, NOT_READY
-
-
-class KnowledgeAssistantDict(TypedDict, total=False):
- """Complete Knowledge Assistant response."""
-
- tile: TileDict
- knowledge_sources: List[KnowledgeSourceDict]
- status: KnowledgeAssistantStatusDict
-
-
-class KnowledgeAssistantResponseDict(TypedDict):
- """GET /knowledge-assistants/{tile_id} response."""
-
- knowledge_assistant: KnowledgeAssistantDict
-
-
-class KnowledgeAssistantExampleDict(TypedDict, total=False):
- """KA example question."""
-
- example_id: str
- question: str
- guidelines: List[str]
- feedback_records: List[Dict[str, Any]]
-
-
-class KnowledgeAssistantListExamplesResponseDict(TypedDict, total=False):
- """List examples response."""
-
- examples: List[KnowledgeAssistantExampleDict]
- tile_id: str
- next_page_token: Optional[str]
-
-
-class BaseAgentDict(TypedDict, total=False):
- """Agent configuration for MAS."""
-
- name: str
- description: str
- agent_type: str # genie, serving_endpoint, unity_catalog_function, external_mcp_server
- genie_space: Optional[Dict[str, str]] # {id: ...}
- serving_endpoint: Optional[Dict[str, str]] # {name: ...}
- app: Optional[Dict[str, str]] # {name: ...}
- unity_catalog_function: Optional[Dict[str, Any]] # {uc_path: {catalog, schema, name}}
- external_mcp_server: Optional[Dict[str, str]] # {connection_name: ...}
-
-
-class MultiAgentSupervisorStatusDict(TypedDict):
- """MAS endpoint status."""
-
- endpoint_status: str # ONLINE, OFFLINE, PROVISIONING, NOT_READY
-
-
-class MultiAgentSupervisorDict(TypedDict, total=False):
- """Complete Supervisor Agent response."""
-
- tile: TileDict
- agents: List[BaseAgentDict]
- status: MultiAgentSupervisorStatusDict
-
-
-class MultiAgentSupervisorResponseDict(TypedDict):
- """GET /multi-agent-supervisors/{tile_id} response."""
-
- multi_agent_supervisor: MultiAgentSupervisorDict
-
-
-class MultiAgentSupervisorExampleDict(TypedDict, total=False):
- """MAS example question."""
-
- example_id: str
- question: str
- guidelines: List[str]
- feedback_records: List[Dict[str, Any]]
-
-
-class MultiAgentSupervisorListExamplesResponseDict(TypedDict, total=False):
- """List examples response."""
-
- examples: List[MultiAgentSupervisorExampleDict]
- tile_id: str
- next_page_token: Optional[str]
-
-
-class GenieSpaceDict(TypedDict, total=False):
- """Genie Space (Data Room) response.
-
- Note: Genie uses /api/2.0/data-rooms endpoint (not defined in KA/MAS protos).
- """
-
- space_id: str
- id: str # Same as space_id
- display_name: str
- description: Optional[str]
- warehouse_id: str
- table_identifiers: List[str]
- run_as_type: str # VIEWER, OWNER, etc.
- created_timestamp: int
- last_updated_timestamp: int
- user_id: str
- folder_node_internal_name: Optional[str]
- sample_questions: Optional[List[str]]
-
-
-class CuratedQuestionDict(TypedDict, total=False):
- """Curated question for Genie space."""
-
- question_id: str
- data_space_id: str
- question_text: str
- question_type: str # SAMPLE_QUESTION, BENCHMARK
- answer_text: Optional[str]
- is_deprecated: bool
-
-
-class GenieListQuestionsResponseDict(TypedDict, total=False):
- """List curated questions response."""
-
- curated_questions: List[CuratedQuestionDict]
-
-
-class InstructionDict(TypedDict, total=False):
- """Genie instruction (text, SQL, or certified answer)."""
-
- instruction_id: str
- title: str
- content: str
- instruction_type: str # TEXT_INSTRUCTION, SQL_INSTRUCTION, CERTIFIED_ANSWER
-
-
-class GenieListInstructionsResponseDict(TypedDict, total=False):
- """List instructions response."""
-
- instructions: List[InstructionDict]
-
-
-class EvaluationRunDict(TypedDict, total=False):
- """Evaluation run metadata."""
-
- mlflow_run_id: str
- tile_id: str
- name: Optional[str]
- created_timestamp_ms: int
- last_updated_timestamp_ms: int
-
-
-class ListEvaluationRunsResponseDict(TypedDict, total=False):
- """List evaluation runs response."""
-
- evaluation_runs: List[EvaluationRunDict]
- next_page_token: Optional[str]
diff --git a/databricks-tools-core/databricks_tools_core/aibi_dashboards/__init__.py b/databricks-tools-core/databricks_tools_core/aibi_dashboards/__init__.py
deleted file mode 100644
index bba5ef6e..00000000
--- a/databricks-tools-core/databricks_tools_core/aibi_dashboards/__init__.py
+++ /dev/null
@@ -1,54 +0,0 @@
-"""AI/BI Dashboard tools for Databricks.
-
-This module provides functions for creating, managing, and deploying
-AI/BI dashboards (formerly known as Lakeview dashboards) in Databricks.
-
-The main entry points are CRUD operations for dashboard management.
-Dashboard JSON content should be generated by the AI assistant using
-the databricks-aibi-dashboards skill documentation.
-
-Example:
- >>> from databricks_tools_core.aibi_dashboards import create_or_update_dashboard
- >>> result = create_or_update_dashboard(
- ... display_name="Sales Overview",
- ... parent_path="/Workspace/Users/me/dashboards",
- ... serialized_dashboard=dashboard_json, # Generated by AI assistant
- ... warehouse_id="abc123",
- ... )
- >>> print(result["url"])
-"""
-
-from .dashboards import (
- create_dashboard,
- create_or_update_dashboard,
- deploy_dashboard,
- find_dashboard_by_path,
- get_dashboard,
- get_dashboard_by_name,
- list_dashboards,
- publish_dashboard,
- trash_dashboard,
- unpublish_dashboard,
- update_dashboard,
-)
-from .models import (
- DashboardDeploymentResult,
-)
-
-__all__ = [
- # CRUD operations
- "create_dashboard",
- "get_dashboard",
- "get_dashboard_by_name",
- "list_dashboards",
- "find_dashboard_by_path",
- "update_dashboard",
- "trash_dashboard",
- "publish_dashboard",
- "unpublish_dashboard",
- # High-level deploy
- "create_or_update_dashboard",
- "deploy_dashboard",
- # Models
- "DashboardDeploymentResult",
-]
diff --git a/databricks-tools-core/databricks_tools_core/aibi_dashboards/dashboards.py b/databricks-tools-core/databricks_tools_core/aibi_dashboards/dashboards.py
deleted file mode 100644
index 9dc869ff..00000000
--- a/databricks-tools-core/databricks_tools_core/aibi_dashboards/dashboards.py
+++ /dev/null
@@ -1,491 +0,0 @@
-"""AI/BI Dashboard CRUD Operations.
-
-Functions for managing AI/BI dashboards using the Databricks Lakeview API.
-Note: AI/BI dashboards were previously known as Lakeview dashboards.
-The SDK/API still uses the 'lakeview' name internally.
-"""
-
-import json
-import logging
-from typing import Any, Dict, Optional, Union
-
-from databricks.sdk.service.dashboards import Dashboard
-
-from ..auth import get_workspace_client
-from .models import DashboardDeploymentResult
-
-logger = logging.getLogger(__name__)
-
-
-def get_dashboard(dashboard_id: str) -> Dict[str, Any]:
- """Get dashboard details by ID.
-
- Args:
- dashboard_id: The dashboard ID
-
- Returns:
- Dictionary with dashboard details including:
- - dashboard_id: The dashboard ID
- - display_name: Dashboard display name
- - warehouse_id: Associated SQL warehouse
- - parent_path: Workspace path
- - serialized_dashboard: Dashboard JSON content (if available)
- """
- w = get_workspace_client()
- dashboard = w.lakeview.get(dashboard_id=dashboard_id)
-
- return {
- "dashboard_id": dashboard.dashboard_id,
- "display_name": dashboard.display_name,
- "warehouse_id": dashboard.warehouse_id,
- "parent_path": dashboard.parent_path,
- "path": dashboard.path,
- "create_time": dashboard.create_time,
- "update_time": dashboard.update_time,
- "lifecycle_state": dashboard.lifecycle_state.value if dashboard.lifecycle_state else None,
- "serialized_dashboard": dashboard.serialized_dashboard,
- }
-
-
-def list_dashboards(
- page_size: int = 100,
- page_token: Optional[str] = None,
-) -> Dict[str, Any]:
- """List AI/BI dashboards in the workspace.
-
- Args:
- page_size: Number of dashboards per page (default: 100)
- page_token: Token for pagination
-
- Returns:
- Dictionary with:
- - dashboards: List of dashboard summaries
- - next_page_token: Token for next page (if available)
- """
- w = get_workspace_client()
-
- dashboards = []
- for dashboard in w.lakeview.list(page_size=page_size, page_token=page_token):
- dashboards.append(
- {
- "dashboard_id": dashboard.dashboard_id,
- "display_name": dashboard.display_name,
- "warehouse_id": dashboard.warehouse_id,
- "parent_path": dashboard.parent_path,
- "path": dashboard.path,
- "lifecycle_state": dashboard.lifecycle_state.value if dashboard.lifecycle_state else None,
- }
- )
-
- return {"dashboards": dashboards}
-
-
-def find_dashboard_by_path(dashboard_path: str) -> Optional[str]:
- """Find a dashboard by its workspace path and return its ID.
-
- Args:
- dashboard_path: Full workspace path (e.g., /Workspace/Users/.../MyDash.lvdash.json)
-
- Returns:
- Dashboard ID if found, None otherwise
- """
- w = get_workspace_client()
-
- try:
- from databricks.sdk.errors.platform import ResourceDoesNotExist
-
- existing = w.workspace.get_status(path=dashboard_path)
- return existing.resource_id
- except ResourceDoesNotExist:
- return None
- except Exception as e:
- logger.warning(f"Error checking dashboard path {dashboard_path}: {e}")
- return None
-
-
-def get_dashboard_by_name(parent_path: str, display_name: str) -> Optional[Dict[str, Any]]:
- """Get dashboard details by parent path and display name.
-
- The dashboard file path is constructed as: {parent_path}/{display_name}.lvdash.json
- This matches the naming convention used by create_dashboard and deploy_dashboard.
-
- Args:
- parent_path: Workspace folder path (e.g., /Workspace/Users/me/dashboards)
- display_name: Dashboard display name (used as filename without .lvdash.json)
-
- Returns:
- Dictionary with dashboard details if found, None otherwise
- """
- dashboard_path = f"{parent_path}/{display_name}.lvdash.json"
- dashboard_id = find_dashboard_by_path(dashboard_path)
-
- if dashboard_id:
- return get_dashboard(dashboard_id)
- return None
-
-
-def create_dashboard(
- display_name: str,
- parent_path: str,
- serialized_dashboard: str,
- warehouse_id: str,
-) -> Dict[str, Any]:
- """Create a new AI/BI dashboard.
-
- Args:
- display_name: Dashboard display name
- parent_path: Workspace folder path (e.g., /Workspace/Users/me/dashboards)
- serialized_dashboard: Dashboard JSON content as string
- warehouse_id: SQL warehouse ID
-
- Returns:
- Dictionary with:
- - dashboard_id: Created dashboard ID
- - display_name: Dashboard name
- - path: Full workspace path
- - url: Dashboard URL
- """
- w = get_workspace_client()
-
- dashboard = Dashboard(
- display_name=display_name,
- warehouse_id=warehouse_id,
- parent_path=parent_path,
- serialized_dashboard=serialized_dashboard,
- )
-
- created = w.lakeview.create(dashboard=dashboard)
- dashboard_url = f"{w.config.host}/sql/dashboardsv3/{created.dashboard_id}"
-
- return {
- "dashboard_id": created.dashboard_id,
- "display_name": created.display_name,
- "path": created.path,
- "url": dashboard_url,
- }
-
-
-def update_dashboard(
- dashboard_id: str,
- display_name: Optional[str] = None,
- serialized_dashboard: Optional[str] = None,
- warehouse_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """Update an existing AI/BI dashboard.
-
- Args:
- dashboard_id: Dashboard ID to update
- display_name: New display name (optional)
- serialized_dashboard: New dashboard JSON content (optional)
- warehouse_id: New warehouse ID (optional)
-
- Returns:
- Dictionary with updated dashboard details
- """
- w = get_workspace_client()
-
- # Get current dashboard to preserve existing values
- current = w.lakeview.get(dashboard_id=dashboard_id)
-
- dashboard = Dashboard(
- display_name=display_name or current.display_name,
- warehouse_id=warehouse_id or current.warehouse_id,
- parent_path=current.parent_path,
- serialized_dashboard=serialized_dashboard or current.serialized_dashboard,
- )
-
- updated = w.lakeview.update(dashboard_id=dashboard_id, dashboard=dashboard)
- dashboard_url = f"{w.config.host}/sql/dashboardsv3/{updated.dashboard_id}"
-
- return {
- "dashboard_id": updated.dashboard_id,
- "display_name": updated.display_name,
- "path": updated.path,
- "url": dashboard_url,
- }
-
-
-def trash_dashboard(dashboard_id: str) -> Dict[str, str]:
- """Move a dashboard to trash.
-
- Args:
- dashboard_id: Dashboard ID to trash
-
- Returns:
- Dictionary with status message
- """
- w = get_workspace_client()
- w.lakeview.trash(dashboard_id=dashboard_id)
-
- return {
- "status": "success",
- "message": f"Dashboard {dashboard_id} moved to trash",
- "dashboard_id": dashboard_id,
- }
-
-
-def publish_dashboard(
- dashboard_id: str,
- warehouse_id: str,
- embed_credentials: bool = True,
-) -> Dict[str, Any]:
- """Publish a dashboard to make it accessible to viewers.
-
- Publishing with embed_credentials=True allows users without direct
- data access to view the dashboard (queries execute using the
- service principal's permissions).
-
- Args:
- dashboard_id: Dashboard ID to publish
- warehouse_id: SQL warehouse ID for query execution
- embed_credentials: Whether to embed credentials (default: True)
-
- Returns:
- Dictionary with publish status
- """
- w = get_workspace_client()
-
- w.lakeview.publish(
- dashboard_id=dashboard_id,
- warehouse_id=warehouse_id,
- embed_credentials=embed_credentials,
- )
-
- dashboard_url = f"{w.config.host}/sql/dashboardsv3/{dashboard_id}"
-
- return {
- "status": "published",
- "dashboard_id": dashboard_id,
- "url": dashboard_url,
- "embed_credentials": embed_credentials,
- }
-
-
-def unpublish_dashboard(dashboard_id: str) -> Dict[str, str]:
- """Unpublish a dashboard.
-
- Args:
- dashboard_id: Dashboard ID to unpublish
-
- Returns:
- Dictionary with status message
- """
- w = get_workspace_client()
- w.lakeview.unpublish(dashboard_id=dashboard_id)
-
- return {
- "status": "unpublished",
- "message": f"Dashboard {dashboard_id} unpublished",
- "dashboard_id": dashboard_id,
- }
-
-
-def _inject_genie_space(
- dashboard_content: Union[str, dict],
- genie_space_id: Optional[str],
-) -> str:
- """Inject Genie space configuration into dashboard JSON.
-
- Args:
- dashboard_content: Dashboard JSON content as string or dict
- genie_space_id: Optional Genie space ID to link
-
- Returns:
- Dashboard JSON string with Genie space configuration
- """
- if isinstance(dashboard_content, str):
- dashboard_dict = json.loads(dashboard_content)
- else:
- dashboard_dict = dashboard_content
-
- if genie_space_id:
- # Ensure uiSettings exists
- if "uiSettings" not in dashboard_dict:
- dashboard_dict["uiSettings"] = {}
-
- # Add Genie space configuration
- dashboard_dict["uiSettings"]["genieSpace"] = {
- "isEnabled": True,
- "overrideId": genie_space_id,
- "enablementMode": "ENABLED",
- }
-
- return json.dumps(dashboard_dict)
-
-
-def deploy_dashboard(
- dashboard_content: Union[str, dict],
- install_path: str,
- dashboard_name: str,
- warehouse_id: str,
- genie_space_id: Optional[str] = None,
- dataset_catalog: Optional[str] = None,
- dataset_schema: Optional[str] = None,
-) -> DashboardDeploymentResult:
- """Deploy a dashboard to Databricks workspace.
-
- This is a high-level function that handles create-or-update logic:
- - Checks if a dashboard exists at the path
- - Creates new or updates existing dashboard
- - Publishes the dashboard
-
- Args:
- dashboard_content: Dashboard JSON content as string or dict
- install_path: Workspace folder path (e.g., /Workspace/Users/me/dashboards)
- dashboard_name: Display name for the dashboard
- warehouse_id: SQL warehouse ID
- genie_space_id: Optional Genie space ID to link to dashboard
- dataset_catalog: Default catalog for datasets (doesn't affect fully qualified names)
- dataset_schema: Default schema for datasets (doesn't affect fully qualified names)
-
- Returns:
- DashboardDeploymentResult with deployment status and details
- """
- from databricks.sdk.errors.platform import ResourceDoesNotExist
-
- # Inject Genie space if provided, and ensure content is JSON string
- dashboard_content = _inject_genie_space(dashboard_content, genie_space_id)
-
- w = get_workspace_client()
- dashboard_path = f"{install_path}/{dashboard_name}.lvdash.json"
-
- try:
- # Ensure the parent directory exists
- try:
- w.workspace.mkdirs(install_path)
- except Exception as e:
- logger.debug(f"Directory creation check: {install_path} - {e}")
-
- # Check if dashboard already exists at path
- existing_dashboard_id = None
- try:
- existing = w.workspace.get_status(path=dashboard_path)
- existing_dashboard_id = existing.resource_id
- except ResourceDoesNotExist:
- pass
-
- dashboard = Dashboard(
- display_name=dashboard_name,
- warehouse_id=warehouse_id,
- parent_path=install_path,
- serialized_dashboard=dashboard_content,
- )
-
- # Update or create
- if existing_dashboard_id:
- try:
- logger.info(f"Updating existing dashboard: {dashboard_name}")
- updated = w.lakeview.update(
- dashboard_id=existing_dashboard_id,
- dashboard=dashboard,
- dataset_catalog=dataset_catalog,
- dataset_schema=dataset_schema,
- )
- dashboard_id = updated.dashboard_id
- status = "updated"
- except Exception as e:
- logger.warning(f"Failed to update dashboard {existing_dashboard_id}: {e}. Creating new.")
- created = w.lakeview.create(
- dashboard=dashboard,
- dataset_catalog=dataset_catalog,
- dataset_schema=dataset_schema,
- )
- dashboard_id = created.dashboard_id
- status = "created"
- else:
- logger.info(f"Creating new dashboard: {dashboard_name}")
- created = w.lakeview.create(
- dashboard=dashboard,
- dataset_catalog=dataset_catalog,
- dataset_schema=dataset_schema,
- )
- dashboard_id = created.dashboard_id
- status = "created"
-
- dashboard_url = f"{w.config.host}/sql/dashboardsv3/{dashboard_id}"
-
- # Publish (best-effort)
- try:
- w.lakeview.publish(
- dashboard_id=dashboard_id,
- warehouse_id=warehouse_id,
- embed_credentials=True,
- )
- logger.info(f"Dashboard {dashboard_id} published successfully")
- except Exception as e:
- logger.warning(f"Failed to publish dashboard {dashboard_id}: {e}")
-
- return DashboardDeploymentResult(
- success=True,
- status=status,
- dashboard_id=dashboard_id,
- path=dashboard_path,
- url=dashboard_url,
- )
-
- except Exception as e:
- logger.error(f"Dashboard deployment failed: {e}", exc_info=True)
- return DashboardDeploymentResult(
- success=False,
- error=str(e),
- path=dashboard_path,
- )
-
-
-def create_or_update_dashboard(
- display_name: str,
- parent_path: str,
- serialized_dashboard: Union[str, dict],
- warehouse_id: str,
- publish: bool = True,
- genie_space_id: Optional[str] = None,
- catalog: Optional[str] = None,
- schema: Optional[str] = None,
-) -> Dict[str, Any]:
- """Create or update a dashboard.
-
- This is a convenience function that:
- 1. Checks if a dashboard exists at the path
- 2. Creates new or updates existing
- 3. Optionally publishes
-
- Args:
- display_name: Dashboard display name
- parent_path: Workspace folder path
- serialized_dashboard: Dashboard JSON content
- warehouse_id: SQL warehouse ID
- publish: Whether to publish after create/update (default: True)
- genie_space_id: Optional Genie space ID to link to dashboard.
- When provided, enables the "Ask Genie" button on the dashboard.
- catalog: Default catalog for datasets. Doesn't affect fully qualified
- table references (e.g., catalog.schema.table).
- schema: Default schema for datasets. Doesn't affect fully qualified
- table references (e.g., schema.table).
-
- Returns:
- Dictionary with:
- - success: Whether operation succeeded
- - status: 'created' or 'updated'
- - dashboard_id: Dashboard ID
- - url: Dashboard URL
- - published: Whether dashboard was published
- """
- result = deploy_dashboard(
- dashboard_content=serialized_dashboard,
- install_path=parent_path,
- dashboard_name=display_name,
- warehouse_id=warehouse_id,
- genie_space_id=genie_space_id,
- dataset_catalog=catalog,
- dataset_schema=schema,
- )
-
- return {
- "success": result.success,
- "status": result.status,
- "dashboard_id": result.dashboard_id,
- "path": result.path,
- "url": result.url,
- "published": result.success and publish,
- "error": result.error,
- }
diff --git a/databricks-tools-core/databricks_tools_core/aibi_dashboards/models.py b/databricks-tools-core/databricks_tools_core/aibi_dashboards/models.py
deleted file mode 100644
index 1f1b06a6..00000000
--- a/databricks-tools-core/databricks_tools_core/aibi_dashboards/models.py
+++ /dev/null
@@ -1,30 +0,0 @@
-"""Models for AI/BI Dashboard operations.
-
-Defines dataclasses for dashboard deployment results.
-"""
-
-from dataclasses import dataclass
-from typing import Any, Dict, Optional
-
-
-@dataclass
-class DashboardDeploymentResult:
- """Result from deploying a dashboard to Databricks."""
-
- success: bool = False
- status: str = "" # 'created' or 'updated'
- dashboard_id: Optional[str] = None
- path: Optional[str] = None
- url: Optional[str] = None
- error: Optional[str] = None
-
- def to_dict(self) -> Dict[str, Any]:
- """Convert to dictionary for JSON serialization."""
- return {
- "success": self.success,
- "status": self.status,
- "dashboard_id": self.dashboard_id,
- "path": self.path,
- "url": self.url,
- "error": self.error,
- }
diff --git a/databricks-tools-core/databricks_tools_core/apps/__init__.py b/databricks-tools-core/databricks_tools_core/apps/__init__.py
deleted file mode 100644
index 5d18e5f1..00000000
--- a/databricks-tools-core/databricks_tools_core/apps/__init__.py
+++ /dev/null
@@ -1,19 +0,0 @@
-"""Databricks Apps operations."""
-
-from .apps import (
- create_app,
- deploy_app,
- delete_app,
- get_app,
- get_app_logs,
- list_apps,
-)
-
-__all__ = [
- "create_app",
- "deploy_app",
- "delete_app",
- "get_app",
- "get_app_logs",
- "list_apps",
-]
diff --git a/databricks-tools-core/databricks_tools_core/apps/apps.py b/databricks-tools-core/databricks_tools_core/apps/apps.py
deleted file mode 100644
index 21574da1..00000000
--- a/databricks-tools-core/databricks_tools_core/apps/apps.py
+++ /dev/null
@@ -1,208 +0,0 @@
-"""
-Databricks Apps - App Lifecycle Management
-
-Functions for managing Databricks Apps lifecycle using the Databricks SDK.
-"""
-
-from typing import Any, Dict, List, Optional
-
-from databricks.sdk.service.apps import App, AppDeployment
-
-from ..auth import get_workspace_client
-
-
-def create_app(
- name: str,
- description: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Create a new Databricks App.
-
- Args:
- name: App name (must be unique within the workspace).
- description: Optional human-readable description.
-
- Returns:
- Dictionary with app details including name, url, and status.
- """
- w = get_workspace_client()
- app_spec = App(name=name, description=description)
- app = w.apps.create(app=app_spec).result()
- return _app_to_dict(app)
-
-
-def get_app(name: str) -> Dict[str, Any]:
- """
- Get details for a Databricks App.
-
- Args:
- name: App name.
-
- Returns:
- Dictionary with app details including name, url, status, and active deployment.
- """
- w = get_workspace_client()
- app = w.apps.get(name=name)
- return _app_to_dict(app)
-
-
-def list_apps(
- name_contains: Optional[str] = None,
- limit: int = 20,
-) -> List[Dict[str, Any]]:
- """
- List Databricks Apps in the workspace.
-
- Returns a limited number of apps, optionally filtered by name substring.
- Apps are returned in API order (most recently created first).
-
- Args:
- name_contains: Optional substring filter applied to app names
- (case-insensitive). Only apps whose name contains this string
- are returned.
- limit: Maximum number of apps to return (default: 20).
- Use 0 for no limit (returns all apps).
-
- Returns:
- List of dictionaries with app details.
- """
- w = get_workspace_client()
- results: List[Dict[str, Any]] = []
-
- for app in w.apps.list():
- if name_contains and name_contains.lower() not in (getattr(app, "name", "") or "").lower():
- continue
- results.append(_app_to_dict(app))
- if limit and len(results) >= limit:
- break
-
- return results
-
-
-def deploy_app(
- app_name: str,
- source_code_path: str,
- mode: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Deploy a Databricks App from a workspace source path.
-
- Args:
- app_name: Name of the app to deploy.
- source_code_path: Workspace path to the app source code
- (e.g., /Workspace/Users/user@example.com/my_app).
- mode: Optional deployment mode (e.g., "snapshot").
-
- Returns:
- Dictionary with deployment details including deployment_id and status.
- """
- w = get_workspace_client()
- # w.apps.deploy returns a Wait[AppDeployment], use .response to get the AppDeployment
- wait_obj = w.apps.deploy(
- app_name=app_name,
- app_deployment=AppDeployment(
- source_code_path=source_code_path,
- mode=mode,
- ),
- )
- return _deployment_to_dict(wait_obj.response)
-
-
-def delete_app(name: str) -> Dict[str, str]:
- """
- Delete a Databricks App.
-
- Args:
- name: App name to delete.
-
- Returns:
- Dictionary confirming deletion.
- """
- w = get_workspace_client()
- w.apps.delete(name=name)
- return {"name": name, "status": "deleted"}
-
-
-def get_app_logs(
- app_name: str,
- deployment_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Get logs for a Databricks App deployment.
-
- If deployment_id is not provided, gets logs for the active deployment.
-
- Args:
- app_name: App name.
- deployment_id: Optional specific deployment ID. If None, uses the
- active deployment.
-
- Returns:
- Dictionary with deployment logs.
- """
- w = get_workspace_client()
-
- # If no deployment_id, get the active one
- if not deployment_id:
- app = w.apps.get(name=app_name)
- if app.active_deployment:
- deployment_id = app.active_deployment.deployment_id
- else:
- return {"app_name": app_name, "error": "No active deployment found"}
-
- # Use the REST client to fetch logs since SDK may not have direct method
- from ..client import get_api_client
-
- client = get_api_client()
- response = client.do(
- "GET",
- f"/api/2.0/apps/{app_name}/deployments/{deployment_id}/logs",
- )
- return {
- "app_name": app_name,
- "deployment_id": deployment_id,
- "logs": response.get("logs", ""),
- }
-
-
-def _app_to_dict(app: Any) -> Dict[str, Any]:
- """Convert an App SDK object to a dictionary."""
- result = {
- "name": getattr(app, "name", None),
- "description": getattr(app, "description", None),
- "url": getattr(app, "url", None),
- "status": None,
- "create_time": str(getattr(app, "create_time", None)),
- "update_time": str(getattr(app, "update_time", None)),
- }
-
- # Extract status from compute_status or status
- compute_status = getattr(app, "compute_status", None)
- if compute_status:
- result["status"] = getattr(compute_status, "state", None)
- if result["status"]:
- result["status"] = str(result["status"])
-
- # Extract active deployment info
- active_deployment = getattr(app, "active_deployment", None)
- if active_deployment:
- result["active_deployment"] = _deployment_to_dict(active_deployment)
-
- return result
-
-
-def _deployment_to_dict(deployment: Any) -> Dict[str, Any]:
- """Convert an AppDeployment SDK object to a dictionary."""
- result = {
- "deployment_id": getattr(deployment, "deployment_id", None),
- "source_code_path": getattr(deployment, "source_code_path", None),
- "mode": str(getattr(deployment, "mode", None)),
- "create_time": str(getattr(deployment, "create_time", None)),
- }
-
- status = getattr(deployment, "status", None)
- if status:
- result["state"] = str(getattr(status, "state", None))
- result["message"] = getattr(status, "message", None)
-
- return result
diff --git a/databricks-tools-core/databricks_tools_core/auth.py b/databricks-tools-core/databricks_tools_core/auth.py
deleted file mode 100644
index c3db9fb4..00000000
--- a/databricks-tools-core/databricks_tools_core/auth.py
+++ /dev/null
@@ -1,216 +0,0 @@
-"""Authentication context for Databricks WorkspaceClient.
-
-Uses Python contextvars to pass authentication through the async call stack
-without threading parameters through every function.
-
-All clients are tagged with a custom product identifier and auto-detected
-project name so that API calls are attributable in ``system.access.audit``.
-
-Usage in FastAPI:
- # In request handler or middleware
- set_databricks_auth(host, token)
- try:
- # Any code here can call get_workspace_client()
- result = some_databricks_function()
- finally:
- clear_databricks_auth()
-
-Cross-workspace (force explicit token over env OAuth):
- set_databricks_auth(target_host, target_token, force_token=True)
-
-Usage in functions:
- from databricks_tools_core.auth import get_workspace_client
-
- def my_function():
- client = get_workspace_client() # Uses context auth or env vars
- # ...
-"""
-
-import logging
-import os
-from contextvars import ContextVar
-from typing import Optional
-
-from databricks.sdk import WorkspaceClient
-
-from .identity import PRODUCT_NAME, PRODUCT_VERSION, tag_client
-
-logger = logging.getLogger(__name__)
-
-# Cached current username — only fetched once per process
-_current_username: Optional[str] = None
-_current_username_fetched: bool = False
-
-# Server-level active workspace override (set by manage_workspace tool).
-# Module-level globals are appropriate here: the standalone MCP server is
-# single-user over stdio, so there is no per-request isolation needed.
-_active_profile: Optional[str] = None
-_active_host: Optional[str] = None
-
-
-def set_active_workspace(profile: Optional[str] = None, host: Optional[str] = None) -> None:
- """Set the active workspace for all subsequent tool calls.
-
- Adds a step 0 to get_workspace_client() that overrides the default SDK
- auth chain. Used by the manage_workspace MCP tool to switch workspaces
- at runtime without restarting the MCP server.
-
- Args:
- profile: Profile name from ~/.databrickscfg to activate.
- host: Workspace URL to activate (used when no profile is available).
- """
- global _active_profile, _active_host, _current_username, _current_username_fetched
- _active_profile = profile
- _active_host = host
- # Reset cached username — it belongs to the previous workspace
- _current_username = None
- _current_username_fetched = False
-
-
-def clear_active_workspace() -> None:
- """Reset to the default workspace from environment / config file."""
- set_active_workspace(None, None)
-
-
-def get_active_workspace() -> dict:
- """Return the current server-level workspace override state.
-
- Returns:
- Dict with 'profile' and 'host' keys (either or both may be None).
- """
- return {"profile": _active_profile, "host": _active_host}
-
-
-def _has_oauth_credentials() -> bool:
- """Check if OAuth credentials (SP) are configured in environment."""
- return bool(os.environ.get("DATABRICKS_CLIENT_ID") and os.environ.get("DATABRICKS_CLIENT_SECRET"))
-
-
-# Context variables for per-request authentication
-_host_ctx: ContextVar[Optional[str]] = ContextVar("databricks_host", default=None)
-_token_ctx: ContextVar[Optional[str]] = ContextVar("databricks_token", default=None)
-_force_token_ctx: ContextVar[bool] = ContextVar("force_token", default=False)
-
-
-def set_databricks_auth(
- host: Optional[str],
- token: Optional[str],
- *,
- force_token: bool = False,
-) -> None:
- """Set Databricks authentication for the current async context.
-
- Call this at the start of a request to set per-user credentials.
- The credentials will be used by all get_workspace_client() calls
- within this async context.
-
- Args:
- host: Databricks workspace URL (e.g., https://xxx.cloud.databricks.com)
- token: Databricks access token
- force_token: When True, the explicit token takes priority over
- environment OAuth credentials. Use for cross-workspace requests
- where the token belongs to a different workspace's SP.
- """
- _host_ctx.set(host)
- _token_ctx.set(token)
- _force_token_ctx.set(force_token)
-
-
-def clear_databricks_auth() -> None:
- """Clear Databricks authentication from the current context.
-
- Call this at the end of a request to clean up.
- """
- _host_ctx.set(None)
- _token_ctx.set(None)
- _force_token_ctx.set(False)
-
-
-def get_workspace_client() -> WorkspaceClient:
- """Get a WorkspaceClient using context auth or environment variables.
-
- Authentication priority:
- 0. Server-level active workspace override (set by manage_workspace tool)
- 1. If force_token is set (cross-workspace), use the explicit token from context
- 2. If OAuth credentials exist in env, use explicit OAuth M2M auth (Databricks Apps)
- - This explicitly sets auth_type to prevent conflicts with other auth methods
- 3. Context variables with explicit token (PAT auth for development)
- 4. Fall back to default authentication (env vars, config file)
-
- Returns:
- Configured WorkspaceClient instance
- """
- host = _host_ctx.get()
- token = _token_ctx.get()
- force = _force_token_ctx.get()
-
- # Common kwargs for product identification in user-agent
- product_kwargs = dict(product=PRODUCT_NAME, product_version=PRODUCT_VERSION)
-
- # Server-level workspace override set by the manage_workspace MCP tool.
- # Profile takes precedence over host when both are set.
- # Skipped when force_token is active (Builder App cross-workspace path wins)
- # or when OAuth M2M credentials are present (Databricks Apps runtime).
- if not force and not _has_oauth_credentials():
- if _active_profile:
- return tag_client(WorkspaceClient(profile=_active_profile, **product_kwargs))
- if _active_host:
- return tag_client(WorkspaceClient(host=_active_host, **product_kwargs))
-
- # Cross-workspace: explicit token overrides env OAuth so tool operations
- # target the caller-specified workspace instead of the app's own workspace
- if force and host and token:
- return tag_client(WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs))
-
- # In Databricks Apps (OAuth credentials in env), explicitly use OAuth M2M.
- # Setting auth_type="oauth-m2m" prevents the SDK from also reading
- # DATABRICKS_TOKEN from os.environ and raising a "more than one
- # authorization method configured" validation error.
- if _has_oauth_credentials():
- oauth_host = host or os.environ.get("DATABRICKS_HOST", "")
- client_id = os.environ.get("DATABRICKS_CLIENT_ID", "")
- client_secret = os.environ.get("DATABRICKS_CLIENT_SECRET", "")
-
- return tag_client(
- WorkspaceClient(
- host=oauth_host,
- client_id=client_id,
- client_secret=client_secret,
- auth_type="oauth-m2m",
- **product_kwargs,
- )
- )
-
- # Development mode: use explicit token if provided
- if host and token:
- return tag_client(WorkspaceClient(host=host, token=token, auth_type="pat", **product_kwargs))
-
- if host:
- return tag_client(WorkspaceClient(host=host, **product_kwargs))
-
- # Fall back to default authentication (env vars, config file)
- return tag_client(WorkspaceClient(**product_kwargs))
-
-
-def get_current_username() -> Optional[str]:
- """Get the current authenticated user's username (email).
-
- Cached after first successful call — the authenticated user doesn't
- change mid-session. Returns None if the API call fails, allowing
- callers to degrade gracefully (e.g., skip user-based filtering).
-
- Returns:
- Username string (typically an email), or None on failure.
- """
- global _current_username, _current_username_fetched
- if _current_username_fetched:
- return _current_username
- try:
- w = get_workspace_client()
- _current_username = w.current_user.me().user_name
- _current_username_fetched = True
- return _current_username
- except Exception as e:
- logger.debug(f"Failed to fetch current username: {e}")
- _current_username_fetched = True
- return None
diff --git a/databricks-tools-core/databricks_tools_core/client.py b/databricks-tools-core/databricks_tools_core/client.py
deleted file mode 100644
index 91d10fda..00000000
--- a/databricks-tools-core/databricks_tools_core/client.py
+++ /dev/null
@@ -1,277 +0,0 @@
-"""
-Databricks REST API Client
-
-Shared HTTP client for all Databricks API operations.
-Uses Databricks SDK for authentication to support both PAT and OAuth.
-
-All clients are tagged with a custom product identifier and auto-detected
-project name so that API calls are attributable in ``system.access.audit``.
-"""
-
-import os
-from typing import Dict, Any, Optional, Callable
-import requests
-
-from databricks.sdk import WorkspaceClient
-
-from .identity import PRODUCT_NAME, PRODUCT_VERSION, tag_client
-
-
-def _has_oauth_credentials() -> bool:
- """Check if OAuth credentials (SP) are configured in environment."""
- return bool(os.environ.get("DATABRICKS_CLIENT_ID") and os.environ.get("DATABRICKS_CLIENT_SECRET"))
-
-
-class FilesAPI:
- """Databricks Files API for Unity Catalog Volumes."""
-
- def __init__(self, client: "DatabricksClient"):
- self.client = client
-
- def create_directory(self, path: str) -> None:
- """
- Create directory in Volume (idempotent).
-
- Args:
- path: Volume path (e.g., "/Volumes/catalog/schema/volume/dir")
-
- Raises:
- requests.HTTPError: If request fails
- """
- self.client.put("/api/2.0/fs/directories", json={"path": path})
-
- def delete_directory(self, path: str, ignore_missing: bool = False) -> None:
- """
- Delete directory recursively.
-
- Args:
- path: Volume path to delete
- ignore_missing: If True, ignore 404 errors
-
- Raises:
- requests.HTTPError: If request fails (unless ignore_missing=True for 404)
- """
- try:
- self.client.delete("/api/2.0/fs/directories", params={"path": path, "recursive": "true"})
- except requests.HTTPError as e:
- if not ignore_missing or e.response.status_code != 404:
- raise
-
- def upload(self, path: str, data: bytes, overwrite: bool = False) -> None:
- """
- Upload file to Volume.
-
- Args:
- path: Volume file path (e.g., "/Volumes/catalog/schema/volume/file.parquet")
- data: File content as bytes
- overwrite: If True, overwrite existing file
-
- Raises:
- requests.HTTPError: If request fails
- """
- self.client.put(f"/api/2.0/fs/files{path}", data=data, params={"overwrite": str(overwrite).lower()})
-
-
-class DatabricksClient:
- """Client for making requests to Databricks REST APIs.
-
- Uses Databricks SDK for authentication to support both PAT and OAuth (SP credentials).
- """
-
- def __init__(self, host: Optional[str] = None, token: Optional[str] = None, profile: Optional[str] = None):
- """
- Initialize Databricks client.
-
- Authentication priority (via Databricks SDK):
- 1. If OAuth credentials exist in env, use explicit OAuth M2M auth (Databricks Apps)
- 2. Explicit host/token parameters (PAT auth for development)
- 3. DATABRICKS_HOST and DATABRICKS_TOKEN env vars
- 4. Profile from ~/.databrickscfg
-
- Args:
- host: Databricks workspace URL
- token: Databricks personal access token (optional - uses SDK auth if not provided)
- profile: Profile name from ~/.databrickscfg (e.g., "ai-strat")
- """
- # Common kwargs for product identification in user-agent
- product_kwargs = dict(product=PRODUCT_NAME, product_version=PRODUCT_VERSION)
-
- # In Databricks Apps (OAuth credentials in env), explicitly use OAuth M2M
- # This prevents the SDK from detecting other auth methods like PAT or config file
- if _has_oauth_credentials():
- oauth_host = host or os.environ.get("DATABRICKS_HOST", "")
- client_id = os.environ.get("DATABRICKS_CLIENT_ID", "")
- client_secret = os.environ.get("DATABRICKS_CLIENT_SECRET", "")
-
- # Explicitly configure OAuth M2M to prevent auth conflicts
- self._sdk_client = tag_client(
- WorkspaceClient(
- host=oauth_host,
- client_id=client_id,
- client_secret=client_secret,
- **product_kwargs,
- )
- )
- elif host and token:
- # Development mode: explicit PAT auth
- self._sdk_client = tag_client(WorkspaceClient(host=host, token=token, **product_kwargs))
- elif host:
- # Host provided, use SDK default auth
- self._sdk_client = tag_client(WorkspaceClient(host=host, **product_kwargs))
- elif profile:
- # Use config profile
- self._sdk_client = tag_client(WorkspaceClient(profile=profile, **product_kwargs))
- else:
- # Use default SDK auth (env vars, config file)
- self._sdk_client = tag_client(WorkspaceClient(**product_kwargs))
-
- # Get host from SDK config
- self.host = self._sdk_client.config.host.rstrip("/") if self._sdk_client.config.host else ""
-
- if not self.host:
- raise ValueError(
- "Databricks host must be provided via:\n"
- " 1. Constructor parameters (host)\n"
- " 2. Environment variables (DATABRICKS_HOST)\n"
- " 3. Config profile (profile parameter or DATABRICKS_CONFIG_PROFILE env var)"
- )
-
- # Store the authenticate function for getting fresh headers
- self._authenticate: Callable[[], dict] = self._sdk_client.config.authenticate
-
- # Initialize Files API
- self.files = FilesAPI(self)
-
- @property
- def headers(self) -> Dict[str, str]:
- """Get authentication and user-agent headers.
-
- Includes SDK authentication (fresh OAuth tokens when needed) and the
- product/project User-Agent so raw ``requests`` calls are also tracked
- in ``system.access.audit``.
- """
- headers = self._authenticate()
- headers["User-Agent"] = self._sdk_client.config.user_agent
- return headers
-
- def get(self, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
- """
- Make GET request to Databricks API.
-
- Args:
- endpoint: API endpoint path (e.g., "/api/2.1/unity-catalog/catalogs")
- params: Query parameters
-
- Returns:
- JSON response as dictionary
-
- Raises:
- requests.HTTPError: If request fails
- """
- url = f"{self.host}{endpoint}"
- response = requests.get(url, headers=self.headers, params=params)
- response.raise_for_status()
- return response.json()
-
- def post(self, endpoint: str, json: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
- """
- Make POST request to Databricks API.
-
- Args:
- endpoint: API endpoint path
- json: JSON request body
-
- Returns:
- JSON response as dictionary
-
- Raises:
- requests.HTTPError: If request fails
- """
- url = f"{self.host}{endpoint}"
- response = requests.post(url, headers=self.headers, json=json)
- response.raise_for_status()
- return response.json()
-
- def patch(self, endpoint: str, json: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
- """
- Make PATCH request to Databricks API.
-
- Args:
- endpoint: API endpoint path
- json: JSON request body
-
- Returns:
- JSON response as dictionary
-
- Raises:
- requests.HTTPError: If request fails
- """
- url = f"{self.host}{endpoint}"
- response = requests.patch(url, headers=self.headers, json=json)
- response.raise_for_status()
- return response.json()
-
- def put(
- self,
- endpoint: str,
- json: Optional[Dict[str, Any]] = None,
- data: Optional[bytes] = None,
- params: Optional[Dict[str, Any]] = None,
- ) -> Dict[str, Any]:
- """
- Make PUT request to Databricks API.
-
- Args:
- endpoint: API endpoint path
- json: JSON request body (mutually exclusive with data)
- data: Binary data for file uploads (mutually exclusive with json)
- params: Query parameters
-
- Returns:
- JSON response as dictionary (or empty dict for 204 responses)
-
- Raises:
- requests.HTTPError: If request fails
- """
- url = f"{self.host}{endpoint}"
-
- if data is not None:
- headers = {**self.headers, "Content-Type": "application/octet-stream"}
- response = requests.put(url, data=data, params=params, headers=headers)
- elif json is not None:
- headers = {**self.headers, "Content-Type": "application/json"}
- response = requests.put(url, json=json, params=params, headers=headers)
- else:
- response = requests.put(url, params=params, headers=self.headers)
-
- response.raise_for_status()
-
- # Handle 204 No Content responses
- if response.status_code == 204:
- return {}
-
- return response.json()
-
- def delete(self, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
- """
- Make DELETE request to Databricks API.
-
- Args:
- endpoint: API endpoint path
- params: Query parameters
-
- Returns:
- JSON response as dictionary (or empty dict for 204 responses)
-
- Raises:
- requests.HTTPError: If request fails
- """
- url = f"{self.host}{endpoint}"
- response = requests.delete(url, headers=self.headers, params=params)
- response.raise_for_status()
-
- # Handle 204 No Content responses
- if response.status_code == 204 or not response.content:
- return {}
-
- return response.json()
diff --git a/databricks-tools-core/databricks_tools_core/common/__init__.py b/databricks-tools-core/databricks_tools_core/common/__init__.py
deleted file mode 100644
index e0bf43fb..00000000
--- a/databricks-tools-core/databricks_tools_core/common/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""Common utilities shared across product lines."""
-
-__all__ = []
diff --git a/databricks-tools-core/databricks_tools_core/compute/__init__.py b/databricks-tools-core/databricks_tools_core/compute/__init__.py
deleted file mode 100644
index 81205e0b..00000000
--- a/databricks-tools-core/databricks_tools_core/compute/__init__.py
+++ /dev/null
@@ -1,60 +0,0 @@
-"""
-Compute - Code Execution and Compute Management Operations
-
-Functions for executing code on Databricks clusters and serverless compute,
-and for creating, modifying, and deleting compute resources.
-"""
-
-from .execution import (
- ExecutionResult,
- NoRunningClusterError,
- list_clusters,
- get_best_cluster,
- start_cluster,
- get_cluster_status,
- create_context,
- destroy_context,
- execute_databricks_command,
- run_file_on_databricks,
-)
-
-from .serverless import (
- ServerlessRunResult,
- run_code_on_serverless,
-)
-
-from .manage import (
- create_cluster,
- modify_cluster,
- terminate_cluster,
- delete_cluster,
- list_node_types,
- list_spark_versions,
- create_sql_warehouse,
- modify_sql_warehouse,
- delete_sql_warehouse,
-)
-
-__all__ = [
- "ExecutionResult",
- "NoRunningClusterError",
- "list_clusters",
- "get_best_cluster",
- "start_cluster",
- "get_cluster_status",
- "create_context",
- "destroy_context",
- "execute_databricks_command",
- "run_file_on_databricks",
- "ServerlessRunResult",
- "run_code_on_serverless",
- "create_cluster",
- "modify_cluster",
- "terminate_cluster",
- "delete_cluster",
- "list_node_types",
- "list_spark_versions",
- "create_sql_warehouse",
- "modify_sql_warehouse",
- "delete_sql_warehouse",
-]
diff --git a/databricks-tools-core/databricks_tools_core/compute/execution.py b/databricks-tools-core/databricks_tools_core/compute/execution.py
deleted file mode 100644
index 9a1c10ac..00000000
--- a/databricks-tools-core/databricks_tools_core/compute/execution.py
+++ /dev/null
@@ -1,785 +0,0 @@
-"""
-Compute - Execution Context Operations
-
-Functions for executing code on Databricks clusters using execution contexts.
-Uses Databricks Command Execution API via SDK.
-"""
-
-import datetime
-from typing import Optional, List, Dict, Any
-from databricks.sdk.service.compute import (
- CommandStatus,
- ClusterSource,
- DataSecurityMode,
- Language,
- ListClustersFilterBy,
- State,
-)
-
-from ..auth import get_workspace_client, get_current_username
-
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-class ExecutionResult:
- """Result from code execution on a Databricks cluster.
-
- Attributes:
- success: Whether the execution completed successfully.
- output: The output from the execution (if successful).
- error: The error message (if failed).
- cluster_id: The cluster ID used for execution.
- context_id: The execution context ID. Can be reused for follow-up
- commands to maintain state and speed up execution.
- context_destroyed: Whether the context was destroyed after execution.
- If False, the context_id can be reused.
- message: A helpful message about reusing the context.
- """
-
- def __init__(
- self,
- success: bool,
- output: Optional[str] = None,
- error: Optional[str] = None,
- cluster_id: Optional[str] = None,
- context_id: Optional[str] = None,
- context_destroyed: bool = True,
- ):
- self.success = success
- self.output = output
- self.error = error
- self.cluster_id = cluster_id
- self.context_id = context_id
- self.context_destroyed = context_destroyed
-
- # Generate helpful message
- if success and context_id and not context_destroyed:
- self.message = (
- f"Execution successful. To speed up follow-up commands and maintain "
- f"state (variables, imports), reuse context_id='{context_id}' with "
- f"cluster_id='{cluster_id}'."
- )
- elif success and context_destroyed:
- self.message = "Execution successful. Context was destroyed."
- elif not success:
- self.message = None
- else:
- self.message = None
-
- def __repr__(self):
- if self.success:
- return (
- f"ExecutionResult(success=True, output={repr(self.output)}, "
- f"cluster_id={repr(self.cluster_id)}, context_id={repr(self.context_id)})"
- )
- return f"ExecutionResult(success=False, error={repr(self.error)})"
-
- def to_dict(self) -> Dict[str, Any]:
- """Convert to dictionary for JSON serialization."""
- return {
- "success": self.success,
- "output": self.output,
- "error": self.error,
- "cluster_id": self.cluster_id,
- "context_id": self.context_id,
- "context_destroyed": self.context_destroyed,
- "message": self.message,
- }
-
-
-# Only list user-created clusters (UI or API), not job/pipeline clusters
-_USER_CLUSTER_SOURCES = [ClusterSource.UI, ClusterSource.API]
-
-# Language string to enum mapping
-_LANGUAGE_MAP = {
- "python": Language.PYTHON,
- "scala": Language.SCALA,
- "sql": Language.SQL,
- "r": Language.R,
-}
-
-
-def list_clusters(
- include_terminated: bool = True,
- limit: Optional[int] = None,
-) -> List[Dict[str, Any]]:
- """
- List user-created clusters in the workspace.
-
- Only returns clusters created by users (via UI or API), not job or pipeline clusters.
- Searches running clusters first, then terminated if include_terminated=True.
-
- Args:
- include_terminated: If True, includes terminated clusters in results.
- If False, only returns running/pending clusters.
- limit: Maximum number of clusters to return. None means no limit.
-
- Returns:
- List of cluster info dicts with cluster_id, cluster_name, state, etc.
- """
- w = get_workspace_client()
- clusters = []
-
- def _add_cluster(cluster) -> bool:
- """Add cluster to list, return False if limit reached."""
- clusters.append(
- {
- "cluster_id": cluster.cluster_id,
- "cluster_name": cluster.cluster_name,
- "state": cluster.state.value if cluster.state else None,
- "creator_user_name": cluster.creator_user_name,
- "cluster_source": cluster.cluster_source.value if cluster.cluster_source else None,
- }
- )
- return limit is None or len(clusters) < limit
-
- # First, get running clusters (faster query)
- running_filter = ListClustersFilterBy(
- cluster_sources=_USER_CLUSTER_SOURCES,
- cluster_states=[State.RUNNING, State.PENDING, State.RESIZING, State.RESTARTING],
- )
- for cluster in w.clusters.list(filter_by=running_filter):
- if not _add_cluster(cluster):
- return clusters
-
- # If requested and not at limit, also get terminated clusters
- if include_terminated and (limit is None or len(clusters) < limit):
- terminated_filter = ListClustersFilterBy(
- cluster_sources=_USER_CLUSTER_SOURCES,
- cluster_states=[State.TERMINATED, State.TERMINATING, State.ERROR],
- )
- for cluster in w.clusters.list(filter_by=terminated_filter):
- if not _add_cluster(cluster):
- return clusters
-
- return clusters
-
-
-def _is_cluster_accessible(cluster, current_user: Optional[str]) -> bool:
- """Check whether the current user can use this cluster.
-
- A cluster is inaccessible when its data_security_mode is SINGLE_USER
- and the single_user_name doesn't match the current user.
-
- Args:
- cluster: SDK ClusterDetails object.
- current_user: Current user's username/email, or None if unknown.
-
- Returns:
- True if the cluster is accessible (or we can't determine either way).
- """
- if current_user is None:
- # Can't determine access — assume accessible (graceful degradation)
- return True
-
- dsm = getattr(cluster, "data_security_mode", None)
- single_user = getattr(cluster, "single_user_name", None)
-
- # If it's a single-user cluster assigned to someone else, skip it
- if dsm == DataSecurityMode.SINGLE_USER and single_user:
- if single_user.lower() != current_user.lower():
- logger.debug(
- f"Skipping cluster '{cluster.cluster_name}' ({cluster.cluster_id}): "
- f"single-user cluster owned by {single_user}, current user is {current_user}"
- )
- return False
-
- return True
-
-
-class ClusterSelectionResult:
- """Result from get_best_cluster with details about skipped clusters."""
-
- def __init__(
- self,
- cluster_id: Optional[str],
- skipped_clusters: Optional[List[Dict[str, str]]] = None,
- ):
- self.cluster_id = cluster_id
- self.skipped_clusters = skipped_clusters or []
-
-
-def get_best_cluster() -> Optional[str]:
- """
- Get the ID of the best available cluster for code execution.
-
- Only considers user-created clusters (UI or API), not job or pipeline clusters.
- Filters out single-user clusters that belong to a different user.
-
- Selection logic:
- 1. Only considers RUNNING clusters accessible to the current user
- 2. Prefers clusters with "shared" in the name (case-insensitive)
- 3. Then prefers clusters with "demo" in the name
- 4. Otherwise returns the first running cluster
-
- Returns:
- Cluster ID string, or None if no running clusters available.
- """
- return _select_best_cluster().cluster_id
-
-
-def _select_best_cluster() -> ClusterSelectionResult:
- """Internal cluster selection that also tracks skipped clusters.
-
- Returns:
- ClusterSelectionResult with the selected cluster_id and any skipped clusters.
- """
- w = get_workspace_client()
- current_user = get_current_username()
-
- # Only get running user-created clusters
- running_filter = ListClustersFilterBy(
- cluster_sources=_USER_CLUSTER_SOURCES,
- cluster_states=[State.RUNNING],
- )
-
- running_clusters = []
- skipped_clusters = []
- for cluster in w.clusters.list(filter_by=running_filter):
- if not _is_cluster_accessible(cluster, current_user):
- skipped_clusters.append(
- {
- "cluster_id": cluster.cluster_id,
- "cluster_name": cluster.cluster_name or "",
- "single_user_name": getattr(cluster, "single_user_name", None) or "unknown",
- }
- )
- continue
- running_clusters.append(
- {
- "cluster_id": cluster.cluster_id,
- "cluster_name": cluster.cluster_name or "",
- }
- )
-
- if not running_clusters:
- return ClusterSelectionResult(cluster_id=None, skipped_clusters=skipped_clusters)
-
- # Priority 1: clusters with "shared" in name
- for c in running_clusters:
- if "shared" in c["cluster_name"].lower():
- return ClusterSelectionResult(cluster_id=c["cluster_id"], skipped_clusters=skipped_clusters)
-
- # Priority 2: clusters with "demo" in name
- for c in running_clusters:
- if "demo" in c["cluster_name"].lower():
- return ClusterSelectionResult(cluster_id=c["cluster_id"], skipped_clusters=skipped_clusters)
-
- # Fallback: first running cluster
- return ClusterSelectionResult(cluster_id=running_clusters[0]["cluster_id"], skipped_clusters=skipped_clusters)
-
-
-class NoRunningClusterError(Exception):
- """Raised when no running cluster is available and none was specified.
-
- Provides structured data so agents can present actionable options to the user.
-
- Attributes:
- available_clusters: All clusters visible to the user (any state).
- skipped_clusters: Running clusters filtered out (e.g., single-user owned by others).
- startable_clusters: Terminated clusters the user can start.
- suggestions: List of actionable suggestion strings for the agent/user.
- """
-
- def __init__(
- self,
- available_clusters: List[Dict[str, str]],
- skipped_clusters: Optional[List[Dict[str, str]]] = None,
- startable_clusters: Optional[List[Dict[str, str]]] = None,
- ):
- self.available_clusters = available_clusters
- self.skipped_clusters = skipped_clusters or []
- self.startable_clusters = startable_clusters or []
- self.suggestions = self._build_suggestions()
-
- message = self._build_message()
- super().__init__(message)
-
- def _build_suggestions(self) -> List[str]:
- """Build a list of actionable suggestions based on available resources."""
- suggestions = []
-
- # Suggestion 1: offer to start a terminated cluster (agent should ask user first)
- if self.startable_clusters:
- best = self.startable_clusters[0]
- suggestions.append(
- f"ASK THE USER: \"I found your terminated cluster '{best['cluster_name']}'. "
- f'Would you like me to start it? (It typically takes 3-8 minutes to start.)". '
- f"If they approve, call start_cluster(cluster_id='{best['cluster_id']}'), "
- f"then poll with get_cluster_status() until it's RUNNING, then retry."
- )
- # Show additional options if there are more
- for c in self.startable_clusters[1:3]:
- suggestions.append(
- f"Alternative cluster: '{c['cluster_name']}' (cluster_id='{c['cluster_id']}', state={c['state']})"
- )
-
- # Suggestion 2: use execute_sql for SQL workloads
- suggestions.append(
- "For SQL-only workloads, use execute_sql() instead — it routes through "
- "SQL warehouses and doesn't require a cluster."
- )
-
- # Suggestion 3: ask for shared access
- suggestions.append("Ask a workspace admin for access to a shared cluster.")
-
- return suggestions
-
- def _build_message(self) -> str:
- """Build a human-readable error message."""
- message = "No running cluster available for the current user."
-
- if self.startable_clusters:
- cluster_list = "\n".join(
- f" - {c['cluster_name']} ({c['cluster_id']}) - {c['state']}" for c in self.startable_clusters[:10]
- )
- message += (
- f"\n\nYou have {len(self.startable_clusters)} terminated cluster(s) you could start:\n{cluster_list}"
- )
-
- if self.skipped_clusters:
- skipped_list = "\n".join(
- f" - {c['cluster_name']} ({c['cluster_id']}) - owned by {c.get('single_user_name', 'unknown')}"
- for c in self.skipped_clusters
- )
- message += (
- f"\n\n{len(self.skipped_clusters)} running cluster(s) were skipped because they are "
- f"single-user clusters assigned to a different user:\n{skipped_list}"
- )
-
- message += "\n\nSuggestions:\n"
- for i, suggestion in enumerate(self.suggestions, 1):
- message += f" {i}. {suggestion}\n"
-
- return message
-
-
-def start_cluster(cluster_id: str) -> Dict[str, Any]:
- """
- Start a terminated Databricks cluster.
-
- This initiates the cluster start process and returns immediately — it does
- NOT wait for the cluster to reach RUNNING state (that typically takes 3-8
- minutes). Use ``get_cluster_status()`` to poll until the cluster is ready.
-
- Args:
- cluster_id: ID of the cluster to start.
-
- Returns:
- Dictionary with cluster_id, cluster_name, state, and a message.
-
- Raises:
- Exception: If the cluster cannot be started (e.g., permissions, not found).
- """
- w = get_workspace_client()
-
- # Get cluster info first for a better response message
- cluster = w.clusters.get(cluster_id)
- cluster_name = cluster.cluster_name or cluster_id
- current_state = cluster.state.value if cluster.state else "UNKNOWN"
-
- if current_state == "RUNNING":
- return {
- "cluster_id": cluster_id,
- "cluster_name": cluster_name,
- "state": "RUNNING",
- "message": f"Cluster '{cluster_name}' is already running.",
- }
-
- if current_state not in ("TERMINATED", "ERROR"):
- return {
- "cluster_id": cluster_id,
- "cluster_name": cluster_name,
- "state": current_state,
- "message": (
- f"Cluster '{cluster_name}' is in state {current_state}. It may already be starting or resizing."
- ),
- }
-
- # Kick off start (non-blocking)
- w.clusters.start(cluster_id)
-
- return {
- "cluster_id": cluster_id,
- "cluster_name": cluster_name,
- "previous_state": current_state,
- "state": "PENDING",
- "message": (
- f"Cluster '{cluster_name}' is now starting. "
- f"This typically takes 3-8 minutes. "
- f"Use get_cluster_status(cluster_id='{cluster_id}') to check progress."
- ),
- }
-
-
-def get_cluster_status(cluster_id: str) -> Dict[str, Any]:
- """
- Get the current status of a Databricks cluster.
-
- Useful for polling a cluster after calling ``start_cluster()`` to check
- whether it has reached RUNNING state.
-
- Args:
- cluster_id: ID of the cluster.
-
- Returns:
- Dictionary with cluster_id, cluster_name, state, and a message.
- """
- w = get_workspace_client()
- cluster = w.clusters.get(cluster_id)
-
- cluster_name = cluster.cluster_name or cluster_id
- state = cluster.state.value if cluster.state else "UNKNOWN"
-
- if state == "RUNNING":
- message = f"Cluster '{cluster_name}' is running and ready for use."
- elif state in ("PENDING", "RESTARTING", "RESIZING"):
- message = f"Cluster '{cluster_name}' is {state.lower()}. Please wait and check again in 30-60 seconds."
- elif state == "TERMINATED":
- message = f"Cluster '{cluster_name}' is terminated."
- elif state == "TERMINATING":
- message = f"Cluster '{cluster_name}' is shutting down."
- else:
- message = f"Cluster '{cluster_name}' is in state: {state}."
-
- return {
- "cluster_id": cluster_id,
- "cluster_name": cluster_name,
- "state": state,
- "message": message,
- }
-
-
-def create_context(cluster_id: str, language: str = "python") -> str:
- """
- Create a new execution context on a Databricks cluster.
-
- Args:
- cluster_id: ID of the cluster to create context on
- language: Programming language ("python", "scala", "sql", "r")
-
- Returns:
- Context ID string
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
-
- lang_enum = _LANGUAGE_MAP.get(language.lower(), Language.PYTHON)
-
- # SDK returns Wait object, need to wait for result
- result = w.command_execution.create(
- cluster_id=cluster_id, language=lang_enum
- ).result() # Blocks until context is created
-
- return result.id
-
-
-def destroy_context(cluster_id: str, context_id: str) -> None:
- """
- Destroy an execution context.
-
- Args:
- cluster_id: ID of the cluster
- context_id: ID of the context to destroy
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- w.command_execution.destroy(cluster_id=cluster_id, context_id=context_id)
-
-
-def _execute_on_context(cluster_id: str, context_id: str, code: str, language: str, timeout: int) -> ExecutionResult:
- """
- Internal function to execute code on an existing context.
-
- Args:
- cluster_id: ID of the cluster
- context_id: ID of the execution context
- code: Code to execute
- language: Programming language
- timeout: Maximum time to wait for execution (seconds)
-
- Returns:
- ExecutionResult with output or error (context_id filled in but context_destroyed=False)
- """
- w = get_workspace_client()
- lang_enum = _LANGUAGE_MAP.get(language.lower(), Language.PYTHON)
-
- try:
- # Execute and wait for result with timeout
- result = w.command_execution.execute(
- cluster_id=cluster_id, context_id=context_id, language=lang_enum, command=code
- ).result(timeout=datetime.timedelta(seconds=timeout))
-
- # Check result status (compare with enum values)
- if result.status == CommandStatus.FINISHED:
- # Check if there was an error in the results
- if result.results and result.results.result_type and result.results.result_type.value == "error":
- error_msg = result.results.cause if result.results.cause else "Unknown error"
- return ExecutionResult(
- success=False,
- error=error_msg,
- cluster_id=cluster_id,
- context_id=context_id,
- context_destroyed=False,
- )
- output = result.results.data if result.results and result.results.data else "Success (no output)"
- return ExecutionResult(
- success=True,
- output=str(output),
- cluster_id=cluster_id,
- context_id=context_id,
- context_destroyed=False,
- )
- elif result.status in [CommandStatus.ERROR, CommandStatus.CANCELLED]:
- error_msg = result.results.cause if result.results and result.results.cause else "Unknown error"
- return ExecutionResult(
- success=False,
- error=error_msg,
- cluster_id=cluster_id,
- context_id=context_id,
- context_destroyed=False,
- )
- else:
- return ExecutionResult(
- success=False,
- error=f"Unexpected status: {result.status}",
- cluster_id=cluster_id,
- context_id=context_id,
- context_destroyed=False,
- )
-
- except TimeoutError:
- return ExecutionResult(
- success=False,
- error="Command timed out",
- cluster_id=cluster_id,
- context_id=context_id,
- context_destroyed=False,
- )
-
-
-def execute_databricks_command(
- code: str,
- cluster_id: Optional[str] = None,
- context_id: Optional[str] = None,
- language: str = "python",
- timeout: int = 120,
- destroy_context_on_completion: bool = False,
-) -> ExecutionResult:
- """
- Execute code on a Databricks cluster.
-
- If context_id is provided, reuses the existing context (faster, maintains state).
- If not provided, creates a new context.
-
- By default, the context is kept alive for reuse. Set destroy_context_on_completion=True
- to destroy it after execution.
-
- Args:
- code: Code to execute
- cluster_id: ID of the cluster. If not provided, auto-selects a running cluster
- (prefers clusters with "shared" or "demo" in name).
- context_id: Optional existing execution context ID. If provided, reuses it
- for faster execution and state preservation (variables, imports).
- language: Programming language ("python", "scala", "sql", "r")
- timeout: Maximum time to wait for execution (seconds)
- destroy_context_on_completion: If True, destroys the context after execution.
- Default is False to allow reuse.
-
- Returns:
- ExecutionResult with output, error, and context info for reuse.
-
- Raises:
- NoRunningClusterError: If no cluster_id provided and no running cluster found
- DatabricksError: If API request fails
- """
- # Auto-select cluster if not provided
- if cluster_id is None:
- selection = _select_best_cluster()
- cluster_id = selection.cluster_id
- if cluster_id is None:
- # No accessible running cluster — build an actionable error
- available_clusters = list_clusters(limit=20)
-
- # Deduplicate clusters by cluster_id (API sometimes returns duplicates)
- seen_ids = set()
- deduped = []
- for c in available_clusters:
- if c["cluster_id"] not in seen_ids:
- seen_ids.add(c["cluster_id"])
- deduped.append(c)
- available_clusters = deduped
-
- # Find terminated clusters the user could start, preferring user-owned
- current_user = get_current_username()
- terminated = [c for c in available_clusters if c.get("state") in ("TERMINATED", "ERROR")]
- if current_user:
- user_lower = current_user.lower()
- user_owned = [c for c in terminated if (c.get("creator_user_name") or "").lower() == user_lower]
- others = [c for c in terminated if (c.get("creator_user_name") or "").lower() != user_lower]
- startable_clusters = user_owned + others
- else:
- startable_clusters = terminated
-
- raise NoRunningClusterError(
- available_clusters,
- skipped_clusters=selection.skipped_clusters,
- startable_clusters=startable_clusters,
- )
-
- # Create context if not provided
- context_created = False
- if context_id is None:
- context_id = create_context(cluster_id, language)
- context_created = True
-
- try:
- # Execute command
- result = _execute_on_context(
- cluster_id=cluster_id,
- context_id=context_id,
- code=code,
- language=language,
- timeout=timeout,
- )
-
- # Destroy context if requested
- if destroy_context_on_completion:
- try:
- destroy_context(cluster_id, context_id)
- result.context_destroyed = True
- result.message = "Execution successful. Context was destroyed."
- except Exception:
- pass # Ignore cleanup errors
-
- return result
-
- except Exception:
- # If we created the context and there's an error, clean up
- if context_created and destroy_context_on_completion:
- try:
- destroy_context(cluster_id, context_id)
- except Exception:
- pass
- raise
-
-
-_FILE_EXT_LANGUAGE = {
- ".py": "python",
- ".scala": "scala",
- ".sql": "sql",
- ".r": "r",
-}
-
-
-def run_file_on_databricks(
- file_path: str,
- cluster_id: Optional[str] = None,
- context_id: Optional[str] = None,
- language: Optional[str] = None,
- timeout: int = 600,
- destroy_context_on_completion: bool = False,
- workspace_path: Optional[str] = None,
-) -> ExecutionResult:
- """
- Read a local file and execute it on a Databricks cluster.
-
- Supports Python, Scala, SQL, and R files. If ``language`` is not specified,
- it is auto-detected from the file extension (.py, .scala, .sql, .r).
-
- Two modes:
- - **Ephemeral** (default): Sends code directly via Command Execution API.
- No artifact is saved in the workspace.
- - **Persistent**: If ``workspace_path`` is provided, also uploads the file
- as a notebook to that workspace path so it is visible in the Databricks UI.
-
- Args:
- file_path: Local path to the file to execute.
- cluster_id: ID of the cluster to run on. If not provided, auto-selects
- a running cluster (prefers "shared" or "demo").
- context_id: Optional existing execution context ID for reuse.
- language: Programming language ("python", "scala", "sql", "r").
- If omitted, auto-detected from file extension.
- timeout: Maximum time to wait for execution (seconds, default 600).
- destroy_context_on_completion: If True, destroys the context after execution.
- workspace_path: Optional workspace path to persist the file as a notebook
- (e.g. "/Workspace/Users/user@company.com/my-project/train").
- If omitted, no workspace artifact is created.
-
- Returns:
- ExecutionResult with output, error, and context info for reuse.
- """
- import os
-
- # Read the file contents
- try:
- with open(file_path, "r", encoding="utf-8") as f:
- code = f.read()
- except FileNotFoundError:
- return ExecutionResult(success=False, error=f"File not found: {file_path}")
- except Exception as e:
- return ExecutionResult(success=False, error=f"Failed to read file {file_path}: {str(e)}")
-
- if not code.strip():
- return ExecutionResult(success=False, error=f"File is empty: {file_path}")
-
- # Auto-detect language from file extension if not specified
- if language is None:
- ext = os.path.splitext(file_path)[1].lower()
- language = _FILE_EXT_LANGUAGE.get(ext, "python")
-
- # Persist to workspace if requested
- if workspace_path:
- try:
- _upload_to_workspace(code, language, workspace_path)
- except Exception as e:
- return ExecutionResult(success=False, error=f"Failed to upload to workspace: {e}")
-
- # Execute the code on Databricks
- return execute_databricks_command(
- code=code,
- cluster_id=cluster_id,
- context_id=context_id,
- language=language,
- timeout=timeout,
- destroy_context_on_completion=destroy_context_on_completion,
- )
-
-
-def _upload_to_workspace(code: str, language: str, workspace_path: str) -> None:
- """Upload code as a notebook to the Databricks workspace for persistence."""
- import base64
-
- from databricks.sdk.service.workspace import ImportFormat, Language
-
- lang_map = {
- "python": Language.PYTHON,
- "scala": Language.SCALA,
- "sql": Language.SQL,
- "r": Language.R,
- }
-
- w = get_workspace_client()
- lang_enum = lang_map.get(language.lower(), Language.PYTHON)
- content_b64 = base64.b64encode(code.encode("utf-8")).decode("utf-8")
-
- # Ensure parent directory exists
- parent = workspace_path.rsplit("/", 1)[0]
- try:
- w.workspace.mkdirs(parent)
- except Exception:
- pass # Directory may already exist
-
- w.workspace.import_(
- path=workspace_path,
- content=content_b64,
- language=lang_enum,
- format=ImportFormat.SOURCE,
- overwrite=True,
- )
diff --git a/databricks-tools-core/databricks_tools_core/compute/manage.py b/databricks-tools-core/databricks_tools_core/compute/manage.py
deleted file mode 100644
index a216dc20..00000000
--- a/databricks-tools-core/databricks_tools_core/compute/manage.py
+++ /dev/null
@@ -1,560 +0,0 @@
-"""
-Compute - Manage Compute Resources
-
-Functions for creating, modifying, and deleting Databricks clusters and SQL warehouses.
-Uses Databricks SDK for all operations.
-"""
-
-import logging
-from typing import Optional, List, Dict, Any
-
-from databricks.sdk.service.compute import (
- AutoScale,
- DataSecurityMode,
-)
-
-from ..auth import get_workspace_client
-
-logger = logging.getLogger(__name__)
-
-
-# --- Clusters ---
-
-
-def _get_latest_lts_spark_version(w) -> str:
- """Pick the latest LTS Databricks Runtime version.
-
- Falls back to the latest non-ML, non-GPU, non-Photon LTS version,
- or the first available version if no LTS is found.
- """
- versions = w.clusters.spark_versions()
- lts_versions = []
- for v in versions.versions:
- key = v.key or ""
- name = (v.name or "").lower()
- # Skip ML, GPU, Photon, and aarch64 runtimes
- if any(tag in key for tag in ("-ml-", "-gpu-", "-photon-", "-aarch64-")):
- continue
- if "lts" in name:
- lts_versions.append(v)
-
- if lts_versions:
- # Sort by key descending to get latest
- lts_versions.sort(key=lambda v: v.key, reverse=True)
- return lts_versions[0].key
-
- # Fallback: first available version
- if versions.versions:
- return versions.versions[0].key
-
- raise RuntimeError("No Spark versions available in this workspace")
-
-
-def _get_default_node_type(w) -> str:
- """Pick a reasonable default node type for the current cloud.
-
- Prefers memory-optimized, mid-size instances. Falls back to the
- smallest available node type.
- """
- node_types = w.clusters.list_node_types()
-
- # Common sensible defaults by cloud
- preferred = [
- "i3.xlarge", # AWS
- "Standard_DS3_v2", # Azure
- "n1-highmem-4", # GCP
- "Standard_D4ds_v5", # Azure newer
- "m5d.xlarge", # AWS newer
- ]
-
- available_ids = {nt.node_type_id for nt in node_types.node_types}
-
- for pref in preferred:
- if pref in available_ids:
- return pref
-
- # Fallback: pick smallest available node type by memory
- if node_types.node_types:
- sorted_types = sorted(
- node_types.node_types,
- key=lambda nt: getattr(nt, "memory_mb", 0) or 0,
- )
- # Skip types with 0 memory (metadata-only entries)
- for nt in sorted_types:
- if (getattr(nt, "memory_mb", 0) or 0) > 0:
- return nt.node_type_id
- return sorted_types[0].node_type_id
-
- raise RuntimeError("No node types available in this workspace")
-
-
-def create_cluster(
- name: str,
- num_workers: int = 1,
- spark_version: Optional[str] = None,
- node_type_id: Optional[str] = None,
- driver_node_type_id: Optional[str] = None,
- autotermination_minutes: int = 120,
- data_security_mode: Optional[str] = None,
- single_user_name: Optional[str] = None,
- spark_conf: Optional[Dict[str, str]] = None,
- autoscale_min_workers: Optional[int] = None,
- autoscale_max_workers: Optional[int] = None,
- **kwargs,
-) -> Dict[str, Any]:
- """Create a new Databricks cluster with sensible defaults.
-
- Provides opinionated defaults so ``create_cluster(name="my-cluster", num_workers=1)``
- just works — auto-picks the latest LTS DBR, a reasonable node type, single-user
- security mode, and 120-minute auto-termination.
-
- Power users can override any parameter or pass additional SDK parameters via kwargs.
-
- Args:
- name: Human-readable cluster name.
- num_workers: Fixed number of workers (ignored if autoscale is set). Default 1.
- spark_version: DBR version key (e.g. "15.4.x-scala2.12"). Auto-picks latest LTS if omitted.
- node_type_id: Worker node type (e.g. "i3.xlarge"). Auto-picked if omitted.
- driver_node_type_id: Driver node type. Defaults to same as worker.
- autotermination_minutes: Minutes of inactivity before auto-termination. Default 120.
- data_security_mode: Security mode string ("SINGLE_USER", "USER_ISOLATION", etc.).
- Defaults to SINGLE_USER.
- single_user_name: User for SINGLE_USER mode. Auto-detected if omitted.
- spark_conf: Spark configuration overrides.
- autoscale_min_workers: If set (with autoscale_max_workers), enables autoscaling
- instead of fixed num_workers.
- autoscale_max_workers: Maximum workers for autoscaling.
- **kwargs: Additional parameters passed directly to the SDK clusters.create() call.
-
- Returns:
- Dict with cluster_id, cluster_name, state, and message.
- """
- w = get_workspace_client()
-
- # Auto-pick defaults
- if spark_version is None:
- spark_version = _get_latest_lts_spark_version(w)
- if node_type_id is None:
- node_type_id = _get_default_node_type(w)
- if driver_node_type_id is None:
- driver_node_type_id = node_type_id
-
- # Security mode defaults
- if data_security_mode is None:
- dsm = DataSecurityMode.SINGLE_USER
- else:
- dsm = DataSecurityMode(data_security_mode)
-
- if dsm == DataSecurityMode.SINGLE_USER and single_user_name is None:
- from ..auth import get_current_username
- single_user_name = get_current_username()
-
- # Build create kwargs
- create_kwargs = {
- "cluster_name": name,
- "spark_version": spark_version,
- "node_type_id": node_type_id,
- "driver_node_type_id": driver_node_type_id,
- "autotermination_minutes": autotermination_minutes,
- "data_security_mode": dsm,
- }
-
- if single_user_name:
- create_kwargs["single_user_name"] = single_user_name
- if spark_conf:
- create_kwargs["spark_conf"] = spark_conf
-
- # Autoscale vs fixed workers
- if autoscale_min_workers is not None and autoscale_max_workers is not None:
- create_kwargs["autoscale"] = AutoScale(
- min_workers=autoscale_min_workers,
- max_workers=autoscale_max_workers,
- )
- else:
- create_kwargs["num_workers"] = num_workers
-
- # Merge any extra SDK parameters
- create_kwargs.update(kwargs)
-
- # Create the cluster (non-blocking — returns immediately)
- wait = w.clusters.create(**create_kwargs)
- cluster_id = wait.cluster_id
-
- return {
- "cluster_id": cluster_id,
- "cluster_name": name,
- "state": "PENDING",
- "spark_version": spark_version,
- "node_type_id": node_type_id,
- "message": (
- f"Cluster '{name}' is being created (cluster_id='{cluster_id}'). "
- f"It typically takes 3-8 minutes to start. "
- f"Use get_cluster_status(cluster_id='{cluster_id}') to check progress."
- ),
- }
-
-
-def modify_cluster(
- cluster_id: str,
- name: Optional[str] = None,
- num_workers: Optional[int] = None,
- spark_version: Optional[str] = None,
- node_type_id: Optional[str] = None,
- driver_node_type_id: Optional[str] = None,
- autotermination_minutes: Optional[int] = None,
- spark_conf: Optional[Dict[str, str]] = None,
- autoscale_min_workers: Optional[int] = None,
- autoscale_max_workers: Optional[int] = None,
- **kwargs,
-) -> Dict[str, Any]:
- """Modify an existing Databricks cluster configuration.
-
- Fetches the current config, applies the requested changes, and calls the
- edit API. The cluster will restart if it is running.
-
- Args:
- cluster_id: ID of the cluster to modify.
- name: New cluster name (optional).
- num_workers: New fixed worker count (optional).
- spark_version: New DBR version (optional).
- node_type_id: New worker node type (optional).
- driver_node_type_id: New driver node type (optional).
- autotermination_minutes: New auto-termination timeout (optional).
- spark_conf: Spark configuration overrides (optional).
- autoscale_min_workers: Set to enable/modify autoscaling (optional).
- autoscale_max_workers: Set to enable/modify autoscaling (optional).
- **kwargs: Additional SDK parameters.
-
- Returns:
- Dict with cluster_id, cluster_name, state, and message.
- """
- w = get_workspace_client()
-
- # Get current cluster config
- cluster = w.clusters.get(cluster_id)
-
- # Build edit kwargs from current config
- edit_kwargs = {
- "cluster_id": cluster_id,
- "cluster_name": name or cluster.cluster_name,
- "spark_version": spark_version or cluster.spark_version,
- "node_type_id": node_type_id or cluster.node_type_id,
- "driver_node_type_id": driver_node_type_id or cluster.driver_node_type_id or cluster.node_type_id,
- }
-
- if autotermination_minutes is not None:
- edit_kwargs["autotermination_minutes"] = autotermination_minutes
- elif cluster.autotermination_minutes:
- edit_kwargs["autotermination_minutes"] = cluster.autotermination_minutes
-
- if spark_conf is not None:
- edit_kwargs["spark_conf"] = spark_conf
- elif cluster.spark_conf:
- edit_kwargs["spark_conf"] = cluster.spark_conf
-
- # Handle data_security_mode and single_user_name from existing config
- if cluster.data_security_mode:
- edit_kwargs["data_security_mode"] = cluster.data_security_mode
- if cluster.single_user_name:
- edit_kwargs["single_user_name"] = cluster.single_user_name
-
- # Autoscale vs fixed workers
- if autoscale_min_workers is not None and autoscale_max_workers is not None:
- edit_kwargs["autoscale"] = AutoScale(
- min_workers=autoscale_min_workers,
- max_workers=autoscale_max_workers,
- )
- elif num_workers is not None:
- edit_kwargs["num_workers"] = num_workers
- elif cluster.autoscale:
- edit_kwargs["autoscale"] = cluster.autoscale
- else:
- edit_kwargs["num_workers"] = cluster.num_workers or 0
-
- # Merge extra SDK params
- edit_kwargs.update(kwargs)
-
- w.clusters.edit(**edit_kwargs)
-
- current_state = cluster.state.value if cluster.state else "UNKNOWN"
- cluster_name = edit_kwargs["cluster_name"]
-
- return {
- "cluster_id": cluster_id,
- "cluster_name": cluster_name,
- "state": current_state,
- "message": (
- f"Cluster '{cluster_name}' configuration updated. "
- + (
- "The cluster will restart to apply changes."
- if current_state == "RUNNING"
- else "Changes will take effect when the cluster starts."
- )
- ),
- }
-
-
-def terminate_cluster(cluster_id: str) -> Dict[str, Any]:
- """Stop a running Databricks cluster (reversible).
-
- The cluster is terminated but not deleted. It can be restarted later
- with start_cluster(). This is a safe, reversible operation.
-
- Args:
- cluster_id: ID of the cluster to terminate.
-
- Returns:
- Dict with cluster_id, cluster_name, state, and message.
- """
- w = get_workspace_client()
- cluster = w.clusters.get(cluster_id)
- cluster_name = cluster.cluster_name or cluster_id
- current_state = cluster.state.value if cluster.state else "UNKNOWN"
-
- if current_state == "TERMINATED":
- return {
- "cluster_id": cluster_id,
- "cluster_name": cluster_name,
- "state": "TERMINATED",
- "message": f"Cluster '{cluster_name}' is already terminated.",
- }
-
- w.clusters.delete(cluster_id) # SDK's delete = terminate (confusing but correct)
-
- return {
- "cluster_id": cluster_id,
- "cluster_name": cluster_name,
- "previous_state": current_state,
- "state": "TERMINATING",
- "message": f"Cluster '{cluster_name}' is being terminated. This is reversible — use start_cluster() to restart.",
- }
-
-
-def delete_cluster(cluster_id: str) -> Dict[str, Any]:
- """Permanently delete a Databricks cluster.
-
- WARNING: This action is PERMANENT and cannot be undone. The cluster
- and its configuration will be permanently removed.
-
- Args:
- cluster_id: ID of the cluster to permanently delete.
-
- Returns:
- Dict with cluster_id, cluster_name, and warning message.
- """
- w = get_workspace_client()
- cluster = w.clusters.get(cluster_id)
- cluster_name = cluster.cluster_name or cluster_id
-
- w.clusters.permanent_delete(cluster_id)
-
- return {
- "cluster_id": cluster_id,
- "cluster_name": cluster_name,
- "state": "DELETED",
- "message": (
- f"WARNING: Cluster '{cluster_name}' has been PERMANENTLY deleted. "
- f"This action cannot be undone. The cluster configuration is gone."
- ),
- }
-
-
-def list_node_types() -> List[Dict[str, Any]]:
- """List available VM/node types for cluster creation.
-
- Returns a summary of each node type including ID, memory, cores,
- and GPU info. Useful for choosing node_type_id when creating clusters.
-
- Returns:
- List of node type info dicts.
- """
- w = get_workspace_client()
- result = w.clusters.list_node_types()
-
- node_types = []
- for nt in result.node_types:
- node_types.append({
- "node_type_id": nt.node_type_id,
- "memory_mb": nt.memory_mb,
- "num_cores": getattr(nt, "num_cores", None),
- "num_gpus": getattr(nt, "num_gpus", None) or 0,
- "description": getattr(nt, "description", None) or nt.node_type_id,
- "is_deprecated": getattr(nt, "is_deprecated", False),
- })
- return node_types
-
-
-def list_spark_versions() -> List[Dict[str, Any]]:
- """List available Databricks Runtime (Spark) versions.
-
- Returns version key and name. Filter for "LTS" in the name to find
- long-term support versions.
-
- Returns:
- List of dicts with key and name for each version.
- """
- w = get_workspace_client()
- result = w.clusters.spark_versions()
-
- versions = []
- for v in result.versions:
- versions.append({
- "key": v.key,
- "name": v.name,
- })
- return versions
-
-
-# --- SQL Warehouses ---
-
-
-def create_sql_warehouse(
- name: str,
- size: str = "Small",
- min_num_clusters: int = 1,
- max_num_clusters: int = 1,
- auto_stop_mins: int = 120,
- warehouse_type: str = "PRO",
- enable_serverless: bool = True,
- **kwargs,
-) -> Dict[str, Any]:
- """Create a new SQL warehouse with sensible defaults.
-
- By default creates a serverless Pro warehouse with auto-stop at 120 minutes.
-
- Args:
- name: Human-readable warehouse name.
- size: T-shirt size ("2X-Small", "X-Small", "Small", "Medium", "Large",
- "X-Large", "2X-Large", "3X-Large", "4X-Large"). Default "Small".
- min_num_clusters: Minimum number of clusters. Default 1.
- max_num_clusters: Maximum number of clusters for scaling. Default 1.
- auto_stop_mins: Minutes of inactivity before auto-stop. Default 120.
- warehouse_type: "PRO", "CLASSIC", or "TYPE_UNSPECIFIED". Default "PRO".
- enable_serverless: Enable serverless compute. Default True.
- **kwargs: Additional SDK parameters.
-
- Returns:
- Dict with warehouse_id, name, state, and message.
- """
- w = get_workspace_client()
-
- from databricks.sdk.service.sql import (
- CreateWarehouseRequestWarehouseType,
- )
-
- # Map warehouse type string to enum
- type_map = {
- "PRO": CreateWarehouseRequestWarehouseType.PRO,
- "CLASSIC": CreateWarehouseRequestWarehouseType.CLASSIC,
- "TYPE_UNSPECIFIED": CreateWarehouseRequestWarehouseType.TYPE_UNSPECIFIED,
- }
- wh_type = type_map.get(warehouse_type.upper(), CreateWarehouseRequestWarehouseType.PRO)
-
- create_kwargs = {
- "name": name,
- "cluster_size": size,
- "min_num_clusters": min_num_clusters,
- "max_num_clusters": max_num_clusters,
- "auto_stop_mins": auto_stop_mins,
- "warehouse_type": wh_type,
- "enable_serverless_compute": enable_serverless,
- }
- create_kwargs.update(kwargs)
-
- wait = w.warehouses.create(**create_kwargs)
- warehouse_id = wait.id
-
- return {
- "warehouse_id": warehouse_id,
- "name": name,
- "size": size,
- "state": "STARTING",
- "message": (
- f"SQL warehouse '{name}' is being created (warehouse_id='{warehouse_id}'). "
- f"It typically takes 1-3 minutes to start."
- ),
- }
-
-
-def modify_sql_warehouse(
- warehouse_id: str,
- name: Optional[str] = None,
- size: Optional[str] = None,
- min_num_clusters: Optional[int] = None,
- max_num_clusters: Optional[int] = None,
- auto_stop_mins: Optional[int] = None,
- **kwargs,
-) -> Dict[str, Any]:
- """Modify an existing SQL warehouse configuration.
-
- Only the specified parameters are changed; others remain as-is.
-
- Args:
- warehouse_id: ID of the warehouse to modify.
- name: New warehouse name (optional).
- size: New T-shirt size (optional).
- min_num_clusters: New minimum clusters (optional).
- max_num_clusters: New maximum clusters (optional).
- auto_stop_mins: New auto-stop timeout in minutes (optional).
- **kwargs: Additional SDK parameters.
-
- Returns:
- Dict with warehouse_id, name, state, and message.
- """
- w = get_workspace_client()
-
- # Get current config
- wh = w.warehouses.get(warehouse_id)
-
- edit_kwargs = {
- "id": warehouse_id,
- "name": name or wh.name,
- "cluster_size": size or wh.cluster_size,
- "min_num_clusters": min_num_clusters if min_num_clusters is not None else wh.min_num_clusters,
- "max_num_clusters": max_num_clusters if max_num_clusters is not None else wh.max_num_clusters,
- "auto_stop_mins": auto_stop_mins if auto_stop_mins is not None else wh.auto_stop_mins,
- }
- edit_kwargs.update(kwargs)
-
- w.warehouses.edit(**edit_kwargs)
-
- current_state = wh.state.value if wh.state else "UNKNOWN"
- wh_name = edit_kwargs["name"]
-
- return {
- "warehouse_id": warehouse_id,
- "name": wh_name,
- "state": current_state,
- "message": f"SQL warehouse '{wh_name}' configuration updated.",
- }
-
-
-def delete_sql_warehouse(warehouse_id: str) -> Dict[str, Any]:
- """Permanently delete a SQL warehouse.
-
- WARNING: This action is PERMANENT and cannot be undone. The warehouse
- and its configuration will be permanently removed.
-
- Args:
- warehouse_id: ID of the warehouse to permanently delete.
-
- Returns:
- Dict with warehouse_id, name, and warning message.
- """
- w = get_workspace_client()
-
- # Get warehouse info before deleting
- wh = w.warehouses.get(warehouse_id)
- wh_name = wh.name or warehouse_id
-
- w.warehouses.delete(warehouse_id)
-
- return {
- "warehouse_id": warehouse_id,
- "name": wh_name,
- "state": "DELETED",
- "message": (
- f"WARNING: SQL warehouse '{wh_name}' has been PERMANENTLY deleted. "
- f"This action cannot be undone."
- ),
- }
diff --git a/databricks-tools-core/databricks_tools_core/compute/serverless.py b/databricks-tools-core/databricks_tools_core/compute/serverless.py
deleted file mode 100644
index 65fdff2b..00000000
--- a/databricks-tools-core/databricks_tools_core/compute/serverless.py
+++ /dev/null
@@ -1,448 +0,0 @@
-"""
-Compute - Serverless Code Execution
-
-Execute Python or SQL code on Databricks serverless compute via the Jobs API
-(runs/submit). No interactive cluster required.
-
-Usage:
- from databricks_tools_core.compute.serverless import run_code_on_serverless
-
- result = run_code_on_serverless("print('hello')", language="python")
- result = run_code_on_serverless("SELECT 1", language="sql")
-"""
-
-import base64
-import datetime
-import json
-import logging
-import time
-import uuid
-from dataclasses import dataclass
-from typing import Dict, Any, Optional
-
-from databricks.sdk.service.compute import Environment
-from databricks.sdk.service.jobs import (
- JobEnvironment,
- NotebookTask,
- RunResultState,
- SubmitTask,
-)
-from databricks.sdk.service.workspace import ImportFormat, Language
-
-from ..auth import get_workspace_client, get_current_username
-
-logger = logging.getLogger(__name__)
-
-# Language string to workspace Language enum
-_LANGUAGE_MAP = {
- "python": Language.PYTHON,
- "sql": Language.SQL,
-}
-
-
-@dataclass
-class ServerlessRunResult:
- """Result from serverless code execution via Jobs API.
-
- Attributes:
- success: Whether the execution completed successfully.
- output: The output from the execution (notebook result or logs).
- error: Error message if execution failed.
- run_id: Databricks Jobs run ID.
- run_url: URL to the run in the Databricks UI.
- duration_seconds: Wall-clock duration of the execution.
- state: Final state string (SUCCESS, FAILED, TIMEDOUT, CANCELED, etc.).
- message: Human-readable summary of the result.
- """
-
- success: bool
- output: Optional[str] = None
- error: Optional[str] = None
- run_id: Optional[int] = None
- run_url: Optional[str] = None
- duration_seconds: Optional[float] = None
- state: Optional[str] = None
- message: Optional[str] = None
- workspace_path: Optional[str] = None
-
- def to_dict(self) -> Dict[str, Any]:
- """Convert to dictionary for JSON serialization."""
- d = {
- "success": self.success,
- "output": self.output,
- "error": self.error,
- "run_id": self.run_id,
- "run_url": self.run_url,
- "duration_seconds": self.duration_seconds,
- "state": self.state,
- "message": self.message,
- }
- if self.workspace_path:
- d["workspace_path"] = self.workspace_path
- return d
-
-
-def _get_temp_notebook_path(run_label: str) -> str:
- """Build a workspace path for a temporary serverless notebook.
-
- Args:
- run_label: Unique label for this run.
-
- Returns:
- Workspace path string under the current user's home directory.
- """
- username = get_current_username()
- base = f"/Workspace/Users/{username}" if username else "/Workspace"
- return f"{base}/.ai_dev_kit_tmp/{run_label}"
-
-
-def _is_ipynb(content: str) -> bool:
- """Check if content is a Jupyter notebook (.ipynb) JSON structure."""
- try:
- data = json.loads(content)
- return isinstance(data, dict) and "cells" in data
- except (json.JSONDecodeError, ValueError):
- return False
-
-
-def _upload_temp_notebook(
- code: str, language: str, workspace_path: str, is_jupyter: bool = False
-) -> None:
- """Upload code as a temporary notebook to the Databricks workspace.
-
- Args:
- code: Source code or .ipynb JSON content to upload.
- language: Language string ("python" or "sql"). Ignored for Jupyter uploads.
- workspace_path: Target workspace path (without file extension).
- is_jupyter: If True, upload as Jupyter format (ImportFormat.JUPYTER).
-
- Raises:
- Exception: If the upload fails.
- """
- w = get_workspace_client()
- content_b64 = base64.b64encode(code.encode("utf-8")).decode("utf-8")
-
- # Ensure parent directory exists
- parent = workspace_path.rsplit("/", 1)[0]
- try:
- w.workspace.mkdirs(parent)
- except Exception:
- pass # Directory may already exist
-
- if is_jupyter:
- w.workspace.import_(
- path=workspace_path,
- content=content_b64,
- format=ImportFormat.JUPYTER,
- overwrite=True,
- )
- else:
- lang_enum = _LANGUAGE_MAP[language]
- w.workspace.import_(
- path=workspace_path,
- content=content_b64,
- language=lang_enum,
- format=ImportFormat.SOURCE,
- overwrite=True,
- )
-
-
-def _cleanup_temp_notebook(workspace_path: str) -> None:
- """Delete a temporary notebook from the workspace (best-effort)."""
- try:
- w = get_workspace_client()
- w.workspace.delete(path=workspace_path, recursive=False)
- except Exception as e:
- logger.debug(f"Cleanup of {workspace_path} failed (non-fatal): {e}")
-
-
-def _get_run_output(task_run_id: int) -> Dict[str, Optional[str]]:
- """Retrieve output and error text from a completed task run.
-
- Args:
- task_run_id: The run ID of the specific task (not the parent run).
-
- Returns:
- Dict with ``output`` and ``error`` keys (both may be None).
- """
- w = get_workspace_client()
- result: Dict[str, Optional[str]] = {"output": None, "error": None}
-
- try:
- run_output = w.jobs.get_run_output(run_id=task_run_id)
-
- # Notebook output (from dbutils.notebook.exit() or last cell)
- if run_output.notebook_output and run_output.notebook_output.result:
- result["output"] = run_output.notebook_output.result
-
- # Logs (stdout/stderr, typically for spark_python_task)
- if run_output.logs:
- if result["output"]:
- result["output"] += f"\n\n--- Logs ---\n{run_output.logs}"
- else:
- result["output"] = run_output.logs
-
- # Error details
- if run_output.error:
- error_parts = [run_output.error]
- if run_output.error_trace:
- error_parts.append(run_output.error_trace)
- result["error"] = "\n\n".join(error_parts)
-
- except Exception as e:
- logger.debug(f"Failed to get output for task run {task_run_id}: {e}")
- result["error"] = str(e)
-
- return result
-
-
-def run_code_on_serverless(
- code: str,
- language: str = "python",
- timeout: int = 1800,
- run_name: Optional[str] = None,
- cleanup: bool = True,
- workspace_path: Optional[str] = None,
- job_extra_params: Optional[Dict[str, Any]] = None,
-) -> ServerlessRunResult:
- """Execute code on serverless compute via Jobs API runs/submit.
-
- Uploads the code as a notebook, submits it as a one-time run on serverless
- compute (no cluster required), waits for completion, and retrieves output.
-
- Two modes:
- - **Ephemeral** (default): Uploads to a temp path and cleans up after.
- - **Persistent**: If ``workspace_path`` is provided, uploads to that path
- and keeps it after execution. Useful for project notebooks (model training,
- ETL) the user wants saved in their workspace.
-
- Jupyter notebooks (.ipynb) are also supported. If the code content is
- detected as .ipynb JSON (contains "cells" key), it is uploaded using
- Databricks' native Jupyter import (ImportFormat.JUPYTER). The language
- parameter is ignored in this case since the notebook carries its own
- kernel metadata.
-
- SQL is supported but SELECT query results are NOT captured in the output.
- SQL via this tool is only useful for DDL/DML (CREATE TABLE, INSERT, MERGE).
- For SQL that needs result rows, use execute_sql() instead.
-
- Args:
- code: Code to execute, or raw .ipynb JSON content (auto-detected).
- language: Programming language ("python" or "sql"). Ignored for .ipynb.
- timeout: Maximum wait time in seconds (default: 1800 = 30 minutes).
- run_name: Optional human-readable run name. Auto-generated if omitted.
- cleanup: Delete the notebook after execution (default: True).
- Ignored when ``workspace_path`` is provided (persistent mode never cleans up).
- workspace_path: Optional workspace path to save the notebook to
- (e.g. "/Workspace/Users/user@company.com/my-project/train").
- If provided, the notebook is persisted at this path. If omitted,
- a temporary path is used and cleaned up after execution.
- job_extra_params: Optional dict of extra parameters to pass to jobs.submit().
- Use for custom environments with dependencies, e.g.:
- {"environments": [{"environment_key": "my_env", "spec": {"client": "4", "dependencies": ["pandas"]}}]}
-
- Returns:
- ServerlessRunResult with output, error, run_id, run_url, and timing info.
- """
- if not code or not code.strip():
- return ServerlessRunResult(
- success=False,
- error="Code cannot be empty.",
- state="INVALID_INPUT",
- message="No code provided to execute.",
- )
-
- # Auto-detect .ipynb content
- is_jupyter = _is_ipynb(code)
-
- language = language.lower()
- if not is_jupyter and language not in _LANGUAGE_MAP:
- return ServerlessRunResult(
- success=False,
- error=f"Unsupported language: {language!r}. Must be 'python' or 'sql'.",
- state="INVALID_INPUT",
- message=f"Unsupported language {language!r}. Use 'python' or 'sql'.",
- )
-
- unique_id = uuid.uuid4().hex[:12]
- if not run_name:
- run_name = f"ai_dev_kit_serverless_{unique_id}"
-
- # Persistent mode: user-specified path, never cleanup
- if workspace_path:
- notebook_path = workspace_path
- cleanup = False
- else:
- notebook_path = _get_temp_notebook_path(f"serverless_{unique_id}")
-
- start_time = time.time()
- w = get_workspace_client()
-
- # --- Step 1: Upload code as a notebook ---
- try:
- _upload_temp_notebook(code, language, notebook_path, is_jupyter=is_jupyter)
- except Exception as e:
- return ServerlessRunResult(
- success=False,
- error=f"Failed to upload code to workspace: {e}",
- state="UPLOAD_FAILED",
- message=f"Could not upload temporary notebook: {e}",
- )
-
- run_id = None
- run_url = None
-
- try:
- # --- Step 2: Submit serverless run ---
- try:
- # Build submit kwargs, allowing job_extra_params to override defaults
- extra = job_extra_params or {}
-
- # Determine environment_key for the task
- env_key = "Default"
- if "environments" in extra and extra["environments"]:
- # Use the first environment's key from extra params
- env_key = extra["environments"][0].get("environment_key", "Default")
-
- submit_kwargs = {
- "run_name": run_name,
- "tasks": [
- SubmitTask(
- task_key="main",
- notebook_task=NotebookTask(notebook_path=notebook_path),
- environment_key=env_key,
- )
- ],
- }
-
- # Use custom environments if provided, otherwise use default
- if "environments" not in extra:
- submit_kwargs["environments"] = [
- JobEnvironment(
- environment_key="Default",
- spec=Environment(client="1"),
- )
- ]
-
- # Merge any extra params (environments, timeout_seconds, etc.)
- submit_kwargs.update(extra)
-
- wait = w.jobs.submit(**submit_kwargs)
- # Extract run_id from the Wait object
- run_id = getattr(wait, "run_id", None)
- if run_id is None and hasattr(wait, "response"):
- run_id = getattr(wait.response, "run_id", None)
-
- # Get the canonical run URL immediately via get_run so the user
- # can monitor progress even before the run completes.
- if run_id:
- try:
- initial_run = w.jobs.get_run(run_id=run_id)
- run_url = initial_run.run_page_url
- except Exception:
- pass # Fall back to no URL rather than a guessed one
-
- except Exception as e:
- return ServerlessRunResult(
- success=False,
- error=f"Failed to submit serverless run: {e}",
- state="SUBMIT_FAILED",
- message=f"Jobs API runs/submit call failed: {e}",
- )
-
- # --- Step 3: Wait for completion ---
- try:
- run = wait.result(timeout=datetime.timedelta(seconds=timeout))
- except TimeoutError:
- elapsed = time.time() - start_time
- return ServerlessRunResult(
- success=False,
- error=f"Run timed out after {timeout}s.",
- run_id=run_id,
- run_url=run_url,
- duration_seconds=round(elapsed, 2),
- state="TIMEDOUT",
- message=(f"Serverless run {run_id} did not complete within {timeout}s. Check status at {run_url}"),
- )
- except Exception as e:
- elapsed = time.time() - start_time
- error_text = str(e)
-
- # Best-effort: retrieve the actual error traceback from run output
- if run_id:
- try:
- failed_run = w.jobs.get_run(run_id=run_id)
- if failed_run.tasks:
- task_run_id = failed_run.tasks[0].run_id
- output_data = _get_run_output(task_run_id)
- if output_data.get("error"):
- error_text = output_data["error"]
- except Exception:
- pass # Fall back to the original exception message
-
- return ServerlessRunResult(
- success=False,
- error=error_text,
- run_id=run_id,
- run_url=run_url,
- duration_seconds=round(elapsed, 2),
- state="FAILED",
- message=f"Run {run_id} failed: {e}",
- )
-
- elapsed = time.time() - start_time
-
- # --- Step 4: Determine result state ---
- result_state = None
- state_message = None
- if run.state:
- result_state = run.state.result_state
- state_message = run.state.state_message
-
- # Prefer the canonical URL from the Run object
- if run.run_page_url:
- run_url = run.run_page_url
-
- is_success = result_state == RunResultState.SUCCESS
- state_str = result_state.value if result_state else "UNKNOWN"
-
- # --- Step 5: Retrieve output ---
- task_run_id = None
- if run.tasks:
- task_run_id = run.tasks[0].run_id
-
- output_text = None
- error_text = None
-
- if task_run_id:
- output_data = _get_run_output(task_run_id)
- output_text = output_data["output"]
- error_text = output_data["error"]
-
- # Fallback error from state message
- if not is_success and not error_text:
- error_text = state_message or f"Run ended with state: {state_str}"
-
- if is_success:
- if not output_text:
- output_text = "Success (no output)"
- message = f"Code executed successfully on serverless compute in {round(elapsed, 1)}s."
- else:
- message = f"Serverless run failed with state {state_str}. Check {run_url} for details."
-
- return ServerlessRunResult(
- success=is_success,
- output=output_text if is_success else None,
- error=error_text if not is_success else None,
- run_id=run_id,
- run_url=run_url,
- duration_seconds=round(elapsed, 2),
- state=state_str,
- message=message,
- workspace_path=notebook_path if workspace_path else None,
- )
-
- finally:
- # --- Step 6: Cleanup temporary notebook ---
- if cleanup:
- _cleanup_temp_notebook(notebook_path)
diff --git a/databricks-tools-core/databricks_tools_core/dabs/__init__.py b/databricks-tools-core/databricks_tools_core/dabs/__init__.py
deleted file mode 100644
index 68b9a176..00000000
--- a/databricks-tools-core/databricks_tools_core/dabs/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Databricks Asset Bundle (DAB) generation - Coming Soon"""
diff --git a/databricks-tools-core/databricks_tools_core/file/__init__.py b/databricks-tools-core/databricks_tools_core/file/__init__.py
deleted file mode 100644
index 77e2ec08..00000000
--- a/databricks-tools-core/databricks_tools_core/file/__init__.py
+++ /dev/null
@@ -1,24 +0,0 @@
-"""
-File - Workspace File Operations
-
-Functions for uploading and deleting files and folders in Databricks Workspace.
-
-Note: For Unity Catalog Volume file operations, use the unity_catalog module.
-"""
-
-from .workspace import (
- UploadResult,
- FolderUploadResult,
- DeleteResult,
- upload_to_workspace,
- delete_from_workspace,
-)
-
-__all__ = [
- # Workspace file operations
- "UploadResult",
- "FolderUploadResult",
- "DeleteResult",
- "upload_to_workspace",
- "delete_from_workspace",
-]
diff --git a/databricks-tools-core/databricks_tools_core/file/workspace.py b/databricks-tools-core/databricks_tools_core/file/workspace.py
deleted file mode 100644
index dc22d875..00000000
--- a/databricks-tools-core/databricks_tools_core/file/workspace.py
+++ /dev/null
@@ -1,715 +0,0 @@
-"""
-File - Workspace File Operations
-
-Functions for uploading files and folders to Databricks Workspace.
-Uses Databricks Workspace API via SDK.
-"""
-
-import glob
-import io
-import os
-import re
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import List, Optional
-
-from databricks.sdk import WorkspaceClient
-import base64
-
-from databricks.sdk.service.workspace import ImportFormat, Language
-
-from ..auth import get_workspace_client
-
-
-@dataclass
-class UploadResult:
- """Result from a single file upload"""
-
- local_path: str
- remote_path: str
- success: bool
- error: Optional[str] = None
-
-
-@dataclass
-class FolderUploadResult:
- """Result from uploading a folder or multiple files"""
-
- local_folder: str
- remote_folder: str
- total_files: int
- successful: int
- failed: int
- results: List[UploadResult] = field(default_factory=list)
-
- @property
- def success(self) -> bool:
- """Returns True if all files were uploaded successfully"""
- return self.failed == 0 and self.total_files > 0
-
- def get_failed_uploads(self) -> List[UploadResult]:
- """Returns list of failed uploads"""
- return [r for r in self.results if not r.success]
-
-
-@dataclass
-class DeleteResult:
- """Result from a workspace delete operation"""
-
- workspace_path: str
- success: bool
- error: Optional[str] = None
-
-
-# Notebook markers for each language
-_NOTEBOOK_MARKERS = {
- Language.PYTHON: b"# Databricks notebook source",
- Language.SQL: b"-- Databricks notebook source",
- Language.SCALA: b"// Databricks notebook source",
- Language.R: b"# Databricks notebook source",
-}
-
-
-def _detect_notebook_language(local_path: str, content: bytes) -> Optional[Language]:
- """
- Detect if a file is a Databricks notebook and return its language.
-
- Notebooks are identified by their marker comment at the start of the file.
- This is required because workspace.upload() creates FILE objects, but
- jobs/pipelines require NOTEBOOK objects.
-
- Args:
- local_path: Path to the file (used for extension-based language hint)
- content: File content as bytes
-
- Returns:
- Language enum if file is a notebook, None otherwise
- """
- # Check for notebook markers in content
- for lang, marker in _NOTEBOOK_MARKERS.items():
- if content.startswith(marker):
- return lang
-
- return None
-
-
-def _upload_single_file(w: WorkspaceClient, local_path: str, remote_path: str, overwrite: bool = True) -> UploadResult:
- """
- Upload a single file to Databricks workspace.
-
- Notebooks (files with Databricks notebook markers) are imported using
- workspace.import_() with SOURCE format to create NOTEBOOK objects.
- Regular files use workspace.upload() with AUTO format.
-
- Args:
- w: WorkspaceClient instance
- local_path: Path to local file
- remote_path: Target path in workspace
- overwrite: Whether to overwrite existing files
-
- Returns:
- UploadResult with success status
- """
- try:
- with open(local_path, "rb") as f:
- content = f.read()
-
- # Check if this is a Databricks notebook
- notebook_language = _detect_notebook_language(local_path, content)
-
- if notebook_language:
- # Use import_() with SOURCE format for notebooks
- # This creates NOTEBOOK objects that jobs/pipelines can run
- w.workspace.import_(
- path=remote_path,
- content=base64.b64encode(content).decode("utf-8"),
- format=ImportFormat.SOURCE,
- language=notebook_language,
- overwrite=overwrite,
- )
- else:
- # Use upload() with AUTO format for regular files
- w.workspace.upload(
- path=remote_path,
- content=io.BytesIO(content),
- format=ImportFormat.AUTO,
- overwrite=overwrite,
- )
-
- return UploadResult(local_path=local_path, remote_path=remote_path, success=True)
-
- except Exception as e:
- error_msg = str(e).lower()
- # Handle type mismatch errors (e.g., overwriting notebook with file or vice versa)
- # When overwrite=True, delete the existing item and retry
- if overwrite and "type mismatch" in error_msg:
- try:
- w.workspace.delete(remote_path)
- # Retry with same logic
- notebook_language = _detect_notebook_language(local_path, content)
- if notebook_language:
- w.workspace.import_(
- path=remote_path,
- content=base64.b64encode(content).decode("utf-8"),
- format=ImportFormat.SOURCE,
- language=notebook_language,
- overwrite=False,
- )
- else:
- w.workspace.upload(
- path=remote_path,
- content=io.BytesIO(content),
- format=ImportFormat.AUTO,
- overwrite=False,
- )
- return UploadResult(local_path=local_path, remote_path=remote_path, success=True)
- except Exception as retry_error:
- return UploadResult(
- local_path=local_path, remote_path=remote_path, success=False, error=str(retry_error)
- )
- return UploadResult(local_path=local_path, remote_path=remote_path, success=False, error=str(e))
-
-
-def _collect_files(local_folder: str) -> List[tuple]:
- """
- Collect all files in a folder recursively.
-
- Args:
- local_folder: Path to local folder
-
- Returns:
- List of (local_path, relative_path) tuples
- """
- files = []
- local_folder = os.path.abspath(local_folder)
-
- for dirpath, _, filenames in os.walk(local_folder):
- for filename in filenames:
- # Skip hidden files and __pycache__
- if filename.startswith(".") or "__pycache__" in dirpath:
- continue
-
- local_path = os.path.join(dirpath, filename)
- rel_path = os.path.relpath(local_path, local_folder)
- files.append((local_path, rel_path))
-
- return files
-
-
-def _collect_directories(local_folder: str) -> List[str]:
- """
- Collect all directories in a folder recursively.
-
- Args:
- local_folder: Path to local folder
-
- Returns:
- List of relative directory paths
- """
- directories = set()
- local_folder = os.path.abspath(local_folder)
-
- for dirpath, dirnames, _ in os.walk(local_folder):
- # Skip hidden directories and __pycache__
- dirnames[:] = [d for d in dirnames if not d.startswith(".") and d != "__pycache__"]
-
- for dirname in dirnames:
- full_path = os.path.join(dirpath, dirname)
- rel_path = os.path.relpath(full_path, local_folder)
- directories.add(rel_path)
- # Also add parent directories
- parent = Path(rel_path).parent
- while str(parent) != ".":
- directories.add(str(parent))
- parent = parent.parent
-
- return sorted(directories)
-
-
-def upload_folder(
- local_folder: str, workspace_folder: str, max_workers: int = 10, overwrite: bool = True
-) -> FolderUploadResult:
- """
- Upload an entire local folder to Databricks workspace.
-
- Uses parallel uploads with ThreadPoolExecutor for performance.
- Automatically handles all file types using ImportFormat.AUTO.
-
- Follows `cp -r` semantics:
- - With trailing slash or /* (e.g., "pipeline/" or "pipeline/*"): copies contents into workspace_folder
- - Without trailing slash (e.g., "pipeline"): creates workspace_folder/pipeline/
-
- Args:
- local_folder: Path to local folder to upload. Add trailing slash to copy
- contents only, omit to preserve folder name.
- workspace_folder: Target path in Databricks workspace
- (e.g., "/Workspace/Users/user@example.com/my-project")
- max_workers: Maximum number of parallel upload threads (default: 10)
- overwrite: Whether to overwrite existing files (default: True)
-
- Returns:
- FolderUploadResult with upload statistics and individual results
-
- Raises:
- FileNotFoundError: If local folder doesn't exist
- ValueError: If local folder is not a directory
-
- Example:
- >>> # Copy folder preserving name: creates /Workspace/.../dest/my-project/
- >>> result = upload_folder(
- ... local_folder="/path/to/my-project",
- ... workspace_folder="/Workspace/Users/me@example.com/dest"
- ... )
- >>> # Copy contents only: files go directly into /Workspace/.../dest/
- >>> result = upload_folder(
- ... local_folder="/path/to/my-project/",
- ... workspace_folder="/Workspace/Users/me@example.com/dest"
- ... )
- >>> print(f"Uploaded {result.successful}/{result.total_files} files")
- >>> if not result.success:
- ... for failed in result.get_failed_uploads():
- ... print(f"Failed: {failed.local_path} - {failed.error}")
- """
- # Check if user wants to copy contents only (trailing slash or /*) or preserve folder name
- # Supports: "folder/", "folder/*", "folder\\*" (Windows)
- copy_contents_suffixes = ("/", os.sep, "/*", os.sep + "*")
- copy_contents_only = local_folder.endswith(copy_contents_suffixes)
-
- # Strip /* or * suffix before validation
- clean_local_folder = local_folder.rstrip("*").rstrip("/").rstrip(os.sep)
-
- # Validate local folder
- local_folder_abs = os.path.abspath(clean_local_folder)
- if not os.path.exists(local_folder_abs):
- raise FileNotFoundError(f"Local folder not found: {local_folder_abs}")
- if not os.path.isdir(local_folder_abs):
- raise ValueError(f"Path is not a directory: {local_folder_abs}")
-
- # Normalize workspace path (remove trailing slash)
- workspace_folder = workspace_folder.rstrip("/")
-
- # If not copying contents only, append the source folder name to destination
- if not copy_contents_only:
- folder_name = os.path.basename(local_folder_abs)
- workspace_folder = f"{workspace_folder}/{folder_name}"
-
- # Use absolute path for file collection
- local_folder = local_folder_abs
-
- # Initialize client
- w = get_workspace_client()
-
- # Create all directories first
- directories = _collect_directories(local_folder)
- for dir_path in directories:
- remote_dir = f"{workspace_folder}/{dir_path}"
- try:
- w.workspace.mkdirs(remote_dir)
- except Exception:
- # Directory might already exist, ignore
- pass
-
- # Create the root directory too
- try:
- w.workspace.mkdirs(workspace_folder)
- except Exception:
- pass
-
- # Collect all files
- files = _collect_files(local_folder)
-
- if not files:
- return FolderUploadResult(
- local_folder=local_folder,
- remote_folder=workspace_folder,
- total_files=0,
- successful=0,
- failed=0,
- results=[],
- )
-
- # Upload files in parallel
- results = []
- successful = 0
- failed = 0
-
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- # Submit all upload tasks
- future_to_file = {}
- for local_path, rel_path in files:
- # Convert Windows paths to forward slashes for workspace
- remote_path = f"{workspace_folder}/{rel_path.replace(os.sep, '/')}"
- future = executor.submit(_upload_single_file, w, local_path, remote_path, overwrite)
- future_to_file[future] = (local_path, remote_path)
-
- # Collect results as they complete
- for future in as_completed(future_to_file):
- result = future.result()
- results.append(result)
- if result.success:
- successful += 1
- else:
- failed += 1
-
- return FolderUploadResult(
- local_folder=local_folder,
- remote_folder=workspace_folder,
- total_files=len(files),
- successful=successful,
- failed=failed,
- results=results,
- )
-
-
-def upload_file(local_path: str, workspace_path: str, overwrite: bool = True) -> UploadResult:
- """
- Upload a single file to Databricks workspace.
-
- Args:
- local_path: Path to local file
- workspace_path: Target path in Databricks workspace
- overwrite: Whether to overwrite existing file (default: True)
-
- Returns:
- UploadResult with success status
-
- Example:
- >>> result = upload_file(
- ... local_path="/path/to/script.py",
- ... workspace_path="/Users/me@example.com/scripts/script.py"
- ... )
- >>> if result.success:
- ... print("Upload complete")
- ... else:
- ... print(f"Error: {result.error}")
- """
- if not os.path.exists(local_path):
- return UploadResult(
- local_path=local_path,
- remote_path=workspace_path,
- success=False,
- error=f"Local file not found: {local_path}",
- )
-
- if not os.path.isfile(local_path):
- return UploadResult(
- local_path=local_path,
- remote_path=workspace_path,
- success=False,
- error=f"Path is not a file: {local_path}",
- )
-
- w = get_workspace_client()
-
- # Create parent directory if needed
- parent_dir = str(Path(workspace_path).parent)
- if parent_dir != "/":
- try:
- w.workspace.mkdirs(parent_dir)
- except Exception:
- pass
-
- return _upload_single_file(w, local_path, workspace_path, overwrite)
-
-
-def _is_protected_path(workspace_path: str) -> bool:
- """
- Check if a workspace path is protected from deletion.
-
- Protected paths include:
- - Root paths (/, /Workspace, /Users, /Repos)
- - User home folders (/Workspace/Users/user@example.com, /Users/user@example.com)
- - Repos user roots (/Workspace/Repos/user@example.com, /Repos/user@example.com)
- - Shared folder root (/Workspace/Shared)
-
- Args:
- workspace_path: Path to check
-
- Returns:
- True if the path is protected, False otherwise
- """
- # Normalize path: remove trailing slashes
- path = workspace_path.rstrip("/")
-
- # Root paths are always protected
- protected_roots = {
- "",
- "/",
- "/Workspace",
- "/Workspace/Users",
- "/Workspace/Repos",
- "/Workspace/Shared",
- "/Users",
- "/Repos",
- }
- if path in protected_roots:
- return True
-
- # User home folders: /Workspace/Users/user@example.com or /Users/user@example.com
- # Pattern: exactly one level below Users (the email)
- user_home_pattern = r"^(/Workspace)?/Users/[^/]+$"
- if re.match(user_home_pattern, path):
- return True
-
- # Repos user roots: /Workspace/Repos/user@example.com or /Repos/user@example.com
- repos_pattern = r"^(/Workspace)?/Repos/[^/]+$"
- if re.match(repos_pattern, path):
- return True
-
- return False
-
-
-def upload_to_workspace(
- local_path: str,
- workspace_path: str,
- max_workers: int = 10,
- overwrite: bool = True,
-) -> FolderUploadResult:
- """
- Upload files or folders to Databricks workspace.
-
- Handles single files, folders, and glob patterns. This is the unified upload
- function that replaces both upload_file and upload_folder.
-
- Args:
- local_path: Path to local file, folder, or glob pattern.
- - Single file: "/path/to/file.py"
- - Folder: "/path/to/folder" (preserves folder name)
- - Folder contents: "/path/to/folder/" or "/path/to/folder/*"
- - Glob pattern: "/path/to/*.py"
- - Tilde expansion: "~/projects/file.py"
- workspace_path: Target path in Databricks workspace
- max_workers: Maximum parallel upload threads (default: 10)
- overwrite: Whether to overwrite existing files (default: True)
-
- Returns:
- FolderUploadResult with upload statistics
-
- Example:
- >>> # Upload single file
- >>> result = upload_to_workspace(
- ... local_path="/path/to/script.py",
- ... workspace_path="/Workspace/Users/me@example.com/script.py",
- ... )
- >>> # Upload folder preserving name
- >>> result = upload_to_workspace(
- ... local_path="/path/to/project",
- ... workspace_path="/Workspace/Users/me@example.com/dest",
- ... )
- >>> # Upload folder contents only
- >>> result = upload_to_workspace(
- ... local_path="/path/to/project/",
- ... workspace_path="/Workspace/Users/me@example.com/dest",
- ... )
- >>> # Upload with glob pattern
- >>> result = upload_to_workspace(
- ... local_path="/path/to/*.py",
- ... workspace_path="/Workspace/Users/me@example.com/scripts",
- ... )
- """
- # Expand ~ in path
- local_path = os.path.expanduser(local_path)
-
- # Normalize workspace path (remove trailing slash)
- workspace_path = workspace_path.rstrip("/")
-
- # Check if this is a glob pattern (contains * or ?)
- has_glob = "*" in local_path or "?" in local_path
-
- if has_glob:
- return _upload_glob_pattern(local_path, workspace_path, max_workers, overwrite)
-
- # Check if path exists
- if not os.path.exists(local_path.rstrip("/")):
- error_result = UploadResult(
- local_path=local_path,
- remote_path=workspace_path,
- success=False,
- error=f"Path not found: {local_path}",
- )
- return FolderUploadResult(
- local_folder=local_path,
- remote_folder=workspace_path,
- total_files=1,
- successful=0,
- failed=1,
- results=[error_result],
- )
-
- # Single file
- if os.path.isfile(local_path):
- result = upload_file(local_path, workspace_path, overwrite)
- return FolderUploadResult(
- local_folder=local_path,
- remote_folder=workspace_path,
- total_files=1,
- successful=1 if result.success else 0,
- failed=0 if result.success else 1,
- results=[result],
- )
-
- # Directory - use existing upload_folder logic
- return upload_folder(local_path, workspace_path, max_workers, overwrite)
-
-
-def _upload_glob_pattern(
- pattern: str,
- workspace_path: str,
- max_workers: int = 10,
- overwrite: bool = True,
-) -> FolderUploadResult:
- """
- Upload files matching a glob pattern.
-
- Args:
- pattern: Glob pattern (e.g., "*.py", "**/*.sql")
- workspace_path: Target workspace folder
- max_workers: Maximum parallel upload threads
- overwrite: Whether to overwrite existing files
-
- Returns:
- FolderUploadResult with upload statistics
- """
- # Expand the glob pattern
- matches = glob.glob(pattern, recursive=True)
-
- if not matches:
- error_result = UploadResult(
- local_path=pattern,
- remote_path=workspace_path,
- success=False,
- error=f"No files match pattern: {pattern}",
- )
- return FolderUploadResult(
- local_folder=pattern,
- remote_folder=workspace_path,
- total_files=1,
- successful=0,
- failed=1,
- results=[error_result],
- )
-
- # Separate files and directories
- files = [m for m in matches if os.path.isfile(m)]
- dirs = [m for m in matches if os.path.isdir(m)]
-
- # Get the base directory from the pattern for relative path calculation
- pattern_base = os.path.dirname(pattern.split("*")[0].rstrip("/")) or "."
- pattern_base = os.path.abspath(pattern_base)
-
- w = get_workspace_client()
-
- # Create workspace directory
- try:
- w.workspace.mkdirs(workspace_path)
- except Exception:
- pass
-
- results = []
- successful = 0
- failed = 0
-
- # Upload files from matched directories
- for dir_path in dirs:
- dir_files = _collect_files(dir_path)
- for local_file, rel_path in dir_files:
- # Calculate relative path from pattern base
- dir_name = os.path.basename(dir_path)
- remote_path = f"{workspace_path}/{dir_name}/{rel_path.replace(os.sep, '/')}"
-
- # Create parent directory
- parent_dir = str(Path(remote_path).parent)
- try:
- w.workspace.mkdirs(parent_dir)
- except Exception:
- pass
-
- result = _upload_single_file(w, local_file, remote_path, overwrite)
- results.append(result)
- if result.success:
- successful += 1
- else:
- failed += 1
-
- # Upload individual files
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- future_to_file = {}
- for local_file in files:
- # Use just the filename for the remote path
- filename = os.path.basename(local_file)
- remote_path = f"{workspace_path}/{filename}"
- future = executor.submit(_upload_single_file, w, local_file, remote_path, overwrite)
- future_to_file[future] = (local_file, remote_path)
-
- for future in as_completed(future_to_file):
- result = future.result()
- results.append(result)
- if result.success:
- successful += 1
- else:
- failed += 1
-
- return FolderUploadResult(
- local_folder=pattern,
- remote_folder=workspace_path,
- total_files=len(results),
- successful=successful,
- failed=failed,
- results=results,
- )
-
-
-def delete_from_workspace(
- workspace_path: str,
- recursive: bool = False,
-) -> DeleteResult:
- """
- Delete a file or folder from Databricks workspace.
-
- Includes safety checks to prevent accidental deletion of protected paths
- like user home folders, repos roots, and shared folder roots.
-
- Args:
- workspace_path: Path to delete in Databricks workspace
- recursive: If True, delete folder and all contents (default: False)
-
- Returns:
- DeleteResult with success status
-
- Example:
- >>> # Delete a single file
- >>> result = delete_from_workspace(
- ... workspace_path="/Workspace/Users/me@example.com/old_file.py",
- ... )
- >>> # Delete a folder recursively
- >>> result = delete_from_workspace(
- ... workspace_path="/Workspace/Users/me@example.com/old_project",
- ... recursive=True,
- ... )
- """
- # Normalize path
- workspace_path = workspace_path.rstrip("/")
-
- # Safety check: prevent deletion of protected paths
- if _is_protected_path(workspace_path):
- return DeleteResult(
- workspace_path=workspace_path,
- success=False,
- error=f"Cannot delete protected path: {workspace_path}. "
- "User home folders, repos roots, and system folders are protected.",
- )
-
- try:
- w = get_workspace_client()
- w.workspace.delete(workspace_path, recursive=recursive)
- return DeleteResult(
- workspace_path=workspace_path,
- success=True,
- )
- except Exception as e:
- return DeleteResult(
- workspace_path=workspace_path,
- success=False,
- error=str(e),
- )
diff --git a/databricks-tools-core/databricks_tools_core/identity.py b/databricks-tools-core/databricks_tools_core/identity.py
deleted file mode 100644
index db9b4493..00000000
--- a/databricks-tools-core/databricks_tools_core/identity.py
+++ /dev/null
@@ -1,240 +0,0 @@
-"""Product identity, project detection, and resource tagging.
-
-Every Databricks API call made through the SDK includes a User-Agent header.
-This module sets a custom product name and auto-detects the project name from
-a config file or git, so all calls are attributable in the
-``system.access.audit`` system table.
-
-Resources created by the MCP server are also tagged with project metadata
-and any freeform tags defined in ``.databricks-ai-dev-kit.yaml``.
-
-Example user-agent string::
-
- databricks-ai-dev-kit/0.1.0 databricks-sdk-py/0.73.0 python/3.11.13 os/darwin auth/pat project/my-repo
-
-Example config file (``.databricks-ai-dev-kit.yaml`` at repo root)::
-
- project: my-sales-dashboard
- tags:
- team: data-eng
- env: dev
-"""
-
-import logging
-import os
-import re
-import subprocess
-from pathlib import Path
-from typing import Any, Dict, Optional
-
-import yaml
-from databricks.sdk import WorkspaceClient
-
-logger = logging.getLogger(__name__)
-
-PRODUCT_NAME = "databricks-ai-dev-kit"
-
-DESCRIPTION_FOOTER = "Built with Databricks AI Dev Kit"
-
-
-def _load_version() -> str:
- """Load version from the repository root ``VERSION`` file.
-
- Searches upward from this module's directory for a ``VERSION`` file.
- Falls back to ``"0.0.0-unknown"`` if not found.
- """
- fallback = "0.0.0-unknown"
- try:
- d = Path(__file__).resolve().parent
- for _ in range(6): # walk up at most 6 levels
- candidate = d / "VERSION"
- if candidate.is_file():
- version = candidate.read_text().strip()
- logger.debug("Loaded version %s from %s", version, candidate)
- return version
- if d.parent == d:
- break
- d = d.parent
- except Exception:
- logger.debug("Failed to read VERSION file", exc_info=True)
- logger.warning("VERSION file not found; falling back to %s", fallback)
- return fallback
-
-
-PRODUCT_VERSION = _load_version()
-
-_CONFIG_FILENAME = ".databricks-ai-dev-kit.yaml"
-
-_cached_project: Optional[str] = None
-_cached_config: Optional[Dict[str, Any]] = None
-
-
-def _sanitize_project_name(name: str) -> str:
- """Sanitize a project name for use in a user-agent string.
-
- Keeps only alphanumeric characters, hyphens, underscores, and dots.
- """
- sanitized = re.sub(r"[^a-zA-Z0-9._-]", "-", name)
- # Collapse multiple hyphens and strip leading/trailing hyphens
- sanitized = re.sub(r"-+", "-", sanitized).strip("-")
- return sanitized or "unknown"
-
-
-def _git_toplevel() -> Optional[str]:
- """Return the git repo root directory, or None."""
- try:
- result = subprocess.run(
- ["git", "rev-parse", "--show-toplevel"],
- stdin=subprocess.DEVNULL,
- capture_output=True,
- text=True,
- timeout=5,
- )
- if result.returncode == 0 and result.stdout.strip():
- return result.stdout.strip()
- except Exception:
- pass
- return None
-
-
-def _load_config() -> Dict[str, Any]:
- """Load ``.databricks-ai-dev-kit.yaml`` from the git repo root or cwd.
-
- The result is cached for the lifetime of the process.
-
- Returns:
- Parsed config dict, or empty dict if the file doesn't exist.
- """
- global _cached_config
- if _cached_config is not None:
- return _cached_config
-
- search_dirs = [_git_toplevel(), os.getcwd()]
- for directory in search_dirs:
- if directory:
- config_path = os.path.join(directory, _CONFIG_FILENAME)
- if os.path.isfile(config_path):
- try:
- with open(config_path) as f:
- _cached_config = yaml.safe_load(f) or {}
- logger.debug("Loaded config from %s", config_path)
- return _cached_config
- except Exception:
- logger.warning("Failed to parse %s", config_path, exc_info=True)
-
- _cached_config = {}
- return _cached_config
-
-
-def detect_project_name() -> str:
- """Detect the project name. Cached after first call.
-
- Detection priority:
- 1. ``project`` field in ``.databricks-ai-dev-kit.yaml``
- 2. Git remote origin URL → extract repository name
- 3. Git repo root directory → use basename
- 4. Current working directory → use basename
-
- Returns:
- A sanitized project name string.
- """
- global _cached_project
- if _cached_project is not None:
- return _cached_project
-
- name: Optional[str] = None
-
- # Priority 1: config file
- config = _load_config()
- config_project = config.get("project")
- if config_project and isinstance(config_project, str):
- name = config_project
-
- # Priority 2: git remote origin URL
- if not name:
- try:
- result = subprocess.run(
- ["git", "remote", "get-url", "origin"],
- stdin=subprocess.DEVNULL,
- capture_output=True,
- text=True,
- timeout=5,
- )
- if result.returncode == 0 and result.stdout.strip():
- url = result.stdout.strip()
- # Handle both HTTPS and SSH URLs
- # https://github.com/org/repo.git -> repo
- # git@github.com:org/repo.git -> repo
- name = url.rstrip("/").split("/")[-1].removesuffix(".git")
- except Exception:
- pass
-
- # Priority 3: git repo root directory name
- if not name:
- toplevel = _git_toplevel()
- if toplevel:
- name = os.path.basename(toplevel)
-
- # Priority 4: current working directory name
- if not name:
- name = os.path.basename(os.getcwd())
-
- _cached_project = _sanitize_project_name(name or "unknown")
- logger.debug("Detected project name: %s", _cached_project)
- return _cached_project
-
-
-def get_default_tags() -> Dict[str, str]:
- """Get tags to apply to all created resources.
-
- Merges:
- 1. ``created_by: databricks-ai-dev-kit`` (always)
- 2. ``project: `` (always)
- 3. Freeform tags from ``.databricks-ai-dev-kit.yaml`` ``tags:`` section
-
- Config file tags do not override ``created_by`` or ``project``.
- User-provided tags at resource creation time take precedence over
- all default tags (handled by callers).
-
- Returns:
- Dictionary of tag key-value pairs.
- """
- tags: Dict[str, str] = {
- "created_by": PRODUCT_NAME,
- "aidevkit_project": detect_project_name(),
- }
- # Merge freeform tags from config file
- config = _load_config()
- config_tags = config.get("tags", {})
- if isinstance(config_tags, dict):
- for k, v in config_tags.items():
- # Config tags don't override the hardcoded created_by / project
- tags.setdefault(str(k), str(v))
- return tags
-
-
-def tag_client(client: WorkspaceClient) -> WorkspaceClient:
- """Add project identifier to a WorkspaceClient's user-agent.
-
- Call this after creating a ``WorkspaceClient`` to append the
- ``project/`` tag to the user-agent header.
-
- Args:
- client: A ``WorkspaceClient`` instance (already created with
- ``product=PRODUCT_NAME, product_version=PRODUCT_VERSION``).
-
- Returns:
- The same client, for chaining.
- """
- client.config.with_user_agent_extra("project", detect_project_name())
- return client
-
-
-def with_description_footer(description: Optional[str]) -> str:
- """Append the AI Dev Kit footer to a resource description.
-
- If *description* is empty or ``None``, returns just the footer.
- """
- if not description:
- return DESCRIPTION_FOOTER
- return f"{description}\n\n{DESCRIPTION_FOOTER}"
diff --git a/databricks-tools-core/databricks_tools_core/jobs/__init__.py b/databricks-tools-core/databricks_tools_core/jobs/__init__.py
deleted file mode 100644
index a0ff1c40..00000000
--- a/databricks-tools-core/databricks_tools_core/jobs/__init__.py
+++ /dev/null
@@ -1,87 +0,0 @@
-"""
-Jobs Module - Databricks Jobs API
-
-This module provides functions for managing Databricks jobs and job runs.
-Uses serverless compute by default for optimal performance and cost.
-
-Core Operations:
-- Job CRUD: create_job, update_job, delete_job, get_job, list_jobs, find_job_by_name
-- Run Management: run_job_now, get_run, get_run_output, cancel_run, list_runs
-- Run Monitoring: wait_for_run (blocks until completion)
-
-Data Models:
-- JobRunResult: Detailed run result with status, timing, and error info
-- JobStatus, RunLifecycleState, RunResultState: Status enums
-- JobError: Exception class for job-related errors
-
-Example Usage:
- >>> from databricks_tools_core.jobs import (
- ... create_job, run_job_now, wait_for_run
- ... )
- >>>
- >>> # Create a job
- >>> tasks = [{
- ... "task_key": "main",
- ... "notebook_task": {
- ... "notebook_path": "/Workspace/ETL/process",
- ... "source": "WORKSPACE"
- ... }
- ... }]
- >>> job = create_job(name="my_etl_job", tasks=tasks)
- >>>
- >>> # Run the job and wait for completion
- >>> run_id = run_job_now(job_id=job["job_id"])
- >>> result = wait_for_run(run_id=run_id)
- >>> if result.success:
- ... print(f"Job completed in {result.duration_seconds}s")
-"""
-
-# Import all public functions and classes
-from .models import (
- JobStatus,
- RunLifecycleState,
- RunResultState,
- JobRunResult,
- JobError,
-)
-
-from .jobs import (
- list_jobs,
- get_job,
- find_job_by_name,
- create_job,
- update_job,
- delete_job,
-)
-
-from .runs import (
- run_job_now,
- get_run,
- get_run_output,
- cancel_run,
- list_runs,
- wait_for_run,
-)
-
-__all__ = [
- # Models and Enums
- "JobStatus",
- "RunLifecycleState",
- "RunResultState",
- "JobRunResult",
- "JobError",
- # Job CRUD Operations
- "list_jobs",
- "get_job",
- "find_job_by_name",
- "create_job",
- "update_job",
- "delete_job",
- # Run Operations
- "run_job_now",
- "get_run",
- "get_run_output",
- "cancel_run",
- "list_runs",
- "wait_for_run",
-]
diff --git a/databricks-tools-core/databricks_tools_core/jobs/jobs.py b/databricks-tools-core/databricks_tools_core/jobs/jobs.py
deleted file mode 100644
index da5fdae8..00000000
--- a/databricks-tools-core/databricks_tools_core/jobs/jobs.py
+++ /dev/null
@@ -1,393 +0,0 @@
-"""
-Jobs - Core Job CRUD Operations
-
-Functions for managing Databricks jobs using the Jobs API.
-Uses serverless compute by default for optimal performance and cost.
-"""
-
-from typing import Optional, List, Dict, Any
-
-from databricks.sdk.service.jobs import JobSettings
-
-from ..auth import get_workspace_client
-from .models import JobError
-
-
-def list_jobs(
- name: Optional[str] = None,
- limit: int = 25,
- expand_tasks: bool = False,
-) -> List[Dict[str, Any]]:
- """
- List jobs in the workspace.
-
- Args:
- name: Optional name filter (partial match, case-insensitive)
- limit: Maximum number of jobs to return (default: 25)
- expand_tasks: If True, include full task definitions in results
-
- Returns:
- List of job info dicts with job_id, name, creator, created_time, etc.
- """
- w = get_workspace_client()
- jobs = []
-
- # SDK list() returns an iterator - we need to consume it
- for job in w.jobs.list(name=name, expand_tasks=expand_tasks, limit=limit):
- job_dict = {
- "job_id": job.job_id,
- "name": job.settings.name if job.settings else None,
- "creator_user_name": job.creator_user_name,
- "created_time": job.created_time,
- }
-
- # Add additional info if available
- if job.settings:
- job_dict["tags"] = job.settings.tags if hasattr(job.settings, "tags") else None
- job_dict["timeout_seconds"] = (
- job.settings.timeout_seconds if hasattr(job.settings, "timeout_seconds") else None
- )
- job_dict["max_concurrent_runs"] = (
- job.settings.max_concurrent_runs if hasattr(job.settings, "max_concurrent_runs") else None
- )
-
- # Include tasks if expanded
- if expand_tasks and job.settings.tasks:
- job_dict["tasks"] = [task.as_dict() for task in job.settings.tasks]
-
- jobs.append(job_dict)
-
- if len(jobs) >= limit:
- break
-
- return jobs
-
-
-def get_job(job_id: int) -> Dict[str, Any]:
- """
- Get detailed job configuration.
-
- Args:
- job_id: Job ID
-
- Returns:
- Dictionary with full job configuration including tasks, clusters, schedule, etc.
-
- Raises:
- JobError: If job not found or API request fails
- """
- w = get_workspace_client()
-
- try:
- job = w.jobs.get(job_id=job_id)
-
- # Convert SDK object to dict for JSON serialization
- return job.as_dict()
-
- except Exception as e:
- raise JobError(f"Failed to get job {job_id}: {str(e)}", job_id=job_id)
-
-
-def find_job_by_name(name: str) -> Optional[int]:
- """
- Find a job by exact name and return its ID.
-
- Args:
- name: Job name to search for (exact match)
-
- Returns:
- Job ID if found, None otherwise
- """
- w = get_workspace_client()
-
- # List jobs with name filter and find exact match
- for job in w.jobs.list(name=name, limit=100):
- if job.settings and job.settings.name == name:
- return job.job_id
-
- return None
-
-
-def create_job(
- name: str,
- tasks: List[Dict[str, Any]],
- job_clusters: Optional[List[Dict[str, Any]]] = None,
- environments: Optional[List[Dict[str, Any]]] = None,
- tags: Optional[Dict[str, str]] = None,
- timeout_seconds: Optional[int] = None,
- max_concurrent_runs: int = 1,
- email_notifications: Optional[Dict[str, Any]] = None,
- webhook_notifications: Optional[Dict[str, Any]] = None,
- notification_settings: Optional[Dict[str, Any]] = None,
- schedule: Optional[Dict[str, Any]] = None,
- queue: Optional[Dict[str, Any]] = None,
- run_as: Optional[Dict[str, Any]] = None,
- git_source: Optional[Dict[str, Any]] = None,
- parameters: Optional[List[Dict[str, Any]]] = None,
- health: Optional[Dict[str, Any]] = None,
- deployment: Optional[Dict[str, Any]] = None,
- **extra_settings,
-) -> Dict[str, Any]:
- """
- Create a new Databricks job with serverless compute by default.
-
- Args:
- name: Job name
- tasks: List of task definitions (dicts). Each task should have:
- - task_key: Unique identifier
- - description: Optional task description
- - depends_on: Optional list of task dependencies
- - [task_type]: One of spark_python_task, notebook_task, python_wheel_task,
- spark_jar_task, spark_submit_task, pipeline_task, sql_task, dbt_task, run_job_task
- - [compute]: One of new_cluster, existing_cluster_id, job_cluster_key, compute_key
- job_clusters: Optional list of job cluster definitions (for non-serverless tasks)
- environments: Optional list of environment definitions for serverless tasks.
- Each dict should have:
- - environment_key: Unique identifier referenced by tasks via environment_key
- - spec: Dict with dependencies (list of pip packages) and optionally client ("4")
- tags: Optional tags dict for organization
- timeout_seconds: Job-level timeout (0 means no timeout)
- max_concurrent_runs: Maximum number of concurrent runs (default: 1)
- email_notifications: Email notification settings
- webhook_notifications: Webhook notification settings
- notification_settings: Notification settings for run lifecycle events
- schedule: Optional schedule configuration
- queue: Optional queue settings for job queueing
- run_as: Optional run-as user/service principal
- git_source: Optional Git source configuration
- parameters: Optional job parameters
- health: Optional health monitoring rules
- deployment: Optional deployment configuration
- **extra_settings: Additional job settings passed directly to SDK
-
- Returns:
- Dictionary with job_id and other creation metadata
-
- Raises:
- JobError: If job creation fails
-
- Example:
- >>> tasks = [
- ... {
- ... "task_key": "data_ingestion",
- ... "notebook_task": {
- ... "notebook_path": "/Workspace/ETL/ingest",
- ... "source": "WORKSPACE"
- ... }
- ... }
- ... ]
- >>> job = create_job(name="my_etl_job", tasks=tasks)
- >>> print(job["job_id"])
- """
- w = get_workspace_client()
-
- try:
- # Build settings dict - JobSettings.from_dict() handles all nested conversions
- settings_dict: Dict[str, Any] = {
- "name": name,
- "max_concurrent_runs": max_concurrent_runs,
- }
-
- # Add tasks
- if tasks:
- settings_dict["tasks"] = tasks
-
- # Add job_clusters if provided
- if job_clusters:
- settings_dict["job_clusters"] = job_clusters
-
- # Add environments if provided (for serverless tasks with dependencies)
- # Auto-inject "client": "4" into spec if missing to avoid API error:
- # "Either base environment or version must be provided for environment"
- if environments:
- for env in environments:
- if "spec" in env and "client" not in env["spec"]:
- env["spec"]["client"] = "4"
- settings_dict["environments"] = environments
-
- # Add optional parameters
- if tags:
- settings_dict["tags"] = tags
- if timeout_seconds is not None:
- settings_dict["timeout_seconds"] = timeout_seconds
- if email_notifications:
- settings_dict["email_notifications"] = email_notifications
- if webhook_notifications:
- settings_dict["webhook_notifications"] = webhook_notifications
- if notification_settings:
- settings_dict["notification_settings"] = notification_settings
- if schedule:
- settings_dict["schedule"] = schedule
- if queue:
- settings_dict["queue"] = queue
- if run_as:
- settings_dict["run_as"] = run_as
- if git_source:
- settings_dict["git_source"] = git_source
- if parameters:
- settings_dict["parameters"] = parameters
- if health:
- settings_dict["health"] = health
- if deployment:
- settings_dict["deployment"] = deployment
-
- # Add any extra settings
- settings_dict.update(extra_settings)
-
- # Convert entire dict to JobSettings - handles all nested type conversions
- settings = JobSettings.from_dict(settings_dict)
-
- # Create job using the converted SDK objects
- response = w.jobs.create(
- name=settings.name,
- tasks=settings.tasks,
- job_clusters=settings.job_clusters,
- environments=settings.environments,
- tags=settings.tags,
- timeout_seconds=settings.timeout_seconds,
- max_concurrent_runs=settings.max_concurrent_runs,
- email_notifications=settings.email_notifications,
- webhook_notifications=settings.webhook_notifications,
- notification_settings=settings.notification_settings,
- schedule=settings.schedule,
- queue=settings.queue,
- run_as=settings.run_as,
- git_source=settings.git_source,
- parameters=settings.parameters,
- health=settings.health,
- deployment=settings.deployment,
- )
-
- # Convert response to dict
- return response.as_dict()
-
- except Exception as e:
- raise JobError(f"Failed to create job '{name}': {str(e)}")
-
-
-def update_job(
- job_id: int,
- name: Optional[str] = None,
- tasks: Optional[List[Dict[str, Any]]] = None,
- job_clusters: Optional[List[Dict[str, Any]]] = None,
- environments: Optional[List[Dict[str, Any]]] = None,
- tags: Optional[Dict[str, str]] = None,
- timeout_seconds: Optional[int] = None,
- max_concurrent_runs: Optional[int] = None,
- email_notifications: Optional[Dict[str, Any]] = None,
- webhook_notifications: Optional[Dict[str, Any]] = None,
- notification_settings: Optional[Dict[str, Any]] = None,
- schedule: Optional[Dict[str, Any]] = None,
- queue: Optional[Dict[str, Any]] = None,
- run_as: Optional[Dict[str, Any]] = None,
- git_source: Optional[Dict[str, Any]] = None,
- parameters: Optional[List[Dict[str, Any]]] = None,
- health: Optional[Dict[str, Any]] = None,
- deployment: Optional[Dict[str, Any]] = None,
- **extra_settings,
-) -> None:
- """
- Update an existing job's configuration.
-
- Only provided parameters will be updated. To remove a field, explicitly set it to None
- or an empty value.
-
- Args:
- job_id: Job ID to update
- name: New job name
- tasks: New task definitions
- job_clusters: New job cluster definitions
- environments: New environment definitions for serverless tasks with dependencies
- tags: New tags (replaces existing)
- timeout_seconds: New timeout
- max_concurrent_runs: New max concurrent runs
- email_notifications: New email notifications
- webhook_notifications: New webhook notifications
- notification_settings: New notification settings
- schedule: New schedule configuration
- queue: New queue settings
- run_as: New run-as configuration
- git_source: New Git source configuration
- parameters: New job parameters
- health: New health monitoring rules
- deployment: New deployment configuration
- **extra_settings: Additional job settings
-
- Raises:
- JobError: If job update fails
- """
- w = get_workspace_client()
-
- try:
- # Build kwargs for SDK call - must include full new_settings
- # Get current job config first
- current_job = w.jobs.get(job_id=job_id)
-
- # Start with current settings as dict
- new_settings_dict = current_job.settings.as_dict() if current_job.settings else {}
-
- # Update with provided parameters
- if name is not None:
- new_settings_dict["name"] = name
- if tasks is not None:
- new_settings_dict["tasks"] = tasks
- if job_clusters is not None:
- new_settings_dict["job_clusters"] = job_clusters
- if environments is not None:
- new_settings_dict["environments"] = environments
- if tags is not None:
- new_settings_dict["tags"] = tags
- if timeout_seconds is not None:
- new_settings_dict["timeout_seconds"] = timeout_seconds
- if max_concurrent_runs is not None:
- new_settings_dict["max_concurrent_runs"] = max_concurrent_runs
- if email_notifications is not None:
- new_settings_dict["email_notifications"] = email_notifications
- if webhook_notifications is not None:
- new_settings_dict["webhook_notifications"] = webhook_notifications
- if notification_settings is not None:
- new_settings_dict["notification_settings"] = notification_settings
- if schedule is not None:
- new_settings_dict["schedule"] = schedule
- if queue is not None:
- new_settings_dict["queue"] = queue
- if run_as is not None:
- new_settings_dict["run_as"] = run_as
- if git_source is not None:
- new_settings_dict["git_source"] = git_source
- if parameters is not None:
- new_settings_dict["parameters"] = parameters
- if health is not None:
- new_settings_dict["health"] = health
- if deployment is not None:
- new_settings_dict["deployment"] = deployment
-
- # Apply extra settings
- new_settings_dict.update(extra_settings)
-
- # Convert to JobSettings object
- new_settings = JobSettings.from_dict(new_settings_dict)
-
- # Update job
- w.jobs.update(job_id=job_id, new_settings=new_settings)
-
- except Exception as e:
- raise JobError(f"Failed to update job {job_id}: {str(e)}", job_id=job_id)
-
-
-def delete_job(job_id: int) -> None:
- """
- Delete a job.
-
- Args:
- job_id: Job ID to delete
-
- Raises:
- JobError: If job deletion fails
- """
- w = get_workspace_client()
-
- try:
- w.jobs.delete(job_id=job_id)
- except Exception as e:
- raise JobError(f"Failed to delete job {job_id}: {str(e)}", job_id=job_id)
diff --git a/databricks-tools-core/databricks_tools_core/jobs/models.py b/databricks-tools-core/databricks_tools_core/jobs/models.py
deleted file mode 100644
index ee5f0493..00000000
--- a/databricks-tools-core/databricks_tools_core/jobs/models.py
+++ /dev/null
@@ -1,112 +0,0 @@
-"""
-Jobs - Data Models and Enums
-
-Data classes and enums for job operations.
-"""
-
-from dataclasses import dataclass, field
-from enum import Enum
-from typing import Optional, List, Dict, Any
-
-
-class JobStatus(Enum):
- """Job lifecycle status enum."""
-
- RUNNING = "RUNNING"
- QUEUED = "QUEUED"
- TERMINATED = "TERMINATED"
- TERMINATING = "TERMINATING"
- PENDING = "PENDING"
- SKIPPED = "SKIPPED"
- INTERNAL_ERROR = "INTERNAL_ERROR"
-
-
-class RunLifecycleState(Enum):
- """Run lifecycle state enum."""
-
- PENDING = "PENDING"
- RUNNING = "RUNNING"
- TERMINATING = "TERMINATING"
- TERMINATED = "TERMINATED"
- SKIPPED = "SKIPPED"
- INTERNAL_ERROR = "INTERNAL_ERROR"
- QUEUED = "QUEUED"
- WAITING_FOR_RETRY = "WAITING_FOR_RETRY"
- BLOCKED = "BLOCKED"
-
-
-class RunResultState(Enum):
- """Run result state enum."""
-
- SUCCESS = "SUCCESS"
- FAILED = "FAILED"
- TIMEDOUT = "TIMEDOUT"
- CANCELED = "CANCELED"
- EXCLUDED = "EXCLUDED"
- SUCCESS_WITH_FAILURES = "SUCCESS_WITH_FAILURES"
- UPSTREAM_FAILED = "UPSTREAM_FAILED"
- UPSTREAM_CANCELED = "UPSTREAM_CANCELED"
-
-
-@dataclass
-class JobRunResult:
- """
- Result from a job run operation with detailed status for LLM consumption.
-
- This dataclass provides comprehensive information about job runs
- to help LLMs understand what happened and take appropriate action.
- """
-
- # Job identification
- job_id: int
- run_id: int
- job_name: Optional[str] = None
-
- # Run status
- lifecycle_state: Optional[str] = None
- result_state: Optional[str] = None
- success: bool = False
-
- # Timing
- duration_seconds: Optional[float] = None
- start_time: Optional[int] = None # epoch millis
- end_time: Optional[int] = None # epoch millis
-
- # Run details
- run_page_url: Optional[str] = None
- state_message: Optional[str] = None
-
- # Error details (if failed)
- error_message: Optional[str] = None
- errors: List[Dict[str, Any]] = field(default_factory=list)
-
- # Human-readable status
- message: str = ""
-
- def to_dict(self) -> Dict[str, Any]:
- """Convert to dictionary for JSON serialization."""
- return {
- "job_id": self.job_id,
- "run_id": self.run_id,
- "job_name": self.job_name,
- "lifecycle_state": self.lifecycle_state,
- "result_state": self.result_state,
- "success": self.success,
- "duration_seconds": self.duration_seconds,
- "start_time": self.start_time,
- "end_time": self.end_time,
- "run_page_url": self.run_page_url,
- "state_message": self.state_message,
- "error_message": self.error_message,
- "errors": self.errors,
- "message": self.message,
- }
-
-
-class JobError(Exception):
- """Exception raised for job-related errors."""
-
- def __init__(self, message: str, job_id: Optional[int] = None, run_id: Optional[int] = None):
- self.job_id = job_id
- self.run_id = run_id
- super().__init__(message)
diff --git a/databricks-tools-core/databricks_tools_core/jobs/runs.py b/databricks-tools-core/databricks_tools_core/jobs/runs.py
deleted file mode 100644
index 6b74c055..00000000
--- a/databricks-tools-core/databricks_tools_core/jobs/runs.py
+++ /dev/null
@@ -1,376 +0,0 @@
-"""
-Jobs - Run Operations
-
-Functions for triggering and monitoring job runs.
-"""
-
-import time
-from typing import Optional, List, Dict, Any
-
-from databricks.sdk.service.jobs import (
- RunLifeCycleState,
- RunResultState,
-)
-
-from ..auth import get_workspace_client
-from .models import JobRunResult, JobError
-
-
-# Terminal states - run has finished (success or failure)
-TERMINAL_STATES = {
- RunLifeCycleState.TERMINATED,
- RunLifeCycleState.SKIPPED,
- RunLifeCycleState.INTERNAL_ERROR,
-}
-
-# Success states - run completed successfully
-SUCCESS_STATES = {
- RunResultState.SUCCESS,
-}
-
-
-def run_job_now(
- job_id: int,
- idempotency_token: Optional[str] = None,
- jar_params: Optional[List[str]] = None,
- notebook_params: Optional[Dict[str, str]] = None,
- python_params: Optional[List[str]] = None,
- spark_submit_params: Optional[List[str]] = None,
- python_named_params: Optional[Dict[str, str]] = None,
- pipeline_params: Optional[Dict[str, Any]] = None,
- sql_params: Optional[Dict[str, str]] = None,
- dbt_commands: Optional[List[str]] = None,
- queue: Optional[Dict[str, Any]] = None,
- **extra_params,
-) -> int:
- """
- Trigger a job run immediately and return the run ID.
-
- Args:
- job_id: Job ID to run
- idempotency_token: Optional token to ensure idempotent job runs
- jar_params: Parameters for JAR tasks
- notebook_params: Parameters for notebook tasks
- python_params: Parameters for Python tasks
- spark_submit_params: Parameters for spark-submit tasks
- python_named_params: Named parameters for Python tasks
- pipeline_params: Parameters for pipeline tasks
- sql_params: Parameters for SQL tasks
- dbt_commands: Commands for dbt tasks
- queue: Queue settings for this run
- **extra_params: Additional run parameters
-
- Returns:
- Run ID (integer) for tracking the run
-
- Raises:
- JobError: If job run fails to start
-
- Example:
- >>> run_id = run_job_now(job_id=123, notebook_params={"env": "prod"})
- >>> print(f"Started run {run_id}")
- """
- w = get_workspace_client()
-
- try:
- # Build kwargs for SDK call
- kwargs: Dict[str, Any] = {"job_id": job_id}
-
- # Add optional parameters
- if idempotency_token:
- kwargs["idempotency_token"] = idempotency_token
- if jar_params:
- kwargs["jar_params"] = jar_params
- if notebook_params:
- kwargs["notebook_params"] = notebook_params
- if python_params:
- kwargs["python_params"] = python_params
- if spark_submit_params:
- kwargs["spark_submit_params"] = spark_submit_params
- if python_named_params:
- kwargs["python_named_params"] = python_named_params
- if pipeline_params:
- kwargs["pipeline_params"] = pipeline_params
- if sql_params:
- kwargs["sql_params"] = sql_params
- if dbt_commands:
- kwargs["dbt_commands"] = dbt_commands
- if queue:
- kwargs["queue"] = queue
-
- # Add extra params
- kwargs.update(extra_params)
-
- # Trigger run - SDK returns Wait[Run] object
- response = w.jobs.run_now(**kwargs)
-
- # Extract run_id from response
- # The Wait object has a response attribute that contains the Run
- if hasattr(response, "response") and hasattr(response.response, "run_id"):
- return response.response.run_id
- elif hasattr(response, "run_id"):
- return response.run_id
- else:
- # Fallback: try to get it from as_dict()
- response_dict = response.as_dict() if hasattr(response, "as_dict") else {}
- if "run_id" in response_dict:
- return response_dict["run_id"]
- raise JobError(f"Failed to extract run_id from response for job {job_id}", job_id=job_id)
-
- except Exception as e:
- raise JobError(f"Failed to start run for job {job_id}: {str(e)}", job_id=job_id)
-
-
-def get_run(run_id: int) -> Dict[str, Any]:
- """
- Get detailed run status and information.
-
- Args:
- run_id: Run ID
-
- Returns:
- Dictionary with run details including state, start_time, end_time, tasks, etc.
-
- Raises:
- JobError: If run not found or API request fails
- """
- w = get_workspace_client()
-
- try:
- run = w.jobs.get_run(run_id=run_id)
-
- # Convert SDK object to dict for JSON serialization
- return run.as_dict()
-
- except Exception as e:
- raise JobError(f"Failed to get run {run_id}: {str(e)}", run_id=run_id)
-
-
-def get_run_output(run_id: int) -> Dict[str, Any]:
- """
- Get run output including logs and results.
-
- Args:
- run_id: Run ID
-
- Returns:
- Dictionary with run output including logs, error messages, and task outputs
-
- Raises:
- JobError: If run not found or API request fails
- """
- w = get_workspace_client()
-
- try:
- output = w.jobs.get_run_output(run_id=run_id)
-
- # Convert SDK object to dict for JSON serialization
- return output.as_dict()
-
- except Exception as e:
- raise JobError(f"Failed to get output for run {run_id}: {str(e)}", run_id=run_id)
-
-
-def cancel_run(run_id: int) -> None:
- """
- Cancel a running job.
-
- Args:
- run_id: Run ID to cancel
-
- Raises:
- JobError: If cancel request fails
- """
- w = get_workspace_client()
-
- try:
- w.jobs.cancel_run(run_id=run_id)
- except Exception as e:
- raise JobError(f"Failed to cancel run {run_id}: {str(e)}", run_id=run_id)
-
-
-def list_runs(
- job_id: Optional[int] = None,
- active_only: bool = False,
- completed_only: bool = False,
- limit: int = 25,
- offset: int = 0,
- start_time_from: Optional[int] = None,
- start_time_to: Optional[int] = None,
-) -> List[Dict[str, Any]]:
- """
- List job runs with optional filters.
-
- Args:
- job_id: Optional filter by specific job ID
- active_only: If True, only return active runs (RUNNING, PENDING, etc.)
- completed_only: If True, only return completed runs
- limit: Maximum number of runs to return (default: 25, max: 1000)
- offset: Offset for pagination
- start_time_from: Filter by start time (epoch milliseconds)
- start_time_to: Filter by start time (epoch milliseconds)
-
- Returns:
- List of run info dicts with run_id, state, start_time, job_id, etc.
-
- Example:
- >>> # Get last 10 runs for a specific job
- >>> runs = list_runs(job_id=123, limit=10)
- >>>
- >>> # Get all active runs
- >>> active_runs = list_runs(active_only=True)
- """
- w = get_workspace_client()
- runs = []
-
- try:
- # SDK list_runs returns an iterator
- for run in w.jobs.list_runs(
- job_id=job_id,
- active_only=active_only,
- completed_only=completed_only,
- limit=limit,
- offset=offset,
- start_time_from=start_time_from,
- start_time_to=start_time_to,
- ):
- run_dict = run.as_dict()
- runs.append(run_dict)
-
- if len(runs) >= limit:
- break
-
- return runs
-
- except Exception as e:
- raise JobError(f"Failed to list runs: {str(e)}", job_id=job_id)
-
-
-def wait_for_run(
- run_id: int,
- timeout: int = 3600,
- poll_interval: int = 10,
-) -> JobRunResult:
- """
- Wait for a job run to complete and return detailed results.
-
- Args:
- run_id: Run ID to wait for
- timeout: Maximum wait time in seconds (default: 3600 = 1 hour)
- poll_interval: Time between status checks in seconds (default: 10)
-
- Returns:
- JobRunResult with detailed run status including:
- - success: True if run completed successfully
- - lifecycle_state: Final lifecycle state (TERMINATED, SKIPPED, etc.)
- - result_state: Final result state (SUCCESS, FAILED, etc.)
- - duration_seconds: Total time taken
- - error_message: Error message if failed
- - run_page_url: Link to run in Databricks UI
-
- Raises:
- TimeoutError: If run doesn't complete within timeout
- JobError: If API request fails
-
- Example:
- >>> run_id = run_job_now(job_id=123)
- >>> result = wait_for_run(run_id=run_id, timeout=1800)
- >>> if result.success:
- ... print(f"Job completed in {result.duration_seconds}s")
- ... else:
- ... print(f"Job failed: {result.error_message}")
- """
- w = get_workspace_client()
- start_time = time.time()
-
- job_id = None
- job_name = None
-
- while True:
- elapsed = time.time() - start_time
-
- if elapsed > timeout:
- raise TimeoutError(
- f"Job run {run_id} did not complete within {timeout} seconds. "
- f"Check run status in Databricks UI or call get_run(run_id={run_id})."
- )
-
- try:
- run = w.jobs.get_run(run_id=run_id)
-
- # Extract job info on first iteration
- if job_id is None:
- job_id = run.job_id
- # Get job name if available
- if run.job_id:
- try:
- job = w.jobs.get(job_id=run.job_id)
- job_name = job.settings.name if job.settings else None
- except Exception:
- pass # Ignore errors getting job name
-
- # Check if run is in terminal state
- lifecycle_state = run.state.life_cycle_state if run.state else None
- result_state = run.state.result_state if run.state else None
- state_message = run.state.state_message if run.state else None
-
- if lifecycle_state in TERMINAL_STATES:
- # Calculate duration
- duration = round(elapsed, 2)
- if run.start_time and run.end_time:
- # Use actual run times if available (more accurate)
- duration = round((run.end_time - run.start_time) / 1000.0, 2)
-
- # Determine success
- success = result_state in SUCCESS_STATES
-
- # Build result
- result = JobRunResult(
- job_id=job_id or 0,
- run_id=run_id,
- job_name=job_name,
- lifecycle_state=lifecycle_state.value if lifecycle_state else None,
- result_state=result_state.value if result_state else None,
- success=success,
- duration_seconds=duration,
- start_time=run.start_time,
- end_time=run.end_time,
- run_page_url=run.run_page_url,
- state_message=state_message,
- )
-
- # Build message
- if success:
- result.message = f"Job run {run_id} completed successfully in {duration}s. View: {run.run_page_url}"
- else:
- # Extract error details
- error_message = (
- state_message or f"Run failed with state: {result_state.value if result_state else 'UNKNOWN'}"
- )
- result.error_message = error_message
-
- # Try to get output for more details
- try:
- output = w.jobs.get_run_output(run_id=run_id)
- if output.error:
- result.error_message = output.error
- if output.error_trace:
- result.errors = [{"trace": output.error_trace}]
- except Exception:
- pass # Ignore errors getting output
-
- result.message = (
- f"Job run {run_id} failed. "
- f"State: {lifecycle_state.value if lifecycle_state else 'UNKNOWN'}, "
- f"Result: {result_state.value if result_state else 'UNKNOWN'}. "
- f"Error: {error_message}. "
- f"View: {run.run_page_url}"
- )
-
- return result
-
- except Exception as e:
- # If we can't get run status, raise error
- raise JobError(f"Failed to get run status for {run_id}: {str(e)}", run_id=run_id)
-
- time.sleep(poll_interval)
diff --git a/databricks-tools-core/databricks_tools_core/lakebase/__init__.py b/databricks-tools-core/databricks_tools_core/lakebase/__init__.py
deleted file mode 100644
index 2f6e458c..00000000
--- a/databricks-tools-core/databricks_tools_core/lakebase/__init__.py
+++ /dev/null
@@ -1,43 +0,0 @@
-"""
-Lakebase Provisioned Operations
-
-Functions for managing Databricks Lakebase Provisioned (managed PostgreSQL)
-instances, Unity Catalog registration, and reverse ETL synced tables.
-"""
-
-from .instances import (
- create_lakebase_instance,
- get_lakebase_instance,
- list_lakebase_instances,
- update_lakebase_instance,
- delete_lakebase_instance,
- generate_lakebase_credential,
-)
-from .catalogs import (
- create_lakebase_catalog,
- get_lakebase_catalog,
- delete_lakebase_catalog,
-)
-from .synced_tables import (
- create_synced_table,
- get_synced_table,
- delete_synced_table,
-)
-
-__all__ = [
- # Instances
- "create_lakebase_instance",
- "get_lakebase_instance",
- "list_lakebase_instances",
- "update_lakebase_instance",
- "delete_lakebase_instance",
- "generate_lakebase_credential",
- # Catalogs
- "create_lakebase_catalog",
- "get_lakebase_catalog",
- "delete_lakebase_catalog",
- # Synced Tables
- "create_synced_table",
- "get_synced_table",
- "delete_synced_table",
-]
diff --git a/databricks-tools-core/databricks_tools_core/lakebase/catalogs.py b/databricks-tools-core/databricks_tools_core/lakebase/catalogs.py
deleted file mode 100644
index 83f6c985..00000000
--- a/databricks-tools-core/databricks_tools_core/lakebase/catalogs.py
+++ /dev/null
@@ -1,152 +0,0 @@
-"""
-Lakebase Catalog Operations
-
-Functions for registering Lakebase database instances with Unity Catalog.
-"""
-
-import logging
-from typing import Any, Dict
-
-from ..auth import get_workspace_client
-
-logger = logging.getLogger(__name__)
-
-
-def create_lakebase_catalog(
- name: str,
- instance_name: str,
- database_name: str = "databricks_postgres",
- create_database_if_not_exists: bool = False,
-) -> Dict[str, Any]:
- """
- Register a Lakebase database instance as a Unity Catalog catalog.
-
- This makes the Lakebase PostgreSQL database discoverable and
- governable through Unity Catalog. The catalog is read-only.
-
- Args:
- name: Catalog name in Unity Catalog
- instance_name: Lakebase instance name to register
- database_name: PostgreSQL database name to register
- (default: "databricks_postgres")
- create_database_if_not_exists: If True, create the Postgres database
- if it does not exist (default: False)
-
- Returns:
- Dictionary with:
- - name: Catalog name
- - instance_name: Associated instance
- - database_name: PostgreSQL database name
- - status: Registration result
-
- Raises:
- Exception: If registration fails
- """
- client = get_workspace_client()
-
- try:
- from databricks.sdk.service.database import DatabaseCatalog
-
- client.database.create_database_catalog(
- DatabaseCatalog(
- name=name,
- database_instance_name=instance_name,
- database_name=database_name,
- create_database_if_not_exists=create_database_if_not_exists,
- )
- )
-
- return {
- "name": name,
- "instance_name": instance_name,
- "database_name": database_name,
- "status": "created",
- "message": f"Catalog '{name}' registered for instance '{instance_name}', database '{database_name}'.",
- }
- except Exception as e:
- error_msg = str(e)
- if "ALREADY_EXISTS" in error_msg or "already exists" in error_msg.lower():
- return {
- "name": name,
- "instance_name": instance_name,
- "status": "ALREADY_EXISTS",
- "error": f"Catalog '{name}' already exists",
- }
- raise Exception(f"Failed to create Lakebase catalog '{name}': {error_msg}")
-
-
-def get_lakebase_catalog(name: str) -> Dict[str, Any]:
- """
- Get details of a Lakebase catalog registered in Unity Catalog.
-
- Args:
- name: Catalog name
-
- Returns:
- Dictionary with catalog details
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
-
- try:
- catalog = client.database.get_database_catalog(name=name)
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "name": name,
- "status": "NOT_FOUND",
- "error": f"Catalog '{name}' not found",
- }
- raise Exception(f"Failed to get Lakebase catalog '{name}': {error_msg}")
-
- result: Dict[str, Any] = {"name": name}
-
- if hasattr(catalog, "database_instance_name") and catalog.database_instance_name:
- result["instance_name"] = catalog.database_instance_name
-
- if hasattr(catalog, "database_name") and catalog.database_name:
- result["database_name"] = catalog.database_name
-
- if hasattr(catalog, "state") and catalog.state:
- result["state"] = str(catalog.state)
-
- return result
-
-
-def delete_lakebase_catalog(name: str) -> Dict[str, Any]:
- """
- Remove a Lakebase catalog registration from Unity Catalog.
-
- This does not delete the underlying database instance.
-
- Args:
- name: Catalog name to remove
-
- Returns:
- Dictionary with:
- - name: Catalog name
- - status: "deleted" or error info
-
- Raises:
- Exception: If deletion fails
- """
- client = get_workspace_client()
-
- try:
- client.database.delete_database_catalog(name=name)
- return {
- "name": name,
- "status": "deleted",
- }
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "name": name,
- "status": "NOT_FOUND",
- "error": f"Catalog '{name}' not found",
- }
- raise Exception(f"Failed to delete Lakebase catalog '{name}': {error_msg}")
diff --git a/databricks-tools-core/databricks_tools_core/lakebase/instances.py b/databricks-tools-core/databricks_tools_core/lakebase/instances.py
deleted file mode 100644
index e9c80853..00000000
--- a/databricks-tools-core/databricks_tools_core/lakebase/instances.py
+++ /dev/null
@@ -1,338 +0,0 @@
-"""
-Lakebase Provisioned Instance Operations
-
-Functions for creating, managing, and connecting to Lakebase Provisioned
-(managed PostgreSQL) database instances.
-"""
-
-import logging
-import uuid
-from typing import Any, Dict, List, Optional
-
-from ..auth import get_workspace_client
-
-logger = logging.getLogger(__name__)
-
-
-def create_lakebase_instance(
- name: str,
- capacity: str = "CU_1",
- stopped: bool = False,
-) -> Dict[str, Any]:
- """
- Create a Lakebase Provisioned database instance.
-
- Args:
- name: Instance name (1-63 characters, letters and hyphens only)
- capacity: Compute capacity: "CU_1", "CU_2", "CU_4", or "CU_8"
- stopped: If True, create in stopped state (default: False)
-
- Returns:
- Dictionary with:
- - name: Instance name
- - capacity: Compute capacity
- - state: Instance state
- - read_write_dns: DNS endpoint for connections (if available)
-
- Raises:
- Exception: If creation fails
- """
- client = get_workspace_client()
-
- try:
- from databricks.sdk.service.database import DatabaseInstance
-
- instance = client.database.create_database_instance(
- DatabaseInstance(
- name=name,
- capacity=capacity,
- stopped=stopped,
- )
- )
-
- result: Dict[str, Any] = {
- "name": name,
- "capacity": capacity,
- "status": "CREATING",
- "message": f"Instance '{name}' creation initiated. Use get_lakebase_instance('{name}') to check status.",
- }
-
- if instance:
- try:
- if instance.state:
- result["state"] = str(instance.state)
- except (KeyError, AttributeError):
- pass
- try:
- if instance.read_write_dns:
- result["read_write_dns"] = instance.read_write_dns
- except (KeyError, AttributeError):
- pass
-
- return result
- except Exception as e:
- error_msg = str(e)
- if "ALREADY_EXISTS" in error_msg or "already exists" in error_msg.lower() or "not unique" in error_msg.lower():
- return {
- "name": name,
- "status": "ALREADY_EXISTS",
- "error": f"Instance '{name}' already exists",
- }
- raise Exception(f"Failed to create Lakebase instance '{name}': {error_msg}")
-
-
-def get_lakebase_instance(name: str) -> Dict[str, Any]:
- """
- Get Lakebase Provisioned instance details.
-
- Args:
- name: Instance name
-
- Returns:
- Dictionary with:
- - name: Instance name
- - state: Current state (e.g., RUNNING, STOPPED, CREATING)
- - capacity: Compute capacity (CU_1, CU_2, CU_4, CU_8)
- - read_write_dns: DNS endpoint for read-write connections
- - read_only_dns: DNS endpoint for read-only connections (if available)
- - creator: Who created the instance
- - creation_time: When instance was created
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
-
- try:
- instance = client.database.get_database_instance(name=name)
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "name": name,
- "state": "NOT_FOUND",
- "error": f"Instance '{name}' not found",
- }
- raise Exception(f"Failed to get Lakebase instance '{name}': {error_msg}")
-
- result: Dict[str, Any] = {"name": instance.name}
-
- # Use try/except for SDK response objects where hasattr is unreliable
- for attr, key, transform in [
- ("state", "state", str),
- ("capacity", "capacity", str),
- ("read_write_dns", "read_write_dns", None),
- ("read_only_dns", "read_only_dns", None),
- ("creator", "creator", None),
- ("creation_time", "creation_time", str),
- ("uid", "uid", None),
- ]:
- try:
- val = getattr(instance, attr)
- if val is not None:
- result[key] = transform(val) if transform else val
- except (KeyError, AttributeError):
- pass
-
- # stopped can be False, so check explicitly
- try:
- if instance.stopped is not None:
- result["stopped"] = instance.stopped
- except (KeyError, AttributeError):
- pass
-
- return result
-
-
-def list_lakebase_instances() -> List[Dict[str, Any]]:
- """
- List all Lakebase Provisioned instances in the workspace.
-
- Returns:
- List of instance dictionaries with:
- - name: Instance name
- - state: Current state
- - capacity: Compute capacity
- - stopped: Whether instance is stopped
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
-
- try:
- response = client.database.list_database_instances()
- except Exception as e:
- raise Exception(f"Failed to list Lakebase instances: {str(e)}")
-
- result = []
- instances = list(response) if response else []
- for inst in instances:
- entry: Dict[str, Any] = {"name": inst.name}
-
- for attr, key, transform in [
- ("state", "state", str),
- ("capacity", "capacity", str),
- ("read_write_dns", "read_write_dns", None),
- ]:
- try:
- val = getattr(inst, attr)
- if val is not None:
- entry[key] = transform(val) if transform else val
- except (KeyError, AttributeError):
- pass
-
- try:
- if inst.stopped is not None:
- entry["stopped"] = inst.stopped
- except (KeyError, AttributeError):
- pass
-
- result.append(entry)
-
- return result
-
-
-def update_lakebase_instance(
- name: str,
- capacity: Optional[str] = None,
- stopped: Optional[bool] = None,
-) -> Dict[str, Any]:
- """
- Update a Lakebase Provisioned instance (resize or start/stop).
-
- Args:
- name: Instance name
- capacity: New compute capacity: "CU_1", "CU_2", "CU_4", or "CU_8"
- stopped: True to stop instance (saves cost), False to start it
-
- Returns:
- Dictionary with updated instance details
-
- Raises:
- Exception: If update fails
- """
- client = get_workspace_client()
-
- try:
- from databricks.sdk.service.database import DatabaseInstance
-
- update_fields: Dict[str, Any] = {"name": name}
- if capacity is not None:
- update_fields["capacity"] = capacity
- if stopped is not None:
- update_fields["stopped"] = stopped
-
- instance = client.database.update_database_instance(
- name=name,
- database_instance=DatabaseInstance(**update_fields),
- update_mask="*",
- )
-
- result: Dict[str, Any] = {
- "name": name,
- "status": "UPDATED",
- }
-
- if capacity is not None:
- result["capacity"] = capacity
- if stopped is not None:
- result["stopped"] = stopped
-
- if instance:
- if hasattr(instance, "state") and instance.state:
- result["state"] = str(instance.state)
-
- return result
- except Exception as e:
- raise Exception(f"Failed to update Lakebase instance '{name}': {str(e)}")
-
-
-def delete_lakebase_instance(
- name: str,
- force: bool = False,
- purge: bool = True,
-) -> Dict[str, Any]:
- """
- Delete a Lakebase Provisioned instance.
-
- Args:
- name: Instance name to delete
- force: If True, delete child instances as well (default: False)
- purge: Required to be True to confirm deletion (default: True)
-
- Returns:
- Dictionary with:
- - name: Instance name
- - status: "deleted" or error info
-
- Raises:
- Exception: If deletion fails
- """
- client = get_workspace_client()
-
- try:
- client.database.delete_database_instance(
- name=name,
- force=force,
- purge=purge,
- )
- return {
- "name": name,
- "status": "deleted",
- }
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "name": name,
- "status": "NOT_FOUND",
- "error": f"Instance '{name}' not found",
- }
- raise Exception(f"Failed to delete Lakebase instance '{name}': {error_msg}")
-
-
-def generate_lakebase_credential(
- instance_names: List[str],
-) -> Dict[str, Any]:
- """
- Generate an OAuth token for connecting to Lakebase instances.
-
- The token is valid for 1 hour. Use it as the password in PostgreSQL
- connection strings with sslmode=require.
-
- Args:
- instance_names: List of instance names to generate credentials for
-
- Returns:
- Dictionary with:
- - token: OAuth token (use as password in connection string)
- - expiration: Token expiration info
- - instance_names: Instances the token is valid for
-
- Raises:
- Exception: If credential generation fails
- """
- client = get_workspace_client()
-
- try:
- cred = client.database.generate_database_credential(
- request_id=str(uuid.uuid4()),
- instance_names=instance_names,
- )
-
- result: Dict[str, Any] = {
- "instance_names": instance_names,
- }
-
- if hasattr(cred, "token") and cred.token:
- result["token"] = cred.token
-
- if hasattr(cred, "expiration_time") and cred.expiration_time:
- result["expiration_time"] = str(cred.expiration_time)
-
- result["message"] = "Token generated. Valid for ~1 hour. Use as password with sslmode=require."
-
- return result
- except Exception as e:
- raise Exception(f"Failed to generate Lakebase credentials: {str(e)}")
diff --git a/databricks-tools-core/databricks_tools_core/lakebase/synced_tables.py b/databricks-tools-core/databricks_tools_core/lakebase/synced_tables.py
deleted file mode 100644
index e88a7f6e..00000000
--- a/databricks-tools-core/databricks_tools_core/lakebase/synced_tables.py
+++ /dev/null
@@ -1,182 +0,0 @@
-"""
-Lakebase Synced Table Operations
-
-Functions for creating and managing reverse ETL synced tables that
-sync data from Unity Catalog Delta tables to Lakebase PostgreSQL.
-"""
-
-import logging
-from typing import Any, Dict, List, Optional
-
-from ..auth import get_workspace_client
-
-logger = logging.getLogger(__name__)
-
-
-def create_synced_table(
- instance_name: str,
- source_table_name: str,
- target_table_name: str,
- primary_key_columns: Optional[List[str]] = None,
- scheduling_policy: str = "TRIGGERED",
-) -> Dict[str, Any]:
- """
- Create a synced table to replicate data from a Delta table to Lakebase.
-
- This enables reverse ETL: syncing processed data from Unity Catalog
- into a PostgreSQL table for low-latency OLTP access.
-
- Args:
- instance_name: Lakebase instance name
- source_table_name: Fully qualified source Delta table
- (catalog.schema.table_name)
- target_table_name: Fully qualified target Lakebase catalog table
- (catalog.schema.table_name)
- primary_key_columns: List of primary key column names from the source
- table. If not provided, the source table's primary key is used.
- scheduling_policy: Sync mode: "TRIGGERED" (manual sync), "SNAPSHOT"
- (full refresh), or "CONTINUOUS" (real-time). Default: "TRIGGERED"
-
- Returns:
- Dictionary with:
- - instance_name: Lakebase instance
- - source_table_name: Source Delta table
- - target_table_name: Target PostgreSQL table
- - status: Creation result
-
- Raises:
- Exception: If creation fails
- """
- client = get_workspace_client()
-
- try:
- from databricks.sdk.service.database import (
- SyncedDatabaseTable,
- SyncedTableSpec,
- SyncedTableSchedulingPolicy,
- )
-
- spec_kwargs: Dict[str, Any] = {
- "source_table_full_name": source_table_name,
- "scheduling_policy": SyncedTableSchedulingPolicy(scheduling_policy),
- }
-
- if primary_key_columns:
- spec_kwargs["primary_key_columns"] = primary_key_columns
-
- synced_table = SyncedDatabaseTable(
- name=target_table_name,
- database_instance_name=instance_name,
- spec=SyncedTableSpec(**spec_kwargs),
- )
-
- client.database.create_synced_database_table(synced_table)
-
- return {
- "instance_name": instance_name,
- "source_table_name": source_table_name,
- "target_table_name": target_table_name,
- "scheduling_policy": scheduling_policy,
- "status": "CREATING",
- "message": (
- f"Synced table creation initiated. Source: '{source_table_name}' -> Target: '{target_table_name}'."
- ),
- }
- except Exception as e:
- error_msg = str(e)
- if "ALREADY_EXISTS" in error_msg or "already exists" in error_msg.lower():
- return {
- "target_table_name": target_table_name,
- "status": "ALREADY_EXISTS",
- "error": f"Synced table '{target_table_name}' already exists",
- }
- raise Exception(f"Failed to create synced table '{target_table_name}': {error_msg}")
-
-
-def get_synced_table(
- table_name: str,
-) -> Dict[str, Any]:
- """
- Get status and details of a synced table.
-
- Args:
- table_name: Fully qualified synced table name
- (catalog.schema.table_name)
-
- Returns:
- Dictionary with synced table details including sync status
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
-
- try:
- table = client.database.get_synced_database_table(name=table_name)
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "table_name": table_name,
- "status": "NOT_FOUND",
- "error": f"Synced table '{table_name}' not found",
- }
- raise Exception(f"Failed to get synced table '{table_name}': {error_msg}")
-
- result: Dict[str, Any] = {"table_name": table_name}
-
- if hasattr(table, "database_instance_name") and table.database_instance_name:
- result["instance_name"] = table.database_instance_name
-
- if hasattr(table, "spec") and table.spec:
- if hasattr(table.spec, "source_table_full_name") and table.spec.source_table_full_name:
- result["source_table_name"] = table.spec.source_table_full_name
- if hasattr(table.spec, "scheduling_policy") and table.spec.scheduling_policy:
- result["scheduling_policy"] = str(table.spec.scheduling_policy.value)
- if hasattr(table.spec, "primary_key_columns") and table.spec.primary_key_columns:
- result["primary_key_columns"] = table.spec.primary_key_columns
-
- if hasattr(table, "status") and table.status:
- result["status"] = str(table.status)
-
- return result
-
-
-def delete_synced_table(
- table_name: str,
-) -> Dict[str, Any]:
- """
- Delete a synced table.
-
- This stops syncing and removes the target table from Lakebase.
- The source Delta table is not affected.
-
- Args:
- table_name: Fully qualified synced table name
- (catalog.schema.table_name)
-
- Returns:
- Dictionary with:
- - table_name: Synced table name
- - status: "deleted" or error info
-
- Raises:
- Exception: If deletion fails
- """
- client = get_workspace_client()
-
- try:
- client.database.delete_synced_database_table(name=table_name)
- return {
- "table_name": table_name,
- "status": "deleted",
- }
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "table_name": table_name,
- "status": "NOT_FOUND",
- "error": f"Synced table '{table_name}' not found",
- }
- raise Exception(f"Failed to delete synced table '{table_name}': {error_msg}")
diff --git a/databricks-tools-core/databricks_tools_core/lakebase_autoscale/__init__.py b/databricks-tools-core/databricks_tools_core/lakebase_autoscale/__init__.py
deleted file mode 100644
index 3da187a2..00000000
--- a/databricks-tools-core/databricks_tools_core/lakebase_autoscale/__init__.py
+++ /dev/null
@@ -1,54 +0,0 @@
-"""
-Lakebase Autoscaling Operations
-
-Functions for managing Databricks Lakebase Autoscaling projects, branches,
-computes (endpoints), and database credentials.
-"""
-
-from .projects import (
- create_project,
- get_project,
- list_projects,
- update_project,
- delete_project,
-)
-from .branches import (
- create_branch,
- get_branch,
- list_branches,
- update_branch,
- delete_branch,
-)
-from .computes import (
- create_endpoint,
- get_endpoint,
- list_endpoints,
- update_endpoint,
- delete_endpoint,
-)
-from .credentials import (
- generate_credential,
-)
-
-__all__ = [
- # Projects
- "create_project",
- "get_project",
- "list_projects",
- "update_project",
- "delete_project",
- # Branches
- "create_branch",
- "get_branch",
- "list_branches",
- "update_branch",
- "delete_branch",
- # Computes (Endpoints)
- "create_endpoint",
- "get_endpoint",
- "list_endpoints",
- "update_endpoint",
- "delete_endpoint",
- # Credentials
- "generate_credential",
-]
diff --git a/databricks-tools-core/databricks_tools_core/lakebase_autoscale/branches.py b/databricks-tools-core/databricks_tools_core/lakebase_autoscale/branches.py
deleted file mode 100644
index cdda3696..00000000
--- a/databricks-tools-core/databricks_tools_core/lakebase_autoscale/branches.py
+++ /dev/null
@@ -1,335 +0,0 @@
-"""
-Lakebase Autoscaling Branch Operations
-
-Functions for creating, managing, and deleting branches within
-Lakebase Autoscaling projects.
-"""
-
-import logging
-from typing import Any, Dict, List, Optional
-
-from ..auth import get_workspace_client
-
-logger = logging.getLogger(__name__)
-
-
-def create_branch(
- project_name: str,
- branch_id: str,
- source_branch: Optional[str] = None,
- ttl_seconds: Optional[int] = None,
- no_expiry: bool = False,
-) -> Dict[str, Any]:
- """
- Create a branch in a Lakebase Autoscaling project.
-
- Args:
- project_name: Project resource name (e.g., "projects/my-app")
- branch_id: Branch identifier (1-63 chars, lowercase letters, digits, hyphens)
- source_branch: Source branch to fork from. If not specified,
- automatically uses the project's default branch.
- ttl_seconds: Time-to-live in seconds (max 30 days = 2592000s).
- Set to create an expiring branch.
- no_expiry: If True, branch never expires. One of ttl_seconds
- or no_expiry must be specified.
-
- Returns:
- Dictionary with:
- - name: Branch resource name
- - status: Creation status
- - expire_time: Expiration time (if TTL set)
-
- Raises:
- Exception: If creation fails
- """
- client = get_workspace_client()
-
- if not project_name.startswith("projects/"):
- project_name = f"projects/{project_name}"
-
- # Resolve the source branch: use the provided one, or find the default branch
- if source_branch is None:
- branches = list_branches(project_name)
- default_branches = [b for b in branches if b.get("is_default") is True]
- if default_branches:
- source_branch = default_branches[0]["name"]
- elif branches:
- source_branch = branches[0]["name"]
- else:
- raise Exception(f"No branches found in project '{project_name}' to fork from")
-
- try:
- from databricks.sdk.service.postgres import Branch, BranchSpec, Duration
-
- spec_kwargs: Dict[str, Any] = {
- "source_branch": source_branch,
- }
-
- if ttl_seconds is not None:
- spec_kwargs["ttl"] = Duration(seconds=ttl_seconds)
- elif no_expiry:
- spec_kwargs["no_expiry"] = True
- else:
- # Default to no expiry if neither is specified
- spec_kwargs["no_expiry"] = True
-
- operation = client.postgres.create_branch(
- parent=project_name,
- branch=Branch(spec=BranchSpec(**spec_kwargs)),
- branch_id=branch_id,
- )
- result_branch = operation.wait()
-
- result: Dict[str, Any] = {
- "name": result_branch.name,
- "status": "CREATED",
- }
-
- if result_branch.status:
- for attr, key, transform in [
- ("current_state", "state", str),
- ("default", "is_default", None),
- ("is_protected", "is_protected", None),
- ("expire_time", "expire_time", str),
- ("logical_size_bytes", "logical_size_bytes", None),
- ]:
- try:
- val = getattr(result_branch.status, attr)
- if val is not None:
- result[key] = transform(val) if transform else val
- except (KeyError, AttributeError):
- pass
-
- return result
- except Exception as e:
- error_msg = str(e)
- if "ALREADY_EXISTS" in error_msg or "already exists" in error_msg.lower():
- return {
- "name": f"{project_name}/branches/{branch_id}",
- "status": "ALREADY_EXISTS",
- "error": f"Branch '{branch_id}' already exists",
- }
- raise Exception(f"Failed to create branch '{branch_id}': {error_msg}")
-
-
-def get_branch(name: str) -> Dict[str, Any]:
- """
- Get Lakebase Autoscaling branch details.
-
- Args:
- name: Branch resource name
- (e.g., "projects/my-app/branches/production")
-
- Returns:
- Dictionary with:
- - name: Branch resource name
- - state: Current state
- - is_default: Whether this is the default branch
- - is_protected: Whether the branch is protected
- - expire_time: Expiration time (if set)
- - logical_size_bytes: Logical data size
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
-
- try:
- branch = client.postgres.get_branch(name=name)
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "name": name,
- "state": "NOT_FOUND",
- "error": f"Branch '{name}' not found",
- }
- raise Exception(f"Failed to get branch '{name}': {error_msg}")
-
- result: Dict[str, Any] = {"name": branch.name}
-
- if branch.status:
- for attr, key, transform in [
- ("current_state", "state", str),
- ("default", "is_default", None),
- ("is_protected", "is_protected", None),
- ("expire_time", "expire_time", str),
- ("logical_size_bytes", "logical_size_bytes", None),
- ("parent_name", "parent_name", None),
- ]:
- try:
- val = getattr(branch.status, attr)
- if val is not None:
- result[key] = transform(val) if transform else val
- except (KeyError, AttributeError):
- pass
-
- return result
-
-
-def list_branches(project_name: str) -> List[Dict[str, Any]]:
- """
- List all branches in a Lakebase Autoscaling project.
-
- Args:
- project_name: Project resource name (e.g., "projects/my-app")
-
- Returns:
- List of branch dictionaries with name, state, is_default, is_protected.
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
-
- if not project_name.startswith("projects/"):
- project_name = f"projects/{project_name}"
-
- try:
- response = client.postgres.list_branches(parent=project_name)
- except Exception as e:
- raise Exception(f"Failed to list branches for '{project_name}': {str(e)}")
-
- result = []
- branches = list(response) if response else []
- for br in branches:
- entry: Dict[str, Any] = {"name": br.name}
-
- if br.status:
- for attr, key, transform in [
- ("current_state", "state", str),
- ("default", "is_default", None),
- ("is_protected", "is_protected", None),
- ("expire_time", "expire_time", str),
- ]:
- try:
- val = getattr(br.status, attr)
- if val is not None:
- entry[key] = transform(val) if transform else val
- except (KeyError, AttributeError):
- pass
-
- result.append(entry)
-
- return result
-
-
-def update_branch(
- name: str,
- is_protected: Optional[bool] = None,
- ttl_seconds: Optional[int] = None,
- no_expiry: Optional[bool] = None,
-) -> Dict[str, Any]:
- """
- Update a Lakebase Autoscaling branch (protect or set expiration).
-
- Args:
- name: Branch resource name
- (e.g., "projects/my-app/branches/production")
- is_protected: Set branch protection status
- ttl_seconds: New TTL in seconds (max 30 days)
- no_expiry: If True, remove expiration
-
- Returns:
- Dictionary with updated branch details
-
- Raises:
- Exception: If update fails
- """
- client = get_workspace_client()
-
- try:
- from databricks.sdk.service.postgres import Branch, BranchSpec, Duration, FieldMask
-
- spec_kwargs: Dict[str, Any] = {}
- update_fields: list[str] = []
-
- if is_protected is not None:
- spec_kwargs["is_protected"] = is_protected
- update_fields.append("spec.is_protected")
-
- if ttl_seconds is not None:
- spec_kwargs["ttl"] = Duration(seconds=ttl_seconds)
- update_fields.append("spec.expiration")
- elif no_expiry is True:
- spec_kwargs["no_expiry"] = True
- update_fields.append("spec.expiration")
-
- if not update_fields:
- return {
- "name": name,
- "status": "NO_CHANGES",
- "error": "No fields specified for update",
- }
-
- operation = client.postgres.update_branch(
- name=name,
- branch=Branch(
- name=name,
- spec=BranchSpec(**spec_kwargs),
- ),
- update_mask=FieldMask(field_mask=update_fields),
- )
- result_branch = operation.wait()
-
- result: Dict[str, Any] = {
- "name": name,
- "status": "UPDATED",
- }
-
- if is_protected is not None:
- result["is_protected"] = is_protected
-
- if result_branch and result_branch.status:
- try:
- if result_branch.status.expire_time:
- result["expire_time"] = str(result_branch.status.expire_time)
- except (KeyError, AttributeError):
- pass
-
- return result
- except Exception as e:
- raise Exception(f"Failed to update branch '{name}': {str(e)}")
-
-
-def delete_branch(name: str) -> Dict[str, Any]:
- """
- Delete a Lakebase Autoscaling branch.
-
- This permanently deletes all databases, roles, computes, and data
- specific to this branch.
-
- Args:
- name: Branch resource name
- (e.g., "projects/my-app/branches/development")
-
- Returns:
- Dictionary with:
- - name: Branch resource name
- - status: "deleted" or error info
-
- Raises:
- Exception: If deletion fails
- """
- client = get_workspace_client()
-
- try:
- operation = client.postgres.delete_branch(name=name)
- operation.wait()
- return {
- "name": name,
- "status": "deleted",
- }
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "name": name,
- "status": "NOT_FOUND",
- "error": f"Branch '{name}' not found",
- }
- raise Exception(f"Failed to delete branch '{name}': {error_msg}")
-
-
-# NOTE: reset_branch is not yet available in the Databricks SDK.
-# It may be added in a future SDK release.
diff --git a/databricks-tools-core/databricks_tools_core/lakebase_autoscale/computes.py b/databricks-tools-core/databricks_tools_core/lakebase_autoscale/computes.py
deleted file mode 100644
index ce323c8f..00000000
--- a/databricks-tools-core/databricks_tools_core/lakebase_autoscale/computes.py
+++ /dev/null
@@ -1,361 +0,0 @@
-"""
-Lakebase Autoscaling Compute (Endpoint) Operations
-
-Functions for creating, managing, and deleting compute endpoints
-within Lakebase Autoscaling branches.
-"""
-
-import logging
-from typing import Any, Dict, List, Optional
-
-from ..auth import get_workspace_client
-
-logger = logging.getLogger(__name__)
-
-
-def create_endpoint(
- branch_name: str,
- endpoint_id: str,
- endpoint_type: str = "ENDPOINT_TYPE_READ_WRITE",
- autoscaling_limit_min_cu: Optional[float] = None,
- autoscaling_limit_max_cu: Optional[float] = None,
- scale_to_zero_seconds: Optional[int] = None,
-) -> Dict[str, Any]:
- """
- Create a compute endpoint on a branch.
-
- Args:
- branch_name: Branch resource name
- (e.g., "projects/my-app/branches/production")
- endpoint_id: Endpoint identifier (1-63 chars, lowercase letters,
- digits, hyphens)
- endpoint_type: Endpoint type: "ENDPOINT_TYPE_READ_WRITE" or
- "ENDPOINT_TYPE_READ_ONLY". Default: "ENDPOINT_TYPE_READ_WRITE"
- autoscaling_limit_min_cu: Minimum compute units (0.5-32)
- autoscaling_limit_max_cu: Maximum compute units (0.5-112)
- scale_to_zero_seconds: Inactivity timeout before suspending.
- Set to 0 to disable scale-to-zero.
-
- Returns:
- Dictionary with:
- - name: Endpoint resource name
- - host: Connection hostname
- - status: Creation status
-
- Raises:
- Exception: If creation fails
- """
- client = get_workspace_client()
-
- try:
- from databricks.sdk.service.postgres import Endpoint, EndpointSpec, EndpointType
-
- ep_type = EndpointType(endpoint_type)
-
- spec_kwargs: Dict[str, Any] = {
- "endpoint_type": ep_type,
- }
-
- if autoscaling_limit_min_cu is not None:
- spec_kwargs["autoscaling_limit_min_cu"] = autoscaling_limit_min_cu
- if autoscaling_limit_max_cu is not None:
- spec_kwargs["autoscaling_limit_max_cu"] = autoscaling_limit_max_cu
- if scale_to_zero_seconds is not None:
- from databricks.sdk.service.postgres import Duration
-
- spec_kwargs["suspend_timeout_duration"] = Duration(seconds=scale_to_zero_seconds)
-
- operation = client.postgres.create_endpoint(
- parent=branch_name,
- endpoint=Endpoint(spec=EndpointSpec(**spec_kwargs)),
- endpoint_id=endpoint_id,
- )
- result_ep = operation.wait()
-
- result: Dict[str, Any] = {
- "name": result_ep.name,
- "status": "CREATED",
- }
-
- if result_ep.status:
- for attr, key, transform in [
- ("current_state", "state", str),
- ("endpoint_type", "endpoint_type", str),
- ("autoscaling_limit_min_cu", "min_cu", None),
- ("autoscaling_limit_max_cu", "max_cu", None),
- ]:
- try:
- val = getattr(result_ep.status, attr)
- if val is not None:
- result[key] = transform(val) if transform else val
- except (KeyError, AttributeError):
- pass
-
- try:
- if result_ep.status.hosts and result_ep.status.hosts.host:
- result["host"] = result_ep.status.hosts.host
- except (KeyError, AttributeError):
- pass
-
- return result
- except Exception as e:
- error_msg = str(e)
- if "ALREADY_EXISTS" in error_msg or "already exists" in error_msg.lower():
- return {
- "name": f"{branch_name}/endpoints/{endpoint_id}",
- "status": "ALREADY_EXISTS",
- "error": f"Endpoint '{endpoint_id}' already exists on branch",
- }
- raise Exception(f"Failed to create endpoint '{endpoint_id}': {error_msg}")
-
-
-def get_endpoint(name: str) -> Dict[str, Any]:
- """
- Get Lakebase Autoscaling endpoint details.
-
- Args:
- name: Endpoint resource name
- (e.g., "projects/my-app/branches/production/endpoints/ep-primary")
-
- Returns:
- Dictionary with:
- - name: Endpoint resource name
- - state: Current state (ACTIVE, SUSPENDED, etc.)
- - endpoint_type: READ_WRITE or READ_ONLY
- - host: Connection hostname
- - min_cu: Minimum compute units
- - max_cu: Maximum compute units
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
-
- try:
- endpoint = client.postgres.get_endpoint(name=name)
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "name": name,
- "state": "NOT_FOUND",
- "error": f"Endpoint '{name}' not found",
- }
- raise Exception(f"Failed to get endpoint '{name}': {error_msg}")
-
- result: Dict[str, Any] = {"name": endpoint.name}
-
- if endpoint.status:
- for attr, key, transform in [
- ("current_state", "state", str),
- ("endpoint_type", "endpoint_type", str),
- ("autoscaling_limit_min_cu", "min_cu", None),
- ("autoscaling_limit_max_cu", "max_cu", None),
- ]:
- try:
- val = getattr(endpoint.status, attr)
- if val is not None:
- result[key] = transform(val) if transform else val
- except (KeyError, AttributeError):
- pass
-
- try:
- if endpoint.status.hosts and endpoint.status.hosts.host:
- result["host"] = endpoint.status.hosts.host
- except (KeyError, AttributeError):
- pass
-
- return result
-
-
-def list_endpoints(branch_name: str) -> List[Dict[str, Any]]:
- """
- List all endpoints on a branch.
-
- Args:
- branch_name: Branch resource name
- (e.g., "projects/my-app/branches/production")
-
- Returns:
- List of endpoint dictionaries with name, state, type, CU settings.
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
-
- try:
- response = client.postgres.list_endpoints(parent=branch_name)
- except Exception as e:
- raise Exception(f"Failed to list endpoints for '{branch_name}': {str(e)}")
-
- result = []
- endpoints = list(response) if response else []
- for ep in endpoints:
- entry: Dict[str, Any] = {"name": ep.name}
-
- if ep.status:
- for attr, key, transform in [
- ("current_state", "state", str),
- ("endpoint_type", "endpoint_type", str),
- ("autoscaling_limit_min_cu", "min_cu", None),
- ("autoscaling_limit_max_cu", "max_cu", None),
- ]:
- try:
- val = getattr(ep.status, attr)
- if val is not None:
- entry[key] = transform(val) if transform else val
- except (KeyError, AttributeError):
- pass
-
- try:
- if ep.status.hosts and ep.status.hosts.host:
- entry["host"] = ep.status.hosts.host
- except (KeyError, AttributeError):
- pass
-
- result.append(entry)
-
- return result
-
-
-def update_endpoint(
- name: str,
- autoscaling_limit_min_cu: Optional[float] = None,
- autoscaling_limit_max_cu: Optional[float] = None,
- scale_to_zero_seconds: Optional[int] = None,
-) -> Dict[str, Any]:
- """
- Update a Lakebase Autoscaling endpoint (resize or configure scale-to-zero).
-
- Args:
- name: Endpoint resource name
- (e.g., "projects/my-app/branches/production/endpoints/ep-primary")
- autoscaling_limit_min_cu: New minimum compute units (0.5-32)
- autoscaling_limit_max_cu: New maximum compute units (0.5-112)
- scale_to_zero_seconds: New inactivity timeout. 0 to disable.
-
- Returns:
- Dictionary with updated endpoint details
-
- Raises:
- Exception: If update fails
- """
- client = get_workspace_client()
-
- try:
- from databricks.sdk.service.postgres import Endpoint, EndpointSpec, EndpointType, FieldMask
-
- spec_kwargs: Dict[str, Any] = {}
- update_fields: list[str] = []
-
- if autoscaling_limit_min_cu is not None:
- spec_kwargs["autoscaling_limit_min_cu"] = autoscaling_limit_min_cu
- update_fields.append("spec.autoscaling_limit_min_cu")
-
- if autoscaling_limit_max_cu is not None:
- spec_kwargs["autoscaling_limit_max_cu"] = autoscaling_limit_max_cu
- update_fields.append("spec.autoscaling_limit_max_cu")
-
- if scale_to_zero_seconds is not None:
- from databricks.sdk.service.postgres import Duration
-
- spec_kwargs["suspend_timeout_duration"] = Duration(seconds=scale_to_zero_seconds)
- update_fields.append("spec.suspend_timeout_duration")
-
- if not update_fields:
- return {
- "name": name,
- "status": "NO_CHANGES",
- "error": "No fields specified for update",
- }
-
- # EndpointSpec requires endpoint_type -- fetch it from the current endpoint
- existing_ep = client.postgres.get_endpoint(name=name)
- ep_type = (
- existing_ep.spec.endpoint_type
- if existing_ep.spec and existing_ep.spec.endpoint_type
- else EndpointType.ENDPOINT_TYPE_READ_WRITE
- )
- spec_kwargs["endpoint_type"] = ep_type
-
- operation = client.postgres.update_endpoint(
- name=name,
- endpoint=Endpoint(
- name=name,
- spec=EndpointSpec(**spec_kwargs),
- ),
- update_mask=FieldMask(field_mask=update_fields),
- )
- result_ep = operation.wait()
-
- result: Dict[str, Any] = {
- "name": name,
- "status": "UPDATED",
- }
-
- if autoscaling_limit_min_cu is not None:
- result["min_cu"] = autoscaling_limit_min_cu
- if autoscaling_limit_max_cu is not None:
- result["max_cu"] = autoscaling_limit_max_cu
-
- if result_ep and result_ep.status:
- try:
- if result_ep.status.current_state:
- result["state"] = str(result_ep.status.current_state)
- except (KeyError, AttributeError):
- pass
-
- return result
- except Exception as e:
- raise Exception(f"Failed to update endpoint '{name}': {str(e)}")
-
-
-def delete_endpoint(name: str, max_retries: int = 6, retry_delay: int = 10) -> Dict[str, Any]:
- """
- Delete a Lakebase Autoscaling endpoint.
-
- Retries on ``Aborted`` errors (reconciliation in progress).
-
- Args:
- name: Endpoint resource name
- (e.g., "projects/my-app/branches/production/endpoints/ep-primary")
- max_retries: Maximum number of retries for transient errors.
- retry_delay: Seconds to wait between retries.
-
- Returns:
- Dictionary with:
- - name: Endpoint resource name
- - status: "deleted" or error info
-
- Raises:
- Exception: If deletion fails after retries
- """
- import time
-
- client = get_workspace_client()
-
- for attempt in range(max_retries + 1):
- try:
- operation = client.postgres.delete_endpoint(name=name)
- operation.wait()
- return {
- "name": name,
- "status": "deleted",
- }
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "name": name,
- "status": "NOT_FOUND",
- "error": f"Endpoint '{name}' not found",
- }
- if ("reconciliation" in error_msg.lower() or "aborted" in error_msg.lower()) and attempt < max_retries:
- logger.info(
- f"Endpoint reconciliation in progress, retrying in {retry_delay}s "
- f"(attempt {attempt + 1}/{max_retries})"
- )
- time.sleep(retry_delay)
- continue
- raise Exception(f"Failed to delete endpoint '{name}': {error_msg}")
diff --git a/databricks-tools-core/databricks_tools_core/lakebase_autoscale/credentials.py b/databricks-tools-core/databricks_tools_core/lakebase_autoscale/credentials.py
deleted file mode 100644
index 3eaf5955..00000000
--- a/databricks-tools-core/databricks_tools_core/lakebase_autoscale/credentials.py
+++ /dev/null
@@ -1,53 +0,0 @@
-"""
-Lakebase Autoscaling Credential Operations
-
-Functions for generating OAuth tokens for connecting to
-Lakebase Autoscaling PostgreSQL databases.
-"""
-
-import logging
-from typing import Any, Dict
-
-from ..auth import get_workspace_client
-
-logger = logging.getLogger(__name__)
-
-
-def generate_credential(endpoint: str) -> Dict[str, Any]:
- """
- Generate an OAuth token for connecting to Lakebase Autoscaling databases.
-
- The token is valid for ~1 hour. Use it as the password in PostgreSQL
- connection strings with sslmode=require.
-
- Args:
- endpoint: Endpoint resource name to scope the credential to
- (e.g., "projects/my-app/branches/production/endpoints/ep-primary").
-
- Returns:
- Dictionary with:
- - token: OAuth token (use as password in connection string)
- - expiration_time: Token expiration time
- - message: Usage instructions
-
- Raises:
- Exception: If credential generation fails
- """
- client = get_workspace_client()
-
- try:
- cred = client.postgres.generate_database_credential(endpoint=endpoint)
-
- result: Dict[str, Any] = {}
-
- if hasattr(cred, "token") and cred.token:
- result["token"] = cred.token
-
- if hasattr(cred, "expiration_time") and cred.expiration_time:
- result["expiration_time"] = str(cred.expiration_time)
-
- result["message"] = "Token generated. Valid for ~1 hour. Use as password with sslmode=require."
-
- return result
- except Exception as e:
- raise Exception(f"Failed to generate Lakebase Autoscaling credentials: {str(e)}")
diff --git a/databricks-tools-core/databricks_tools_core/lakebase_autoscale/projects.py b/databricks-tools-core/databricks_tools_core/lakebase_autoscale/projects.py
deleted file mode 100644
index cd27596e..00000000
--- a/databricks-tools-core/databricks_tools_core/lakebase_autoscale/projects.py
+++ /dev/null
@@ -1,291 +0,0 @@
-"""
-Lakebase Autoscaling Project Operations
-
-Functions for creating, managing, and deleting Lakebase Autoscaling projects.
-Projects are the top-level container for branches, computes, databases, and roles.
-"""
-
-import logging
-from typing import Any, Dict, List, Optional
-
-from ..auth import get_workspace_client
-
-logger = logging.getLogger(__name__)
-
-
-def _normalize_project_name(name: str) -> str:
- """Ensure project name has the 'projects/' prefix."""
- if not name.startswith("projects/"):
- return f"projects/{name}"
- return name
-
-
-def create_project(
- project_id: str,
- display_name: Optional[str] = None,
- pg_version: str = "17",
-) -> Dict[str, Any]:
- """
- Create a Lakebase Autoscaling project.
-
- Args:
- project_id: Project identifier (1-63 chars, lowercase letters, digits, hyphens).
- display_name: Human-readable display name. Defaults to project_id.
- pg_version: Postgres version ("16" or "17"). Default: "17".
-
- Returns:
- Dictionary with:
- - name: Project resource name (projects/{project_id})
- - display_name: Display name
- - pg_version: Postgres version
- - status: Creation status
-
- Raises:
- Exception: If creation fails
- """
- client = get_workspace_client()
-
- try:
- from databricks.sdk.service.postgres import Project, ProjectSpec
-
- spec = ProjectSpec(
- display_name=display_name or project_id,
- pg_version=int(pg_version),
- )
-
- operation = client.postgres.create_project(
- project=Project(spec=spec),
- project_id=project_id,
- )
- result_project = operation.wait()
-
- result: Dict[str, Any] = {
- "name": result_project.name,
- "display_name": display_name or project_id,
- "pg_version": pg_version,
- "status": "CREATED",
- }
-
- if result_project.status:
- try:
- if result_project.status.display_name:
- result["display_name"] = result_project.status.display_name
- except (KeyError, AttributeError):
- pass
- try:
- if result_project.status.pg_version:
- result["pg_version"] = str(result_project.status.pg_version)
- except (KeyError, AttributeError):
- pass
- try:
- if result_project.status.state:
- result["state"] = str(result_project.status.state)
- except (KeyError, AttributeError):
- pass
-
- return result
- except Exception as e:
- error_msg = str(e)
- if "ALREADY_EXISTS" in error_msg or "already exists" in error_msg.lower():
- return {
- "name": f"projects/{project_id}",
- "status": "ALREADY_EXISTS",
- "error": f"Project '{project_id}' already exists",
- }
- raise Exception(f"Failed to create Lakebase Autoscaling project '{project_id}': {error_msg}")
-
-
-def get_project(name: str) -> Dict[str, Any]:
- """
- Get Lakebase Autoscaling project details.
-
- Args:
- name: Project resource name (e.g., "projects/my-app" or "my-app")
-
- Returns:
- Dictionary with:
- - name: Project resource name
- - display_name: Display name
- - pg_version: Postgres version
- - state: Current state (READY, CREATING, etc.)
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
- full_name = _normalize_project_name(name)
-
- try:
- project = client.postgres.get_project(name=full_name)
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "name": full_name,
- "state": "NOT_FOUND",
- "error": f"Project '{full_name}' not found",
- }
- raise Exception(f"Failed to get Lakebase Autoscaling project '{full_name}': {error_msg}")
-
- result: Dict[str, Any] = {"name": project.name}
-
- if project.status:
- for attr, key, transform in [
- ("display_name", "display_name", None),
- ("pg_version", "pg_version", str),
- ("owner", "owner", None),
- ]:
- try:
- val = getattr(project.status, attr)
- if val is not None:
- result[key] = transform(val) if transform else val
- except (KeyError, AttributeError):
- pass
-
- return result
-
-
-def list_projects() -> List[Dict[str, Any]]:
- """
- List all Lakebase Autoscaling projects in the workspace.
-
- Returns:
- List of project dictionaries with name, display_name, pg_version, state.
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
-
- try:
- response = client.postgres.list_projects()
- except Exception as e:
- raise Exception(f"Failed to list Lakebase Autoscaling projects: {str(e)}")
-
- result = []
- projects = list(response) if response else []
- for proj in projects:
- entry: Dict[str, Any] = {"name": proj.name}
-
- if proj.status:
- for attr, key, transform in [
- ("display_name", "display_name", None),
- ("pg_version", "pg_version", str),
- ("owner", "owner", None),
- ]:
- try:
- val = getattr(proj.status, attr)
- if val is not None:
- entry[key] = transform(val) if transform else val
- except (KeyError, AttributeError):
- pass
-
- result.append(entry)
-
- return result
-
-
-def update_project(
- name: str,
- display_name: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Update a Lakebase Autoscaling project.
-
- Args:
- name: Project resource name (e.g., "projects/my-app" or "my-app")
- display_name: New display name for the project
-
- Returns:
- Dictionary with updated project details
-
- Raises:
- Exception: If update fails
- """
- client = get_workspace_client()
- full_name = _normalize_project_name(name)
-
- try:
- from databricks.sdk.service.postgres import Project, ProjectSpec, FieldMask
-
- update_fields = []
- spec_kwargs: Dict[str, Any] = {}
-
- if display_name is not None:
- spec_kwargs["display_name"] = display_name
- update_fields.append("spec.display_name")
-
- if not update_fields:
- return {
- "name": full_name,
- "status": "NO_CHANGES",
- "error": "No fields specified for update",
- }
-
- operation = client.postgres.update_project(
- name=full_name,
- project=Project(
- name=full_name,
- spec=ProjectSpec(**spec_kwargs),
- ),
- update_mask=FieldMask(field_mask=update_fields),
- )
- result_project = operation.wait()
-
- result: Dict[str, Any] = {
- "name": full_name,
- "status": "UPDATED",
- }
-
- if display_name is not None:
- result["display_name"] = display_name
-
- if result_project and result_project.status:
- try:
- if result_project.status.state:
- result["state"] = str(result_project.status.state)
- except (KeyError, AttributeError):
- pass
-
- return result
- except Exception as e:
- raise Exception(f"Failed to update Lakebase Autoscaling project '{full_name}': {str(e)}")
-
-
-def delete_project(name: str) -> Dict[str, Any]:
- """
- Delete a Lakebase Autoscaling project and all its resources.
-
- WARNING: This permanently deletes all branches, computes, databases,
- roles, and data in the project.
-
- Args:
- name: Project resource name (e.g., "projects/my-app" or "my-app")
-
- Returns:
- Dictionary with:
- - name: Project resource name
- - status: "deleted" or error info
-
- Raises:
- Exception: If deletion fails
- """
- client = get_workspace_client()
- full_name = _normalize_project_name(name)
-
- try:
- operation = client.postgres.delete_project(name=full_name)
- operation.wait()
- return {
- "name": full_name,
- "status": "deleted",
- }
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "name": full_name,
- "status": "NOT_FOUND",
- "error": f"Project '{full_name}' not found",
- }
- raise Exception(f"Failed to delete Lakebase Autoscaling project '{full_name}': {error_msg}")
diff --git a/databricks-tools-core/databricks_tools_core/pdf/__init__.py b/databricks-tools-core/databricks_tools_core/pdf/__init__.py
deleted file mode 100644
index 823ae93c..00000000
--- a/databricks-tools-core/databricks_tools_core/pdf/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-"""PDF - Convert HTML to PDF and upload to Unity Catalog volumes."""
-
-from .generator import PDFResult, generate_and_upload_pdf
-
-__all__ = [
- "generate_and_upload_pdf",
- "PDFResult",
-]
diff --git a/databricks-tools-core/databricks_tools_core/pdf/generator.py b/databricks-tools-core/databricks_tools_core/pdf/generator.py
deleted file mode 100644
index d0cea223..00000000
--- a/databricks-tools-core/databricks_tools_core/pdf/generator.py
+++ /dev/null
@@ -1,160 +0,0 @@
-"""PDF document generation - convert HTML to PDF and upload to Unity Catalog volumes."""
-
-import logging
-import tempfile
-from pathlib import Path
-from typing import Optional
-
-from pydantic import BaseModel
-
-from ..auth import get_workspace_client
-from ..unity_catalog.volume_files import upload_to_volume
-
-logger = logging.getLogger(__name__)
-
-
-class PDFResult(BaseModel):
- """Result from generating a PDF."""
-
- success: bool
- volume_path: Optional[str] = None
- error: Optional[str] = None
-
-
-def _convert_html_to_pdf(html_content: str, output_path: str) -> bool:
- """Convert HTML content to PDF using PlutoPrint.
-
- Args:
- html_content: HTML string to convert
- output_path: Path where PDF should be saved
-
- Returns:
- True if successful, False otherwise
- """
- output_dir = Path(output_path).parent
- output_dir.mkdir(parents=True, exist_ok=True)
-
- try:
- import plutoprint
-
- logger.debug(f"Converting HTML to PDF using PlutoPrint: {output_path}")
-
- book = plutoprint.Book(plutoprint.PAGE_SIZE_A4)
- book.load_html(html_content)
- book.write_to_pdf(output_path)
-
- if Path(output_path).exists():
- file_size = Path(output_path).stat().st_size
- logger.info(f"PDF saved: {output_path} (size: {file_size:,} bytes)")
- return True
- else:
- logger.error("PlutoPrint conversion failed - file not created")
- return False
-
- except ImportError:
- logger.error("PlutoPrint is not installed. Install with: pip install plutoprint")
- return False
- except Exception as e:
- logger.error(f"Failed to convert HTML to PDF: {str(e)}", exc_info=True)
- return False
-
-
-def _validate_volume_path(catalog: str, schema: str, volume: str) -> None:
- """Validate that the catalog, schema, and volume exist."""
- w = get_workspace_client()
-
- try:
- w.schemas.get(full_name=f"{catalog}.{schema}")
- except Exception as e:
- raise ValueError(f"Schema '{catalog}.{schema}' does not exist: {e}") from e
-
- try:
- w.volumes.read(name=f"{catalog}.{schema}.{volume}")
- except Exception as e:
- raise ValueError(f"Volume '{catalog}.{schema}.{volume}' does not exist: {e}") from e
-
-
-def generate_and_upload_pdf(
- html_content: str,
- filename: str,
- catalog: str,
- schema: str,
- volume: str = "raw_data",
- folder: Optional[str] = None,
-) -> PDFResult:
- """Convert HTML to PDF and upload to a Unity Catalog volume.
-
- Args:
- html_content: Complete HTML document (including , , ,
- ... Hello World
- ...
- ... '''
- >>> result = generate_and_upload_pdf(
- ... html_content=html,
- ... filename="hello.pdf",
- ... catalog="my_catalog",
- ... schema="my_schema",
- ... )
- >>> print(result.volume_path)
- /Volumes/my_catalog/my_schema/raw_data/hello.pdf
- """
- # Ensure filename ends with .pdf
- if not filename.lower().endswith(".pdf"):
- filename = f"{filename}.pdf"
-
- # Validate volume exists
- try:
- _validate_volume_path(catalog, schema, volume)
- except ValueError as e:
- return PDFResult(success=False, error=str(e))
-
- # Build volume path
- if folder:
- volume_path = f"/Volumes/{catalog}/{schema}/{volume}/{folder}/{filename}"
- else:
- volume_path = f"/Volumes/{catalog}/{schema}/{volume}/{filename}"
-
- try:
- with tempfile.TemporaryDirectory() as temp_dir:
- local_pdf_path = str(Path(temp_dir) / filename)
-
- # Convert HTML to PDF
- if not _convert_html_to_pdf(html_content, local_pdf_path):
- return PDFResult(success=False, error="Failed to convert HTML to PDF")
-
- # Create folder if needed
- if folder:
- from ..unity_catalog.volume_files import create_volume_directory
-
- folder_path = f"/Volumes/{catalog}/{schema}/{volume}/{folder}"
- try:
- create_volume_directory(folder_path)
- except Exception:
- pass # Folder may already exist
-
- # Upload to volume
- result = upload_to_volume(local_pdf_path, volume_path, overwrite=True)
- if not result.success:
- return PDFResult(success=False, error=f"Failed to upload PDF: {result.error}")
-
- logger.info(f"PDF uploaded to {volume_path}")
- return PDFResult(success=True, volume_path=volume_path)
-
- except Exception as e:
- error_msg = f"Error generating PDF: {str(e)}"
- logger.error(error_msg, exc_info=True)
- return PDFResult(success=False, error=error_msg)
diff --git a/databricks-tools-core/databricks_tools_core/serving/__init__.py b/databricks-tools-core/databricks_tools_core/serving/__init__.py
deleted file mode 100644
index eea824ef..00000000
--- a/databricks-tools-core/databricks_tools_core/serving/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-"""
-Model Serving Operations
-
-Functions for managing and querying Databricks Model Serving endpoints.
-"""
-
-from .endpoints import (
- get_serving_endpoint_status,
- query_serving_endpoint,
- list_serving_endpoints,
-)
-
-__all__ = [
- "get_serving_endpoint_status",
- "query_serving_endpoint",
- "list_serving_endpoints",
-]
diff --git a/databricks-tools-core/databricks_tools_core/serving/endpoints.py b/databricks-tools-core/databricks_tools_core/serving/endpoints.py
deleted file mode 100644
index 5645e0c6..00000000
--- a/databricks-tools-core/databricks_tools_core/serving/endpoints.py
+++ /dev/null
@@ -1,250 +0,0 @@
-"""
-Model Serving Endpoints Operations
-
-Functions for checking status and querying Databricks Model Serving endpoints.
-"""
-
-import logging
-from typing import Any, Dict, List, Optional
-
-from databricks.sdk.service.serving import ChatMessage
-
-from ..auth import get_workspace_client
-
-logger = logging.getLogger(__name__)
-
-
-def get_serving_endpoint_status(name: str) -> Dict[str, Any]:
- """
- Get the status of a Model Serving endpoint.
-
- Args:
- name: The name of the serving endpoint
-
- Returns:
- Dictionary with endpoint status:
- - name: Endpoint name
- - state: Current state (READY, NOT_READY, etc.)
- - config_update: Config update state if updating
- - creation_timestamp: When endpoint was created
- - last_updated_timestamp: When endpoint was last updated
- - pending_config: Details of pending config update if any
- - served_entities: List of served models/entities with their states
- - error: Error message if endpoint is in error state
-
- Raises:
- Exception: If endpoint not found or API request fails
- """
- client = get_workspace_client()
-
- try:
- endpoint = client.serving_endpoints.get(name=name)
- except Exception as e:
- error_msg = str(e)
- if "RESOURCE_DOES_NOT_EXIST" in error_msg or "404" in error_msg:
- return {
- "name": name,
- "state": "NOT_FOUND",
- "error": f"Endpoint '{name}' not found",
- }
- raise Exception(f"Failed to get serving endpoint '{name}': {error_msg}")
-
- # Extract state information
- state_info = {}
- if endpoint.state:
- state_info["state"] = endpoint.state.ready.value if endpoint.state.ready else None
- state_info["config_update"] = endpoint.state.config_update.value if endpoint.state.config_update else None
-
- # Extract served entities status
- served_entities = []
- if endpoint.config and endpoint.config.served_entities:
- for entity in endpoint.config.served_entities:
- entity_info = {
- "name": entity.name,
- "entity_name": entity.entity_name,
- "entity_version": entity.entity_version,
- }
- if entity.state:
- entity_info["deployment_state"] = entity.state.deployment.value if entity.state.deployment else None
- entity_info["deployment_state_message"] = entity.state.deployment_state_message
- served_entities.append(entity_info)
-
- # Check for pending config
- pending_config = None
- if endpoint.pending_config:
- pending_config = {
- "served_entities": [
- {
- "name": e.name,
- "entity_name": e.entity_name,
- "entity_version": e.entity_version,
- }
- for e in (endpoint.pending_config.served_entities or [])
- ]
- }
-
- return {
- "name": endpoint.name,
- "state": state_info.get("state"),
- "config_update": state_info.get("config_update"),
- "creation_timestamp": endpoint.creation_timestamp,
- "last_updated_timestamp": endpoint.last_updated_timestamp,
- "served_entities": served_entities,
- "pending_config": pending_config,
- "error": None,
- }
-
-
-def query_serving_endpoint(
- name: str,
- messages: Optional[List[Dict[str, str]]] = None,
- inputs: Optional[Dict[str, Any]] = None,
- dataframe_records: Optional[List[Dict[str, Any]]] = None,
- max_tokens: Optional[int] = None,
- temperature: Optional[float] = None,
-) -> Dict[str, Any]:
- """
- Query a Model Serving endpoint.
-
- Supports multiple input formats:
- - messages: For chat/agent endpoints (OpenAI-compatible format)
- - inputs: For custom pyfunc models
- - dataframe_records: For traditional ML models (pandas DataFrame format)
-
- Args:
- name: The name of the serving endpoint
- messages: List of chat messages [{"role": "user", "content": "..."}]
- inputs: Dictionary of inputs for custom models
- dataframe_records: List of records for DataFrame input
- max_tokens: Maximum tokens for chat/completion endpoints
- temperature: Temperature for chat/completion endpoints
-
- Returns:
- Dictionary with query response:
- - For chat endpoints: Contains 'choices' with assistant response
- - For ML endpoints: Contains 'predictions'
- - Always includes 'usage' if available
-
- Raises:
- Exception: If query fails or endpoint not ready
- """
- client = get_workspace_client()
-
- # Build query kwargs
- query_kwargs: Dict[str, Any] = {"name": name}
-
- if messages is not None:
- # Chat/Agent endpoint - convert dicts to ChatMessage objects
- query_kwargs["messages"] = [ChatMessage.from_dict(m) for m in messages]
- if max_tokens is not None:
- query_kwargs["max_tokens"] = max_tokens
- if temperature is not None:
- query_kwargs["temperature"] = temperature
- elif inputs is not None:
- # Custom pyfunc model - use instances format
- query_kwargs["instances"] = [inputs]
- elif dataframe_records is not None:
- # Traditional ML model - DataFrame format
- query_kwargs["dataframe_records"] = dataframe_records
- else:
- raise ValueError(
- "Must provide one of: messages (for chat/agents), "
- "inputs (for custom models), or dataframe_records (for ML models)"
- )
-
- try:
- response = client.serving_endpoints.query(**query_kwargs)
- except Exception as e:
- error_msg = str(e)
- if "RESOURCE_DOES_NOT_EXIST" in error_msg:
- raise Exception(f"Endpoint '{name}' not found")
- if "NOT_READY" in error_msg or "PENDING" in error_msg:
- raise Exception(f"Endpoint '{name}' is not ready. Check status with get_serving_endpoint_status('{name}')")
- raise Exception(f"Failed to query endpoint '{name}': {error_msg}")
-
- # Convert response to dict
- result: Dict[str, Any] = {}
-
- # Handle chat response format
- if hasattr(response, "choices") and response.choices:
- result["choices"] = [
- {
- "index": c.index,
- "message": {
- "role": c.message.role if c.message else None,
- "content": c.message.content if c.message else None,
- },
- "finish_reason": c.finish_reason,
- }
- for c in response.choices
- ]
-
- # Handle predictions format (ML models)
- if hasattr(response, "predictions") and response.predictions:
- result["predictions"] = response.predictions
-
- # Handle generic output
- if hasattr(response, "output") and response.output:
- result["output"] = response.output
-
- # Include usage if available
- if hasattr(response, "usage") and response.usage:
- result["usage"] = {
- "prompt_tokens": response.usage.prompt_tokens,
- "completion_tokens": response.usage.completion_tokens,
- "total_tokens": response.usage.total_tokens,
- }
-
- # If empty, return raw response as dict
- if not result:
- result = response.as_dict() if hasattr(response, "as_dict") else {"raw": str(response)}
-
- return result
-
-
-def list_serving_endpoints(limit: Optional[int] = 50) -> List[Dict[str, Any]]:
- """
- List Model Serving endpoints in the workspace.
-
- Args:
- limit: Maximum number of endpoints to return (default: 50). Pass None for all.
-
- Returns:
- List of endpoint dictionaries with keys:
- - name: Endpoint name
- - state: Current state (READY, NOT_READY, etc.)
- - creation_timestamp: When endpoint was created
- - creator: Who created the endpoint
- - served_entities_count: Number of served models
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
-
- try:
- endpoints = list(client.serving_endpoints.list())
- except Exception as e:
- raise Exception(f"Failed to list serving endpoints: {str(e)}")
-
- result = []
- for ep in endpoints[:limit]:
- state = None
- if ep.state:
- state = ep.state.ready.value if ep.state.ready else None
-
- served_count = 0
- if ep.config and ep.config.served_entities:
- served_count = len(ep.config.served_entities)
-
- result.append(
- {
- "name": ep.name,
- "state": state,
- "creation_timestamp": ep.creation_timestamp,
- "creator": ep.creator,
- "served_entities_count": served_count,
- }
- )
-
- return result
diff --git a/databricks-tools-core/databricks_tools_core/spark_declarative_pipelines/__init__.py b/databricks-tools-core/databricks_tools_core/spark_declarative_pipelines/__init__.py
deleted file mode 100644
index 44e53763..00000000
--- a/databricks-tools-core/databricks_tools_core/spark_declarative_pipelines/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""Spark Declarative Pipelines (SDP) operations"""
-
-from . import pipelines as pipelines, workspace_files as workspace_files
diff --git a/databricks-tools-core/databricks_tools_core/spark_declarative_pipelines/pipelines.py b/databricks-tools-core/databricks_tools_core/spark_declarative_pipelines/pipelines.py
deleted file mode 100644
index dd0e4053..00000000
--- a/databricks-tools-core/databricks_tools_core/spark_declarative_pipelines/pipelines.py
+++ /dev/null
@@ -1,866 +0,0 @@
-"""
-Spark Declarative Pipelines - Pipeline Management
-
-Functions for managing SDP pipeline lifecycle using Databricks Pipelines API.
-All pipelines use Unity Catalog and serverless compute by default.
-"""
-
-import time
-from dataclasses import dataclass, field
-from typing import List, Optional, Dict, Any
-
-from databricks.sdk.service.pipelines import (
- CreatePipelineResponse,
- GetPipelineResponse,
- PipelineLibrary,
- FileLibrary,
- PipelineEvent,
- UpdateInfoState,
- PipelineCluster,
- EventLogSpec,
- Notifications,
- RestartWindow,
- PipelineDeployment,
- Filters,
- PipelinesEnvironment,
- IngestionGatewayPipelineDefinition,
- IngestionPipelineDefinition,
- PipelineTrigger,
- RunAs,
-)
-
-from ..auth import get_workspace_client
-
-
-# Fields that are not valid SDK parameters and should be filtered out
-_INVALID_SDK_FIELDS = {"pipeline_type"}
-
-# Fields that need conversion from dict to SDK objects
-_COMPLEX_FIELD_CONVERTERS = {
- "libraries": lambda items: [PipelineLibrary.from_dict(item) for item in items] if items else None,
- "clusters": lambda items: [PipelineCluster.from_dict(item) for item in items] if items else None,
- "event_log": lambda item: EventLogSpec.from_dict(item) if item else None,
- "notifications": lambda items: [Notifications.from_dict(item) for item in items] if items else None,
- "restart_window": lambda item: RestartWindow.from_dict(item) if item else None,
- "deployment": lambda item: PipelineDeployment.from_dict(item) if item else None,
- "filters": lambda item: Filters.from_dict(item) if item else None,
- "environment": lambda item: PipelinesEnvironment.from_dict(item) if item else None,
- "gateway_definition": lambda item: IngestionGatewayPipelineDefinition.from_dict(item) if item else None,
- "ingestion_definition": lambda item: IngestionPipelineDefinition.from_dict(item) if item else None,
- "trigger": lambda item: PipelineTrigger.from_dict(item) if item else None,
- "run_as": lambda item: RunAs.from_dict(item) if item else None,
-}
-
-
-def _convert_extra_settings(extra_settings: Dict[str, Any]) -> Dict[str, Any]:
- """
- Convert extra_settings dict to SDK-compatible kwargs.
-
- - Filters out invalid fields (e.g., pipeline_type)
- - Converts nested dicts to SDK objects (e.g., clusters, event_log)
- - Passes simple types directly
-
- Args:
- extra_settings: Raw dict from user (e.g., from Databricks UI JSON export)
-
- Returns:
- Dict with SDK-compatible values
- """
- result = {}
-
- for key, value in extra_settings.items():
- # Skip invalid fields
- if key in _INVALID_SDK_FIELDS:
- continue
-
- # Skip None values
- if value is None:
- continue
-
- # Convert complex fields
- if key in _COMPLEX_FIELD_CONVERTERS:
- converted = _COMPLEX_FIELD_CONVERTERS[key](value)
- if converted is not None:
- result[key] = converted
- else:
- # Pass simple types directly (strings, bools, dicts like configuration/tags)
- result[key] = value
-
- return result
-
-
-# Terminal states - pipeline update has finished (success or failure)
-TERMINAL_STATES = {
- UpdateInfoState.COMPLETED,
- UpdateInfoState.FAILED,
- UpdateInfoState.CANCELED,
-}
-
-# Running states - pipeline update is in progress
-RUNNING_STATES = {
- UpdateInfoState.RUNNING,
- UpdateInfoState.INITIALIZING,
- UpdateInfoState.SETTING_UP_TABLES,
- UpdateInfoState.WAITING_FOR_RESOURCES,
- UpdateInfoState.QUEUED,
- UpdateInfoState.RESETTING,
- UpdateInfoState.STOPPING,
- UpdateInfoState.CREATED,
-}
-
-
-def _build_libraries(workspace_file_paths: List[str]) -> List[PipelineLibrary]:
- """Build PipelineLibrary list from file paths."""
- return [PipelineLibrary(file=FileLibrary(path=path)) for path in workspace_file_paths]
-
-
-def _extract_error_summary(events: List[PipelineEvent]) -> List[str]:
- """
- Extract concise error messages from pipeline events.
-
- Returns a deduplicated list of error messages, trying multiple fallbacks:
- 1. First exception message from error.exceptions (most detailed)
- 2. Event message (always present, e.g., "Update X is FAILED")
-
- This is the default for MCP tools since full stack traces are too verbose.
- """
- summaries = []
- for event in events:
- message = None
-
- # Try to get detailed exception message first
- if event.error and event.error.exceptions:
- for exc in event.error.exceptions:
- if exc.message:
- short_class = (exc.class_name or "Error").split(".")[-1]
- message = f"{short_class}: {exc.message}"
- break # Take the first exception with a message
-
- # Fall back to event message if no exception message found
- if not message and event.message:
- message = str(event.message)
-
- if message:
- summaries.append(message)
-
- # Deduplicate while preserving order
- seen = set()
- return [s for s in summaries if not (s in seen or seen.add(s))]
-
-
-def _extract_error_details(events: List[PipelineEvent]) -> List[Dict[str, Any]]:
- """Extract full error details from pipeline events (includes stack traces)."""
- errors = []
- for event in events:
- if event.error:
- error_info = {
- "message": str(event.message) if event.message else None,
- "level": event.level.value if event.level else None,
- "timestamp": event.timestamp if event.timestamp else None,
- }
- # Extract exception details
- if event.error.exceptions:
- exceptions = []
- for exc in event.error.exceptions:
- exc_detail = {
- "class_name": exc.class_name if hasattr(exc, "class_name") else None,
- "message": exc.message if hasattr(exc, "message") else str(exc),
- }
- exceptions.append(exc_detail)
- error_info["exceptions"] = exceptions
- errors.append(error_info)
- return errors
-
-
-@dataclass
-class PipelineRunResult:
- """
- Result from a pipeline operation with detailed status for LLM consumption.
-
- This dataclass provides comprehensive information about pipeline operations
- to help LLMs understand what happened and take appropriate action.
- """
-
- # Pipeline identification
- pipeline_id: str
- pipeline_name: str
-
- # Operation details
- update_id: Optional[str] = None
- state: Optional[str] = None
- success: bool = False
- created: bool = False # True if pipeline was created, False if updated
-
- # Configuration (for context)
- catalog: Optional[str] = None
- schema: Optional[str] = None
- root_path: Optional[str] = None
-
- # Timing
- duration_seconds: Optional[float] = None
-
- # Error details (if failed)
- error_message: Optional[str] = None
- errors: List[Dict[str, Any]] = field(default_factory=list)
-
- # Human-readable status
- message: str = ""
-
- def to_dict(self) -> Dict[str, Any]:
- """Convert to dictionary for JSON serialization."""
- return {
- "pipeline_id": self.pipeline_id,
- "pipeline_name": self.pipeline_name,
- "update_id": self.update_id,
- "state": self.state,
- "success": self.success,
- "created": self.created,
- "catalog": self.catalog,
- "schema": self.schema,
- "root_path": self.root_path,
- "duration_seconds": self.duration_seconds,
- "error_message": self.error_message,
- "errors": self.errors,
- "message": self.message,
- }
-
-
-def find_pipeline_by_name(name: str) -> Optional[str]:
- """
- Find a pipeline by name and return its ID.
-
- Args:
- name: Pipeline name to search for (exact match)
-
- Returns:
- Pipeline ID if found, None otherwise
- """
- w = get_workspace_client()
-
- # List pipelines with name filter and find exact match
- for pipeline in w.pipelines.list_pipelines(filter=f"name LIKE '{name}'"):
- if pipeline.name == name:
- return pipeline.pipeline_id
-
- return None
-
-
-def create_pipeline(
- name: str,
- root_path: str,
- catalog: str,
- schema: str,
- workspace_file_paths: List[str],
- extra_settings: Optional[Dict[str, Any]] = None,
-) -> CreatePipelineResponse:
- """
- Create a new Spark Declarative Pipeline (Unity Catalog, serverless by default).
-
- Args:
- name: Pipeline name
- root_path: Root folder for source code (added to Python sys.path for imports)
- catalog: Unity Catalog name
- schema: Schema name for output tables
- workspace_file_paths: List of workspace file paths (raw .sql or .py files)
- extra_settings: Optional dict with additional pipeline settings. These are passed
- directly to the Databricks SDK pipelines.create() call. Explicit parameters
- (name, root_path, catalog, schema, workspace_file_paths) take precedence.
- Supports all SDK options: clusters, continuous, development, photon, edition,
- channel, event_log, configuration, notifications, tags, etc.
- Note: If 'id' is provided in extra_settings, use update_pipeline instead.
-
- Returns:
- CreatePipelineResponse with pipeline_id
-
- Raises:
- DatabricksError: If pipeline already exists or API request fails
- """
- w = get_workspace_client()
- libraries = _build_libraries(workspace_file_paths)
-
- # Start with converted extra_settings as base
- kwargs: Dict[str, Any] = {}
- if extra_settings:
- kwargs = _convert_extra_settings(extra_settings)
-
- # Explicit parameters always take precedence
- kwargs["name"] = name
- kwargs["root_path"] = root_path
- kwargs["catalog"] = catalog
- kwargs["schema"] = schema
- kwargs["libraries"] = libraries
-
- # Set defaults only if not provided in extra_settings
- if "continuous" not in kwargs:
- kwargs["continuous"] = False
- if "serverless" not in kwargs:
- kwargs["serverless"] = True
-
- # Remove 'id' if present - create should not have an id
- kwargs.pop("id", None)
-
- return w.pipelines.create(**kwargs)
-
-
-def get_pipeline(pipeline_id: str) -> GetPipelineResponse:
- """
- Get pipeline details and configuration.
-
- Args:
- pipeline_id: Pipeline ID
-
- Returns:
- GetPipelineResponse with full pipeline configuration and state
- """
- w = get_workspace_client()
- return w.pipelines.get(pipeline_id=pipeline_id)
-
-
-def update_pipeline(
- pipeline_id: str,
- name: Optional[str] = None,
- root_path: Optional[str] = None,
- catalog: Optional[str] = None,
- schema: Optional[str] = None,
- workspace_file_paths: Optional[List[str]] = None,
- extra_settings: Optional[Dict[str, Any]] = None,
-) -> None:
- """
- Update pipeline configuration.
-
- Args:
- pipeline_id: Pipeline ID
- name: New pipeline name
- root_path: New root folder for source code
- catalog: New catalog name
- schema: New schema name
- workspace_file_paths: New list of file paths (raw .sql or .py files)
- extra_settings: Optional dict with additional pipeline settings. These are passed
- directly to the Databricks SDK pipelines.update() call. Explicit parameters
- take precedence over values in extra_settings.
- Supports all SDK options: clusters, continuous, development, photon, edition,
- channel, event_log, configuration, notifications, tags, etc.
- """
- w = get_workspace_client()
-
- # Start with converted extra_settings as base
- kwargs: Dict[str, Any] = {}
- if extra_settings:
- kwargs = _convert_extra_settings(extra_settings)
-
- # pipeline_id is required and always set
- kwargs["pipeline_id"] = pipeline_id
-
- # Explicit parameters take precedence (only if provided)
- if name:
- kwargs["name"] = name
- if root_path:
- kwargs["root_path"] = root_path
- if catalog:
- kwargs["catalog"] = catalog
- if schema:
- kwargs["schema"] = schema
- if workspace_file_paths:
- kwargs["libraries"] = _build_libraries(workspace_file_paths)
-
- # Ensure id in kwargs matches pipeline_id (SDK uses both)
- if "id" in kwargs and kwargs["id"] != pipeline_id:
- kwargs["id"] = pipeline_id
-
- w.pipelines.update(**kwargs)
-
-
-def delete_pipeline(pipeline_id: str) -> None:
- """
- Delete a pipeline.
-
- Args:
- pipeline_id: Pipeline ID
- """
- w = get_workspace_client()
- w.pipelines.delete(pipeline_id=pipeline_id)
-
-
-def start_update(
- pipeline_id: str,
- refresh_selection: Optional[List[str]] = None,
- full_refresh: bool = False,
- full_refresh_selection: Optional[List[str]] = None,
- validate_only: bool = False,
- wait: bool = True,
- timeout: int = 300,
- poll_interval: int = 5,
- full_error_details: bool = False,
-) -> Dict[str, Any]:
- """
- Start a pipeline update or dry-run validation.
-
- Args:
- pipeline_id: Pipeline ID
- refresh_selection: List of table names to refresh
- full_refresh: If True, performs full refresh of all tables
- full_refresh_selection: List of table names for full refresh
- validate_only: If True, performs dry-run validation without updating data
- wait: If True (default), wait for the update to complete and return results.
- If False, return immediately with just the update_id.
- timeout: Maximum wait time in seconds (default: 300 = 5 minutes)
- poll_interval: Time between status checks in seconds (default: 5)
- full_error_details: If True, return full error events with stack traces.
- If False (default), return only concise error messages.
-
- Returns:
- Dictionary with:
- - update_id: The update ID
- - If wait=True, also includes:
- - state: Final state (COMPLETED, FAILED, CANCELED)
- - success: True if completed successfully
- - duration_seconds: Total time taken
- - error_summary: List of concise error messages (default)
- - errors: Full error events with stack traces (only if full_error_details=True)
- """
- w = get_workspace_client()
-
- response = w.pipelines.start_update(
- pipeline_id=pipeline_id,
- refresh_selection=refresh_selection,
- full_refresh=full_refresh,
- full_refresh_selection=full_refresh_selection,
- validate_only=validate_only,
- )
-
- update_id = response.update_id
-
- if not wait:
- return {"update_id": update_id}
-
- # Wait for completion
- start_time = time.time()
-
- while True:
- elapsed = time.time() - start_time
-
- if elapsed > timeout:
- return {
- "update_id": update_id,
- "state": "TIMEOUT",
- "success": False,
- "duration_seconds": round(elapsed, 2),
- "error_summary": [
- f"Pipeline update did not complete within {timeout} seconds. "
- f"Check status with get_update(pipeline_id='{pipeline_id}', update_id='{update_id}')."
- ],
- }
-
- update_response = w.pipelines.get_update(pipeline_id=pipeline_id, update_id=update_id)
- update_info = update_response.update
-
- if not update_info:
- time.sleep(poll_interval)
- continue
-
- state = update_info.state
-
- if state in TERMINAL_STATES:
- result = {
- "update_id": update_id,
- "state": state.value if state else None,
- "success": state == UpdateInfoState.COMPLETED,
- "duration_seconds": round(elapsed, 2),
- }
-
- # If failed, get error/warning events for this specific update
- if state == UpdateInfoState.FAILED:
- events = get_pipeline_events(
- pipeline_id=pipeline_id,
- max_results=10,
- filter="level in ('ERROR', 'WARN')",
- update_id=update_id,
- )
- result["error_summary"] = _extract_error_summary(events)
- if full_error_details:
- result["errors"] = [e.as_dict() if hasattr(e, "as_dict") else vars(e) for e in events]
-
- return result
-
- time.sleep(poll_interval)
-
-
-def get_update(
- pipeline_id: str,
- update_id: str,
- include_config: bool = False,
- full_error_details: bool = False,
-) -> Dict[str, Any]:
- """
- Get pipeline update status and results.
-
- If the update failed, automatically fetches ERROR/WARN events for that update.
-
- Args:
- pipeline_id: Pipeline ID
- update_id: Update ID from start_update
- include_config: If True, include the full pipeline configuration in the response.
- Default is False since the config is very large and verbose.
- full_error_details: If True, return full error events with stack traces.
- If False (default), return only concise error messages.
-
- Returns:
- Dictionary with:
- - update_id: The update ID
- - state: Current state (QUEUED, RUNNING, COMPLETED, FAILED, CANCELED)
- - success: True if completed successfully, False if failed, None if still running
- - cause: What triggered the update (USER_ACTION, RETRY_ON_FAILURE, etc.)
- - creation_time: When the update was created
- - error_summary: List of concise error messages (default)
- - errors: Full error events with stack traces (only if full_error_details=True)
- - config: Pipeline configuration (only if include_config=True)
- """
- w = get_workspace_client()
- response = w.pipelines.get_update(pipeline_id=pipeline_id, update_id=update_id)
-
- update_info = response.update
- if not update_info:
- return {"update_id": update_id, "state": None, "success": None}
-
- state = update_info.state
-
- # Determine success status
- success = None
- if state == UpdateInfoState.COMPLETED:
- success = True
- elif state in (UpdateInfoState.FAILED, UpdateInfoState.CANCELED):
- success = False
-
- result = {
- "update_id": update_id,
- "state": state.value if state else None,
- "success": success,
- "cause": update_info.cause.value if update_info.cause else None,
- "creation_time": update_info.creation_time,
- }
-
- # If failed, get error/warning events for this specific update
- if state == UpdateInfoState.FAILED:
- events = get_pipeline_events(
- pipeline_id=pipeline_id,
- max_results=10,
- filter="level in ('ERROR', 'WARN')",
- update_id=update_id,
- )
- result["error_summary"] = _extract_error_summary(events)
- if full_error_details:
- result["errors"] = [e.as_dict() if hasattr(e, "as_dict") else vars(e) for e in events]
-
- # Optionally include config
- if include_config and update_info.config:
- config = update_info.config
- result["config"] = config.as_dict() if hasattr(config, "as_dict") else vars(config)
-
- return result
-
-
-def stop_pipeline(pipeline_id: str) -> None:
- """
- Stop a running pipeline.
-
- Args:
- pipeline_id: Pipeline ID
- """
- w = get_workspace_client()
- w.pipelines.stop(pipeline_id=pipeline_id)
-
-
-def get_pipeline_events(
- pipeline_id: str,
- max_results: int = 5,
- filter: str = "level in ('ERROR', 'WARN')",
- update_id: str = None,
-) -> List[PipelineEvent]:
- """
- Get pipeline events, issues, and error messages.
-
- Use this to debug pipeline failures. By default returns ERROR and WARN events
- since those contain the failure details. Each event can include full stack
- traces, so output can be verbose.
-
- Args:
- pipeline_id: Pipeline ID
- max_results: Maximum number of events to return (default: 5)
- filter: SQL-like filter expression (default: "level in ('ERROR', 'WARN')").
- Examples:
- - "level in ('ERROR', 'WARN')" - errors and warnings (default)
- - "level='ERROR'" - only errors
- - "level='INFO'" - info events (state transitions)
- - None or "" - all events (no filter)
- update_id: Optional update ID to filter events. If provided, only
- events from this specific update are returned. Get update IDs
- from get_pipeline().latest_updates or start_update().
-
- Returns:
- List of PipelineEvent objects with error details
- """
- w = get_workspace_client()
-
- effective_filter = filter if filter else None
-
- # If filtering by update_id, we need to fetch more events and filter client-side
- # since the API doesn't support origin.update_id in filter expressions
- api_max_results = max_results * 10 if update_id else max_results
-
- events = w.pipelines.list_pipeline_events(
- pipeline_id=pipeline_id,
- max_results=api_max_results,
- filter=effective_filter,
- )
-
- result = []
- for event in events:
- # Filter by update_id client-side if specified
- if update_id:
- event_update_id = event.origin.update_id if event.origin else None
- if event_update_id != update_id:
- continue
-
- result.append(event)
- if len(result) >= max_results:
- break
-
- return result
-
-
-def wait_for_pipeline_update(
- pipeline_id: str, update_id: str, timeout: int = 1800, poll_interval: int = 5
-) -> Dict[str, Any]:
- """
- Wait for a pipeline update to complete and return detailed results.
-
- Args:
- pipeline_id: Pipeline ID
- update_id: Update ID from start_update
- timeout: Maximum wait time in seconds (default: 30 minutes)
- poll_interval: Time between status checks in seconds
-
- Returns:
- Dictionary with detailed update results:
- - state: Final state (COMPLETED, FAILED, CANCELED)
- - success: True if completed successfully
- - duration_seconds: Total time taken
- - errors: List of error details if failed
-
- Raises:
- TimeoutError: If pipeline doesn't complete within timeout
- """
- w = get_workspace_client()
- start_time = time.time()
-
- while True:
- elapsed = time.time() - start_time
-
- if elapsed > timeout:
- raise TimeoutError(
- f"Pipeline update {update_id} did not complete within {timeout} seconds. "
- f"Check status in UI or call get_update(pipeline_id='{pipeline_id}', update_id='{update_id}')."
- )
-
- response = w.pipelines.get_update(pipeline_id=pipeline_id, update_id=update_id)
-
- update_info = response.update
- if not update_info:
- time.sleep(poll_interval)
- continue
-
- state = update_info.state
-
- if state in TERMINAL_STATES:
- result = {
- "state": state.value if state else None,
- "success": state == UpdateInfoState.COMPLETED,
- "duration_seconds": round(elapsed, 2),
- "update_id": update_id,
- "errors": [],
- }
-
- # If failed, get detailed error information
- if state == UpdateInfoState.FAILED:
- events = get_pipeline_events(pipeline_id, max_results=50)
- result["errors"] = _extract_error_details(events)
-
- return result
-
- time.sleep(poll_interval)
-
-
-def create_or_update_pipeline(
- name: str,
- root_path: str,
- catalog: str,
- schema: str,
- workspace_file_paths: List[str],
- start_run: bool = False,
- wait_for_completion: bool = False,
- full_refresh: bool = True,
- timeout: int = 1800,
- extra_settings: Optional[Dict[str, Any]] = None,
-) -> PipelineRunResult:
- """
- Create a new pipeline or update an existing one with the same name.
-
- This is the main entry point for pipeline management. It:
- 1. Searches for an existing pipeline with the same name (or uses 'id' from extra_settings)
- 2. Creates a new pipeline or updates the existing one
- 3. Optionally starts a pipeline run
- 4. Optionally waits for the run to complete
-
- Uses Unity Catalog and serverless compute by default.
-
- Args:
- name: Pipeline name (used for lookup and creation)
- root_path: Root folder for source code (added to Python sys.path for imports)
- catalog: Unity Catalog name for output tables
- schema: Schema name for output tables
- workspace_file_paths: List of workspace file paths (raw .sql or .py files)
- start_run: If True, start a pipeline run after create/update
- wait_for_completion: If True, wait for the run to complete (requires start_run=True)
- full_refresh: If True, perform full refresh when starting
- timeout: Maximum wait time in seconds (default: 30 minutes)
- extra_settings: Optional dict with additional pipeline settings. Supports all SDK
- options: clusters, continuous, development, photon, edition, channel, event_log,
- configuration, notifications, tags, serverless, etc.
- If 'id' is provided, the pipeline will be updated instead of created.
- Explicit parameters (name, root_path, catalog, schema) take precedence.
-
- Returns:
- PipelineRunResult with detailed status including:
- - pipeline_id, pipeline_name, catalog, schema, root_path
- - created: True if newly created, False if updated
- - success: True if all operations succeeded
- - state: Final state if run was started (COMPLETED, FAILED, etc.)
- - duration_seconds: Time taken if waited
- - error_message: Summary error message if failed
- - errors: List of detailed errors if failed
- - message: Human-readable status message
- """
- # Step 1: Check if pipeline exists (by name or by id in extra_settings)
- existing_pipeline_id = None
-
- # If extra_settings contains an 'id', use it for update
- if extra_settings and extra_settings.get("id"):
- existing_pipeline_id = extra_settings["id"]
- else:
- existing_pipeline_id = find_pipeline_by_name(name)
-
- created = existing_pipeline_id is None
-
- # Step 2: Create or update
- try:
- if created:
- response = create_pipeline(
- name=name,
- root_path=root_path,
- catalog=catalog,
- schema=schema,
- workspace_file_paths=workspace_file_paths,
- extra_settings=extra_settings,
- )
- pipeline_id = response.pipeline_id
- else:
- pipeline_id = existing_pipeline_id
- update_pipeline(
- pipeline_id=pipeline_id,
- name=name,
- root_path=root_path,
- catalog=catalog,
- schema=schema,
- workspace_file_paths=workspace_file_paths,
- extra_settings=extra_settings,
- )
- except Exception as e:
- # Return detailed error for LLM consumption
- return PipelineRunResult(
- pipeline_id=existing_pipeline_id or "unknown",
- pipeline_name=name,
- catalog=catalog,
- schema=schema,
- root_path=root_path,
- success=False,
- created=False,
- error_message=str(e),
- message=f"Failed to {'create' if created else 'update'} pipeline: {e}",
- )
-
- # Build result with context
- result = PipelineRunResult(
- pipeline_id=pipeline_id,
- pipeline_name=name,
- catalog=catalog,
- schema=schema,
- root_path=root_path,
- created=created,
- success=True,
- message=f"Pipeline {'created' if created else 'updated'} successfully. Target: {catalog}.{schema}",
- )
-
- # Step 3: Start run if requested
- if start_run:
- try:
- update_id = start_update(
- pipeline_id=pipeline_id,
- full_refresh=full_refresh,
- )
- result.update_id = update_id
- result.message = f"Pipeline {'created' if created else 'updated'} and run started. Update ID: {update_id}"
- except Exception as e:
- result.success = False
- result.error_message = f"Pipeline created but failed to start run: {e}"
- result.message = result.error_message
- return result
-
- # Step 4: Wait for completion if requested
- if wait_for_completion:
- try:
- wait_result = wait_for_pipeline_update(
- pipeline_id=pipeline_id,
- update_id=update_id,
- timeout=timeout,
- )
- result.state = wait_result["state"]
- result.success = wait_result["success"]
- result.duration_seconds = wait_result["duration_seconds"]
-
- if result.success:
- result.message = (
- f"Pipeline {'created' if created else 'updated'} and "
- f"completed successfully in {result.duration_seconds}s. "
- f"Tables written to {catalog}.{schema}"
- )
- else:
- result.errors = wait_result.get("errors", [])
- # Build informative error message for LLM
- if result.errors:
- first_error = result.errors[0]
- error_msg = first_error.get("message", "")
- if first_error.get("exceptions"):
- exc = first_error["exceptions"][0]
- error_msg = exc.get("message", error_msg)
- result.error_message = error_msg
- else:
- result.error_message = f"Pipeline failed with state: {result.state}"
-
- result.message = (
- f"Pipeline {'created' if created else 'updated'} but run failed. "
- f"State: {result.state}. "
- f"Error: {result.error_message}. "
- f"Use get_pipeline_events(pipeline_id='{pipeline_id}') for full details."
- )
-
- except TimeoutError as e:
- result.success = False
- result.state = "TIMEOUT"
- result.error_message = str(e)
- result.message = (
- f"Pipeline run timed out after {timeout}s. "
- f"The pipeline may still be running. "
- f"Check status with get_update(pipeline_id='{pipeline_id}', update_id='{update_id}')"
- )
-
- return result
diff --git a/databricks-tools-core/databricks_tools_core/spark_declarative_pipelines/workspace_files.py b/databricks-tools-core/databricks_tools_core/spark_declarative_pipelines/workspace_files.py
deleted file mode 100644
index 4e7e9f26..00000000
--- a/databricks-tools-core/databricks_tools_core/spark_declarative_pipelines/workspace_files.py
+++ /dev/null
@@ -1,141 +0,0 @@
-"""
-Spark Declarative Pipelines - Workspace File Operations
-
-Functions for managing workspace files and directories for SDP pipelines.
-"""
-
-import base64
-from typing import List
-from databricks.sdk.service.workspace import ObjectInfo, Language, ImportFormat, ExportFormat
-
-from ..auth import get_workspace_client
-
-
-def list_files(path: str) -> List[ObjectInfo]:
- """
- List files and directories in a workspace path.
-
- Args:
- path: Workspace path to list
-
- Returns:
- List of ObjectInfo objects with file/directory metadata:
- - path: Full workspace path
- - object_type: DIRECTORY, NOTEBOOK, FILE, LIBRARY, or REPO
- - language: For notebooks (PYTHON, SQL, SCALA, R)
- - object_id: Unique identifier
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return list(w.workspace.list(path=path))
-
-
-def get_file_status(path: str) -> ObjectInfo:
- """
- Get file or directory metadata.
-
- Args:
- path: Workspace path
-
- Returns:
- ObjectInfo object with metadata:
- - path: Full workspace path
- - object_type: DIRECTORY, NOTEBOOK, FILE, LIBRARY, or REPO
- - language: For notebooks (PYTHON, SQL, SCALA, R)
- - object_id: Unique identifier
- - size: File size in bytes (for files)
- - created_at: Creation timestamp
- - modified_at: Last modification timestamp
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return w.workspace.get_status(path=path)
-
-
-def read_file(path: str) -> str:
- """
- Read workspace file contents.
-
- Args:
- path: Workspace file path
-
- Returns:
- Decoded file contents as string
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- response = w.workspace.export(path=path, format=ExportFormat.SOURCE)
-
- # SDK returns ExportResponse with .content field (base64 encoded)
- return base64.b64decode(response.content).decode("utf-8")
-
-
-def write_file(path: str, content: str, language: str = "PYTHON", overwrite: bool = True) -> None:
- """
- Write or update workspace file.
-
- Args:
- path: Workspace file path
- content: File content as string
- language: PYTHON, SQL, SCALA, or R
- overwrite: If True, replaces existing file
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
-
- # Convert language string to enum
- lang_map = {
- "PYTHON": Language.PYTHON,
- "SQL": Language.SQL,
- "SCALA": Language.SCALA,
- "R": Language.R,
- }
- lang_enum = lang_map.get(language.upper(), Language.PYTHON)
-
- # Base64 encode content
- content_b64 = base64.b64encode(content.encode("utf-8")).decode("utf-8")
-
- w.workspace.import_(
- path=path,
- content=content_b64,
- language=lang_enum,
- format=ImportFormat.SOURCE,
- overwrite=overwrite,
- )
-
-
-def create_directory(path: str) -> None:
- """
- Create workspace directory.
-
- Args:
- path: Workspace directory path
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- w.workspace.mkdirs(path=path)
-
-
-def delete_path(path: str, recursive: bool = False) -> None:
- """
- Delete workspace file or directory.
-
- Args:
- path: Workspace path to delete
- recursive: If True, recursively deletes directories
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- w.workspace.delete(path=path, recursive=recursive)
diff --git a/databricks-tools-core/databricks_tools_core/sql/__init__.py b/databricks-tools-core/databricks_tools_core/sql/__init__.py
deleted file mode 100644
index befb8d53..00000000
--- a/databricks-tools-core/databricks_tools_core/sql/__init__.py
+++ /dev/null
@@ -1,41 +0,0 @@
-"""
-SQL - SQL Warehouse Operations
-
-Functions for executing SQL queries, managing SQL warehouses, and getting table statistics.
-"""
-
-from .sql import execute_sql, execute_sql_multi
-from .warehouse import list_warehouses, get_best_warehouse
-from .table_stats import get_table_stats_and_schema, get_volume_folder_details
-from .sql_utils import (
- SQLExecutionError,
- TableStatLevel,
- TableSchemaResult,
- DataSourceInfo,
- TableInfo, # Alias for DataSourceInfo (backwards compatibility)
- ColumnDetail,
- VolumeFileInfo,
- VolumeFolderResult, # Alias for DataSourceInfo (backwards compatibility)
-)
-
-__all__ = [
- # SQL execution
- "execute_sql",
- "execute_sql_multi",
- # Warehouse management
- "list_warehouses",
- "get_best_warehouse",
- # Table statistics
- "get_table_stats_and_schema",
- "get_volume_folder_details",
- "TableStatLevel",
- "TableSchemaResult",
- "DataSourceInfo",
- "TableInfo", # Alias for DataSourceInfo
- "ColumnDetail",
- # Volume folder statistics
- "VolumeFileInfo",
- "VolumeFolderResult", # Alias for DataSourceInfo
- # Errors
- "SQLExecutionError",
-]
diff --git a/databricks-tools-core/databricks_tools_core/sql/sql.py b/databricks-tools-core/databricks_tools_core/sql/sql.py
deleted file mode 100644
index 2701314e..00000000
--- a/databricks-tools-core/databricks_tools_core/sql/sql.py
+++ /dev/null
@@ -1,159 +0,0 @@
-"""
-SQL Execution
-
-High-level functions for executing SQL queries on Databricks.
-"""
-
-import logging
-from typing import Any, Dict, List, Optional
-
-from .sql_utils import SQLExecutor, SQLExecutionError, SQLParallelExecutor
-from .warehouse import get_best_warehouse
-
-logger = logging.getLogger(__name__)
-
-
-def execute_sql(
- sql_query: str,
- warehouse_id: Optional[str] = None,
- catalog: Optional[str] = None,
- schema: Optional[str] = None,
- timeout: int = 180,
- query_tags: Optional[str] = None,
-) -> List[Dict[str, Any]]:
- """
- Execute a SQL query on a Databricks SQL Warehouse.
-
- If no warehouse_id is provided, automatically selects the best available
- warehouse using get_best_warehouse().
-
- Args:
- sql_query: SQL query to execute
- warehouse_id: Optional warehouse ID. If not provided, auto-selects one.
- catalog: Optional catalog context. If not provided, use fully qualified names.
- schema: Optional schema context. If not provided, use fully qualified names.
- timeout: Timeout in seconds (default: 180)
- query_tags: Optional query tags for cost attribution and filtering.
- Format: "key:value,key2:value2" (e.g., "team:eng,cost_center:701").
- Appears in system.query.history and Query History UI.
-
- Returns:
- List of dictionaries, each representing a row with column names as keys.
- Example: [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
-
- Raises:
- SQLExecutionError: If query execution fails, with detailed error message:
- - No warehouse available
- - Warehouse not accessible
- - Query syntax error
- - Query timeout
- - Permission denied
- """
- # Auto-select warehouse if not provided
- if not warehouse_id:
- logger.debug("No warehouse_id provided, selecting best available warehouse")
- warehouse_id = get_best_warehouse()
- if not warehouse_id:
- raise SQLExecutionError(
- "No SQL warehouse available in the workspace. "
- "Please create a SQL warehouse or start an existing one, "
- "or provide a specific warehouse_id."
- )
- logger.debug(f"Auto-selected warehouse: {warehouse_id}")
-
- # Execute the query
- executor = SQLExecutor(warehouse_id=warehouse_id)
- return executor.execute(
- sql_query=sql_query,
- catalog=catalog,
- schema=schema,
- timeout=timeout,
- query_tags=query_tags,
- )
-
-
-def execute_sql_multi(
- sql_content: str,
- warehouse_id: Optional[str] = None,
- catalog: Optional[str] = None,
- schema: Optional[str] = None,
- timeout: int = 180,
- max_workers: int = 4,
- query_tags: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Execute multiple SQL statements with dependency-aware parallelism.
-
- Parses the SQL content into individual statements, analyzes dependencies
- between them (based on table creation and references), and executes them
- in optimal order. Queries that don't depend on each other run in parallel.
-
- If no warehouse_id is provided, automatically selects the best available
- warehouse using get_best_warehouse().
-
- Args:
- sql_content: SQL content with multiple statements separated by ;
- warehouse_id: Optional warehouse ID. If not provided, auto-selects one.
- catalog: Optional catalog context. If not provided, use fully qualified names.
- schema: Optional schema context. If not provided, use fully qualified names.
- timeout: Timeout per query in seconds (default: 180)
- max_workers: Maximum parallel queries per group (default: 4)
- query_tags: Optional query tags for cost attribution (e.g., "team:eng,cost_center:701").
-
- Returns:
- Dictionary with:
- - results: Dict mapping query index to result dict, each containing:
- - query_index: 0-based index of the query
- - status: "success" or "error"
- - execution_time: Time taken in seconds
- - query_preview: First 100 chars of the query
- - result_rows: Number of rows returned (for success)
- - sample_results: First 5 rows (for success)
- - error: Error message (for error)
- - error_category: Error type like SYNTAX_ERROR, MISSING_TABLE (for error)
- - suggestion: Hint on how to fix (for error)
- - group_number: Which execution group this query was in
- - is_parallel: Whether it ran in parallel with other queries
- - execution_summary: Overall statistics including:
- - total_queries: Number of queries parsed
- - total_groups: Number of execution groups
- - total_time: Total execution time
- - stopped_after_group: Group number where execution stopped (if error)
- - groups: List of group details
-
- Raises:
- SQLExecutionError: If parsing fails or no warehouse available
-
- Example:
- >>> result = execute_sql_multi('''
- ... CREATE TABLE t1 AS SELECT 1 as id;
- ... CREATE TABLE t2 AS SELECT 2 as id;
- ... CREATE TABLE t3 AS SELECT * FROM t1 JOIN t2;
- ... ''')
- >>> # t1 and t2 run in parallel (no dependencies)
- >>> # t3 runs after both complete (depends on t1 and t2)
- """
- # Auto-select warehouse if not provided
- if not warehouse_id:
- logger.debug("No warehouse_id provided, selecting best available warehouse")
- warehouse_id = get_best_warehouse()
- if not warehouse_id:
- raise SQLExecutionError(
- "No SQL warehouse available in the workspace. "
- "Please create a SQL warehouse or start an existing one, "
- "or provide a specific warehouse_id."
- )
- logger.debug(f"Auto-selected warehouse: {warehouse_id}")
-
- # Execute with parallel executor
- executor = SQLParallelExecutor(
- warehouse_id=warehouse_id,
- max_workers=max_workers,
- )
- return executor.execute(
- sql_content=sql_content,
- catalog=catalog,
- schema=schema,
- timeout=timeout,
- query_tags=query_tags,
- )
diff --git a/databricks-tools-core/databricks_tools_core/sql/sql_utils/__init__.py b/databricks-tools-core/databricks_tools_core/sql/sql_utils/__init__.py
deleted file mode 100644
index 18a125a2..00000000
--- a/databricks-tools-core/databricks_tools_core/sql/sql_utils/__init__.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""
-SQL Utilities - Internal helpers for SQL operations.
-"""
-
-from .executor import SQLExecutor, SQLExecutionError
-from .dependency_analyzer import SQLDependencyAnalyzer
-from .parallel_executor import SQLParallelExecutor
-from .models import (
- TableStatLevel,
- HistogramBin,
- ColumnDetail,
- DataSourceInfo,
- TableInfo, # Alias for DataSourceInfo
- TableSchemaResult,
- VolumeFileInfo,
- VolumeFolderResult, # Alias for DataSourceInfo
-)
-from .table_stats_collector import TableStatsCollector
-
-__all__ = [
- "SQLExecutor",
- "SQLExecutionError",
- "SQLDependencyAnalyzer",
- "SQLParallelExecutor",
- "TableStatLevel",
- "HistogramBin",
- "ColumnDetail",
- "DataSourceInfo",
- "TableInfo", # Alias for DataSourceInfo
- "TableSchemaResult",
- "VolumeFileInfo",
- "VolumeFolderResult", # Alias for DataSourceInfo
- "TableStatsCollector",
-]
diff --git a/databricks-tools-core/databricks_tools_core/sql/sql_utils/dependency_analyzer.py b/databricks-tools-core/databricks_tools_core/sql/sql_utils/dependency_analyzer.py
deleted file mode 100644
index 42e7dac8..00000000
--- a/databricks-tools-core/databricks_tools_core/sql/sql_utils/dependency_analyzer.py
+++ /dev/null
@@ -1,306 +0,0 @@
-"""
-SQL Dependency Analyzer
-
-Analyzes SQL queries to determine dependencies and optimal execution order.
-Uses sqlglot for parsing and sqlfluff for comment stripping.
-"""
-
-import logging
-from typing import Dict, List, Optional, Set
-
-import sqlglot
-from sqlglot import exp
-from sqlfluff.core import Linter
-
-logger = logging.getLogger(__name__)
-
-
-class SQLDependencyAnalyzer:
- """Analyzes SQL queries to determine dependencies and execution order.
-
- Parses SQL statements to identify:
- - Which tables are created by each query
- - Which tables are referenced by each query
- - Dependencies between queries based on table usage
-
- Uses this information to group queries into execution levels where
- queries within a level can run in parallel.
- """
-
- def __init__(self, dialect: str = "databricks"):
- """
- Initialize the analyzer.
-
- Args:
- dialect: SQL dialect for parsing (default: "databricks")
- """
- self.dialect = dialect
- self.created_tables: Dict[str, int] = {} # table_name -> query_index
- self.query_dependencies: Dict[int, Set[str]] = {} # query_index -> referenced tables
- self._linter = Linter(dialect=self.dialect)
-
- def parse_sql_content(self, sql_content: str) -> List[str]:
- """
- Parse SQL content into individual queries.
-
- Handles:
- - Multiple statements separated by semicolons
- - Comments (both -- and /* */)
- - Complex statements with subqueries
-
- Args:
- sql_content: Raw SQL content with potentially multiple statements
-
- Returns:
- List of individual SQL queries (each ending with ;)
- """
- cleaned = self._strip_comments(sql_content)
-
- queries: List[str] = []
- # Split by semicolons first to preserve original SQL
- raw_statements = [s.strip() for s in cleaned.split(";") if s.strip()]
-
- for raw_sql in raw_statements:
- # Parse to validate and analyze dependencies, but keep original SQL
- # to preserve formatting like backticks that sqlglot might modify
- parsed = sqlglot.parse(raw_sql, read=self.dialect)
- if parsed and parsed[0] is not None:
- # Use original SQL to preserve backticks and other formatting
- # that sqlglot might change during regeneration
- queries.append(raw_sql + ";")
- elif raw_sql:
- # If parsing fails, still try to include the statement
- queries.append(raw_sql + ";")
-
- logger.debug(f"Parsed {len(queries)} queries from SQL content")
- return queries
-
- def analyze_dependencies(self, queries: List[str]) -> List[List[int]]:
- """
- Analyze query dependencies and return execution groups.
-
- Queries within a group have no dependencies on each other and
- can be executed in parallel. Groups must be executed sequentially.
-
- Args:
- queries: List of SQL query strings
-
- Returns:
- List of groups, where each group is a list of query indices
- Example: [[0, 1], [2], [3, 4]] means:
- - Queries 0 and 1 can run in parallel (group 1)
- - Query 2 runs after group 1 (group 2)
- - Queries 3 and 4 can run in parallel after query 2 (group 3)
- """
- self.created_tables.clear()
- self.query_dependencies.clear()
-
- # Pass 1: Discover created objects and references
- for idx, query in enumerate(queries):
- exprs = [e for e in (sqlglot.parse(query, read=self.dialect) or []) if e is not None]
-
- created_here: Set[str] = set()
- referenced_here: Set[str] = set()
-
- for root in exprs:
- # Track CREATE statements
- if isinstance(root, exp.Create):
- table = self._bare(self._table_from_create(root))
- if table:
- created_here.add(table)
-
- # Track ALTER statements (reference existing table)
- if isinstance(root, exp.Alter):
- table = self._bare(self._table_from_alter(root))
- if table:
- referenced_here.add(table)
-
- # Track DROP statements (reference existing table)
- if isinstance(root, exp.Drop):
- table = self._bare(self._table_from_drop(root))
- if table:
- referenced_here.add(table)
-
- # Track general references (FROM, JOIN, etc.)
- refs = self._extract_referenced_tables(root)
- if refs:
- referenced_here |= refs
-
- # Record created tables
- for table in created_here:
- self.created_tables[table] = idx
- logger.debug(f"Query {idx} creates: {table}")
-
- # Record dependencies
- if referenced_here:
- self.query_dependencies.setdefault(idx, set()).update(referenced_here)
- logger.debug(f"Query {idx} references: {sorted(referenced_here)}")
-
- # Pass 2: Build query-to-query dependency edges
- edges: Dict[int, Set[int]] = {}
- for query_idx, tables in self.query_dependencies.items():
- deps = set()
- for table in tables:
- creator = self.created_tables.get(table)
- if creator is not None and creator != query_idx:
- deps.add(creator)
- if deps:
- edges[query_idx] = deps
- logger.debug(f"Query {query_idx} depends on queries: {sorted(deps)}")
-
- # Topological sort into execution groups
- groups = self._topological_sort(len(queries), edges)
- logger.info(f"Organized {len(queries)} queries into {len(groups)} execution groups")
- return groups
-
- def _strip_comments(self, sql: str) -> str:
- """Strip comments using sqlfluff, preserving line structure."""
- try:
- parsed = self._linter.parse_string(sql)
- if not parsed or not parsed.tree:
- return sql
-
- out_parts: List[str] = []
- for seg in parsed.tree.raw_segments:
- if seg.is_type("comment"):
- # Preserve newlines to maintain line boundaries
- if "\n" in seg.raw:
- out_parts.append("\n")
- continue
- out_parts.append(seg.raw)
- return "".join(out_parts)
- except Exception as e:
- logger.warning(f"Failed to strip comments with sqlfluff: {e}")
- return sql
-
- def _topological_sort(self, num_queries: int, dependencies: Dict[int, Set[int]]) -> List[List[int]]:
- """
- Kahn's algorithm for levelized topological ordering.
-
- Returns groups where each group can be executed in parallel.
- """
- in_degree = [0] * num_queries
- reverse_deps: Dict[int, Set[int]] = {i: set() for i in range(num_queries)}
-
- for query, deps in dependencies.items():
- in_degree[query] = len(deps)
- for dep in deps:
- reverse_deps[dep].add(query)
-
- # Start with queries that have no dependencies
- queue = [i for i in range(num_queries) if in_degree[i] == 0]
- groups: List[List[int]] = []
- visited: Set[int] = set()
-
- while queue:
- # All queries in current queue can run in parallel
- current = sorted(queue)
- groups.append(current)
- queue = []
-
- for query in current:
- visited.add(query)
- for next_query in reverse_deps[query]:
- if next_query in visited:
- continue
- in_degree[next_query] -= 1
- if in_degree[next_query] == 0:
- queue.append(next_query)
-
- # Handle circular dependencies (shouldn't happen in well-formed SQL)
- remaining = [i for i in range(num_queries) if i not in visited]
- if remaining:
- logger.warning(f"Circular dependencies detected for queries: {remaining}")
- groups.append(remaining)
-
- return groups
-
- def _extract_referenced_tables(self, root: exp.Expression) -> Set[str]:
- """
- Extract referenced tables from an AST node.
-
- Excludes:
- - CTE names (WITH clause aliases)
- - The target table in CREATE/INSERT statements
- """
- referenced: Set[str] = set()
- if root is None:
- return referenced
-
- cte_names = self._collect_cte_names(root)
- created_target = None
-
- if isinstance(root, exp.Create):
- created_target = self._bare(self._table_from_create(root))
-
- for table in root.find_all(exp.Table):
- bare = self._bare(table)
- if not bare:
- continue
- if bare == created_target:
- continue
- if bare in cte_names:
- continue
- referenced.add(bare)
-
- # Exclude INSERT target from dependencies
- if isinstance(root, exp.Insert):
- target = getattr(root, "this", None)
- if isinstance(target, exp.Table):
- target_bare = self._bare(target)
- if target_bare:
- referenced.discard(target_bare)
-
- return referenced
-
- def _collect_cte_names(self, root: exp.Expression) -> Set[str]:
- """Collect CTE names from WITH clause for exclusion."""
- names: Set[str] = set()
- with_clause = root.args.get("with")
- if isinstance(with_clause, exp.With):
- for cte in with_clause.expressions or []:
- alias = getattr(cte, "alias", None)
- if alias:
- ident = getattr(alias, "this", None)
- if isinstance(ident, exp.Identifier):
- names.add(ident.this.lower())
- return names
-
- def _bare(self, table_exp: Optional[exp.Expression]) -> Optional[str]:
- """Extract bare table name (lowercase) from expression."""
- if table_exp is None:
- return None
- if isinstance(table_exp, exp.Table):
- name = table_exp.name or ""
- return name.strip('`"').lower() or None
- if hasattr(table_exp, "name"):
- name = table_exp.name
- if isinstance(name, str):
- return name.strip('`"').lower() or None
- if hasattr(table_exp, "this") and isinstance(table_exp.this, exp.Table):
- name = table_exp.this.name or ""
- return name.strip('`"').lower() or None
- return None
-
- def _table_from_create(self, node: exp.Create) -> Optional[exp.Table]:
- """Extract table from CREATE statement."""
- target = node.this
- if isinstance(target, exp.Table):
- return target
- if isinstance(target, exp.Schema) and isinstance(target.this, exp.Table):
- return target.this
- return None
-
- def _table_from_alter(self, node: exp.Alter) -> Optional[exp.Table]:
- """Extract table from ALTER statement."""
- target = node.this
- if isinstance(target, exp.Table):
- return target
- return None
-
- def _table_from_drop(self, node: exp.Drop) -> Optional[exp.Table]:
- """Extract table from DROP statement."""
- target = node.this
- if isinstance(target, exp.Table):
- return target
- return None
diff --git a/databricks-tools-core/databricks_tools_core/sql/sql_utils/executor.py b/databricks-tools-core/databricks_tools_core/sql/sql_utils/executor.py
deleted file mode 100644
index 73d4b420..00000000
--- a/databricks-tools-core/databricks_tools_core/sql/sql_utils/executor.py
+++ /dev/null
@@ -1,187 +0,0 @@
-"""
-SQL Executor - Internal class for executing SQL queries on Databricks.
-"""
-
-import time
-import logging
-from typing import Any, Dict, List, Optional
-
-from databricks.sdk import WorkspaceClient
-from databricks.sdk.service.sql import StatementState
-
-from ...auth import get_workspace_client
-
-logger = logging.getLogger(__name__)
-
-
-class SQLExecutionError(Exception):
- """Exception raised when SQL execution fails.
-
- Provides detailed error messages for LLM consumption.
- """
-
-
-class SQLExecutor:
- """Execute SQL queries on Databricks SQL Warehouses."""
-
- def __init__(self, warehouse_id: str, client: Optional[WorkspaceClient] = None):
- """
- Initialize the SQL executor.
-
- Args:
- warehouse_id: SQL warehouse ID to use for queries
- client: Optional WorkspaceClient (creates new one if not provided)
-
- Raises:
- SQLExecutionError: If no warehouse ID is provided
- """
- if not warehouse_id:
- raise SQLExecutionError(
- "No SQL warehouse ID provided. "
- "Either specify a warehouse_id or let the system select one automatically."
- )
- self.warehouse_id = warehouse_id
- self.client = client or get_workspace_client()
-
- def execute(
- self,
- sql_query: str,
- catalog: Optional[str] = None,
- schema: Optional[str] = None,
- row_limit: Optional[int] = None,
- timeout: int = 180,
- query_tags: Optional[str] = None,
- ) -> List[Dict[str, Any]]:
- """
- Execute a SQL query and return results as a list of dictionaries.
-
- Args:
- sql_query: SQL query to execute
- catalog: Optional catalog context for the query
- schema: Optional schema context for the query
- row_limit: Optional maximum number of rows to return
- timeout: Timeout in seconds (default: 180)
- query_tags: Optional query tags for cost attribution and filtering.
- Format: "key:value,key2:value2" (e.g., "team:eng,cost_center:701").
- Appears in system.query.history and Query History UI.
-
- Returns:
- List of dictionaries, each representing a row with column names as keys
-
- Raises:
- SQLExecutionError: If query execution fails with detailed error message
- """
- logger.debug(f"Executing SQL query: {sql_query[:100]}...")
-
- # Build execution parameters
- exec_params = {
- "warehouse_id": self.warehouse_id,
- "statement": sql_query,
- "wait_timeout": "0s", # Immediate return, we poll manually
- }
- if catalog:
- exec_params["catalog"] = catalog
- if schema:
- exec_params["schema"] = schema
- if row_limit is not None:
- exec_params["row_limit"] = row_limit
- if query_tags:
- from databricks.sdk.service.sql import QueryTag
-
- exec_params["query_tags"] = [
- QueryTag(key=k.strip(), value=v.strip())
- for pair in query_tags.split(",")
- for k, v in [pair.split(":", 1)]
- if ":" in pair
- ]
-
- # Submit the statement
- try:
- response = self.client.statement_execution.execute_statement(**exec_params)
- except Exception as e:
- raise SQLExecutionError(
- f"Failed to submit SQL query to warehouse '{self.warehouse_id}': {str(e)}. "
- f"Check that the warehouse exists and is accessible."
- )
-
- statement_id = response.statement_id
- logger.debug(f"Statement submitted with ID: {statement_id}")
-
- # Poll for completion
- poll_interval = 2
- elapsed = 0
-
- while elapsed < timeout:
- try:
- status = self.client.statement_execution.get_statement(statement_id=statement_id)
- except Exception as e:
- raise SQLExecutionError(f"Failed to check status of statement '{statement_id}': {str(e)}")
-
- state = status.status.state
-
- if state == StatementState.SUCCEEDED:
- return self._extract_results(status)
-
- if state == StatementState.FAILED:
- error_msg = self._get_error_message(status)
- raise SQLExecutionError(
- f"SQL query failed: {error_msg}\nQuery: {sql_query[:500]}{'...' if len(sql_query) > 500 else ''}"
- )
-
- if state == StatementState.CANCELED:
- raise SQLExecutionError(f"SQL query was canceled before completion. Statement ID: {statement_id}")
-
- if state == StatementState.CLOSED:
- raise SQLExecutionError(f"SQL statement was closed unexpectedly. Statement ID: {statement_id}")
-
- # Still running, wait and poll again
- time.sleep(poll_interval)
- elapsed += poll_interval
-
- # Timeout reached - cancel the statement
- self._cancel_statement(statement_id)
- raise SQLExecutionError(
- f"SQL query timed out after {timeout} seconds and was canceled. "
- f"Consider increasing the timeout or optimizing the query. "
- f"Statement ID: {statement_id}"
- )
-
- def _extract_results(self, response) -> List[Dict[str, Any]]:
- """Extract results from a successful statement response."""
- results: List[Dict[str, Any]] = []
-
- if not response.result or not response.result.data_array:
- return results
-
- # Get column names from manifest
- columns = None
- if response.manifest and response.manifest.schema and response.manifest.schema.columns:
- columns = [col.name for col in response.manifest.schema.columns]
-
- # Convert rows to dicts
- for row in response.result.data_array:
- if columns:
- results.append(dict(zip(columns, row, strict=False)))
- else:
- # Fallback if no schema available
- results.append({"values": list(row)})
-
- return results
-
- def _get_error_message(self, response) -> str:
- """Extract error message from a failed statement response."""
- if response.status and response.status.error:
- error = response.status.error
- msg = error.message if error.message else "Unknown error"
- if error.error_code:
- msg = f"[{error.error_code}] {msg}"
- return msg
- return "Unknown error (no error details available)"
-
- def _cancel_statement(self, statement_id: str) -> None:
- """Attempt to cancel a running statement."""
- try:
- self.client.statement_execution.cancel_execution(statement_id=statement_id)
- logger.debug(f"Canceled statement {statement_id}")
- except Exception as e:
- logger.warning(f"Failed to cancel statement {statement_id}: {e}")
diff --git a/databricks-tools-core/databricks_tools_core/sql/sql_utils/models.py b/databricks-tools-core/databricks_tools_core/sql/sql_utils/models.py
deleted file mode 100644
index 2353e546..00000000
--- a/databricks-tools-core/databricks_tools_core/sql/sql_utils/models.py
+++ /dev/null
@@ -1,247 +0,0 @@
-"""
-SQL Models - Pydantic models for table statistics and schema information.
-"""
-
-from enum import Enum
-from typing import Any, Dict, List, Optional, Union
-
-from pydantic import BaseModel
-
-
-class TableStatLevel(str, Enum):
- """Level of statistics to collect for tables."""
-
- NONE = "none" # Just describe table structure, no stats
- SIMPLE = "simple" # Basic stats: samples, cardinality, min/max, null counts
- DETAILED = "detailed" # Full stats: histograms, percentiles, stddev, value counts
-
-
-# Constants for column statistics
-NUMERIC_TYPES = ["int", "bigint", "float", "double", "decimal", "numeric"]
-TIMESTAMP_TYPES = ["timestamp", "date"]
-ID_PATTERNS = ["_id", "id_", "_uuid", "uuid_", "_key", "key_"]
-MAX_CATEGORICAL_VALUES = 30
-SAMPLE_ROW_COUNT = 10
-HISTOGRAM_BINS = 10
-
-
-class HistogramBin(BaseModel):
- """Histogram bin data."""
-
- bin_center: float
- count: int
- date_label: Optional[str] = None # For timestamp histograms
-
-
-class ColumnDetail(BaseModel):
- """Detailed information about a table column including statistics."""
-
- name: str
- data_type: str
- samples: Optional[List[Any]] = None # Up to 3 distinct sample values
- cardinality: Optional[int] = None # count distinct for string columns
- min: Optional[Union[str, float, int]] = None # for numeric and timestamp
- max: Optional[Union[str, float, int]] = None # for numeric and timestamp
- avg: Optional[float] = None # for numeric columns only
- null_count: Optional[int] = None
- total_count: Optional[int] = None
- # Enhanced statistics (DETAILED level)
- unique_count: Optional[int] = None
- mean: Optional[float] = None # alias for avg
- stddev: Optional[float] = None
- q1: Optional[float] = None # 25th percentile
- median: Optional[float] = None # 50th percentile
- q3: Optional[float] = None # 75th percentile
- min_date: Optional[str] = None # for timestamp columns
- max_date: Optional[str] = None # for timestamp columns
- histogram: Optional[List[HistogramBin]] = None # histogram data
- value_counts: Optional[Dict[str, int]] = None # value counts for categorical
-
-
-class VolumeFileInfo(BaseModel):
- """Information about a file in a volume."""
-
- name: str
- path: str
- size_bytes: Optional[int] = None
- is_directory: bool = False
- modification_time: Optional[str] = None
-
-
-def _get_basic_column_details(
- column_details: Optional[Dict[str, ColumnDetail]],
-) -> Optional[Dict[str, ColumnDetail]]:
- """Return simplified column details with basic stats only.
-
- Removes heavy stats like histograms, stddev, percentiles.
- For categorical columns with value_counts, replaces samples with value_counts.
- """
- if not column_details:
- return None
-
- basic_columns = {}
- for col_name, col_detail in column_details.items():
- basic_col = ColumnDetail(
- name=col_detail.name,
- data_type=col_detail.data_type,
- samples=col_detail.samples,
- cardinality=col_detail.cardinality,
- min=col_detail.min,
- max=col_detail.max,
- avg=col_detail.avg,
- null_count=col_detail.null_count if col_detail.null_count and col_detail.null_count > 0 else None,
- total_count=col_detail.total_count,
- unique_count=col_detail.unique_count,
- # Exclude heavy stats
- mean=None,
- stddev=None,
- q1=None,
- median=None,
- q3=None,
- min_date=None,
- max_date=None,
- histogram=None,
- value_counts=col_detail.value_counts,
- )
-
- # For categorical columns with value_counts, use those instead of samples
- if col_detail.value_counts:
- basic_col.samples = None
-
- basic_columns[col_name] = basic_col
-
- return basic_columns
-
-
-class DataSourceInfo(BaseModel):
- """Unified information about a data source (UC table or Volume folder).
-
- For UC tables: name is the table name, ddl is populated
- For Volume folders: name is the volume path, format and file info are populated
-
- When serializing, use model_dump(exclude_none=True) to omit irrelevant fields.
- """
-
- name: str # Table name or volume path
-
- # Common fields
- column_details: Optional[Dict[str, ColumnDetail]] = None
- total_rows: Optional[int] = None
- sample_data: Optional[List[Dict[str, Any]]] = None
- error: Optional[str] = None
-
- # UC Table specific fields
- comment: Optional[str] = None
- ddl: Optional[str] = None
- updated_at: Optional[int] = None # Timestamp in epoch ms from Databricks
-
- # Volume folder specific fields
- format: Optional[str] = None # "parquet", "csv", "json", "delta", "file"
- total_files: Optional[int] = None
- total_size_bytes: Optional[int] = None
- files: Optional[List["VolumeFileInfo"]] = None # Only for format="file"
-
- def get_basic_column_details(self) -> Optional[Dict[str, ColumnDetail]]:
- """Return simplified column details with basic stats only."""
- return _get_basic_column_details(self.column_details)
-
-
-# Backwards compatibility alias
-TableInfo = DataSourceInfo
-
-
-class TableSchemaResult(BaseModel):
- """Result model for data source schema information.
-
- Works for both UC tables and Volume folders.
- - For UC tables: catalog and schema_name identify the location
- - For Volume folders: catalog/schema_name can be extracted from volume path,
- or left as descriptive values
- """
-
- catalog: str
- schema_name: str
- tables: List[DataSourceInfo] # List of tables or volume folders
-
- @property
- def table_count(self) -> int:
- """Get the number of data sources in this result."""
- return len(self.tables)
-
- def keep_basic_stats(self) -> "TableSchemaResult":
- """Return a new TableSchemaResult with only basic stats preserved.
-
- Creates a lightweight version suitable for SIMPLE stat level.
- Does not mutate the original cached object.
- """
- tables_with_basic = []
- for table in self.tables:
- basic_columns = table.get_basic_column_details()
-
- table_basic = DataSourceInfo(
- name=table.name,
- comment=table.comment,
- ddl=table.ddl,
- column_details=basic_columns,
- updated_at=None, # Don't expose cache timestamp
- error=table.error,
- total_rows=table.total_rows,
- sample_data=None, # Exclude sample data for lighter payload
- # Volume fields
- format=table.format,
- total_files=table.total_files,
- total_size_bytes=table.total_size_bytes,
- files=table.files,
- )
- tables_with_basic.append(table_basic)
-
- return TableSchemaResult(
- catalog=self.catalog,
- schema_name=self.schema_name,
- tables=tables_with_basic,
- )
-
- def remove_stats(self) -> "TableSchemaResult":
- """Return a new TableSchemaResult with column statistics removed.
-
- Keeps column names and types but removes all numeric/histogram stats.
- """
- tables_no_stats = []
- for table in self.tables:
- # Strip stats from column details if they exist
- basic_columns = None
- if table.column_details:
- basic_columns = {}
- for col_name, col_detail in table.column_details.items():
- basic_columns[col_name] = ColumnDetail(
- name=col_detail.name,
- data_type=col_detail.data_type,
- )
-
- table_no_stats = DataSourceInfo(
- name=table.name,
- comment=table.comment,
- ddl=table.ddl,
- column_details=basic_columns,
- updated_at=None,
- error=table.error,
- total_rows=None,
- sample_data=None,
- # Volume fields
- format=table.format,
- total_files=table.total_files,
- total_size_bytes=table.total_size_bytes,
- files=table.files,
- )
- tables_no_stats.append(table_no_stats)
-
- return TableSchemaResult(
- catalog=self.catalog,
- schema_name=self.schema_name,
- tables=tables_no_stats,
- )
-
-
-# VolumeFolderResult is deprecated - use DataSourceInfo within TableSchemaResult instead
-# Kept as alias for backwards compatibility
-VolumeFolderResult = DataSourceInfo
diff --git a/databricks-tools-core/databricks_tools_core/sql/sql_utils/parallel_executor.py b/databricks-tools-core/databricks_tools_core/sql/sql_utils/parallel_executor.py
deleted file mode 100644
index 1c376ba3..00000000
--- a/databricks-tools-core/databricks_tools_core/sql/sql_utils/parallel_executor.py
+++ /dev/null
@@ -1,308 +0,0 @@
-"""
-SQL Parallel Executor
-
-Executes multiple SQL queries with dependency-aware parallelism.
-Queries within a group run in parallel; groups run sequentially.
-"""
-
-import logging
-import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Any, Dict, List, Optional
-
-from databricks.sdk import WorkspaceClient
-
-from ...auth import get_workspace_client
-from .dependency_analyzer import SQLDependencyAnalyzer
-from .executor import SQLExecutor, SQLExecutionError
-
-logger = logging.getLogger(__name__)
-
-
-class SQLParallelExecutor:
- """Execute multiple SQL queries with dependency-aware parallelism.
-
- Analyzes query dependencies and executes them in optimal order:
- - Queries with no dependencies on each other run in parallel
- - Queries that depend on others wait for their dependencies
- - Execution stops on first error (fail-fast)
- """
-
- def __init__(
- self,
- warehouse_id: str,
- max_workers: int = 4,
- client: Optional[WorkspaceClient] = None,
- ):
- """
- Initialize the parallel executor.
-
- Args:
- warehouse_id: SQL warehouse ID to use for queries
- max_workers: Maximum parallel queries per group (default: 4)
- client: Optional WorkspaceClient (creates new one if not provided)
- """
- self.warehouse_id = warehouse_id
- self.max_workers = max_workers
- self.client = client or get_workspace_client()
- self.analyzer = SQLDependencyAnalyzer()
- self.sql_executor = SQLExecutor(warehouse_id=warehouse_id, client=self.client)
-
- def execute(
- self,
- sql_content: str,
- catalog: Optional[str] = None,
- schema: Optional[str] = None,
- timeout: int = 180,
- query_tags: Optional[str] = None,
- ) -> Dict[str, Any]:
- """
- Execute multiple SQL statements with dependency-aware parallelism.
-
- Args:
- sql_content: SQL content with multiple statements separated by ;
- catalog: Optional catalog context for queries
- schema: Optional schema context for queries
- timeout: Timeout per query in seconds (default: 180)
- query_tags: Optional query tags for cost attribution (e.g., "team:eng,cost_center:701")
-
- Returns:
- Dictionary with:
- - results: Dict mapping query index to result dict
- - execution_summary: Overall execution statistics
-
- Raises:
- SQLExecutionError: If parsing fails or no queries found
- """
- logger.info("Starting SQL parallel execution")
- start_time = time.time()
-
- # Parse and analyze
- queries = self.analyzer.parse_sql_content(sql_content)
- if not queries:
- raise SQLExecutionError(
- "No valid SQL statements found in the provided content. "
- "Check that statements are properly terminated with semicolons."
- )
-
- logger.info(f"Parsed {len(queries)} queries")
-
- execution_groups = self.analyzer.analyze_dependencies(queries)
- logger.info(f"Found {len(execution_groups)} execution groups")
-
- results: Dict[int, Dict[str, Any]] = {}
- stopped_after_group: Optional[int] = None
-
- # Execute groups sequentially
- for group_idx, group in enumerate(execution_groups):
- group_num = group_idx + 1
- logger.info(f"Executing group {group_num}/{len(execution_groups)} with {len(group)} queries")
-
- # Log query previews
- for query_idx in group:
- preview = queries[query_idx][:80].replace("\n", " ").replace(" ", " ")
- logger.debug(f" [Q{query_idx + 1}] {preview}...")
-
- # Execute queries in this group (parallel if multiple)
- group_results = self._execute_group(
- queries=queries,
- query_indices=group,
- group_num=group_num,
- catalog=catalog,
- schema=schema,
- timeout=timeout,
- query_tags=query_tags,
- )
-
- # Store results and check for errors
- error_in_group = False
- for query_idx, result in group_results.items():
- results[query_idx] = result
- if result.get("status") == "error":
- error_in_group = True
-
- # Stop on first error
- if error_in_group:
- stopped_after_group = group_num
- logger.warning(
- f"Stopping execution after group {group_num} due to errors; subsequent groups will be skipped."
- )
- break
-
- total_time = round(time.time() - start_time, 2)
- logger.info(f"SQL parallel execution completed in {total_time:.2f}s")
-
- # Build summary
- execution_summary = self._build_summary(
- execution_groups=execution_groups,
- stopped_after_group=stopped_after_group,
- total_time=total_time,
- )
-
- return {
- "results": results,
- "execution_summary": execution_summary,
- }
-
- def _execute_group(
- self,
- queries: List[str],
- query_indices: List[int],
- group_num: int,
- catalog: Optional[str],
- schema: Optional[str],
- timeout: int,
- query_tags: Optional[str] = None,
- ) -> Dict[int, Dict[str, Any]]:
- """Execute a group of queries in parallel using ThreadPoolExecutor."""
- results: Dict[int, Dict[str, Any]] = {}
- is_parallel = len(query_indices) > 1
-
- def execute_single(query_idx: int) -> Dict[str, Any]:
- query_text = queries[query_idx]
- query_preview = query_text[:100] + "..." if len(query_text) > 100 else query_text
- query_preview = query_preview.replace("\n", " ").replace(" ", " ")
-
- try:
- logger.info(f"[SQL] Executing query {query_idx + 1}: {query_preview}")
- t0 = time.time()
-
- result_data = self.sql_executor.execute(
- sql_query=query_text,
- catalog=catalog,
- schema=schema,
- timeout=timeout,
- query_tags=query_tags,
- )
-
- dt = round(time.time() - t0, 2)
- row_count = len(result_data) if result_data else 0
- logger.info(f"Query {query_idx + 1} completed ({dt}s, {row_count} rows)")
-
- return {
- "query_index": query_idx,
- "status": "success",
- "execution_time": dt,
- "query_preview": query_preview,
- "result_rows": row_count,
- "sample_results": result_data[:5] if result_data else [],
- "group_number": group_num,
- "group_size": len(query_indices),
- "is_parallel": is_parallel,
- }
-
- except Exception as e:
- error_str = str(e)
- error_category, suggestion = self._categorize_error(error_str)
-
- logger.error(f"Query {query_idx + 1} failed: {error_category} - {error_str}")
-
- return {
- "query_index": query_idx,
- "status": "error",
- "error": error_str,
- "error_category": error_category,
- "suggestion": suggestion,
- "execution_time": 0,
- "query_preview": query_preview,
- "group_number": group_num,
- "group_size": len(query_indices),
- "is_parallel": is_parallel,
- }
-
- # Use ThreadPoolExecutor for parallel execution within group
- max_workers = min(self.max_workers, len(query_indices))
-
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- future_to_idx = {executor.submit(execute_single, idx): idx for idx in query_indices}
-
- for future in as_completed(future_to_idx):
- query_idx = future_to_idx[future]
- try:
- results[query_idx] = future.result()
- except Exception as e:
- # Unexpected error in thread
- results[query_idx] = {
- "query_index": query_idx,
- "status": "error",
- "error": f"Unexpected execution error: {str(e)}",
- "error_category": "UNEXPECTED_ERROR",
- "suggestion": "Check the query syntax and try again",
- "execution_time": 0,
- "group_number": group_num,
- "group_size": len(query_indices),
- "is_parallel": is_parallel,
- }
-
- return results
-
- def _categorize_error(self, error_str: str) -> tuple:
- """Categorize error for better LLM understanding."""
- error_lower = error_str.lower()
-
- if "table or view not found" in error_lower or "table not found" in error_lower:
- return (
- "MISSING_TABLE",
- "Check if the table exists in the specified catalog and schema",
- )
- elif "column not found" in error_lower or "cannot resolve" in error_lower:
- return (
- "MISSING_COLUMN",
- "Verify column names and check for typos",
- )
- elif "syntax error" in error_lower or "parse error" in error_lower:
- return (
- "SYNTAX_ERROR",
- "Review SQL syntax for errors",
- )
- elif "permission denied" in error_lower or "access denied" in error_lower:
- return (
- "PERMISSION_ERROR",
- "Check user permissions on the table or catalog",
- )
- elif "timeout" in error_lower:
- return (
- "TIMEOUT_ERROR",
- "Query took too long; consider optimizing or increasing timeout",
- )
- elif "warehouse" in error_lower:
- return (
- "WAREHOUSE_ERROR",
- "Check that the SQL warehouse is running and accessible",
- )
- else:
- return (
- "GENERAL_ERROR",
- "Review the error message for details",
- )
-
- def _build_summary(
- self,
- execution_groups: List[List[int]],
- stopped_after_group: Optional[int],
- total_time: float,
- ) -> Dict[str, Any]:
- """Build execution summary with group details."""
- groups_summary: List[Dict[str, Any]] = []
-
- for group_idx, group in enumerate(execution_groups):
- group_num = group_idx + 1
- ran = stopped_after_group is None or group_num <= stopped_after_group
-
- groups_summary.append(
- {
- "group_number": group_num,
- "group_size": len(group),
- "query_indices": [i + 1 for i in group], # 1-based for display
- "is_parallel": len(group) > 1,
- "status": "executed" if ran else "skipped",
- }
- )
-
- return {
- "total_queries": sum(len(g) for g in execution_groups),
- "total_groups": len(execution_groups),
- "total_time": total_time,
- "stopped_after_group": stopped_after_group,
- "groups": groups_summary,
- }
diff --git a/databricks-tools-core/databricks_tools_core/sql/sql_utils/table_stats_collector.py b/databricks-tools-core/databricks_tools_core/sql/sql_utils/table_stats_collector.py
deleted file mode 100644
index 7ed5de26..00000000
--- a/databricks-tools-core/databricks_tools_core/sql/sql_utils/table_stats_collector.py
+++ /dev/null
@@ -1,787 +0,0 @@
-"""
-Table Stats Collector - Collects column statistics for tables with caching.
-"""
-
-import fnmatch
-import json
-import logging
-import threading
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Any, Dict, List, Optional, Tuple
-
-from databricks.sdk import WorkspaceClient
-
-from ...auth import get_workspace_client
-from .executor import SQLExecutor
-from .models import (
- ColumnDetail,
- HistogramBin,
- TableInfo,
- HISTOGRAM_BINS,
- ID_PATTERNS,
- MAX_CATEGORICAL_VALUES,
- NUMERIC_TYPES,
- SAMPLE_ROW_COUNT,
-)
-
-logger = logging.getLogger(__name__)
-
-# Module-level cache for table information
-# Key: "catalog.schema", Value: Dict[table_name, (updated_at_timestamp, TableInfo)]
-_table_cache: Dict[str, Dict[str, Tuple[Optional[int], TableInfo]]] = {}
-_table_cache_lock = threading.Lock()
-
-# Locks per table to prevent duplicate fetches
-# Key: "catalog.schema.table", Value: threading.Lock
-_table_locks: Dict[str, threading.Lock] = {}
-_table_locks_creation_lock = threading.Lock()
-
-
-def _get_table_lock(full_table_name: str) -> threading.Lock:
- """Get or create a lock for a specific table."""
- if full_table_name not in _table_locks:
- with _table_locks_creation_lock:
- if full_table_name not in _table_locks:
- _table_locks[full_table_name] = threading.Lock()
- return _table_locks[full_table_name]
-
-
-def _get_schema_cache(catalog: str, schema: str) -> Dict[str, Tuple[Optional[int], TableInfo]]:
- """Get or create cache for a schema."""
- schema_key = f"{catalog}.{schema}"
- if schema_key not in _table_cache:
- with _table_cache_lock:
- if schema_key not in _table_cache:
- _table_cache[schema_key] = {}
- return _table_cache[schema_key]
-
-
-def _check_cache(catalog: str, schema: str, table_name: str, updated_at_ms: Optional[int]) -> Optional[TableInfo]:
- """Check if table info is in cache and still valid."""
- schema_cache = _get_schema_cache(catalog, schema)
- if table_name in schema_cache:
- cached_updated_at, cached_info = schema_cache[table_name]
- if updated_at_ms == cached_updated_at:
- return cached_info
- return None
-
-
-def _update_cache(
- catalog: str, schema: str, table_name: str, updated_at_ms: Optional[int], table_info: TableInfo
-) -> None:
- """Update cache with table info."""
- schema_cache = _get_schema_cache(catalog, schema)
- with _table_cache_lock:
- schema_cache[table_name] = (updated_at_ms, table_info)
-
-
-class TableStatsCollector:
- """Collects table statistics with caching support."""
-
- def __init__(
- self,
- warehouse_id: str,
- max_workers: int = 10,
- client: Optional[WorkspaceClient] = None,
- ):
- """
- Initialize the stats collector.
-
- Args:
- warehouse_id: SQL warehouse ID to use for queries
- max_workers: Maximum parallel table fetches (default: 10)
- client: Optional WorkspaceClient (creates new one if not provided)
- """
- self.warehouse_id = warehouse_id
- self.max_workers = max_workers
- self.client = client or get_workspace_client()
- self.executor = SQLExecutor(warehouse_id=warehouse_id, client=self.client)
-
- def list_tables(self, catalog: str, schema: str) -> List[Dict[str, Any]]:
- """
- List all tables in a schema.
-
- Returns:
- List of table info dicts with 'name' and 'updated_at' keys
- """
- try:
- tables = list(
- self.client.tables.list(
- catalog_name=catalog,
- schema_name=schema,
- )
- )
- return [
- {
- "name": t.name,
- "updated_at": getattr(t, "updated_at", None),
- "comment": getattr(t, "comment", None),
- }
- for t in tables
- if t.name
- ]
- except Exception as e:
- raise Exception(
- f"Failed to list tables in {catalog}.{schema}: {str(e)}. "
- f"Check that the catalog and schema exist and you have access."
- )
-
- def filter_tables_by_patterns(self, tables: List[Dict[str, Any]], patterns: List[str]) -> List[Dict[str, Any]]:
- """
- Filter tables by glob patterns.
-
- Args:
- tables: List of table info dicts
- patterns: List of glob patterns (e.g., ["raw_*", "silver_customers"])
-
- Returns:
- Filtered list of tables matching any pattern
- """
- if not patterns:
- return tables
-
- filtered = []
- for table in tables:
- table_name = table["name"]
- for pattern in patterns:
- if fnmatch.fnmatch(table_name, pattern):
- filtered.append(table)
- break
- return filtered
-
- def get_table_ddl(self, catalog: str, schema: str, table_name: str) -> str:
- """Get the DDL (CREATE TABLE statement) for a table."""
- full_table_name = f"`{catalog}`.`{schema}`.`{table_name}`"
- query = f"SHOW CREATE TABLE {full_table_name}"
-
- try:
- result = self.executor.execute(
- sql_query=query,
- catalog=catalog,
- schema=schema,
- timeout=45,
- )
-
- if not result:
- return ""
-
- row = result[0]
- create_statement = ""
-
- # Try different possible column names
- for col_name in [
- "createtab_stmt",
- "create_statement",
- "Create Table",
- "create_table_statement",
- ]:
- if col_name in row:
- create_statement = row[col_name]
- break
-
- if not create_statement:
- # Use first string value containing CREATE
- for value in row.values():
- if isinstance(value, str) and "CREATE" in value.upper():
- create_statement = value
- break
-
- # Clean up the CREATE statement
- if create_statement:
- # Remove TBLPROPERTIES section
- tblprops_pos = create_statement.upper().find("TBLPROPERTIES")
- if tblprops_pos != -1:
- create_statement = create_statement[:tblprops_pos].strip()
- if create_statement.endswith(","):
- create_statement = create_statement[:-1]
- if not create_statement.endswith(")"):
- create_statement += ")"
-
- # Remove USING delta section
- using_pos = create_statement.upper().find("USING ")
- if using_pos != -1:
- create_statement = create_statement[:using_pos].strip()
-
- # Clean formatting
- create_statement = create_statement.replace("\n", " ").replace(" ", " ").strip()
-
- return create_statement
-
- except Exception as e:
- logger.warning(f"Failed to get DDL for {full_table_name}: {e}")
- return ""
-
- def collect_column_stats(
- self, catalog: str, schema: str, table_name: str
- ) -> Tuple[Dict[str, ColumnDetail], Optional[int], List[Dict[str, Any]]]:
- """
- Collect enhanced column statistics for a UC table.
-
- Args:
- catalog: Catalog name
- schema: Schema name
- table_name: Table name
-
- Returns:
- Tuple of (column_details dict, total_rows, sample_data)
- """
- full_table_name = f"`{catalog}`.`{schema}`.`{table_name}`"
- return self._collect_stats_for_ref(
- table_ref=full_table_name,
- catalog=catalog,
- schema=schema,
- use_describe_table=True,
- fetch_value_counts_table=f"{catalog}.{schema}.{table_name}",
- )
-
- def collect_volume_stats(
- self, volume_path: str, format: str
- ) -> Tuple[Dict[str, ColumnDetail], Optional[int], List[Dict[str, Any]]]:
- """
- Collect enhanced column statistics for volume folder data.
-
- Args:
- volume_path: Full volume path (e.g., /Volumes/catalog/schema/volume/path)
- format: Data format (parquet, csv, json, delta)
-
- Returns:
- Tuple of (column_details dict, total_rows, sample_data)
- """
- table_ref = f"read_files('{volume_path}', format => '{format}')"
- return self._collect_stats_for_ref(
- table_ref=table_ref,
- catalog=None,
- schema=None,
- use_describe_table=False,
- fetch_value_counts_table=None,
- )
-
- def _describe_columns(self, catalog: str, schema: str, table_name: str) -> Dict[str, ColumnDetail]:
- """
- Return column names and types for a UC table without collecting statistics.
-
- Used by get_table_info when stat level is NONE.
- """
- full_table_name = f"`{catalog}`.`{schema}`.`{table_name}`"
- try:
- describe_result = self.executor.execute(
- sql_query=f"DESCRIBE TABLE {full_table_name}",
- catalog=catalog,
- schema=schema,
- timeout=45,
- )
- column_details: Dict[str, ColumnDetail] = {}
- for col in describe_result or []:
- col_name = col.get("col_name")
- data_type = col.get("data_type", "string").lower()
- if not col_name or col_name.startswith("#") or col_name == "":
- continue
- column_details[col_name] = ColumnDetail(name=col_name, data_type=data_type)
- return column_details
- except Exception as e:
- logger.warning(f"Failed to describe columns for {full_table_name}: {e}")
- return {}
-
- def _collect_stats_for_ref(
- self,
- table_ref: str,
- catalog: Optional[str],
- schema: Optional[str],
- use_describe_table: bool,
- fetch_value_counts_table: Optional[str],
- ) -> Tuple[Dict[str, ColumnDetail], Optional[int], List[Dict[str, Any]]]:
- """
- Internal method to collect column statistics for any table reference.
-
- Args:
- table_ref: SQL table reference (UC table name or read_files(...))
- catalog: Catalog for query context (optional)
- schema: Schema for query context (optional)
- use_describe_table: If True, use DESCRIBE TABLE; otherwise infer from SELECT
- fetch_value_counts_table: Table name for value counts query (None to skip)
-
- Returns:
- Tuple of (column_details dict, total_rows, sample_data)
- """
- try:
- # Step 1: Get column information
- if use_describe_table:
- describe_result = self.executor.execute(
- sql_query=f"DESCRIBE TABLE {table_ref}",
- catalog=catalog,
- schema=schema,
- timeout=45,
- )
- if not describe_result:
- return {}, 0, []
- else:
- # For read_files, infer schema from a sample query
- sample_query = f"SELECT * FROM {table_ref} LIMIT 1"
- sample_row = self.executor.execute(sql_query=sample_query, timeout=60)
- if not sample_row:
- return {}, 0, []
- # Build describe_result-like structure from first row
- describe_result = []
- for col_name, value in sample_row[0].items():
- if col_name == "_rescued_data":
- continue
- # Infer type from Python value
- if isinstance(value, bool):
- data_type = "boolean"
- elif isinstance(value, int):
- data_type = "bigint"
- elif isinstance(value, float):
- data_type = "double"
- else:
- data_type = "string"
- describe_result.append({"col_name": col_name, "data_type": data_type})
-
- # Map describe columns to ColumnDetail
- column_details = {}
- for col in describe_result:
- col_name = col.get("col_name")
- data_type = col.get("data_type", "string").lower()
- if not col_name:
- continue
- # Handle empty rows/comments in DESCRIBE
- if col_name.startswith("#") or col_name == "":
- continue
- column_details[col_name] = ColumnDetail(name=col_name, data_type=data_type)
-
- # Step 2: Get row count
- count_result = self.executor.execute(
- sql_query=f"SELECT COUNT(*) as total_rows FROM {table_ref}",
- catalog=catalog,
- schema=schema,
- timeout=45,
- )
- total_rows = count_result[0]["total_rows"] if count_result else 0
-
- # Step 3: Build unified stats query
- union_queries = []
- column_types: Dict[str, str] = {}
- columns_needing_value_counts: List[Tuple[str, str]] = []
-
- base_ref = "base"
- base_cte = f"WITH {base_ref} AS (SELECT * FROM {table_ref})\n"
-
- for col_info in describe_result:
- col_name = col_info.get("col_name", "")
- if not col_name or col_name.startswith(("#", "_")):
- continue
-
- data_type = col_info.get("data_type", "").lower()
- escaped_col = f"`{col_name}`"
-
- # Determine column type for building query
- is_numeric = any(t in data_type for t in NUMERIC_TYPES)
- is_timestamp = "timestamp" in data_type
- is_array = "array" in data_type
- is_struct_or_map = "struct" in data_type or "map" in data_type or "variant" in data_type
- is_boolean = "boolean" in data_type
- is_id = any(p in col_name.lower() for p in ID_PATTERNS) and (
- "bigint" in data_type or "string" in data_type
- )
-
- if is_array or is_struct_or_map:
- col_type = "complex"
- elif is_timestamp:
- col_type = "timestamp"
- elif is_boolean:
- col_type = "boolean"
- elif is_id:
- col_type = "id"
- elif is_numeric:
- col_type = "numeric"
- else:
- col_type = "categorical"
-
- column_types[col_name] = col_type
-
- # Build query based on type
- query = self._build_column_stats_query(col_name, escaped_col, data_type, col_type, base_ref)
- union_queries.append(query)
-
- if col_type == "boolean":
- columns_needing_value_counts.append((col_name, "boolean"))
-
- if not union_queries:
- return column_details, total_rows, []
-
- # Execute combined stats query
- combined_query = base_cte + "\nUNION ALL\n".join(union_queries)
- stats_result = self.executor.execute(
- sql_query=combined_query,
- catalog=catalog,
- schema=schema,
- timeout=60,
- )
- # Step 6: Parse stats results (updates column_details in-place)
- self._parse_stats_results(stats_result, column_types, column_details)
-
- # Step 4: Get sample data
- sample_result = []
- try:
- sample_result = self.executor.execute(
- sql_query=f"SELECT * FROM {table_ref} LIMIT {SAMPLE_ROW_COUNT}",
- catalog=catalog,
- schema=schema,
- timeout=45,
- )
- # Filter out _rescued_data column from samples
- if sample_result:
- sample_result = [{k: v for k, v in row.items() if k != "_rescued_data"} for row in sample_result]
- except Exception as e:
- logger.warning(f"Failed to get sample data for {table_ref}: {e}")
-
- # Step 5: Build column samples from sample data
- column_samples = self._extract_column_samples(describe_result, sample_result)
-
- # Update column details with samples
- for col_name, samples in column_samples.items():
- if col_name in column_details:
- column_details[col_name].samples = samples
-
- # Step 7: Get value counts for categorical columns (only for UC tables)
- if fetch_value_counts_table:
- columns_needing_value_counts = []
- for col_name, detail in column_details.items():
- if column_types.get(col_name) == "categorical":
- approx_unique = detail.unique_count or 0
- if 0 < approx_unique < MAX_CATEGORICAL_VALUES:
- columns_needing_value_counts.append((col_name, "categorical"))
-
- # Parse catalog.schema.table from fetch_value_counts_table
- parts = fetch_value_counts_table.split(".")
- if len(parts) == 3:
- self._fetch_value_counts(parts[0], parts[1], parts[2], columns_needing_value_counts, column_details)
-
- return column_details, total_rows, sample_result or []
-
- except Exception as e:
- logger.warning(f"Failed to collect stats for {table_ref}: {e}")
- return {}, 0, []
-
- def _build_column_stats_query(
- self, col_name: str, escaped_col: str, data_type: str, col_type: str, base_ref: str
- ) -> str:
- """Build stats query for a column based on its type."""
- if col_type == "complex":
- return f"""
- SELECT
- '{col_name}' AS column_name,
- '{data_type}' AS data_type,
- COUNT(*) AS total_count,
- SUM(CASE WHEN {escaped_col} IS NULL THEN 1 ELSE 0 END) AS null_count,
- NULL AS unique_count,
- NULL AS min_val, NULL AS max_val,
- NULL AS mean_val, NULL AS stddev_val,
- NULL AS q1_val, NULL AS median_val, NULL AS q3_val,
- NULL AS histogram_data
- FROM {base_ref}
- """
- elif col_type == "numeric":
- return f"""
- SELECT
- '{col_name}' AS column_name,
- '{data_type}' AS data_type,
- COUNT(*) AS total_count,
- SUM(CASE WHEN {escaped_col} IS NULL THEN 1 ELSE 0 END) AS null_count,
- approx_count_distinct({escaped_col}) AS unique_count,
- CAST(MIN({escaped_col}) AS STRING) AS min_val,
- CAST(MAX({escaped_col}) AS STRING) AS max_val,
- CAST(AVG({escaped_col}) AS STRING) AS mean_val,
- CAST(STDDEV({escaped_col}) AS STRING) AS stddev_val,
- CAST(approx_percentile({escaped_col}, 0.25) AS STRING) AS q1_val,
- CAST(approx_percentile({escaped_col}, 0.5) AS STRING) AS median_val,
- CAST(approx_percentile({escaped_col}, 0.75) AS STRING) AS q3_val,
- to_json(histogram_numeric({escaped_col}, {HISTOGRAM_BINS})) AS histogram_data
- FROM {base_ref}
- """
- elif col_type == "timestamp":
- return f"""
- SELECT
- '{col_name}' AS column_name,
- '{data_type}' AS data_type,
- COUNT(*) AS total_count,
- SUM(CASE WHEN {escaped_col} IS NULL THEN 1 ELSE 0 END) AS null_count,
- approx_count_distinct({escaped_col}) AS unique_count,
- CAST(MIN({escaped_col}) AS STRING) AS min_val,
- CAST(MAX({escaped_col}) AS STRING) AS max_val,
- NULL AS mean_val, NULL AS stddev_val,
- NULL AS q1_val, NULL AS median_val, NULL AS q3_val,
- to_json(histogram_numeric(unix_timestamp({escaped_col}), {HISTOGRAM_BINS})) AS histogram_data
- FROM {base_ref}
- """
- else:
- # boolean, id, categorical, date
- return f"""
- SELECT
- '{col_name}' AS column_name,
- '{data_type}' AS data_type,
- COUNT(*) AS total_count,
- SUM(CASE WHEN {escaped_col} IS NULL THEN 1 ELSE 0 END) AS null_count,
- approx_count_distinct({escaped_col}) AS unique_count,
- {"CAST(MIN(" + escaped_col + ") AS STRING)" if col_type == "date" else "NULL"} AS min_val,
- {"CAST(MAX(" + escaped_col + ") AS STRING)" if col_type == "date" else "NULL"} AS max_val,
- NULL AS mean_val, NULL AS stddev_val,
- NULL AS q1_val, NULL AS median_val, NULL AS q3_val,
- NULL AS histogram_data
- FROM {base_ref}
- """
-
- def _extract_column_samples(
- self, columns_info: List[Dict], sample_data: Optional[List[Dict]]
- ) -> Dict[str, List[str]]:
- """Extract sample values for each column."""
- column_samples: Dict[str, List[str]] = {}
- if not sample_data:
- return column_samples
-
- for col_info in columns_info:
- col_name = col_info.get("col_name", "")
- if not col_name or col_name.startswith(("#", "_")):
- continue
-
- seen = set()
- samples = []
- for row in sample_data:
- if col_name in row and row[col_name] is not None:
- val_str = str(row[col_name])
- if len(val_str) > 15:
- val_str = val_str[:15] + "..."
- if val_str not in seen:
- seen.add(val_str)
- samples.append(val_str)
- if len(samples) >= 3:
- break
- column_samples[col_name] = samples
-
- return column_samples
-
- def _parse_stats_results(
- self,
- stats_result: List[Dict],
- column_types: Dict[str, str],
- column_details: Dict[str, ColumnDetail],
- ) -> None:
- """Parse stats query results into existing ColumnDetail objects."""
- for row in stats_result:
- col_name = row.get("column_name")
- if not col_name or col_name not in column_details:
- continue
-
- detail = column_details[col_name]
- col_type = column_types.get(col_name, "categorical")
- approx_unique = int(row.get("unique_count") or 0) if row.get("unique_count") is not None else None
-
- # Update base stats
- detail.null_count = int(row.get("null_count") or 0) if row.get("null_count") is not None else 0
- detail.unique_count = approx_unique
-
- # Parse histogram if present
- histogram_bins = None
- if row.get("histogram_data"):
- try:
- hist_str = row.get("histogram_data")
- if hist_str and hist_str != "null":
- hist_data = json.loads(hist_str)
- if isinstance(hist_data, list) and hist_data:
- histogram_bins = [
- HistogramBin(
- bin_center=float(item.get("x") or 0),
- count=int(item.get("y") or 0),
- )
- for item in hist_data
- ]
- except Exception as e:
- logger.debug(f"Failed to parse histogram for {col_name}: {e}")
-
- detail.histogram = histogram_bins
-
- # Update numeric/timestamp/etc based on type and stats row
- if col_type == "numeric":
- detail.total_count = int(row.get("total_count") or 0)
- detail.min = float(row["min_val"]) if row.get("min_val") else None
- detail.max = float(row["max_val"]) if row.get("max_val") else None
- detail.avg = float(row["mean_val"]) if row.get("mean_val") else None
- detail.mean = float(row["mean_val"]) if row.get("mean_val") else None
- detail.stddev = float(row["stddev_val"]) if row.get("stddev_val") else None
- detail.q1 = float(row["q1_val"]) if row.get("q1_val") else None
- detail.median = float(row["median_val"]) if row.get("median_val") else None
- detail.q3 = float(row["q3_val"]) if row.get("q3_val") else None
- elif col_type == "timestamp":
- detail.total_count = int(row.get("total_count") or 0)
- detail.min_date = str(row["min_val"]) if row.get("min_val") else None
- detail.max_date = str(row["max_val"]) if row.get("max_val") else None
- else:
- detail.total_count = int(row.get("total_count") or 0)
- detail.min = str(row["min_val"]) if row.get("min_val") else None
- detail.max = str(row["max_val"]) if row.get("max_val") else None
-
- def _fetch_value_counts(
- self,
- catalog: str,
- schema: str,
- table_name: str,
- columns: List[Tuple[str, str]],
- column_details: Dict[str, ColumnDetail],
- ) -> None:
- """Fetch exact value counts for small-cardinality columns."""
- # Ensure table name is properly quoted for SQL
- quoted_table = table_name if table_name.startswith("`") else f"`{table_name}`"
- full_table_name = f"`{catalog}`.`{schema}`.{quoted_table}"
-
- for col_name, _col_type in columns:
- if col_name not in column_details:
- continue
-
- escaped_col = f"`{col_name}`"
- query = f"""
- SELECT {escaped_col} AS value, COUNT(*) AS count
- FROM {full_table_name}
- WHERE {escaped_col} IS NOT NULL
- GROUP BY {escaped_col}
- ORDER BY COUNT(*) DESC
- """
-
- try:
- result = self.executor.execute(
- sql_query=query,
- catalog=catalog,
- schema=schema,
- timeout=45,
- )
- if result:
- actual_count = len(result)
- if actual_count <= MAX_CATEGORICAL_VALUES:
- value_counts = {str(row.get("value", "")): int(row.get("count") or 0) for row in result}
- column_details[col_name].value_counts = value_counts
- column_details[col_name].unique_count = actual_count
- except Exception as e:
- logger.debug(f"Failed to get value counts for {col_name}: {e}")
-
- def get_table_info(
- self,
- catalog: str,
- schema: str,
- table_name: str,
- updated_at_ms: Optional[int],
- comment: Optional[str],
- collect_stats: bool = True,
- ) -> TableInfo:
- """
- Get complete info for a single table.
-
- Args:
- catalog: Catalog name
- schema: Schema name
- table_name: Table name
- updated_at_ms: Table's updated_at timestamp (for cache validation)
- comment: Table comment
- collect_stats: Whether to collect column statistics
-
- Returns:
- TableInfo with DDL and optionally column stats
- """
- full_table_name = f"{catalog}.{schema}.{table_name}"
- table_lock = _get_table_lock(full_table_name)
-
- with table_lock:
- # Check cache (only if collecting stats)
- if collect_stats:
- cached = _check_cache(catalog, schema, table_name, updated_at_ms)
- if cached:
- logger.debug(f"Using cached info for {full_table_name}")
- return cached
-
- # Get DDL
- ddl = self.get_table_ddl(catalog, schema, table_name)
-
- # Collect schema and stats (or schema-only if collect_stats=False)
- column_details = None
- total_rows = None
- sample_data = None
-
- try:
- if collect_stats:
- column_details, total_rows, sample_data = self.collect_column_stats(catalog, schema, table_name)
- else:
- column_details = self._describe_columns(catalog, schema, table_name)
- except Exception as e:
- logger.warning(f"Failed to collect stats for {full_table_name}: {e}")
-
- table_info = TableInfo(
- name=full_table_name,
- comment=comment,
- ddl=ddl,
- column_details=column_details,
- updated_at=updated_at_ms,
- total_rows=total_rows,
- sample_data=sample_data,
- )
-
- # Update cache (only if we collected stats)
- if collect_stats and not table_info.error:
- _update_cache(catalog, schema, table_name, updated_at_ms, table_info)
-
- return table_info
-
- def get_tables_info_parallel(
- self,
- catalog: str,
- schema: str,
- tables: List[Dict[str, Any]],
- collect_stats: bool = True,
- ) -> List[TableInfo]:
- """
- Get info for multiple tables in parallel.
-
- Args:
- catalog: Catalog name
- schema: Schema name
- tables: List of table info dicts with 'name', 'updated_at', 'comment'
- collect_stats: Whether to collect column statistics
-
- Returns:
- List of TableInfo objects
- """
- if not tables:
- return []
-
- def process_table(table_info: Dict) -> TableInfo:
- table_name = table_info["name"]
- updated_at = table_info.get("updated_at")
- updated_at_ms = int(updated_at) if updated_at else None
- comment = table_info.get("comment")
-
- try:
- return self.get_table_info(catalog, schema, table_name, updated_at_ms, comment, collect_stats)
- except Exception as e:
- logger.error(f"Failed to get info for {catalog}.{schema}.{table_name}: {e}")
- return TableInfo(
- name=f"{catalog}.{schema}.{table_name}",
- ddl="",
- error=str(e),
- )
-
- results = []
- max_workers = min(self.max_workers, len(tables))
-
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- future_to_table = {executor.submit(process_table, t): t for t in tables}
- for future in as_completed(future_to_table):
- try:
- results.append(future.result())
- except Exception as e:
- table = future_to_table[future]
- logger.error(f"Unexpected error for {table['name']}: {e}")
- results.append(
- TableInfo(
- name=f"{catalog}.{schema}.{table['name']}",
- ddl="",
- error=str(e),
- )
- )
-
- return results
diff --git a/databricks-tools-core/databricks_tools_core/sql/table_stats.py b/databricks-tools-core/databricks_tools_core/sql/table_stats.py
deleted file mode 100644
index 68526767..00000000
--- a/databricks-tools-core/databricks_tools_core/sql/table_stats.py
+++ /dev/null
@@ -1,474 +0,0 @@
-"""
-Table Statistics
-
-High-level functions for getting table details and statistics.
-Supports both Unity Catalog tables and Volume folder data.
-"""
-
-import logging
-from typing import List, Literal, Optional
-
-from .sql_utils.models import (
- ColumnDetail,
- DataSourceInfo,
- TableSchemaResult,
- TableStatLevel,
- VolumeFileInfo,
-)
-from .sql_utils.table_stats_collector import TableStatsCollector
-from .warehouse import get_best_warehouse
-from ..auth import get_workspace_client
-
-logger = logging.getLogger(__name__)
-
-
-def _has_glob_pattern(name: str) -> bool:
- """Check if a name contains glob pattern characters."""
- return any(c in name for c in ["*", "?", "[", "]"])
-
-
-def get_table_stats_and_schema(
- catalog: str,
- schema: str,
- table_names: Optional[List[str]] = None,
- table_stat_level: TableStatLevel = TableStatLevel.SIMPLE,
- warehouse_id: Optional[str] = None,
-) -> TableSchemaResult:
- """
- Get schema and statistics for tables in a Unity Catalog schema.
-
- Returns column names, data types, row counts, and optionally detailed
- column-level statistics (cardinality, min/max, histograms, percentiles).
-
- Supports three modes based on table_names:
- 1. Empty list or None: List all tables in the schema
- 2. Names with glob patterns (*, ?, []): List tables and filter by pattern
- 3. Exact names: Get tables directly without listing (faster)
-
- Args:
- catalog: Catalog name
- schema: Schema name
- table_names: Optional list of table names or glob patterns.
- Examples:
- - None or []: Get all tables
- - ["customers", "orders"]: Get specific tables
- - ["raw_*"]: Get all tables starting with "raw_"
- - ["*_customers", "orders"]: Mix of patterns and exact names
- table_stat_level: Level of statistics to collect:
- - NONE: Schema only (column names, types) - fast, no stats
- - SIMPLE: Schema + row count, basic column info (default)
- - DETAILED: Full stats including cardinality, min/max,
- null counts, histograms, and percentiles per column
- warehouse_id: Optional warehouse ID. If not provided, auto-selects one.
-
- Returns:
- TableSchemaResult containing table information with requested stat level
-
- Raises:
- Exception: If warehouse not available or catalog/schema doesn't exist
-
- Examples:
- >>> # Get all tables with basic stats
- >>> result = get_table_stats_and_schema("my_catalog", "my_schema")
-
- >>> # Get specific tables
- >>> result = get_table_stats_and_schema("my_catalog", "my_schema", ["customers", "orders"])
-
- >>> # Get tables matching pattern with full stats (histograms, percentiles)
- >>> result = get_table_stats_and_schema(
- ... "my_catalog", "my_schema",
- ... ["gold_*"],
- ... table_stat_level=TableStatLevel.DETAILED
- ... )
-
- >>> # Quick schema-only lookup (no stats)
- >>> result = get_table_stats_and_schema(
- ... "my_catalog", "my_schema",
- ... ["my_table"],
- ... table_stat_level=TableStatLevel.NONE
- ... )
- """
- # Auto-select warehouse if not provided
- if not warehouse_id:
- logger.debug("No warehouse_id provided, selecting best available warehouse")
- warehouse_id = get_best_warehouse()
- if not warehouse_id:
- raise Exception(
- "No SQL warehouse available in the workspace. "
- "Please create a SQL warehouse or start an existing one, "
- "or provide a specific warehouse_id."
- )
- logger.debug(f"Auto-selected warehouse: {warehouse_id}")
-
- collector = TableStatsCollector(warehouse_id=warehouse_id)
-
- # Determine if we need to list tables
- table_names = table_names or []
- has_patterns = any(_has_glob_pattern(name) for name in table_names)
- needs_listing = len(table_names) == 0 or has_patterns
- failed_tables: List[DataSourceInfo] = []
-
- if needs_listing:
- # List all tables first
- logger.debug(f"Listing tables in {catalog}.{schema}")
- all_tables = collector.list_tables(catalog, schema)
-
- if table_names:
- # Filter by patterns
- tables_to_fetch = collector.filter_tables_by_patterns(all_tables, table_names)
- logger.debug(
- f"Filtered {len(all_tables)} tables to {len(tables_to_fetch)} matching patterns: {table_names}"
- )
- else:
- tables_to_fetch = all_tables
- logger.debug(f"Found {len(tables_to_fetch)} tables")
- else:
- # Direct lookup - build table info without listing
- logger.debug(f"Direct lookup for tables: {table_names}")
- tables_to_fetch = []
- for name in table_names:
- try:
- # Fetch metadata via SDK to get the comment and updated_at
- t = collector.client.tables.get(f"{catalog}.{schema}.{name}")
- tables_to_fetch.append(
- {
- "name": t.name,
- "updated_at": getattr(t, "updated_at", None),
- "comment": getattr(t, "comment", None),
- }
- )
- except Exception as e:
- logger.warning(f"Failed to fetch metadata for {catalog}.{schema}.{name}: {e}")
- failed_tables.append(
- DataSourceInfo(
- name=f"{catalog}.{schema}.{name}",
- error=f"Failed to fetch table metadata: {e}",
- )
- )
-
- if not tables_to_fetch and not failed_tables:
- return TableSchemaResult(catalog=catalog, schema_name=schema, tables=[])
-
- # Determine whether to collect stats
- collect_stats = table_stat_level != TableStatLevel.NONE
-
- # Fetch table info (with or without stats)
- logger.info(f"Fetching {len(tables_to_fetch)} tables with stat_level={table_stat_level.value}")
- table_infos = collector.get_tables_info_parallel(
- catalog=catalog,
- schema=schema,
- tables=tables_to_fetch,
- collect_stats=collect_stats,
- )
-
- # Append any tables that failed metadata lookup with their error info
- if failed_tables:
- table_infos.extend(failed_tables)
-
- # Build result
- result = TableSchemaResult(
- catalog=catalog,
- schema_name=schema,
- tables=table_infos,
- )
-
- # Apply stat level transformation
- if table_stat_level == TableStatLevel.SIMPLE:
- return result.keep_basic_stats()
- elif table_stat_level == TableStatLevel.NONE:
- return result.remove_stats()
- else:
- # DETAILED - return everything
- return result
-
-
-def _parse_volume_path(volume_path: str) -> str:
- """
- Parse volume path and return the full /Volumes/... path.
-
- Accepts:
- - catalog/schema/volume/path
- - /Volumes/catalog/schema/volume/path
-
- Returns:
- Full path in /Volumes/catalog/schema/volume/path format
- """
- path = volume_path.strip("/")
- if path.lower().startswith("volumes/"):
- return f"/{path}"
- return f"/Volumes/{path}"
-
-
-def _list_volume_files(volume_path: str) -> tuple[List[VolumeFileInfo], int, Optional[str]]:
- """
- List files in a volume folder using the Files API.
-
- Returns:
- Tuple of (files_list, total_size_bytes, error_message)
- """
- w = get_workspace_client()
- files = []
- total_size = 0
-
- try:
- for entry in w.files.list_directory_contents(volume_path):
- file_info = VolumeFileInfo(
- name=entry.name,
- path=entry.path,
- size_bytes=getattr(entry, "file_size", None),
- is_directory=entry.is_directory,
- modification_time=str(getattr(entry, "last_modified", None))
- if hasattr(entry, "last_modified")
- else None,
- )
- files.append(file_info)
- if file_info.size_bytes:
- total_size += file_info.size_bytes
-
- return files, total_size, None
-
- except Exception as e:
- error_msg = str(e)
- if "NOT_FOUND" in error_msg or "404" in error_msg:
- return (
- [],
- 0,
- f"Volume path not found: {volume_path}. Check that the catalog, schema, volume, and path exist.",
- )
- return [], 0, f"Failed to list volume path: {volume_path}. Error: {error_msg}"
-
-
-def _extract_catalog_schema_from_volume_path(volume_path: str) -> tuple[str, str]:
- """Extract catalog and schema from a volume path like /Volumes/catalog/schema/volume/..."""
- parts = volume_path.strip("/").split("/")
- if parts[0].lower() == "volumes" and len(parts) >= 3:
- return parts[1], parts[2]
- elif len(parts) >= 2:
- return parts[0], parts[1]
- return "volumes", "data"
-
-
-def get_volume_folder_details(
- volume_path: str,
- format: Literal["parquet", "csv", "json", "delta", "file"] = "parquet",
- table_stat_level: TableStatLevel = TableStatLevel.SIMPLE,
- warehouse_id: Optional[str] = None,
-) -> TableSchemaResult:
- """
- Get detailed information about data files in a Databricks Volume folder.
-
- Similar to get_table_stats_and_schema but for raw files stored in Volumes.
- Uses SQL warehouse to read volume data via read_files() function.
-
- Args:
- volume_path: Path to the volume folder. Can be:
- - "catalog/schema/volume/path" (e.g., "ai_dev_kit/demo/raw_data/customers")
- - "/Volumes/catalog/schema/volume/path"
- format: Data format:
- - "parquet", "csv", "json", "delta": Read data and compute stats
- - "file": Just list files without reading data (fast)
- table_stat_level: Level of statistics to collect:
- - NONE: Just schema, no stats
- - SIMPLE: Basic stats (default)
- - DETAILED: Full stats including samples
- warehouse_id: Optional warehouse ID. If not provided, auto-selects one.
-
- Returns:
- TableSchemaResult with a single DataSourceInfo containing file info, column stats, and sample data
-
- Examples:
- >>> # Get stats for parquet files
- >>> result = get_volume_folder_details(
- ... "ai_dev_kit/demo/raw_data/customers",
- ... format="parquet"
- ... )
- >>> info = result.tables[0]
- >>> print(f"Rows: {info.total_rows}, Columns: {len(info.column_details)}")
-
- >>> # Just list files (fast, no data reading)
- >>> result = get_volume_folder_details(
- ... "ai_dev_kit/demo/raw_data/customers",
- ... format="file"
- ... )
- >>> info = result.tables[0]
- >>> print(f"Files: {info.total_files}, Size: {info.total_size_bytes}")
- """
- full_path = _parse_volume_path(volume_path)
- logger.debug(f"Getting volume folder details for: {full_path}, format={format}")
-
- # Extract catalog/schema for the result
- catalog, schema = _extract_catalog_schema_from_volume_path(full_path)
-
- def _make_result(info: DataSourceInfo) -> TableSchemaResult:
- """Helper to wrap DataSourceInfo in TableSchemaResult."""
- return TableSchemaResult(catalog=catalog, schema_name=schema, tables=[info])
-
- # Step 1: List files to verify folder exists and get file info
- files, total_size, error = _list_volume_files(full_path)
-
- if error:
- return _make_result(
- DataSourceInfo(
- name=full_path,
- format=format,
- error=error,
- )
- )
-
- if not files:
- return _make_result(
- DataSourceInfo(
- name=full_path,
- format=format,
- total_files=0,
- error=f"Volume path exists but is empty: {full_path}",
- )
- )
-
- # Count data files (not directories)
- data_files = [f for f in files if not f.is_directory]
- directories = [f for f in files if f.is_directory]
- total_files = len(data_files) if data_files else len(directories)
-
- # Step 2: For format="file", just return file listing
- if format == "file":
- return _make_result(
- DataSourceInfo(
- name=full_path,
- format=format,
- total_files=len(files),
- total_size_bytes=total_size,
- files=files,
- )
- )
-
- # Step 3: For data formats, use TableStatsCollector to read and compute stats
- # Auto-select warehouse if not provided
- if not warehouse_id:
- logger.debug("No warehouse_id provided, selecting best available warehouse")
- warehouse_id = get_best_warehouse()
- if not warehouse_id:
- raise Exception(
- "No SQL warehouse available in the workspace. "
- "Please create a SQL warehouse or start an existing one, "
- "or provide a specific warehouse_id."
- )
- logger.debug(f"Auto-selected warehouse: {warehouse_id}")
-
- # Determine whether to collect stats
- collect_stats = table_stat_level != TableStatLevel.NONE
-
- if not collect_stats:
- # Just get schema without stats - use a simple query
- from .sql_utils.executor import SQLExecutor
-
- executor = SQLExecutor(warehouse_id=warehouse_id)
- volume_ref = f"read_files('{full_path}', format => '{format}')"
-
- try:
- # Get schema from first row
- sample_query = f"SELECT * FROM {volume_ref} LIMIT 1"
- sample_result = executor.execute(sql_query=sample_query, timeout=60)
-
- if not sample_result:
- return _make_result(
- DataSourceInfo(
- name=full_path,
- format=format,
- total_files=total_files,
- total_size_bytes=total_size,
- error="Failed to read volume data - no data returned",
- )
- )
-
- # Get row count
- count_result = executor.execute(
- sql_query=f"SELECT COUNT(*) as total_rows FROM {volume_ref}",
- timeout=60,
- )
- total_rows = count_result[0]["total_rows"] if count_result else 0
-
- # Build column details from first row
- column_details = {}
- for col_name, value in sample_result[0].items():
- if col_name == "_rescued_data":
- continue
- if isinstance(value, bool):
- data_type = "boolean"
- elif isinstance(value, int):
- data_type = "bigint"
- elif isinstance(value, float):
- data_type = "double"
- else:
- data_type = "string"
- column_details[col_name] = ColumnDetail(name=col_name, data_type=data_type)
-
- return _make_result(
- DataSourceInfo(
- name=full_path,
- format=format,
- total_files=total_files,
- total_size_bytes=total_size,
- total_rows=total_rows,
- column_details=column_details,
- )
- )
- except Exception as e:
- return _make_result(
- DataSourceInfo(
- name=full_path,
- format=format,
- total_files=total_files,
- total_size_bytes=total_size,
- error=f"Failed to read volume data: {str(e)}",
- )
- )
-
- # Use TableStatsCollector for full stats
- collector = TableStatsCollector(warehouse_id=warehouse_id)
-
- try:
- column_details, total_rows, sample_data = collector.collect_volume_stats(
- volume_path=full_path,
- format=format,
- )
-
- if not column_details:
- return _make_result(
- DataSourceInfo(
- name=full_path,
- format=format,
- total_files=total_files,
- total_size_bytes=total_size,
- error="Failed to collect volume stats - no columns found",
- )
- )
-
- volume_info = DataSourceInfo(
- name=full_path,
- format=format,
- total_files=total_files,
- total_size_bytes=total_size,
- total_rows=total_rows,
- column_details=column_details,
- sample_data=sample_data if table_stat_level == TableStatLevel.DETAILED else None,
- )
-
- result = _make_result(volume_info)
-
- # Apply stat level transformation
- if table_stat_level == TableStatLevel.SIMPLE:
- return result.keep_basic_stats()
- else:
- return result
-
- except Exception as e:
- return _make_result(
- DataSourceInfo(
- name=full_path,
- format=format,
- total_files=total_files,
- total_size_bytes=total_size,
- error=f"Failed to read volume data: {str(e)}",
- )
- )
diff --git a/databricks-tools-core/databricks_tools_core/sql/warehouse.py b/databricks-tools-core/databricks_tools_core/sql/warehouse.py
deleted file mode 100644
index 7dee2cbd..00000000
--- a/databricks-tools-core/databricks_tools_core/sql/warehouse.py
+++ /dev/null
@@ -1,176 +0,0 @@
-"""
-SQL Warehouse Operations
-
-Functions for listing and selecting SQL warehouses.
-"""
-
-import logging
-from typing import Any, Dict, List, Optional
-
-from databricks.sdk.service.sql import State
-
-from ..auth import get_workspace_client, get_current_username
-
-logger = logging.getLogger(__name__)
-
-
-def list_warehouses(limit: int = 20) -> List[Dict[str, Any]]:
- """
- List SQL warehouses, with online (RUNNING) warehouses first.
-
- Args:
- limit: Maximum number of warehouses to return (default: 20)
-
- Returns:
- List of warehouse dictionaries with keys:
- - id: Warehouse ID
- - name: Warehouse name
- - state: Current state (RUNNING, STOPPED, STARTING, etc.)
- - cluster_size: Size of the warehouse
- - auto_stop_mins: Auto-stop timeout in minutes
- - creator_name: Who created the warehouse
- - warehouse_type: Type of warehouse (PRO, CLASSIC)
- - enable_serverless_compute: Whether serverless compute is enabled
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
-
- try:
- warehouses = list(client.warehouses.list())
- except Exception as e:
- raise Exception(f"Failed to list SQL warehouses: {str(e)}. Check that you have permission to view warehouses.")
-
- # Sort: RUNNING first, then by name
- def sort_key(w):
- # RUNNING = 0 (first), others = 1
- state_priority = 0 if w.state == State.RUNNING else 1
- return (state_priority, w.name.lower() if w.name else "")
-
- warehouses.sort(key=sort_key)
-
- # Convert to dicts and limit
- result = []
- for w in warehouses[:limit]:
- result.append(
- {
- "id": w.id,
- "name": w.name,
- "state": w.state.value if w.state else None,
- "cluster_size": w.cluster_size,
- "auto_stop_mins": w.auto_stop_mins,
- "creator_name": w.creator_name,
- "warehouse_type": getattr(w, "warehouse_type", None),
- "enable_serverless_compute": getattr(w, "enable_serverless_compute", None),
- }
- )
-
- return result
-
-
-def _sort_within_tier(warehouses: list, current_user: Optional[str]) -> list:
- """Sort warehouses within a tier: serverless first, then user-owned.
-
- This is a *soft* preference — no warehouses are removed. Within the same
- priority bucket, serverless warehouses are tried first, then user-owned.
-
- Args:
- warehouses: List of SDK warehouse objects.
- current_user: Current user's username/email, or None.
-
- Returns:
- Reordered list (serverless first, then user-owned, then the rest).
- """
- if not warehouses:
- return warehouses
-
- def sort_key(w):
- is_serverless = 0 if getattr(w, "enable_serverless_compute", False) else 1
- user_lower = (current_user or "").lower()
- is_owned = 0 if user_lower and (w.creator_name or "").lower() == user_lower else 1
- return (is_serverless, is_owned)
-
- return sorted(warehouses, key=sort_key)
-
-
-def get_best_warehouse() -> Optional[str]:
- """
- Select the best available SQL warehouse based on priority rules.
-
- Within each priority tier, serverless warehouses are preferred first
- (instant start, auto-scale, no idle costs), then warehouses created
- by the current user. No warehouses are excluded.
-
- Priority:
- 1. Running warehouse named "Shared endpoint" or "dbdemos-shared-endpoint"
- 2. Any running warehouse with 'shared' in name
- 3. Any running warehouse
- 4. Stopped warehouse with 'shared' in name
- 5. Any stopped warehouse
-
- Returns:
- Warehouse ID string, or None if no warehouses available
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
- current_user = get_current_username()
-
- try:
- warehouses = list(client.warehouses.list())
- except Exception as e:
- raise Exception(f"Failed to list SQL warehouses: {str(e)}. Check that you have permission to view warehouses.")
-
- if not warehouses:
- logger.warning("No SQL warehouses found in workspace")
- return None
-
- # Categorize warehouses
- standard_shared = [] # Specific shared endpoint names
- online_shared = [] # Running + 'shared' in name
- online_other = [] # Running, no 'shared'
- offline_shared = [] # Stopped + 'shared' in name
- offline_other = [] # Stopped, no 'shared'
-
- for warehouse in warehouses:
- is_running = warehouse.state == State.RUNNING
- name_lower = warehouse.name.lower() if warehouse.name else ""
- is_shared = "shared" in name_lower
-
- # Check for standard shared endpoint names
- if is_running and warehouse.name in ("Shared endpoint", "dbdemos-shared-endpoint"):
- standard_shared.append(warehouse)
- elif is_running and is_shared:
- online_shared.append(warehouse)
- elif is_running:
- online_other.append(warehouse)
- elif is_shared:
- offline_shared.append(warehouse)
- else:
- offline_other.append(warehouse)
-
- # Within each tier, prefer warehouses owned by the current user
- standard_shared = _sort_within_tier(standard_shared, current_user)
- online_shared = _sort_within_tier(online_shared, current_user)
- online_other = _sort_within_tier(online_other, current_user)
- offline_shared = _sort_within_tier(offline_shared, current_user)
- offline_other = _sort_within_tier(offline_other, current_user)
-
- # Select based on priority
- if standard_shared:
- selected = standard_shared[0]
- elif online_shared:
- selected = online_shared[0]
- elif online_other:
- selected = online_other[0]
- elif offline_shared:
- selected = offline_shared[0]
- elif offline_other:
- selected = offline_other[0]
- else:
- return None
-
- logger.debug(f"Selected warehouse: {selected.name} (state: {selected.state})")
- return selected.id
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/__init__.py b/databricks-tools-core/databricks_tools_core/unity_catalog/__init__.py
deleted file mode 100644
index d5ded19d..00000000
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/__init__.py
+++ /dev/null
@@ -1,261 +0,0 @@
-"""
-Unity Catalog Operations
-
-Functions for managing Unity Catalog objects, permissions, storage,
-governance metadata, monitors, and data sharing.
-"""
-
-# Catalogs
-from .catalogs import (
- list_catalogs,
- get_catalog,
- create_catalog,
- update_catalog,
- delete_catalog,
-)
-
-# Schemas
-from .schemas import (
- list_schemas,
- get_schema,
- create_schema,
- update_schema,
- delete_schema,
-)
-
-# Tables
-from .tables import (
- list_tables,
- get_table,
- create_table,
- delete_table,
-)
-
-# Volumes
-from .volumes import (
- list_volumes,
- get_volume,
- create_volume,
- update_volume,
- delete_volume,
-)
-
-# Volume Files
-from .volume_files import (
- VolumeFileInfo,
- VolumeUploadResult,
- VolumeFolderUploadResult,
- VolumeDownloadResult,
- VolumeDeleteResult,
- list_volume_files,
- upload_to_volume,
- download_from_volume,
- delete_from_volume,
- create_volume_directory,
- get_volume_file_metadata,
-)
-
-# Functions
-from .functions_uc import (
- list_functions,
- get_function,
- delete_function,
-)
-
-# Grants
-from .grants import (
- grant_privileges,
- revoke_privileges,
- get_grants,
- get_effective_grants,
-)
-
-# Storage credentials and external locations
-from .storage import (
- list_storage_credentials,
- get_storage_credential,
- create_storage_credential,
- update_storage_credential,
- delete_storage_credential,
- validate_storage_credential,
- list_external_locations,
- get_external_location,
- create_external_location,
- update_external_location,
- delete_external_location,
-)
-
-# Connections (Lakehouse Federation)
-from .connections import (
- list_connections,
- get_connection,
- create_connection,
- update_connection,
- delete_connection,
- create_foreign_catalog,
-)
-
-# Tags and comments
-from .tags import (
- set_tags,
- unset_tags,
- set_comment,
- query_table_tags,
- query_column_tags,
-)
-
-# Security policies (RLS, column masking)
-from .security_policies import (
- create_security_function,
- set_row_filter,
- drop_row_filter,
- set_column_mask,
- drop_column_mask,
-)
-
-# Quality monitors
-from .monitors import (
- create_monitor,
- get_monitor,
- run_monitor_refresh,
- list_monitor_refreshes,
- delete_monitor,
-)
-
-# Metric Views
-from .metric_views import (
- create_metric_view,
- alter_metric_view,
- drop_metric_view,
- describe_metric_view,
- query_metric_view,
- grant_metric_view,
-)
-
-# Delta Sharing
-from .sharing import (
- list_shares,
- get_share,
- create_share,
- add_table_to_share,
- remove_table_from_share,
- delete_share,
- grant_share_to_recipient,
- revoke_share_from_recipient,
- list_recipients,
- get_recipient,
- create_recipient,
- rotate_recipient_token,
- delete_recipient,
- list_providers,
- get_provider,
- list_provider_shares,
-)
-
-__all__ = [
- # Catalogs
- "list_catalogs",
- "get_catalog",
- "create_catalog",
- "update_catalog",
- "delete_catalog",
- # Schemas
- "list_schemas",
- "get_schema",
- "create_schema",
- "update_schema",
- "delete_schema",
- # Tables
- "list_tables",
- "get_table",
- "create_table",
- "delete_table",
- # Volumes
- "list_volumes",
- "get_volume",
- "create_volume",
- "update_volume",
- "delete_volume",
- # Volume Files
- "VolumeFileInfo",
- "VolumeUploadResult",
- "VolumeFolderUploadResult",
- "VolumeDownloadResult",
- "VolumeDeleteResult",
- "list_volume_files",
- "upload_to_volume",
- "download_from_volume",
- "delete_from_volume",
- "create_volume_directory",
- "get_volume_file_metadata",
- # Functions
- "list_functions",
- "get_function",
- "delete_function",
- # Grants
- "grant_privileges",
- "revoke_privileges",
- "get_grants",
- "get_effective_grants",
- # Storage
- "list_storage_credentials",
- "get_storage_credential",
- "create_storage_credential",
- "update_storage_credential",
- "delete_storage_credential",
- "validate_storage_credential",
- "list_external_locations",
- "get_external_location",
- "create_external_location",
- "update_external_location",
- "delete_external_location",
- # Connections
- "list_connections",
- "get_connection",
- "create_connection",
- "update_connection",
- "delete_connection",
- "create_foreign_catalog",
- # Tags and comments
- "set_tags",
- "unset_tags",
- "set_comment",
- "query_table_tags",
- "query_column_tags",
- # Security policies
- "create_security_function",
- "set_row_filter",
- "drop_row_filter",
- "set_column_mask",
- "drop_column_mask",
- # Quality monitors
- "create_monitor",
- "get_monitor",
- "run_monitor_refresh",
- "list_monitor_refreshes",
- "delete_monitor",
- # Metric Views
- "create_metric_view",
- "alter_metric_view",
- "drop_metric_view",
- "describe_metric_view",
- "query_metric_view",
- "grant_metric_view",
- # Sharing
- "list_shares",
- "get_share",
- "create_share",
- "add_table_to_share",
- "remove_table_from_share",
- "delete_share",
- "grant_share_to_recipient",
- "revoke_share_from_recipient",
- "list_recipients",
- "get_recipient",
- "create_recipient",
- "rotate_recipient_token",
- "delete_recipient",
- "list_providers",
- "get_provider",
- "list_provider_shares",
-]
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/catalogs.py b/databricks-tools-core/databricks_tools_core/unity_catalog/catalogs.py
deleted file mode 100644
index 583bbb03..00000000
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/catalogs.py
+++ /dev/null
@@ -1,131 +0,0 @@
-"""
-Unity Catalog - Catalog Operations
-
-Functions for managing catalogs in Unity Catalog.
-"""
-
-from typing import Dict, List, Optional
-from databricks.sdk.service.catalog import CatalogInfo, IsolationMode
-
-from ..auth import get_workspace_client
-
-
-def list_catalogs() -> List[CatalogInfo]:
- """
- List all catalogs in Unity Catalog.
-
- Returns:
- List of CatalogInfo objects with catalog metadata
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return list(w.catalogs.list())
-
-
-def get_catalog(catalog_name: str) -> CatalogInfo:
- """
- Get detailed information about a specific catalog.
-
- Args:
- catalog_name: Name of the catalog
-
- Returns:
- CatalogInfo object with catalog metadata including:
- - name, full_name, owner, comment
- - created_at, updated_at
- - storage_location
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return w.catalogs.get(name=catalog_name)
-
-
-def create_catalog(
- name: str,
- comment: Optional[str] = None,
- storage_root: Optional[str] = None,
- properties: Optional[Dict[str, str]] = None,
-) -> CatalogInfo:
- """
- Create a new catalog in Unity Catalog.
-
- Args:
- name: Name of the catalog to create
- comment: Optional description
- storage_root: Optional managed storage location (cloud URL)
- properties: Optional key-value properties
-
- Returns:
- CatalogInfo object with created catalog metadata
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- kwargs: Dict = {"name": name}
- if comment is not None:
- kwargs["comment"] = comment
- if storage_root is not None:
- kwargs["storage_root"] = storage_root
- if properties is not None:
- kwargs["properties"] = properties
- return w.catalogs.create(**kwargs)
-
-
-def update_catalog(
- catalog_name: str,
- new_name: Optional[str] = None,
- comment: Optional[str] = None,
- owner: Optional[str] = None,
- isolation_mode: Optional[str] = None,
-) -> CatalogInfo:
- """
- Update an existing catalog in Unity Catalog.
-
- Args:
- catalog_name: Current name of the catalog
- new_name: New name for the catalog
- comment: New comment/description
- owner: New owner (user or group)
- isolation_mode: Isolation mode ("OPEN" or "ISOLATED")
-
- Returns:
- CatalogInfo object with updated catalog metadata
-
- Raises:
- ValueError: If no fields are provided to update
- DatabricksError: If API request fails
- """
- if not any([new_name, comment, owner, isolation_mode]):
- raise ValueError("At least one field must be provided to update")
-
- w = get_workspace_client()
- kwargs: Dict = {"name": catalog_name}
- if new_name is not None:
- kwargs["new_name"] = new_name
- if comment is not None:
- kwargs["comment"] = comment
- if owner is not None:
- kwargs["owner"] = owner
- if isolation_mode is not None:
- kwargs["isolation_mode"] = IsolationMode(isolation_mode)
- return w.catalogs.update(**kwargs)
-
-
-def delete_catalog(catalog_name: str, force: bool = False) -> None:
- """
- Delete a catalog from Unity Catalog.
-
- Args:
- catalog_name: Name of the catalog to delete
- force: If True, force deletion even if catalog contains schemas
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- w.catalogs.delete(name=catalog_name, force=force)
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/connections.py b/databricks-tools-core/databricks_tools_core/unity_catalog/connections.py
deleted file mode 100644
index 06d492da..00000000
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/connections.py
+++ /dev/null
@@ -1,185 +0,0 @@
-"""
-Unity Catalog - Connection Operations
-
-Functions for managing Lakehouse Federation foreign connections.
-"""
-
-import re
-from typing import Any, Dict, List, Optional
-from databricks.sdk.service.catalog import ConnectionInfo, ConnectionType
-
-from ..auth import get_workspace_client
-
-_IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z0-9_][a-zA-Z0-9_.\-]*$")
-
-
-def _validate_identifier(name: str) -> str:
- """Validate a SQL identifier to prevent injection."""
- if not _IDENTIFIER_PATTERN.match(name):
- raise ValueError(f"Invalid SQL identifier: '{name}'")
- return name
-
-
-def _execute_uc_sql(sql_query: str, warehouse_id: Optional[str] = None) -> List[Dict[str, Any]]:
- """Execute SQL using the existing execute_sql infrastructure."""
- from ..sql.sql import execute_sql
-
- return execute_sql(sql_query=sql_query, warehouse_id=warehouse_id)
-
-
-def list_connections() -> List[ConnectionInfo]:
- """
- List all foreign connections.
-
- Returns:
- List of ConnectionInfo objects
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return list(w.connections.list())
-
-
-def get_connection(name: str) -> ConnectionInfo:
- """
- Get a specific foreign connection.
-
- Args:
- name: Name of the connection
-
- Returns:
- ConnectionInfo with connection details
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return w.connections.get(name=name)
-
-
-def create_connection(
- name: str,
- connection_type: str,
- options: Dict[str, str],
- comment: Optional[str] = None,
-) -> ConnectionInfo:
- """
- Create a foreign connection for Lakehouse Federation.
-
- Args:
- name: Name for the connection
- connection_type: Type of connection. Valid values:
- "SNOWFLAKE", "POSTGRESQL", "MYSQL", "SQLSERVER", "BIGQUERY",
- "REDSHIFT", "SQLDW"
- options: Connection options dict. Common keys:
- - host: Database hostname
- - port: Database port
- - user: Username
- - password: Password (use secret('scope', 'key') for security)
- - database: Database name
- - warehouse: Snowflake warehouse
- - httpPath: For some connectors
- comment: Optional description
-
- Returns:
- ConnectionInfo with created connection details
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- kwargs: Dict[str, Any] = {
- "name": name,
- "connection_type": ConnectionType(connection_type.upper()),
- "options": options,
- }
- if comment is not None:
- kwargs["comment"] = comment
- return w.connections.create(**kwargs)
-
-
-def update_connection(
- name: str,
- options: Optional[Dict[str, str]] = None,
- new_name: Optional[str] = None,
- owner: Optional[str] = None,
-) -> ConnectionInfo:
- """
- Update a foreign connection.
-
- Args:
- name: Current name of the connection
- options: New connection options
- new_name: New name for the connection
- owner: New owner
-
- Returns:
- ConnectionInfo with updated details
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- kwargs: Dict[str, Any] = {"name": name}
- if options is not None:
- kwargs["options"] = options
- if new_name is not None:
- kwargs["new_name"] = new_name
- if owner is not None:
- kwargs["owner"] = owner
- return w.connections.update(**kwargs)
-
-
-def delete_connection(name: str) -> None:
- """
- Delete a foreign connection.
-
- Args:
- name: Name of the connection to delete
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- w.connections.delete(name=name)
-
-
-def create_foreign_catalog(
- catalog_name: str,
- connection_name: str,
- catalog_options: Optional[Dict[str, str]] = None,
- comment: Optional[str] = None,
- warehouse_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Create a foreign catalog using a connection (Lakehouse Federation).
-
- Args:
- catalog_name: Name for the new foreign catalog
- connection_name: Name of the connection to use
- catalog_options: Options (e.g., {"database": "my_db"})
- comment: Optional description
- warehouse_id: Optional SQL warehouse ID
-
- Returns:
- Dict with status and executed SQL
- """
- _validate_identifier(catalog_name)
- _validate_identifier(connection_name)
-
- sql = f"CREATE FOREIGN CATALOG {catalog_name} USING CONNECTION {connection_name}"
- if catalog_options:
- opts = ", ".join(f"'{k}' = '{v}'" for k, v in catalog_options.items())
- sql += f" OPTIONS ({opts})"
- if comment:
- escaped = comment.replace("'", "\\'")
- sql += f" COMMENT '{escaped}'"
-
- _execute_uc_sql(sql, warehouse_id=warehouse_id)
- return {
- "status": "created",
- "catalog_name": catalog_name,
- "connection_name": connection_name,
- "sql": sql,
- }
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/functions_uc.py b/databricks-tools-core/databricks_tools_core/unity_catalog/functions_uc.py
deleted file mode 100644
index abf52925..00000000
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/functions_uc.py
+++ /dev/null
@@ -1,70 +0,0 @@
-"""
-Unity Catalog - Function Operations
-
-Functions for managing UC functions (UDFs).
-Note: Creating functions requires SQL (CREATE FUNCTION statement).
-Use execute_sql or the security_policies module for function creation.
-"""
-
-from typing import List
-from databricks.sdk.service.catalog import FunctionInfo
-
-from ..auth import get_workspace_client
-
-
-def list_functions(catalog_name: str, schema_name: str) -> List[FunctionInfo]:
- """
- List all functions in a schema.
-
- Args:
- catalog_name: Name of the catalog
- schema_name: Name of the schema
-
- Returns:
- List of FunctionInfo objects with function metadata
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return list(
- w.functions.list(
- catalog_name=catalog_name,
- schema_name=schema_name,
- )
- )
-
-
-def get_function(full_function_name: str) -> FunctionInfo:
- """
- Get detailed information about a specific function.
-
- Args:
- full_function_name: Full function name (catalog.schema.function format)
-
- Returns:
- FunctionInfo object with function metadata including:
- - name, full_name, catalog_name, schema_name
- - input_params, return_params, routine_body
- - owner, comment, created_at
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return w.functions.get(name=full_function_name)
-
-
-def delete_function(full_function_name: str, force: bool = False) -> None:
- """
- Delete a function from Unity Catalog.
-
- Args:
- full_function_name: Full function name (catalog.schema.function format)
- force: If True, force deletion
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- w.functions.delete(name=full_function_name, force=force)
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/grants.py b/databricks-tools-core/databricks_tools_core/unity_catalog/grants.py
deleted file mode 100644
index b603b2a0..00000000
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/grants.py
+++ /dev/null
@@ -1,239 +0,0 @@
-"""
-Unity Catalog - Grant Operations
-
-Functions for managing permissions on Unity Catalog securables.
-"""
-
-from typing import Any, Dict, List, Optional
-from databricks.sdk.service.catalog import (
- Privilege,
- PermissionsChange,
-)
-
-from ..auth import get_workspace_client
-
-
-def _parse_securable_type(securable_type: str) -> str:
- """Parse securable type string to the API-expected string value.
-
- The GrantsAPI methods expect securable_type as a plain string,
- not a SecurableType enum instance.
- """
- valid_types = {
- "catalog",
- "schema",
- "table",
- "volume",
- "function",
- "storage_credential",
- "external_location",
- "connection",
- "share",
- "metastore",
- }
- key = securable_type.lower().replace("-", "_").replace(" ", "_")
- if key not in valid_types:
- raise ValueError(f"Invalid securable_type: '{securable_type}'. Valid types: {sorted(valid_types)}")
- return key
-
-
-def _parse_privileges(privileges: List[str]) -> List[Privilege]:
- """Parse privilege strings to SDK enum values."""
- result = []
- for p in privileges:
- try:
- result.append(Privilege(p.upper().replace(" ", "_")))
- except ValueError:
- raise ValueError(
- f"Invalid privilege: '{p}'. "
- f"Common privileges: SELECT, MODIFY, CREATE_TABLE, CREATE_SCHEMA, "
- f"USE_CATALOG, USE_SCHEMA, ALL_PRIVILEGES, EXECUTE, "
- f"READ_VOLUME, WRITE_VOLUME, CREATE_VOLUME, CREATE_FUNCTION"
- )
- return result
-
-
-def grant_privileges(
- securable_type: str,
- full_name: str,
- principal: str,
- privileges: List[str],
-) -> Dict[str, Any]:
- """
- Grant privileges to a principal on a UC securable.
-
- Args:
- securable_type: Type of object (catalog, schema, table, volume, function,
- storage_credential, external_location, connection, share)
- full_name: Full name of the securable object
- principal: User, group, or service principal to grant to
- privileges: List of privileges to grant (e.g., ["SELECT", "MODIFY"])
-
- Returns:
- Dict with grant result including privilege assignments
-
- Raises:
- ValueError: If securable_type or privileges are invalid
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- stype = _parse_securable_type(securable_type)
- privs = _parse_privileges(privileges)
-
- result = w.grants.update(
- securable_type=stype,
- full_name=full_name,
- changes=[
- PermissionsChange(
- principal=principal,
- add=privs,
- )
- ],
- )
- return {
- "status": "granted",
- "securable_type": securable_type,
- "full_name": full_name,
- "principal": principal,
- "privileges": privileges,
- "assignments": [
- {"principal": a.principal, "privileges": [p.value for p in (a.privileges or [])]}
- for a in (result.privilege_assignments or [])
- ],
- }
-
-
-def revoke_privileges(
- securable_type: str,
- full_name: str,
- principal: str,
- privileges: List[str],
-) -> Dict[str, Any]:
- """
- Revoke privileges from a principal on a UC securable.
-
- Args:
- securable_type: Type of object (catalog, schema, table, volume, function,
- storage_credential, external_location, connection, share)
- full_name: Full name of the securable object
- principal: User, group, or service principal to revoke from
- privileges: List of privileges to revoke (e.g., ["SELECT", "MODIFY"])
-
- Returns:
- Dict with revoke result
-
- Raises:
- ValueError: If securable_type or privileges are invalid
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- stype = _parse_securable_type(securable_type)
- privs = _parse_privileges(privileges)
-
- result = w.grants.update(
- securable_type=stype,
- full_name=full_name,
- changes=[
- PermissionsChange(
- principal=principal,
- remove=privs,
- )
- ],
- )
- return {
- "status": "revoked",
- "securable_type": securable_type,
- "full_name": full_name,
- "principal": principal,
- "privileges": privileges,
- "assignments": [
- {"principal": a.principal, "privileges": [p.value for p in (a.privileges or [])]}
- for a in (result.privilege_assignments or [])
- ],
- }
-
-
-def get_grants(
- securable_type: str,
- full_name: str,
- principal: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Get current permission grants on a UC securable.
-
- Args:
- securable_type: Type of object
- full_name: Full name of the securable object
- principal: Optional - filter grants for specific principal
-
- Returns:
- Dict with privilege assignments list
-
- Raises:
- ValueError: If securable_type is invalid
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- stype = _parse_securable_type(securable_type)
-
- result = w.grants.get(
- securable_type=stype,
- full_name=full_name,
- principal=principal,
- )
- return {
- "securable_type": securable_type,
- "full_name": full_name,
- "assignments": [
- {"principal": a.principal, "privileges": [p.value for p in (a.privileges or [])]}
- for a in (result.privilege_assignments or [])
- ],
- }
-
-
-def get_effective_grants(
- securable_type: str,
- full_name: str,
- principal: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Get effective (inherited + direct) permission grants on a UC securable.
-
- Args:
- securable_type: Type of object
- full_name: Full name of the securable object
- principal: Optional - filter grants for specific principal
-
- Returns:
- Dict with effective privilege assignments (includes inherited permissions)
-
- Raises:
- ValueError: If securable_type is invalid
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- stype = _parse_securable_type(securable_type)
-
- result = w.grants.get_effective(
- securable_type=stype,
- full_name=full_name,
- principal=principal,
- )
- return {
- "securable_type": securable_type,
- "full_name": full_name,
- "effective_assignments": [
- {
- "principal": a.principal,
- "privileges": [
- {
- "privilege": p.privilege.value if p.privilege else None,
- "inherited_from_name": p.inherited_from_name,
- "inherited_from_type": p.inherited_from_type.value if p.inherited_from_type else None,
- }
- for p in (a.privileges or [])
- ],
- }
- for a in (result.privilege_assignments or [])
- ],
- }
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/metric_views.py b/databricks-tools-core/databricks_tools_core/unity_catalog/metric_views.py
deleted file mode 100644
index 42238828..00000000
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/metric_views.py
+++ /dev/null
@@ -1,412 +0,0 @@
-"""
-Unity Catalog - Metric View Operations
-
-Functions for creating, altering, describing, dropping, and querying
-Unity Catalog metric views via SQL DDL.
-
-Metric views are defined in YAML and executed through the Statement Execution API
-since there is no dedicated REST API or Python SDK support for metric views.
-"""
-
-import logging
-import textwrap
-from typing import Any, Dict, List, Optional
-
-from ..sql.sql import execute_sql
-
-logger = logging.getLogger(__name__)
-
-
-def _build_yaml_block(
- source: str,
- dimensions: List[Dict[str, str]],
- measures: List[Dict[str, str]],
- version: str = "1.1",
- comment: Optional[str] = None,
- filter_expr: Optional[str] = None,
- joins: Optional[List[Dict[str, Any]]] = None,
- materialization: Optional[Dict[str, Any]] = None,
-) -> str:
- """Build the YAML definition block for a metric view.
-
- Args:
- source: Source table, view, or SQL query (three-level namespace).
- dimensions: List of dimension dicts with keys: name, expr, comment (optional).
- measures: List of measure dicts with keys: name, expr, comment (optional).
- version: YAML spec version (default "1.1" for DBR 17.2+).
- comment: Optional description of the metric view.
- filter_expr: Optional SQL boolean filter expression applied to all queries.
- joins: Optional list of join dicts with keys: name, source, on/using, joins (nested).
- materialization: Optional materialization config dict.
-
- Returns:
- The YAML string to embed in the SQL $$ block.
- """
- lines = [f"version: {version}"]
-
- if comment:
- lines.append(f'comment: "{comment}"')
-
- lines.append(f"source: {source}")
-
- if filter_expr:
- lines.append(f"filter: {filter_expr}")
-
- # Joins
- if joins:
- lines.append("joins:")
- _render_joins(lines, joins, indent=2)
-
- # Dimensions
- lines.append("dimensions:")
- for dim in dimensions:
- lines.append(f" - name: {dim['name']}")
- lines.append(f" expr: {dim['expr']}")
- if dim.get("comment"):
- lines.append(f' comment: "{dim["comment"]}"')
-
- # Measures
- lines.append("measures:")
- for measure in measures:
- lines.append(f" - name: {measure['name']}")
- lines.append(f" expr: {measure['expr']}")
- if measure.get("comment"):
- lines.append(f' comment: "{measure["comment"]}"')
-
- # Materialization
- if materialization:
- lines.append("materialization:")
- if materialization.get("schedule"):
- lines.append(f" schedule: {materialization['schedule']}")
- if materialization.get("mode"):
- lines.append(f" mode: {materialization['mode']}")
- if materialization.get("materialized_views"):
- lines.append(" materialized_views:")
- for mv in materialization["materialized_views"]:
- lines.append(f" - name: {mv['name']}")
- lines.append(f" type: {mv['type']}")
- if mv.get("dimensions"):
- lines.append(" dimensions:")
- for d in mv["dimensions"]:
- lines.append(f" - {d}")
- if mv.get("measures"):
- lines.append(" measures:")
- for m in mv["measures"]:
- lines.append(f" - {m}")
-
- return "\n".join(lines)
-
-
-def _render_joins(lines: List[str], joins: List[Dict[str, Any]], indent: int) -> None:
- """Recursively render join definitions into YAML lines."""
- prefix = " " * indent
- for join in joins:
- lines.append(f"{prefix}- name: {join['name']}")
- lines.append(f"{prefix} source: {join['source']}")
- if join.get("on"):
- lines.append(f"{prefix} on: {join['on']}")
- if join.get("using"):
- lines.append(f"{prefix} using:")
- for col in join["using"]:
- lines.append(f"{prefix} - {col}")
- # Nested joins (snowflake schema)
- if join.get("joins"):
- lines.append(f"{prefix} joins:")
- _render_joins(lines, join["joins"], indent + 4)
-
-
-def create_metric_view(
- full_name: str,
- source: str,
- dimensions: List[Dict[str, str]],
- measures: List[Dict[str, str]],
- version: str = "1.1",
- comment: Optional[str] = None,
- filter_expr: Optional[str] = None,
- joins: Optional[List[Dict[str, Any]]] = None,
- materialization: Optional[Dict[str, Any]] = None,
- or_replace: bool = False,
- warehouse_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Create a metric view in Unity Catalog.
-
- Metric views define reusable, governed business metrics in YAML. They
- separate measure definitions from dimension groupings so metrics can
- be queried flexibly across any dimension at runtime.
-
- Requires Databricks Runtime 17.2+ and a SQL warehouse.
-
- Args:
- full_name: Three-level name (catalog.schema.metric_view_name).
- source: Source table, view, or SQL query.
- dimensions: List of dimension dicts. Each must have:
- - name: Display name for the dimension.
- - expr: SQL expression (column reference or transformation).
- - comment: (optional) Description of the dimension.
- measures: List of measure dicts. Each must have:
- - name: Display name for the measure.
- - expr: Aggregate SQL expression (e.g. "SUM(amount)").
- - comment: (optional) Description of the measure.
- version: YAML spec version ("1.1" for DBR 17.2+, "0.1" for DBR 16.4-17.1).
- comment: Optional description of the metric view.
- filter_expr: Optional SQL boolean filter applied to all queries.
- joins: Optional list of join definitions for star/snowflake schemas.
- Each dict: name, source, on (or using), joins (nested, optional).
- materialization: Optional materialization config dict with keys:
- - schedule: e.g. "every 6 hours"
- - mode: "relaxed" (only supported value)
- - materialized_views: list of {name, type, dimensions, measures}
- or_replace: If True, uses CREATE OR REPLACE (default: False).
- warehouse_id: Optional SQL warehouse ID.
-
- Returns:
- Dict with status, full_name, and the generated SQL.
-
- Raises:
- ValueError: If dimensions or measures are empty.
- SQLExecutionError: If query execution fails.
- """
- if not dimensions:
- raise ValueError("At least one dimension is required")
- if not measures:
- raise ValueError("At least one measure is required")
-
- yaml_block = _build_yaml_block(
- source=source,
- dimensions=dimensions,
- measures=measures,
- version=version,
- comment=comment,
- filter_expr=filter_expr,
- joins=joins,
- materialization=materialization,
- )
-
- create_keyword = "CREATE OR REPLACE" if or_replace else "CREATE"
- sql = textwrap.dedent(f"""\
- {create_keyword} VIEW {full_name}
- WITH METRICS
- LANGUAGE YAML
- AS $$
- {textwrap.indent(yaml_block, " ")}
- $$""")
-
- execute_sql(sql_query=sql, warehouse_id=warehouse_id)
-
- return {
- "status": "created",
- "full_name": full_name,
- "sql": sql,
- }
-
-
-def alter_metric_view(
- full_name: str,
- source: str,
- dimensions: List[Dict[str, str]],
- measures: List[Dict[str, str]],
- version: str = "1.1",
- comment: Optional[str] = None,
- filter_expr: Optional[str] = None,
- joins: Optional[List[Dict[str, Any]]] = None,
- materialization: Optional[Dict[str, Any]] = None,
- warehouse_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Alter an existing metric view with a new YAML definition.
- Args:
- full_name: Three-level name (catalog.schema.metric_view_name).
- source: Source table, view, or SQL query.
- dimensions: List of dimension dicts (name, expr, comment).
- measures: List of measure dicts (name, expr, comment).
- version: YAML spec version.
- comment: Optional description.
- filter_expr: Optional SQL boolean filter.
- joins: Optional join definitions.
- materialization: Optional materialization config.
- warehouse_id: Optional SQL warehouse ID.
-
- Returns:
- Dict with status, full_name, and the generated SQL.
- """
- yaml_block = _build_yaml_block(
- source=source,
- dimensions=dimensions,
- measures=measures,
- version=version,
- comment=comment,
- filter_expr=filter_expr,
- joins=joins,
- materialization=materialization,
- )
-
- sql = textwrap.dedent(f"""\
- ALTER VIEW {full_name}
- AS $$
- {textwrap.indent(yaml_block, " ")}
- $$""")
-
- execute_sql(sql_query=sql, warehouse_id=warehouse_id)
-
- return {
- "status": "altered",
- "full_name": full_name,
- "sql": sql,
- }
-
-
-def drop_metric_view(
- full_name: str,
- if_exists: bool = True,
- warehouse_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Drop a metric view.
-
- Args:
- full_name: Three-level name (catalog.schema.metric_view_name).
- if_exists: If True, does not error if the view doesn't exist.
- warehouse_id: Optional SQL warehouse ID.
-
- Returns:
- Dict with status and full_name.
- """
- exists_clause = " IF EXISTS" if if_exists else ""
- sql = f"DROP VIEW{exists_clause} {full_name}"
-
- execute_sql(sql_query=sql, warehouse_id=warehouse_id)
-
- return {
- "status": "dropped",
- "full_name": full_name,
- "sql": sql,
- }
-
-
-def describe_metric_view(
- full_name: str,
- warehouse_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Get the definition and metadata of a metric view.
-
- Uses DESCRIBE TABLE EXTENDED ... AS JSON to retrieve the full
- YAML definition, dimensions, measures, and other metadata.
-
- Args:
- full_name: Three-level name (catalog.schema.metric_view_name).
- warehouse_id: Optional SQL warehouse ID.
-
- Returns:
- Dict with the metric view definition and metadata.
- """
- sql = f"DESCRIBE TABLE EXTENDED {full_name} AS JSON"
- result = execute_sql(sql_query=sql, warehouse_id=warehouse_id)
- return {"full_name": full_name, "definition": result}
-
-
-def query_metric_view(
- full_name: str,
- measures: List[str],
- dimensions: Optional[List[str]] = None,
- where: Optional[str] = None,
- order_by: Optional[str] = None,
- limit: Optional[int] = None,
- warehouse_id: Optional[str] = None,
-) -> List[Dict[str, Any]]:
- """
- Query a metric view by selecting dimensions and measures.
-
- Measures must be wrapped in MEASURE() aggregate function.
- Dimensions are used in GROUP BY.
-
- Args:
- full_name: Three-level name (catalog.schema.metric_view_name).
- measures: List of measure names to query (wrapped in MEASURE() automatically).
- dimensions: Optional list of dimension names/expressions to group by.
- where: Optional WHERE clause filter (without the WHERE keyword).
- order_by: Optional ORDER BY clause (without the ORDER BY keyword).
- Use "ALL" for ORDER BY ALL.
- limit: Optional row limit.
- warehouse_id: Optional SQL warehouse ID.
-
- Returns:
- List of row dicts with query results.
-
- Example:
- query_metric_view(
- full_name="catalog.schema.orders_metrics",
- measures=["Total Revenue", "Order Count"],
- dimensions=["Order Month", "Order Status"],
- where="extract(year FROM `Order Month`) = 2024",
- order_by="ALL",
- limit=100,
- )
- """
- select_parts = []
-
- if dimensions:
- for dim in dimensions:
- # Backtick-quote dimension names that contain spaces
- if " " in dim and not dim.startswith("`"):
- select_parts.append(f"`{dim}`")
- else:
- select_parts.append(dim)
-
- for measure in measures:
- # Backtick-quote measure names that contain spaces
- if " " in measure and not measure.startswith("`"):
- select_parts.append(f"MEASURE(`{measure}`)")
- else:
- select_parts.append(f"MEASURE({measure})")
-
- select_clause = ",\n ".join(select_parts)
- sql = f"SELECT\n {select_clause}\nFROM {full_name}"
-
- if where:
- sql += f"\nWHERE {where}"
-
- if dimensions:
- sql += "\nGROUP BY ALL"
-
- if order_by:
- sql += f"\nORDER BY {order_by}"
-
- if limit:
- sql += f"\nLIMIT {limit}"
-
- return execute_sql(sql_query=sql, warehouse_id=warehouse_id)
-
-
-def grant_metric_view(
- full_name: str,
- principal: str,
- privileges: Optional[List[str]] = None,
- warehouse_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Grant privileges on a metric view.
-
- Args:
- full_name: Three-level name (catalog.schema.metric_view_name).
- principal: User, group, or service principal to grant to.
- privileges: List of privileges (default: ["SELECT"]).
- warehouse_id: Optional SQL warehouse ID.
-
- Returns:
- Dict with status and executed SQL.
- """
- privs = privileges or ["SELECT"]
- priv_str = ", ".join(privs)
- sql = f"GRANT {priv_str} ON {full_name} TO `{principal}`"
-
- execute_sql(sql_query=sql, warehouse_id=warehouse_id)
-
- return {
- "status": "granted",
- "full_name": full_name,
- "principal": principal,
- "privileges": privs,
- "sql": sql,
- }
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/monitors.py b/databricks-tools-core/databricks_tools_core/unity_catalog/monitors.py
deleted file mode 100644
index 0b146f3e..00000000
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/monitors.py
+++ /dev/null
@@ -1,155 +0,0 @@
-"""
-Unity Catalog - Quality Monitor Operations
-
-Functions for managing Lakehouse Monitors (data quality monitors).
-"""
-
-from typing import Any, Dict, List, Optional
-
-from ..auth import get_workspace_client
-
-
-def create_monitor(
- table_name: str,
- output_schema_name: str,
- monitor_type: str = "snapshot",
- assets_dir: Optional[str] = None,
- time_series_timestamp_col: Optional[str] = None,
- time_series_granularities: Optional[List[str]] = None,
- schedule_cron: Optional[str] = None,
- schedule_timezone: str = "UTC",
-) -> Dict[str, Any]:
- """
- Create a quality monitor on a table.
-
- Args:
- table_name: Full table name to monitor (catalog.schema.table)
- output_schema_name: Schema for monitor output tables (catalog.schema)
- monitor_type: Type of monitor: "snapshot" (default), "time_series", or "inference"
- assets_dir: Workspace path for monitor assets. If not provided,
- auto-generated under /Workspace/Users/{user}/databricks_lakehouse_monitoring/
- time_series_timestamp_col: Timestamp column (required for time_series type)
- time_series_granularities: Granularities list (e.g., ["1 day"]) for time_series
- schedule_cron: Quartz cron expression for schedule (e.g., "0 0 * * * ?")
- schedule_timezone: Timezone for schedule (default: "UTC")
-
- Returns:
- Dict with monitor details
-
- Raises:
- ValueError: If monitor_type is invalid or required params missing
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
-
- # assets_dir is required by the SDK; generate a default if not provided
- if assets_dir is None:
- user = w.current_user.me()
- safe_table = table_name.replace(".", "_")
- assets_dir = f"/Workspace/Users/{user.user_name}/databricks_lakehouse_monitoring/{safe_table}"
-
- kwargs: Dict[str, Any] = {
- "table_name": table_name,
- "output_schema_name": output_schema_name,
- "assets_dir": assets_dir,
- }
-
- # Configure monitor type (exactly one is required by the API)
- if monitor_type == "snapshot":
- from databricks.sdk.service.catalog import MonitorSnapshot
-
- kwargs["snapshot"] = MonitorSnapshot()
- elif monitor_type == "time_series":
- if not time_series_timestamp_col:
- raise ValueError("time_series_timestamp_col is required for time_series monitors")
- from databricks.sdk.service.catalog import MonitorTimeSeries
-
- kwargs["time_series"] = MonitorTimeSeries(
- timestamp_col=time_series_timestamp_col,
- granularities=time_series_granularities or ["1 day"],
- )
- elif monitor_type == "inference":
- from databricks.sdk.service.catalog import MonitorInferenceLog
-
- kwargs["inference_log"] = MonitorInferenceLog()
- else:
- raise ValueError(f"Invalid monitor_type: '{monitor_type}'. Valid types: snapshot, time_series, inference")
-
- if schedule_cron is not None:
- from databricks.sdk.service.catalog import MonitorCronSchedule
-
- kwargs["schedule"] = MonitorCronSchedule(
- quartz_cron_expression=schedule_cron,
- timezone_id=schedule_timezone,
- )
-
- result = w.quality_monitors.create(**kwargs)
- return result.as_dict() if hasattr(result, "as_dict") else vars(result)
-
-
-def get_monitor(table_name: str) -> Dict[str, Any]:
- """
- Get the quality monitor on a table.
-
- Args:
- table_name: Full table name (catalog.schema.table)
-
- Returns:
- Dict with monitor details
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- result = w.quality_monitors.get(table_name=table_name)
- return result.as_dict() if hasattr(result, "as_dict") else vars(result)
-
-
-def run_monitor_refresh(table_name: str) -> Dict[str, Any]:
- """
- Trigger a refresh of a quality monitor.
-
- Args:
- table_name: Full table name (catalog.schema.table)
-
- Returns:
- Dict with refresh details including refresh_id
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- result = w.quality_monitors.run_refresh(table_name=table_name)
- return result.as_dict() if hasattr(result, "as_dict") else vars(result)
-
-
-def list_monitor_refreshes(table_name: str) -> List[Dict[str, Any]]:
- """
- List refresh history for a quality monitor.
-
- Args:
- table_name: Full table name (catalog.schema.table)
-
- Returns:
- List of dicts with refresh history
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- result = w.quality_monitors.list_refreshes(table_name=table_name)
- return [r.as_dict() if hasattr(r, "as_dict") else vars(r) for r in (result.refreshes or [])]
-
-
-def delete_monitor(table_name: str) -> None:
- """
- Delete the quality monitor on a table.
-
- Args:
- table_name: Full table name (catalog.schema.table)
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- w.quality_monitors.delete(table_name=table_name)
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/schemas.py b/databricks-tools-core/databricks_tools_core/unity_catalog/schemas.py
deleted file mode 100644
index 1ee1697e..00000000
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/schemas.py
+++ /dev/null
@@ -1,109 +0,0 @@
-"""
-Unity Catalog - Schema Operations
-
-Functions for managing schemas (databases) in Unity Catalog.
-"""
-
-from typing import List, Optional
-from databricks.sdk.service.catalog import SchemaInfo
-
-from ..auth import get_workspace_client
-
-
-def list_schemas(catalog_name: str) -> List[SchemaInfo]:
- """
- List all schemas in a catalog.
-
- Args:
- catalog_name: Name of the catalog
-
- Returns:
- List of SchemaInfo objects with schema metadata
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return list(w.schemas.list(catalog_name=catalog_name))
-
-
-def get_schema(full_schema_name: str) -> SchemaInfo:
- """
- Get detailed information about a specific schema.
-
- Args:
- full_schema_name: Full schema name (catalog.schema format)
-
- Returns:
- SchemaInfo object with schema metadata including:
- - name, full_name, catalog_name, owner, comment
- - created_at, updated_at
- - storage_location
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return w.schemas.get(full_name=full_schema_name)
-
-
-def create_schema(catalog_name: str, schema_name: str, comment: Optional[str] = None) -> SchemaInfo:
- """
- Create a new schema in Unity Catalog.
-
- Args:
- catalog_name: Name of the catalog
- schema_name: Name of the schema to create
- comment: Optional description of the schema
-
- Returns:
- SchemaInfo object with created schema metadata
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return w.schemas.create(name=schema_name, catalog_name=catalog_name, comment=comment)
-
-
-def update_schema(
- full_schema_name: str,
- new_name: Optional[str] = None,
- comment: Optional[str] = None,
- owner: Optional[str] = None,
-) -> SchemaInfo:
- """
- Update an existing schema in Unity Catalog.
-
- Args:
- full_schema_name: Full schema name (catalog.schema format)
- new_name: New name for the schema
- comment: New comment/description
- owner: New owner
-
- Returns:
- SchemaInfo object with updated schema metadata
-
- Raises:
- ValueError: If no fields are provided to update
- DatabricksError: If API request fails
- """
- if not any([new_name, comment, owner]):
- raise ValueError("At least one field (new_name, comment, or owner) must be provided")
-
- w = get_workspace_client()
- return w.schemas.update(full_name=full_schema_name, new_name=new_name, comment=comment, owner=owner)
-
-
-def delete_schema(full_schema_name: str) -> None:
- """
- Delete a schema from Unity Catalog.
-
- Args:
- full_schema_name: Full schema name (catalog.schema format)
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- w.schemas.delete(full_name=full_schema_name)
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/security_policies.py b/databricks-tools-core/databricks_tools_core/unity_catalog/security_policies.py
deleted file mode 100644
index af77f130..00000000
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/security_policies.py
+++ /dev/null
@@ -1,184 +0,0 @@
-"""
-Unity Catalog - Security Policy Operations
-
-Functions for managing row-level security (row filters) and column masking.
-All operations are SQL-based via execute_sql.
-"""
-
-import re
-from typing import Any, Dict, List, Optional
-
-_IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z0-9_][a-zA-Z0-9_.\-]*$")
-
-
-def _validate_identifier(name: str) -> str:
- """Validate a SQL identifier to prevent injection."""
- if not _IDENTIFIER_PATTERN.match(name):
- raise ValueError(f"Invalid SQL identifier: '{name}'")
- return name
-
-
-def _execute_uc_sql(sql_query: str, warehouse_id: Optional[str] = None) -> List[Dict[str, Any]]:
- """Execute SQL using the existing execute_sql infrastructure."""
- from ..sql.sql import execute_sql
-
- return execute_sql(sql_query=sql_query, warehouse_id=warehouse_id)
-
-
-def create_security_function(
- function_name: str,
- parameter_name: str,
- parameter_type: str,
- return_type: str,
- function_body: str,
- comment: Optional[str] = None,
- warehouse_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Create a SQL function for use as a row filter or column mask.
-
- Args:
- function_name: Full function name (catalog.schema.function)
- parameter_name: Input parameter name (e.g., "val")
- parameter_type: Input parameter type (e.g., "STRING")
- return_type: Return type ("BOOLEAN" for row filters, data type for masks)
- function_body: SQL function body (e.g., "RETURN IF(IS_ACCOUNT_GROUP_MEMBER('admins'), val, '***')")
- comment: Optional function description
- warehouse_id: Optional SQL warehouse ID
-
- Returns:
- Dict with status and executed SQL
-
- Note:
- function_body accepts raw SQL. Only use with trusted input.
- """
- _validate_identifier(function_name)
- _validate_identifier(parameter_name)
-
- sql_parts = [
- f"CREATE OR REPLACE FUNCTION {function_name}({parameter_name} {parameter_type})",
- f"RETURNS {return_type}",
- ]
- if comment:
- escaped = comment.replace("'", "\\'")
- sql_parts.append(f"COMMENT '{escaped}'")
- sql_parts.append(function_body)
-
- sql = "\n".join(sql_parts)
- _execute_uc_sql(sql, warehouse_id=warehouse_id)
- return {"status": "created", "function_name": function_name, "sql": sql}
-
-
-def set_row_filter(
- table_name: str,
- filter_function: str,
- filter_columns: List[str],
- warehouse_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Apply a row filter function to a table.
-
- The filter function must return BOOLEAN and accept the specified columns
- as input parameters.
-
- Args:
- table_name: Full table name (catalog.schema.table)
- filter_function: Full function name (catalog.schema.function)
- filter_columns: Column names passed to the filter function
- warehouse_id: Optional SQL warehouse ID
-
- Returns:
- Dict with status and executed SQL
- """
- _validate_identifier(table_name)
- _validate_identifier(filter_function)
- validated_cols = [_validate_identifier(c) for c in filter_columns]
- cols_str = ", ".join(validated_cols)
-
- sql = f"ALTER TABLE {table_name} SET ROW FILTER {filter_function} ON ({cols_str})"
- _execute_uc_sql(sql, warehouse_id=warehouse_id)
- return {
- "status": "row_filter_set",
- "table": table_name,
- "function": filter_function,
- "sql": sql,
- }
-
-
-def drop_row_filter(
- table_name: str,
- warehouse_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Remove the row filter from a table.
-
- Args:
- table_name: Full table name (catalog.schema.table)
- warehouse_id: Optional SQL warehouse ID
-
- Returns:
- Dict with status and executed SQL
- """
- _validate_identifier(table_name)
- sql = f"ALTER TABLE {table_name} DROP ROW FILTER"
- _execute_uc_sql(sql, warehouse_id=warehouse_id)
- return {"status": "row_filter_dropped", "table": table_name, "sql": sql}
-
-
-def set_column_mask(
- table_name: str,
- column_name: str,
- mask_function: str,
- warehouse_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Apply a column mask function to a specific column.
-
- The mask function must accept the column value and return the same type.
-
- Args:
- table_name: Full table name (catalog.schema.table)
- column_name: Column to mask
- mask_function: Full function name (catalog.schema.function)
- warehouse_id: Optional SQL warehouse ID
-
- Returns:
- Dict with status and executed SQL
- """
- _validate_identifier(table_name)
- _validate_identifier(column_name)
- _validate_identifier(mask_function)
-
- sql = f"ALTER TABLE {table_name} ALTER COLUMN `{column_name}` SET MASK {mask_function}"
- _execute_uc_sql(sql, warehouse_id=warehouse_id)
- return {
- "status": "column_mask_set",
- "table": table_name,
- "column": column_name,
- "function": mask_function,
- "sql": sql,
- }
-
-
-def drop_column_mask(
- table_name: str,
- column_name: str,
- warehouse_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Remove the column mask from a specific column.
-
- Args:
- table_name: Full table name (catalog.schema.table)
- column_name: Column to unmask
- warehouse_id: Optional SQL warehouse ID
-
- Returns:
- Dict with status and executed SQL
- """
- _validate_identifier(table_name)
- _validate_identifier(column_name)
-
- sql = f"ALTER TABLE {table_name} ALTER COLUMN `{column_name}` DROP MASK"
- _execute_uc_sql(sql, warehouse_id=warehouse_id)
- return {"status": "column_mask_dropped", "table": table_name, "column": column_name, "sql": sql}
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/sharing.py b/databricks-tools-core/databricks_tools_core/unity_catalog/sharing.py
deleted file mode 100644
index fc09cc2e..00000000
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/sharing.py
+++ /dev/null
@@ -1,407 +0,0 @@
-"""
-Unity Catalog - Delta Sharing Operations
-
-Functions for managing shares, recipients, and providers.
-"""
-
-from typing import Any, Dict, List, Optional
-
-from ..auth import get_workspace_client
-
-
-# --- Shares ---
-
-
-def list_shares() -> List[Dict[str, Any]]:
- """
- List all shares.
-
- Returns:
- List of share info dicts
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- shares = list(w.shares.list_shares())
- return [s.as_dict() if hasattr(s, "as_dict") else vars(s) for s in shares]
-
-
-def get_share(name: str, include_shared_data: bool = True) -> Dict[str, Any]:
- """
- Get details of a share including its shared objects.
-
- Args:
- name: Name of the share
- include_shared_data: Whether to include shared data objects (default: True)
-
- Returns:
- Dict with share details and objects
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- result = w.shares.get(name=name, include_shared_data=include_shared_data)
- return result.as_dict() if hasattr(result, "as_dict") else vars(result)
-
-
-def create_share(
- name: str,
- comment: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Create a new share.
-
- Args:
- name: Name for the share
- comment: Optional description
-
- Returns:
- Dict with created share details
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- kwargs: Dict[str, Any] = {"name": name}
- if comment is not None:
- kwargs["comment"] = comment
- result = w.shares.create(**kwargs)
- return result.as_dict() if hasattr(result, "as_dict") else vars(result)
-
-
-def add_table_to_share(
- share_name: str,
- table_name: str,
- shared_as: Optional[str] = None,
- partition_spec: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Add a table to a share.
-
- Args:
- share_name: Name of the share
- table_name: Full table name (catalog.schema.table)
- shared_as: Alias for the shared table (hides internal naming)
- partition_spec: Partition filter (e.g., "(date = '2024-01-01')")
-
- Returns:
- Dict with updated share details
-
- Raises:
- DatabricksError: If API request fails
- """
- from databricks.sdk.service.sharing import (
- SharedDataObject,
- SharedDataObjectDataObjectType,
- SharedDataObjectUpdate,
- SharedDataObjectUpdateAction,
- )
-
- w = get_workspace_client()
- data_object = SharedDataObject(
- name=table_name,
- data_object_type=SharedDataObjectDataObjectType.TABLE,
- shared_as=shared_as,
- )
- if partition_spec:
- from databricks.sdk.service.sharing import Partition, PartitionValue
-
- data_object.partitions = [
- Partition(values=[PartitionValue(name="partition", op="EQUAL", value=partition_spec)])
- ]
-
- result = w.shares.update(
- name=share_name,
- updates=[
- SharedDataObjectUpdate(
- action=SharedDataObjectUpdateAction.ADD,
- data_object=data_object,
- )
- ],
- )
- return result.as_dict() if hasattr(result, "as_dict") else vars(result)
-
-
-def remove_table_from_share(
- share_name: str,
- table_name: str,
-) -> Dict[str, Any]:
- """
- Remove a table from a share.
-
- Args:
- share_name: Name of the share
- table_name: Full table name (catalog.schema.table)
-
- Returns:
- Dict with updated share details
-
- Raises:
- DatabricksError: If API request fails
- """
- from databricks.sdk.service.sharing import (
- SharedDataObject,
- SharedDataObjectDataObjectType,
- SharedDataObjectUpdate,
- SharedDataObjectUpdateAction,
- )
-
- w = get_workspace_client()
- result = w.shares.update(
- name=share_name,
- updates=[
- SharedDataObjectUpdate(
- action=SharedDataObjectUpdateAction.REMOVE,
- data_object=SharedDataObject(
- name=table_name,
- data_object_type=SharedDataObjectDataObjectType.TABLE,
- ),
- )
- ],
- )
- return result.as_dict() if hasattr(result, "as_dict") else vars(result)
-
-
-def delete_share(name: str) -> None:
- """
- Delete a share.
-
- Args:
- name: Name of the share to delete
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- w.shares.delete(name=name)
-
-
-def grant_share_to_recipient(
- share_name: str,
- recipient_name: str,
-) -> Dict[str, Any]:
- """
- Grant SELECT permission on a share to a recipient.
-
- Args:
- share_name: Name of the share
- recipient_name: Name of the recipient
-
- Returns:
- Dict with permission update result
-
- Raises:
- DatabricksError: If API request fails
- """
- from databricks.sdk.service.catalog import Privilege, PermissionsChange
-
- w = get_workspace_client()
- w.shares.update_permissions(
- name=share_name,
- changes=[
- PermissionsChange(
- principal=recipient_name,
- add=[Privilege.SELECT],
- )
- ],
- )
- return {"status": "granted", "share": share_name, "recipient": recipient_name}
-
-
-def revoke_share_from_recipient(
- share_name: str,
- recipient_name: str,
-) -> Dict[str, Any]:
- """
- Revoke SELECT permission on a share from a recipient.
-
- Args:
- share_name: Name of the share
- recipient_name: Name of the recipient
-
- Returns:
- Dict with permission update result
-
- Raises:
- DatabricksError: If API request fails
- """
- from databricks.sdk.service.catalog import Privilege, PermissionsChange
-
- w = get_workspace_client()
- w.shares.update_permissions(
- name=share_name,
- changes=[
- PermissionsChange(
- principal=recipient_name,
- remove=[Privilege.SELECT],
- )
- ],
- )
- return {"status": "revoked", "share": share_name, "recipient": recipient_name}
-
-
-# --- Recipients ---
-
-
-def list_recipients() -> List[Dict[str, Any]]:
- """
- List all sharing recipients.
-
- Returns:
- List of recipient info dicts
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- recipients = list(w.recipients.list())
- return [r.as_dict() if hasattr(r, "as_dict") else vars(r) for r in recipients]
-
-
-def get_recipient(name: str) -> Dict[str, Any]:
- """
- Get a specific recipient.
-
- Args:
- name: Name of the recipient
-
- Returns:
- Dict with recipient details
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- result = w.recipients.get(name=name)
- return result.as_dict() if hasattr(result, "as_dict") else vars(result)
-
-
-def create_recipient(
- name: str,
- authentication_type: str = "TOKEN",
- sharing_id: Optional[str] = None,
- comment: Optional[str] = None,
- ip_access_list: Optional[List[str]] = None,
-) -> Dict[str, Any]:
- """
- Create a sharing recipient.
-
- Args:
- name: Name for the recipient
- authentication_type: "TOKEN" (external) or "DATABRICKS" (D2D sharing)
- sharing_id: Required for DATABRICKS authentication (recipient's sharing identifier)
- comment: Optional description
- ip_access_list: Optional list of allowed IP addresses/CIDR ranges
-
- Returns:
- Dict with recipient details (includes activation_url for TOKEN type)
-
- Raises:
- DatabricksError: If API request fails
- """
- from databricks.sdk.service.sharing import AuthenticationType
-
- w = get_workspace_client()
- kwargs: Dict[str, Any] = {
- "name": name,
- "authentication_type": AuthenticationType(authentication_type.upper()),
- }
- if sharing_id is not None:
- kwargs["sharing_code"] = sharing_id
- if comment is not None:
- kwargs["comment"] = comment
- if ip_access_list is not None:
- from databricks.sdk.service.sharing import IpAccessList
-
- kwargs["ip_access_list"] = IpAccessList(allowed_ip_addresses=ip_access_list)
-
- result = w.recipients.create(**kwargs)
- return result.as_dict() if hasattr(result, "as_dict") else vars(result)
-
-
-def rotate_recipient_token(name: str) -> Dict[str, Any]:
- """
- Rotate the authentication token for a recipient.
-
- Args:
- name: Name of the recipient
-
- Returns:
- Dict with new token details
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- result = w.recipients.rotate_token(name=name, existing_token_expire_in_seconds=0)
- return result.as_dict() if hasattr(result, "as_dict") else vars(result)
-
-
-def delete_recipient(name: str) -> None:
- """
- Delete a sharing recipient.
-
- Args:
- name: Name of the recipient to delete
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- w.recipients.delete(name=name)
-
-
-# --- Providers ---
-
-
-def list_providers() -> List[Dict[str, Any]]:
- """
- List all sharing providers.
-
- Returns:
- List of provider info dicts
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- providers = list(w.providers.list())
- return [p.as_dict() if hasattr(p, "as_dict") else vars(p) for p in providers]
-
-
-def get_provider(name: str) -> Dict[str, Any]:
- """
- Get a specific provider.
-
- Args:
- name: Name of the provider
-
- Returns:
- Dict with provider details
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- result = w.providers.get(name=name)
- return result.as_dict() if hasattr(result, "as_dict") else vars(result)
-
-
-def list_provider_shares(name: str) -> List[Dict[str, Any]]:
- """
- List shares available from a provider.
-
- Args:
- name: Name of the provider
-
- Returns:
- List of share info dicts
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- shares = list(w.providers.list_shares(name=name))
- return [s.as_dict() if hasattr(s, "as_dict") else vars(s) for s in shares]
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/storage.py b/databricks-tools-core/databricks_tools_core/unity_catalog/storage.py
deleted file mode 100644
index b1f506a1..00000000
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/storage.py
+++ /dev/null
@@ -1,310 +0,0 @@
-"""
-Unity Catalog - Storage Operations
-
-Functions for managing storage credentials and external locations.
-"""
-
-from typing import Any, Dict, List, Optional
-from databricks.sdk.service.catalog import (
- StorageCredentialInfo,
- ExternalLocationInfo,
- AwsIamRoleRequest,
- AzureManagedIdentityRequest,
-)
-
-from ..auth import get_workspace_client
-
-
-# --- Storage Credentials ---
-
-
-def list_storage_credentials() -> List[StorageCredentialInfo]:
- """
- List all storage credentials.
-
- Returns:
- List of StorageCredentialInfo objects
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return list(w.storage_credentials.list())
-
-
-def get_storage_credential(name: str) -> StorageCredentialInfo:
- """
- Get a specific storage credential.
-
- Args:
- name: Name of the storage credential
-
- Returns:
- StorageCredentialInfo with credential details
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return w.storage_credentials.get(name=name)
-
-
-def create_storage_credential(
- name: str,
- comment: Optional[str] = None,
- aws_iam_role_arn: Optional[str] = None,
- azure_access_connector_id: Optional[str] = None,
- read_only: bool = False,
-) -> StorageCredentialInfo:
- """
- Create a storage credential for accessing cloud storage.
-
- Provide exactly one of aws_iam_role_arn or azure_access_connector_id
- based on your cloud provider.
-
- Args:
- name: Name for the credential
- comment: Optional description
- aws_iam_role_arn: AWS IAM role ARN (for AWS)
- azure_access_connector_id: Azure Access Connector resource ID (for Azure)
- read_only: Whether the credential is read-only
-
- Returns:
- StorageCredentialInfo with created credential details
-
- Raises:
- ValueError: If no cloud credential is provided
- DatabricksError: If API request fails
- """
- if not aws_iam_role_arn and not azure_access_connector_id:
- raise ValueError("Provide aws_iam_role_arn (AWS) or azure_access_connector_id (Azure)")
-
- w = get_workspace_client()
- kwargs: Dict[str, Any] = {"name": name, "read_only": read_only}
- if comment is not None:
- kwargs["comment"] = comment
- if aws_iam_role_arn:
- kwargs["aws_iam_role"] = AwsIamRoleRequest(role_arn=aws_iam_role_arn)
- if azure_access_connector_id:
- kwargs["azure_managed_identity"] = AzureManagedIdentityRequest(access_connector_id=azure_access_connector_id)
- return w.storage_credentials.create(**kwargs)
-
-
-def update_storage_credential(
- name: str,
- new_name: Optional[str] = None,
- comment: Optional[str] = None,
- owner: Optional[str] = None,
- aws_iam_role_arn: Optional[str] = None,
- azure_access_connector_id: Optional[str] = None,
-) -> StorageCredentialInfo:
- """
- Update a storage credential.
-
- Args:
- name: Current name of the credential
- new_name: New name for the credential
- comment: New comment
- owner: New owner
- aws_iam_role_arn: New AWS IAM role ARN
- azure_access_connector_id: New Azure Access Connector ID
-
- Returns:
- StorageCredentialInfo with updated details
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- kwargs: Dict[str, Any] = {"name": name}
- if new_name is not None:
- kwargs["new_name"] = new_name
- if comment is not None:
- kwargs["comment"] = comment
- if owner is not None:
- kwargs["owner"] = owner
- if aws_iam_role_arn:
- kwargs["aws_iam_role"] = AwsIamRoleRequest(role_arn=aws_iam_role_arn)
- if azure_access_connector_id:
- kwargs["azure_managed_identity"] = AzureManagedIdentityRequest(access_connector_id=azure_access_connector_id)
- return w.storage_credentials.update(**kwargs)
-
-
-def delete_storage_credential(name: str, force: bool = False) -> None:
- """
- Delete a storage credential.
-
- Args:
- name: Name of the credential to delete
- force: If True, force deletion
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- w.storage_credentials.delete(name=name, force=force)
-
-
-def validate_storage_credential(
- name: str,
- url: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Validate a storage credential against a cloud storage URL.
-
- Args:
- name: Name of the credential to validate
- url: Cloud storage URL to validate against
-
- Returns:
- Dict with validation results
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- kwargs: Dict[str, Any] = {"storage_credential_name": name}
- if url is not None:
- kwargs["url"] = url
- result = w.storage_credentials.validate(**kwargs)
- return {
- "is_valid": result.is_dir if hasattr(result, "is_dir") else None,
- "results": [
- {
- "operation": r.operation.value if r.operation else None,
- "result": r.result.value if r.result else None,
- "message": r.message,
- }
- for r in (result.results or [])
- ]
- if hasattr(result, "results") and result.results
- else [],
- }
-
-
-# --- External Locations ---
-
-
-def list_external_locations() -> List[ExternalLocationInfo]:
- """
- List all external locations.
-
- Returns:
- List of ExternalLocationInfo objects
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return list(w.external_locations.list())
-
-
-def get_external_location(name: str) -> ExternalLocationInfo:
- """
- Get a specific external location.
-
- Args:
- name: Name of the external location
-
- Returns:
- ExternalLocationInfo with location details
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return w.external_locations.get(name=name)
-
-
-def create_external_location(
- name: str,
- url: str,
- credential_name: str,
- comment: Optional[str] = None,
- read_only: bool = False,
-) -> ExternalLocationInfo:
- """
- Create an external location pointing to cloud storage.
-
- Args:
- name: Name for the external location
- url: Cloud storage URL (s3://, abfss://, gs://)
- credential_name: Name of the storage credential to use
- comment: Optional description
- read_only: Whether location is read-only
-
- Returns:
- ExternalLocationInfo with created location details
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- kwargs: Dict[str, Any] = {
- "name": name,
- "url": url,
- "credential_name": credential_name,
- "read_only": read_only,
- }
- if comment is not None:
- kwargs["comment"] = comment
- return w.external_locations.create(**kwargs)
-
-
-def update_external_location(
- name: str,
- new_name: Optional[str] = None,
- url: Optional[str] = None,
- credential_name: Optional[str] = None,
- comment: Optional[str] = None,
- owner: Optional[str] = None,
- read_only: Optional[bool] = None,
-) -> ExternalLocationInfo:
- """
- Update an external location.
-
- Args:
- name: Current name of the external location
- new_name: New name
- url: New cloud storage URL
- credential_name: New storage credential name
- comment: New comment
- owner: New owner
- read_only: New read-only setting
-
- Returns:
- ExternalLocationInfo with updated details
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- kwargs: Dict[str, Any] = {"name": name}
- if new_name is not None:
- kwargs["new_name"] = new_name
- if url is not None:
- kwargs["url"] = url
- if credential_name is not None:
- kwargs["credential_name"] = credential_name
- if comment is not None:
- kwargs["comment"] = comment
- if owner is not None:
- kwargs["owner"] = owner
- if read_only is not None:
- kwargs["read_only"] = read_only
- return w.external_locations.update(**kwargs)
-
-
-def delete_external_location(name: str, force: bool = False) -> None:
- """
- Delete an external location.
-
- Args:
- name: Name of the external location to delete
- force: If True, force deletion
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- w.external_locations.delete(name=name, force=force)
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/tables.py b/databricks-tools-core/databricks_tools_core/unity_catalog/tables.py
deleted file mode 100644
index e27e8179..00000000
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/tables.py
+++ /dev/null
@@ -1,131 +0,0 @@
-"""
-Unity Catalog - Table Operations
-
-Functions for managing tables in Unity Catalog.
-"""
-
-from typing import List, Optional
-from databricks.sdk.service.catalog import TableInfo, ColumnInfo, TableType, DataSourceFormat
-
-from ..auth import get_workspace_client
-
-
-def list_tables(catalog_name: str, schema_name: str) -> List[TableInfo]:
- """
- List all tables in a schema.
-
- Args:
- catalog_name: Name of the catalog
- schema_name: Name of the schema
-
- Returns:
- List of TableInfo objects with table metadata
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return list(w.tables.list(catalog_name=catalog_name, schema_name=schema_name))
-
-
-def get_table(full_table_name: str) -> TableInfo:
- """
- Get detailed information about a specific table.
-
- Args:
- full_table_name: Full table name (catalog.schema.table format)
-
- Returns:
- TableInfo object with table metadata including:
- - name, full_name, catalog_name, schema_name
- - table_type, owner, comment
- - created_at, updated_at
- - storage_location, columns
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return w.tables.get(full_name=full_table_name)
-
-
-def create_table(
- catalog_name: str,
- schema_name: str,
- table_name: str,
- columns: List[ColumnInfo],
- table_type: TableType = TableType.MANAGED,
- comment: Optional[str] = None,
- storage_location: Optional[str] = None,
-) -> TableInfo:
- """
- Create a new table in Unity Catalog.
-
- Args:
- catalog_name: Name of the catalog
- schema_name: Name of the schema
- table_name: Name of the table to create
- columns: List of ColumnInfo objects defining table columns
- Example: [ColumnInfo(name="id", type_name="INT"),
- ColumnInfo(name="value", type_name="STRING")]
- table_type: TableType.MANAGED or TableType.EXTERNAL (default: TableType.MANAGED)
- comment: Optional description of the table
- storage_location: Storage location for EXTERNAL tables
-
- Returns:
- TableInfo object with created table metadata
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
-
- # Build full table name for updates
- full_name = f"{catalog_name}.{schema_name}.{table_name}"
-
- # Build kwargs - name should be just the table name, not full path
- kwargs = {
- "name": table_name,
- "catalog_name": catalog_name,
- "schema_name": schema_name,
- "table_type": table_type,
- "columns": columns,
- "data_source_format": DataSourceFormat.DELTA,
- }
-
- # Add storage_location - required for all tables, use default for MANAGED
- if table_type == TableType.EXTERNAL:
- if not storage_location:
- raise ValueError("storage_location is required for EXTERNAL tables")
- kwargs["storage_location"] = storage_location
- else:
- # MANAGED tables don't need storage_location in newer SDK versions
- pass
-
- # Note: comment parameter removed as it's not supported in create()
- # Comments must be set via ALTER TABLE after creation
-
- table = w.tables.create(**kwargs)
-
- # Update comment if provided (via separate API call)
- if comment:
- try:
- w.tables.update(full_name=full_name, comment=comment)
- except Exception:
- pass # Ignore comment update failures
-
- return table
-
-
-def delete_table(full_table_name: str) -> None:
- """
- Delete a table from Unity Catalog.
-
- Args:
- full_table_name: Full table name (catalog.schema.table format)
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- w.tables.delete(full_name=full_table_name)
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/tags.py b/databricks-tools-core/databricks_tools_core/unity_catalog/tags.py
deleted file mode 100644
index fb169456..00000000
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/tags.py
+++ /dev/null
@@ -1,215 +0,0 @@
-"""
-Unity Catalog - Tag and Comment Operations
-
-Functions for managing tags and comments on UC objects.
-All operations are SQL-based via execute_sql.
-"""
-
-import re
-from typing import Any, Dict, List, Optional
-
-_IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z0-9_][a-zA-Z0-9_.\-]*$")
-
-
-def _validate_identifier(name: str) -> str:
- """Validate a SQL identifier to prevent injection."""
- if not _IDENTIFIER_PATTERN.match(name):
- raise ValueError(f"Invalid SQL identifier: '{name}'")
- return name
-
-
-def _execute_uc_sql(sql_query: str, warehouse_id: Optional[str] = None) -> List[Dict[str, Any]]:
- """Execute SQL using the existing execute_sql infrastructure."""
- from ..sql.sql import execute_sql
-
- return execute_sql(sql_query=sql_query, warehouse_id=warehouse_id)
-
-
-def set_tags(
- object_type: str,
- full_name: str,
- tags: Dict[str, str],
- column_name: Optional[str] = None,
- warehouse_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Set tags on a UC object or column.
-
- Args:
- object_type: Type of object ("catalog", "schema", "table", "column")
- full_name: Full object name (e.g., "catalog.schema.table")
- tags: Key-value tag pairs (e.g., {"pii": "true", "classification": "confidential"})
- column_name: Column name (required when object_type is "column")
- warehouse_id: Optional SQL warehouse ID
-
- Returns:
- Dict with status and executed SQL
-
- Raises:
- ValueError: If object_type is "column" but column_name not provided
- """
- _validate_identifier(full_name)
- tag_pairs = ", ".join(f"'{k}' = '{v}'" for k, v in tags.items())
-
- if object_type.lower() == "column":
- if not column_name:
- raise ValueError("column_name is required when object_type is 'column'")
- _validate_identifier(column_name)
- sql = f"ALTER TABLE {full_name} ALTER COLUMN `{column_name}` SET TAGS ({tag_pairs})"
- else:
- obj_keyword = object_type.upper()
- if obj_keyword not in ("CATALOG", "SCHEMA", "TABLE"):
- raise ValueError(f"object_type must be catalog, schema, table, or column. Got: {object_type}")
- sql = f"ALTER {obj_keyword} {full_name} SET TAGS ({tag_pairs})"
-
- _execute_uc_sql(sql, warehouse_id=warehouse_id)
- return {"status": "tags_set", "object": full_name, "tags": tags, "sql": sql}
-
-
-def unset_tags(
- object_type: str,
- full_name: str,
- tag_names: List[str],
- column_name: Optional[str] = None,
- warehouse_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Remove tags from a UC object or column.
-
- Args:
- object_type: Type of object ("catalog", "schema", "table", "column")
- full_name: Full object name
- tag_names: List of tag keys to remove
- column_name: Column name (required when object_type is "column")
- warehouse_id: Optional SQL warehouse ID
-
- Returns:
- Dict with status and executed SQL
- """
- _validate_identifier(full_name)
- tag_keys = ", ".join(f"'{t}'" for t in tag_names)
-
- if object_type.lower() == "column":
- if not column_name:
- raise ValueError("column_name is required when object_type is 'column'")
- _validate_identifier(column_name)
- sql = f"ALTER TABLE {full_name} ALTER COLUMN `{column_name}` UNSET TAGS ({tag_keys})"
- else:
- obj_keyword = object_type.upper()
- if obj_keyword not in ("CATALOG", "SCHEMA", "TABLE"):
- raise ValueError(f"object_type must be catalog, schema, table, or column. Got: {object_type}")
- sql = f"ALTER {obj_keyword} {full_name} UNSET TAGS ({tag_keys})"
-
- _execute_uc_sql(sql, warehouse_id=warehouse_id)
- return {"status": "tags_unset", "object": full_name, "tag_names": tag_names, "sql": sql}
-
-
-def set_comment(
- object_type: str,
- full_name: str,
- comment_text: str,
- column_name: Optional[str] = None,
- warehouse_id: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Set a comment on a UC object or column.
-
- Args:
- object_type: Type of object ("catalog", "schema", "table", "column")
- full_name: Full object name
- comment_text: The comment text to set
- column_name: Column name (required when object_type is "column")
- warehouse_id: Optional SQL warehouse ID
-
- Returns:
- Dict with status and executed SQL
- """
- _validate_identifier(full_name)
- escaped_comment = comment_text.replace("'", "\\'")
-
- if object_type.lower() == "column":
- if not column_name:
- raise ValueError("column_name is required when object_type is 'column'")
- _validate_identifier(column_name)
- sql = f"ALTER TABLE {full_name} ALTER COLUMN `{column_name}` COMMENT '{escaped_comment}'"
- else:
- obj_keyword = object_type.upper()
- if obj_keyword not in ("CATALOG", "SCHEMA", "TABLE"):
- raise ValueError(f"object_type must be catalog, schema, table, or column. Got: {object_type}")
- sql = f"COMMENT ON {obj_keyword} {full_name} IS '{escaped_comment}'"
-
- _execute_uc_sql(sql, warehouse_id=warehouse_id)
- return {"status": "comment_set", "object": full_name, "sql": sql}
-
-
-def query_table_tags(
- catalog_filter: Optional[str] = None,
- tag_name: Optional[str] = None,
- tag_value: Optional[str] = None,
- limit: int = 100,
- warehouse_id: Optional[str] = None,
-) -> List[Dict[str, Any]]:
- """
- Query tags on tables from system.information_schema.table_tags.
-
- Args:
- catalog_filter: Filter by catalog name
- tag_name: Filter by tag name
- tag_value: Filter by tag value
- limit: Maximum rows to return (default: 100)
- warehouse_id: Optional SQL warehouse ID
-
- Returns:
- List of dicts with tag information
- """
- conditions = []
- if catalog_filter:
- _validate_identifier(catalog_filter)
- conditions.append(f"catalog_name = '{catalog_filter}'")
- if tag_name:
- conditions.append(f"tag_name = '{tag_name}'")
- if tag_value:
- conditions.append(f"tag_value = '{tag_value}'")
-
- where = f" WHERE {' AND '.join(conditions)}" if conditions else ""
- sql = f"SELECT * FROM system.information_schema.table_tags{where} LIMIT {limit}"
- return _execute_uc_sql(sql, warehouse_id=warehouse_id)
-
-
-def query_column_tags(
- catalog_filter: Optional[str] = None,
- table_name: Optional[str] = None,
- tag_name: Optional[str] = None,
- tag_value: Optional[str] = None,
- limit: int = 100,
- warehouse_id: Optional[str] = None,
-) -> List[Dict[str, Any]]:
- """
- Query tags on columns from system.information_schema.column_tags.
-
- Args:
- catalog_filter: Filter by catalog name
- table_name: Filter by table name
- tag_name: Filter by tag name
- tag_value: Filter by tag value
- limit: Maximum rows to return (default: 100)
- warehouse_id: Optional SQL warehouse ID
-
- Returns:
- List of dicts with column tag information
- """
- conditions = []
- if catalog_filter:
- _validate_identifier(catalog_filter)
- conditions.append(f"catalog_name = '{catalog_filter}'")
- if table_name:
- _validate_identifier(table_name)
- conditions.append(f"table_name = '{table_name}'")
- if tag_name:
- conditions.append(f"tag_name = '{tag_name}'")
- if tag_value:
- conditions.append(f"tag_value = '{tag_value}'")
-
- where = f" WHERE {' AND '.join(conditions)}" if conditions else ""
- sql = f"SELECT * FROM system.information_schema.column_tags{where} LIMIT {limit}"
- return _execute_uc_sql(sql, warehouse_id=warehouse_id)
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/volume_files.py b/databricks-tools-core/databricks_tools_core/unity_catalog/volume_files.py
deleted file mode 100644
index 3bae08fa..00000000
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/volume_files.py
+++ /dev/null
@@ -1,690 +0,0 @@
-"""
-Volume Files - Unity Catalog Volume File Operations
-
-Functions for working with files in Unity Catalog Volumes.
-Uses Databricks Files API via SDK (w.files).
-
-Volume paths use the format: /Volumes////
-"""
-
-import glob as glob_module
-import os
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import List, Optional
-
-from databricks.sdk import WorkspaceClient
-
-from ..auth import get_workspace_client
-
-
-@dataclass
-class VolumeFileInfo:
- """Information about a file or directory in a volume."""
-
- name: str
- path: str
- is_directory: bool
- file_size: Optional[int] = None
- last_modified: Optional[str] = None
-
-
-@dataclass
-class VolumeUploadResult:
- """Result from uploading a single file to a volume."""
-
- local_path: str
- volume_path: str
- success: bool
- error: Optional[str] = None
-
-
-@dataclass
-class VolumeFolderUploadResult:
- """Result from uploading multiple files to a volume."""
-
- local_folder: str
- remote_folder: str
- total_files: int
- successful: int
- failed: int
- results: List[VolumeUploadResult] = field(default_factory=list)
-
- @property
- def success(self) -> bool:
- """Returns True if all files were uploaded successfully"""
- return self.failed == 0
-
- def get_failed_uploads(self) -> List[VolumeUploadResult]:
- """Returns list of failed uploads"""
- return [r for r in self.results if not r.success]
-
-
-@dataclass
-class VolumeDownloadResult:
- """Result from downloading a file from a volume."""
-
- volume_path: str
- local_path: str
- success: bool
- error: Optional[str] = None
-
-
-@dataclass
-class VolumeDeleteResult:
- """Result from deleting a file or directory from a volume."""
-
- volume_path: str
- success: bool
- files_deleted: int = 0
- directories_deleted: int = 0
- error: Optional[str] = None
-
-
-def list_volume_files(volume_path: str, max_results: Optional[int] = None) -> List[VolumeFileInfo]:
- """
- List files and directories in a volume path.
-
- Args:
- volume_path: Path in volume (e.g., "/Volumes/catalog/schema/volume/folder")
- max_results: Optional maximum number of results to return (None = no limit)
-
- Returns:
- List of VolumeFileInfo objects
-
- Raises:
- Exception: If path doesn't exist or access denied
-
- Example:
- >>> files = list_volume_files("/Volumes/main/default/my_volume/data")
- >>> for f in files:
- ... print(f"{f.name}: {'dir' if f.is_directory else 'file'}")
- """
- w = get_workspace_client()
-
- # Ensure path ends with / for directory listing
- if not volume_path.endswith("/"):
- volume_path = volume_path + "/"
-
- results = []
- for entry in w.files.list_directory_contents(volume_path):
- # Handle last_modified - can be datetime, int (Unix timestamp), or None
- last_modified = None
- if entry.last_modified is not None:
- if isinstance(entry.last_modified, int):
- # Unix timestamp - convert to ISO format string
- last_modified = str(entry.last_modified)
- else:
- # datetime object - convert to ISO format string
- last_modified = entry.last_modified.isoformat()
-
- results.append(
- VolumeFileInfo(
- name=entry.name,
- path=entry.path,
- is_directory=entry.is_directory,
- file_size=entry.file_size,
- last_modified=last_modified,
- )
- )
- # Early exit if we've hit the limit
- if max_results is not None and len(results) >= max_results:
- break
-
- return results
-
-
-def _collect_local_files(local_folder: str) -> List[tuple]:
- """
- Collect all files in a folder recursively.
-
- Args:
- local_folder: Path to local folder
-
- Returns:
- List of (local_path, relative_path) tuples
- """
- files = []
- local_folder = os.path.abspath(local_folder)
-
- for dirpath, _, filenames in os.walk(local_folder):
- for filename in filenames:
- # Skip hidden files and __pycache__
- if filename.startswith(".") or "__pycache__" in dirpath:
- continue
-
- local_path = os.path.join(dirpath, filename)
- rel_path = os.path.relpath(local_path, local_folder)
- files.append((local_path, rel_path))
-
- return files
-
-
-def _collect_local_directories(local_folder: str) -> List[str]:
- """
- Collect all directories in a folder recursively.
-
- Args:
- local_folder: Path to local folder
-
- Returns:
- List of relative directory paths
- """
- directories = set()
- local_folder = os.path.abspath(local_folder)
-
- for dirpath, dirnames, _ in os.walk(local_folder):
- # Skip hidden directories and __pycache__
- dirnames[:] = [d for d in dirnames if not d.startswith(".") and d != "__pycache__"]
-
- for dirname in dirnames:
- full_path = os.path.join(dirpath, dirname)
- rel_path = os.path.relpath(full_path, local_folder)
- directories.add(rel_path)
- # Also add parent directories
- parent = Path(rel_path).parent
- while str(parent) != ".":
- directories.add(str(parent))
- parent = parent.parent
-
- return sorted(directories)
-
-
-def _upload_single_file_to_volume(
- w: WorkspaceClient, local_path: str, volume_path: str, overwrite: bool
-) -> VolumeUploadResult:
- """Upload a single file to volume using w.files API."""
- try:
- w.files.upload_from(file_path=volume_path, source_path=local_path, overwrite=overwrite)
- return VolumeUploadResult(local_path=local_path, volume_path=volume_path, success=True)
- except Exception as e:
- return VolumeUploadResult(local_path=local_path, volume_path=volume_path, success=False, error=str(e))
-
-
-def _create_volume_directory_safe(w: WorkspaceClient, volume_path: str) -> None:
- """Create a directory in volume, ignoring errors if it already exists."""
- try:
- w.files.create_directory(volume_path)
- except Exception:
- pass # Directory may already exist
-
-
-def upload_to_volume(
- local_path: str, volume_path: str, max_workers: int = 4, overwrite: bool = True
-) -> VolumeFolderUploadResult:
- """
- Upload local file(s) or folder(s) to a Unity Catalog volume.
-
- Works like the `cp` command - handles single files, folders, and glob patterns.
- Automatically creates parent directories in volume as needed.
-
- Args:
- local_path: Path to local file, folder, or glob pattern. Examples:
- - "/path/to/file.csv" - single file
- - "/path/to/folder" - entire folder (recursive)
- - "/path/to/folder/*" - all files/folders in folder
- - "/path/to/*.json" - glob pattern
- volume_path: Target path in Unity Catalog volume
- (e.g., "/Volumes/catalog/schema/volume/folder")
- max_workers: Maximum parallel upload threads (default: 4)
- overwrite: Whether to overwrite existing files (default: True)
-
- Returns:
- VolumeFolderUploadResult with upload statistics and individual results
-
- Example:
- >>> # Upload a single file
- >>> result = upload_to_volume(
- ... local_path="/tmp/data.csv",
- ... volume_path="/Volumes/main/default/my_volume/data.csv"
- ... )
-
- >>> # Upload a folder
- >>> result = upload_to_volume(
- ... local_path="/tmp/my_data",
- ... volume_path="/Volumes/main/default/my_volume/my_data"
- ... )
-
- >>> # Upload folder contents (not the folder itself)
- >>> result = upload_to_volume(
- ... local_path="/tmp/my_data/*",
- ... volume_path="/Volumes/main/default/my_volume/destination"
- ... )
- """
- local_path = os.path.expanduser(local_path)
- volume_path = volume_path.rstrip("/")
-
- w = get_workspace_client()
-
- # Determine what we're uploading
- has_glob = "*" in local_path or "?" in local_path
-
- if has_glob:
- return _upload_glob_to_volume(w, local_path, volume_path, max_workers, overwrite)
- elif os.path.isfile(local_path):
- return _upload_single_to_volume(w, local_path, volume_path, overwrite)
- elif os.path.isdir(local_path):
- return _upload_folder_to_volume(w, local_path, volume_path, max_workers, overwrite)
- else:
- return VolumeFolderUploadResult(
- local_folder=local_path,
- remote_folder=volume_path,
- total_files=0,
- successful=0,
- failed=1,
- results=[
- VolumeUploadResult(
- local_path=local_path,
- volume_path=volume_path,
- success=False,
- error=f"Path not found: {local_path}",
- )
- ],
- )
-
-
-def _upload_single_to_volume(
- w: WorkspaceClient, local_path: str, volume_path: str, overwrite: bool
-) -> VolumeFolderUploadResult:
- """Upload a single file to volume."""
- # Create parent directory if needed
- parent_dir = str(Path(volume_path).parent)
- _create_volume_directory_safe(w, parent_dir)
-
- result = _upload_single_file_to_volume(w, local_path, volume_path, overwrite)
- return VolumeFolderUploadResult(
- local_folder=os.path.dirname(local_path),
- remote_folder=os.path.dirname(volume_path),
- total_files=1,
- successful=1 if result.success else 0,
- failed=0 if result.success else 1,
- results=[result],
- )
-
-
-def _upload_glob_to_volume(
- w: WorkspaceClient, pattern: str, volume_path: str, max_workers: int, overwrite: bool
-) -> VolumeFolderUploadResult:
- """Upload files matching a glob pattern to volume."""
- matches = glob_module.glob(pattern)
- if not matches:
- return VolumeFolderUploadResult(
- local_folder=os.path.dirname(pattern),
- remote_folder=volume_path,
- total_files=0,
- successful=0,
- failed=1,
- results=[
- VolumeUploadResult(
- local_path=pattern,
- volume_path=volume_path,
- success=False,
- error=f"No files match pattern: {pattern}",
- )
- ],
- )
-
- # Get the base directory
- pattern_dir = os.path.dirname(pattern)
- if pattern_dir:
- base_dir = os.path.abspath(pattern_dir)
- else:
- base_dir = os.getcwd()
-
- # Create volume root directory
- _create_volume_directory_safe(w, volume_path)
-
- # Collect all files from all matches
- all_files = []
- all_dirs = set()
-
- for match in matches:
- match = os.path.abspath(match)
- if os.path.isfile(match):
- rel_path = os.path.basename(match)
- all_files.append((match, rel_path))
- elif os.path.isdir(match):
- folder_name = os.path.basename(match)
- for local_file, rel_in_folder in _collect_local_files(match):
- rel_path = os.path.join(folder_name, rel_in_folder)
- all_files.append((local_file, rel_path))
- parent = str(Path(rel_path).parent)
- while parent != ".":
- all_dirs.add(parent)
- parent = str(Path(parent).parent)
- for subdir in _collect_local_directories(match):
- all_dirs.add(os.path.join(folder_name, subdir))
- all_dirs.add(folder_name)
-
- # Create all directories
- for dir_path in sorted(all_dirs):
- _create_volume_directory_safe(w, f"{volume_path}/{dir_path}")
-
- if not all_files:
- return VolumeFolderUploadResult(
- local_folder=base_dir,
- remote_folder=volume_path,
- total_files=0,
- successful=0,
- failed=0,
- results=[],
- )
-
- return _parallel_upload_to_volume(w, all_files, base_dir, volume_path, max_workers, overwrite)
-
-
-def _upload_folder_to_volume(
- w: WorkspaceClient, local_folder: str, volume_folder: str, max_workers: int, overwrite: bool
-) -> VolumeFolderUploadResult:
- """Upload an entire folder to volume."""
- local_folder = os.path.abspath(local_folder)
-
- # Create root directory
- _create_volume_directory_safe(w, volume_folder)
-
- # Create all subdirectories
- directories = _collect_local_directories(local_folder)
- for dir_path in directories:
- _create_volume_directory_safe(w, f"{volume_folder}/{dir_path}")
-
- # Collect all files
- files = _collect_local_files(local_folder)
-
- if not files:
- return VolumeFolderUploadResult(
- local_folder=local_folder,
- remote_folder=volume_folder,
- total_files=0,
- successful=0,
- failed=0,
- results=[],
- )
-
- return _parallel_upload_to_volume(w, files, local_folder, volume_folder, max_workers, overwrite)
-
-
-def _parallel_upload_to_volume(
- w: WorkspaceClient,
- files: List[tuple],
- local_base: str,
- volume_base: str,
- max_workers: int,
- overwrite: bool,
-) -> VolumeFolderUploadResult:
- """Upload files in parallel to volume."""
- results = []
- successful = 0
- failed = 0
-
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- future_to_file = {}
- for local_path, rel_path in files:
- remote_path = f"{volume_base}/{rel_path.replace(os.sep, '/')}"
- future = executor.submit(_upload_single_file_to_volume, w, local_path, remote_path, overwrite)
- future_to_file[future] = (local_path, remote_path)
-
- for future in as_completed(future_to_file):
- result = future.result()
- results.append(result)
- if result.success:
- successful += 1
- else:
- failed += 1
-
- return VolumeFolderUploadResult(
- local_folder=local_base,
- remote_folder=volume_base,
- total_files=len(files),
- successful=successful,
- failed=failed,
- results=results,
- )
-
-
-def download_from_volume(volume_path: str, local_path: str, overwrite: bool = True) -> VolumeDownloadResult:
- """
- Download a file from a Unity Catalog volume to local path.
-
- Args:
- volume_path: Path in volume (e.g., "/Volumes/catalog/schema/volume/file.csv")
- local_path: Target local file path
- overwrite: Whether to overwrite existing local file (default: True)
-
- Returns:
- VolumeDownloadResult with success status
-
- Example:
- >>> result = download_from_volume(
- ... volume_path="/Volumes/main/default/my_volume/data.csv",
- ... local_path="/tmp/downloaded.csv"
- ... )
- >>> if result.success:
- ... print("Download complete")
- """
- # Check if local file exists and overwrite is False
- if os.path.exists(local_path) and not overwrite:
- return VolumeDownloadResult(
- volume_path=volume_path,
- local_path=local_path,
- success=False,
- error=f"Local file already exists: {local_path}",
- )
-
- try:
- w = get_workspace_client()
-
- # Create parent directory if needed
- parent_dir = str(Path(local_path).parent)
- if parent_dir and not os.path.exists(parent_dir):
- os.makedirs(parent_dir)
-
- # Use download_to for direct volume-to-file download
- w.files.download_to(file_path=volume_path, destination=local_path, overwrite=overwrite)
-
- return VolumeDownloadResult(volume_path=volume_path, local_path=local_path, success=True)
-
- except Exception as e:
- return VolumeDownloadResult(volume_path=volume_path, local_path=local_path, success=False, error=str(e))
-
-
-def _delete_single_file(w: WorkspaceClient, volume_path: str) -> bool:
- """Delete a single file, returns True if successful."""
- try:
- w.files.delete(volume_path)
- return True
- except Exception:
- return False
-
-
-def _delete_single_directory(w: WorkspaceClient, volume_path: str) -> bool:
- """Delete a single empty directory, returns True if successful."""
- try:
- w.files.delete_directory(volume_path)
- return True
- except Exception:
- return False
-
-
-def _collect_volume_contents(w: WorkspaceClient, volume_path: str) -> tuple[List[str], List[str]]:
- """
- Recursively collect all files and directories in a volume path.
-
- Returns:
- Tuple of (files, directories) where directories are sorted deepest-first
- for proper deletion order.
- """
- files = []
- directories = []
-
- def _scan(path: str):
- try:
- if not path.endswith("/"):
- path = path + "/"
- for entry in w.files.list_directory_contents(path):
- if entry.is_directory:
- directories.append(entry.path)
- _scan(entry.path)
- else:
- files.append(entry.path)
- except Exception:
- pass
-
- _scan(volume_path)
-
- # Sort directories by depth (deepest first) for proper deletion order
- directories.sort(key=lambda x: x.count("/"), reverse=True)
-
- return files, directories
-
-
-def delete_from_volume(volume_path: str, recursive: bool = False, max_workers: int = 4) -> VolumeDeleteResult:
- """
- Delete a file or directory from a Unity Catalog volume.
-
- For files, deletes the file directly.
- For directories, requires recursive=True to delete non-empty directories.
- When recursive=True, deletes all files in parallel, then directories deepest-first.
-
- Args:
- volume_path: Path to file or directory in volume
- (e.g., "/Volumes/catalog/schema/volume/folder")
- recursive: If True, delete directory and all contents. Required for non-empty directories.
- (default: False)
- max_workers: Maximum parallel delete threads (default: 4)
-
- Returns:
- VolumeDeleteResult with success status and counts
-
- Example:
- >>> # Delete a single file
- >>> result = delete_from_volume("/Volumes/main/default/my_volume/old_file.csv")
-
- >>> # Delete a folder and all contents
- >>> result = delete_from_volume(
- ... "/Volumes/main/default/my_volume/old_folder",
- ... recursive=True
- ... )
- """
- volume_path = volume_path.rstrip("/")
- w = get_workspace_client()
-
- # Check if path is a file or directory
- try:
- w.files.get_metadata(volume_path)
- is_directory = False
- except Exception:
- # get_metadata fails for directories, try listing
- try:
- list(w.files.list_directory_contents(volume_path + "/"))
- is_directory = True
- except Exception as e:
- return VolumeDeleteResult(
- volume_path=volume_path,
- success=False,
- error=f"Path not found or access denied: {volume_path}. {str(e)}",
- )
-
- if not is_directory:
- # Simple file deletion
- try:
- w.files.delete(volume_path)
- return VolumeDeleteResult(volume_path=volume_path, success=True, files_deleted=1)
- except Exception as e:
- return VolumeDeleteResult(volume_path=volume_path, success=False, error=str(e))
-
- # It's a directory
- if not recursive:
- # Try to delete empty directory
- try:
- w.files.delete_directory(volume_path)
- return VolumeDeleteResult(volume_path=volume_path, success=True, directories_deleted=1)
- except Exception as e:
- return VolumeDeleteResult(
- volume_path=volume_path,
- success=False,
- error=f"Directory not empty. Use recursive=True to delete all contents. {str(e)}",
- )
-
- # Recursive deletion
- files, directories = _collect_volume_contents(w, volume_path)
-
- files_deleted = 0
- directories_deleted = 0
-
- # Delete all files in parallel
- if files:
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- futures = {executor.submit(_delete_single_file, w, f): f for f in files}
- for future in as_completed(futures):
- if future.result():
- files_deleted += 1
-
- # Delete directories sequentially (deepest first)
- for dir_path in directories:
- if _delete_single_directory(w, dir_path):
- directories_deleted += 1
-
- # Finally delete the root directory
- try:
- w.files.delete_directory(volume_path)
- directories_deleted += 1
- success = True
- error = None
- except Exception as e:
- success = False
- error = f"Failed to delete root directory: {str(e)}"
-
- return VolumeDeleteResult(
- volume_path=volume_path,
- success=success,
- files_deleted=files_deleted,
- directories_deleted=directories_deleted,
- error=error,
- )
-
-
-def create_volume_directory(volume_path: str) -> None:
- """
- Create a directory in a Unity Catalog volume.
-
- Creates parent directories as needed (like mkdir -p).
- Idempotent - succeeds if directory already exists.
-
- Args:
- volume_path: Path for new directory (e.g., "/Volumes/catalog/schema/volume/new_folder")
-
- Example:
- >>> create_volume_directory("/Volumes/main/default/my_volume/data/2024/01")
- """
- w = get_workspace_client()
- w.files.create_directory(volume_path)
-
-
-def get_volume_file_metadata(volume_path: str) -> VolumeFileInfo:
- """
- Get metadata for a file in a Unity Catalog volume.
-
- Args:
- volume_path: Path to file in volume
-
- Returns:
- VolumeFileInfo with file metadata
-
- Raises:
- Exception: If file doesn't exist or access denied
-
- Example:
- >>> info = get_volume_file_metadata("/Volumes/main/default/my_volume/data.csv")
- >>> print(f"Size: {info.file_size} bytes")
- """
- w = get_workspace_client()
- metadata = w.files.get_metadata(volume_path)
-
- return VolumeFileInfo(
- name=Path(volume_path).name,
- path=volume_path,
- is_directory=False,
- file_size=metadata.content_length,
- last_modified=metadata.last_modified.isoformat() if metadata.last_modified else None,
- )
diff --git a/databricks-tools-core/databricks_tools_core/unity_catalog/volumes.py b/databricks-tools-core/databricks_tools_core/unity_catalog/volumes.py
deleted file mode 100644
index 36378787..00000000
--- a/databricks-tools-core/databricks_tools_core/unity_catalog/volumes.py
+++ /dev/null
@@ -1,148 +0,0 @@
-"""
-Unity Catalog - Volume Operations
-
-Functions for managing volumes in Unity Catalog.
-"""
-
-from typing import Dict, List, Optional
-from databricks.sdk.service.catalog import VolumeInfo, VolumeType
-
-from ..auth import get_workspace_client
-
-
-def list_volumes(catalog_name: str, schema_name: str) -> List[VolumeInfo]:
- """
- List all volumes in a schema.
-
- Args:
- catalog_name: Name of the catalog
- schema_name: Name of the schema
-
- Returns:
- List of VolumeInfo objects with volume metadata
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return list(
- w.volumes.list(
- catalog_name=catalog_name,
- schema_name=schema_name,
- )
- )
-
-
-def get_volume(full_volume_name: str) -> VolumeInfo:
- """
- Get detailed information about a specific volume.
-
- Args:
- full_volume_name: Full volume name (catalog.schema.volume format)
-
- Returns:
- VolumeInfo object with volume metadata including:
- - name, full_name, catalog_name, schema_name
- - volume_type, owner, comment
- - created_at, updated_at
- - storage_location
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- return w.volumes.read(name=full_volume_name)
-
-
-def create_volume(
- catalog_name: str,
- schema_name: str,
- name: str,
- volume_type: str = "MANAGED",
- comment: Optional[str] = None,
- storage_location: Optional[str] = None,
-) -> VolumeInfo:
- """
- Create a new volume in Unity Catalog.
-
- Args:
- catalog_name: Name of the catalog
- schema_name: Name of the schema
- name: Name of the volume to create
- volume_type: "MANAGED" or "EXTERNAL" (default: "MANAGED")
- comment: Optional description
- storage_location: Required for EXTERNAL volumes (cloud storage URL)
-
- Returns:
- VolumeInfo object with created volume metadata
-
- Raises:
- ValueError: If EXTERNAL volume without storage_location
- DatabricksError: If API request fails
- """
- vtype = VolumeType(volume_type)
- if vtype == VolumeType.EXTERNAL and not storage_location:
- raise ValueError("storage_location is required for EXTERNAL volumes")
-
- w = get_workspace_client()
- kwargs: Dict = {
- "catalog_name": catalog_name,
- "schema_name": schema_name,
- "name": name,
- "volume_type": vtype,
- }
- if comment is not None:
- kwargs["comment"] = comment
- if storage_location is not None:
- kwargs["storage_location"] = storage_location
- return w.volumes.create(**kwargs)
-
-
-def update_volume(
- full_volume_name: str,
- new_name: Optional[str] = None,
- comment: Optional[str] = None,
- owner: Optional[str] = None,
-) -> VolumeInfo:
- """
- Update an existing volume.
-
- Args:
- full_volume_name: Full volume name (catalog.schema.volume format)
- new_name: New name for the volume
- comment: New comment/description
- owner: New owner
-
- Returns:
- VolumeInfo object with updated volume metadata
-
- Raises:
- ValueError: If no fields are provided to update
- DatabricksError: If API request fails
- """
- if not any([new_name, comment, owner]):
- raise ValueError("At least one field must be provided to update")
-
- w = get_workspace_client()
- kwargs: Dict = {"name": full_volume_name}
- if new_name is not None:
- kwargs["new_name"] = new_name
- if comment is not None:
- kwargs["comment"] = comment
- if owner is not None:
- kwargs["owner"] = owner
- return w.volumes.update(**kwargs)
-
-
-def delete_volume(full_volume_name: str) -> None:
- """
- Delete a volume from Unity Catalog.
-
- Args:
- full_volume_name: Full volume name (catalog.schema.volume format)
-
- Raises:
- DatabricksError: If API request fails
- """
- w = get_workspace_client()
- w.volumes.delete(name=full_volume_name)
diff --git a/databricks-tools-core/databricks_tools_core/vector_search/__init__.py b/databricks-tools-core/databricks_tools_core/vector_search/__init__.py
deleted file mode 100644
index c35557df..00000000
--- a/databricks-tools-core/databricks_tools_core/vector_search/__init__.py
+++ /dev/null
@@ -1,42 +0,0 @@
-"""
-Vector Search Operations
-
-Functions for managing Databricks Vector Search endpoints, indexes,
-and performing similarity queries.
-"""
-
-from .endpoints import (
- create_vs_endpoint,
- get_vs_endpoint,
- list_vs_endpoints,
- delete_vs_endpoint,
-)
-from .indexes import (
- create_vs_index,
- get_vs_index,
- list_vs_indexes,
- delete_vs_index,
- sync_vs_index,
- query_vs_index,
- upsert_vs_data,
- delete_vs_data,
- scan_vs_index,
-)
-
-__all__ = [
- # Endpoints
- "create_vs_endpoint",
- "get_vs_endpoint",
- "list_vs_endpoints",
- "delete_vs_endpoint",
- # Indexes
- "create_vs_index",
- "get_vs_index",
- "list_vs_indexes",
- "delete_vs_index",
- "sync_vs_index",
- "query_vs_index",
- "upsert_vs_data",
- "delete_vs_data",
- "scan_vs_index",
-]
diff --git a/databricks-tools-core/databricks_tools_core/vector_search/endpoints.py b/databricks-tools-core/databricks_tools_core/vector_search/endpoints.py
deleted file mode 100644
index 157398fb..00000000
--- a/databricks-tools-core/databricks_tools_core/vector_search/endpoints.py
+++ /dev/null
@@ -1,209 +0,0 @@
-"""
-Vector Search Endpoint Operations
-
-Functions for creating, managing, and deleting Databricks Vector Search endpoints.
-"""
-
-import logging
-from typing import Any, Dict, List
-
-from ..auth import get_workspace_client
-
-logger = logging.getLogger(__name__)
-
-
-def create_vs_endpoint(
- name: str,
- endpoint_type: str = "STANDARD",
-) -> Dict[str, Any]:
- """
- Create a Vector Search endpoint.
-
- Endpoint creation is asynchronous. Use get_vs_endpoint() to check status.
-
- Args:
- name: Endpoint name (unique within workspace)
- endpoint_type: "STANDARD" (low-latency, <100ms) or
- "STORAGE_OPTIMIZED" (cost-effective, ~250ms, 1B+ vectors)
-
- Returns:
- Dictionary with:
- - name: Endpoint name
- - endpoint_type: Type of endpoint created
- - status: Creation status
-
- Raises:
- Exception: If creation fails
- """
- client = get_workspace_client()
-
- try:
- from databricks.sdk.service.vectorsearch import EndpointType
-
- ep_type = EndpointType(endpoint_type)
- client.vector_search_endpoints.create_endpoint(
- name=name,
- endpoint_type=ep_type,
- )
-
- return {
- "name": name,
- "endpoint_type": endpoint_type,
- "status": "CREATING",
- "message": f"Endpoint '{name}' creation initiated. Use get_vs_endpoint('{name}') to check status.",
- }
- except Exception as e:
- error_msg = str(e)
- if "ALREADY_EXISTS" in error_msg or "already exists" in error_msg.lower():
- return {
- "name": name,
- "endpoint_type": endpoint_type,
- "status": "ALREADY_EXISTS",
- "error": f"Endpoint '{name}' already exists",
- }
- raise Exception(f"Failed to create vector search endpoint '{name}': {error_msg}")
-
-
-def get_vs_endpoint(name: str) -> Dict[str, Any]:
- """
- Get Vector Search endpoint status and details.
-
- Args:
- name: Endpoint name
-
- Returns:
- Dictionary with:
- - name: Endpoint name
- - endpoint_type: STANDARD or STORAGE_OPTIMIZED
- - state: Current state (e.g., ONLINE, PROVISIONING, OFFLINE)
- - creation_timestamp: When endpoint was created
- - last_updated_timestamp: When endpoint was last updated
- - num_indexes: Number of indexes on this endpoint
- - error: Error message if endpoint is in error state
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
-
- try:
- endpoint = client.vector_search_endpoints.get_endpoint(endpoint_name=name)
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "name": name,
- "state": "NOT_FOUND",
- "error": f"Endpoint '{name}' not found",
- }
- raise Exception(f"Failed to get vector search endpoint '{name}': {error_msg}")
-
- result: Dict[str, Any] = {
- "name": endpoint.name,
- "state": (
- endpoint.endpoint_status.state.value
- if endpoint.endpoint_status and endpoint.endpoint_status.state
- else None
- ),
- "error": None,
- }
-
- if endpoint.endpoint_type:
- result["endpoint_type"] = endpoint.endpoint_type.value
-
- if endpoint.endpoint_status and endpoint.endpoint_status.message:
- result["message"] = endpoint.endpoint_status.message
-
- if endpoint.creation_timestamp:
- result["creation_timestamp"] = endpoint.creation_timestamp
-
- if endpoint.last_updated_timestamp:
- result["last_updated_timestamp"] = endpoint.last_updated_timestamp
-
- if endpoint.num_indexes is not None:
- result["num_indexes"] = endpoint.num_indexes
-
- return result
-
-
-def list_vs_endpoints() -> List[Dict[str, Any]]:
- """
- List all Vector Search endpoints in the workspace.
-
- Returns:
- List of endpoint dictionaries with:
- - name: Endpoint name
- - endpoint_type: STANDARD or STORAGE_OPTIMIZED
- - state: Current state
- - num_indexes: Number of indexes on endpoint
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
-
- try:
- response = client.vector_search_endpoints.list_endpoints()
- except Exception as e:
- raise Exception(f"Failed to list vector search endpoints: {str(e)}")
-
- result = []
- # SDK may return a generator or an object with .endpoints attribute
- if hasattr(response, "endpoints"):
- endpoints = response.endpoints if response.endpoints else []
- else:
- endpoints = list(response) if response else []
- for ep in endpoints:
- entry: Dict[str, Any] = {"name": ep.name}
-
- if ep.endpoint_type:
- entry["endpoint_type"] = ep.endpoint_type.value
-
- if ep.endpoint_status and ep.endpoint_status.state:
- entry["state"] = ep.endpoint_status.state.value
-
- if ep.num_indexes is not None:
- entry["num_indexes"] = ep.num_indexes
-
- if ep.creation_timestamp:
- entry["creation_timestamp"] = ep.creation_timestamp
-
- result.append(entry)
-
- return result
-
-
-def delete_vs_endpoint(name: str) -> Dict[str, Any]:
- """
- Delete a Vector Search endpoint.
-
- All indexes on the endpoint must be deleted first.
-
- Args:
- name: Endpoint name to delete
-
- Returns:
- Dictionary with:
- - name: Endpoint name
- - status: "deleted" or error info
-
- Raises:
- Exception: If deletion fails
- """
- client = get_workspace_client()
-
- try:
- client.vector_search_endpoints.delete_endpoint(endpoint_name=name)
- return {
- "name": name,
- "status": "deleted",
- }
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "name": name,
- "status": "NOT_FOUND",
- "error": f"Endpoint '{name}' not found",
- }
- raise Exception(f"Failed to delete vector search endpoint '{name}': {error_msg}")
diff --git a/databricks-tools-core/databricks_tools_core/vector_search/indexes.py b/databricks-tools-core/databricks_tools_core/vector_search/indexes.py
deleted file mode 100644
index 0c7cf55f..00000000
--- a/databricks-tools-core/databricks_tools_core/vector_search/indexes.py
+++ /dev/null
@@ -1,602 +0,0 @@
-"""
-Vector Search Index Operations
-
-Functions for creating, managing, querying, and syncing Vector Search indexes.
-"""
-
-import json
-import logging
-from typing import Any, Dict, List, Optional
-
-from ..auth import get_workspace_client
-
-logger = logging.getLogger(__name__)
-
-
-def create_vs_index(
- name: str,
- endpoint_name: str,
- primary_key: str,
- index_type: str = "DELTA_SYNC",
- delta_sync_index_spec: Optional[Dict[str, Any]] = None,
- direct_access_index_spec: Optional[Dict[str, Any]] = None,
-) -> Dict[str, Any]:
- """
- Create a Vector Search index.
-
- For DELTA_SYNC indexes, provide delta_sync_index_spec with either:
- - embedding_source_columns (managed embeddings): Databricks computes embeddings
- - embedding_vector_columns (self-managed): you provide pre-computed embeddings
-
- For DIRECT_ACCESS indexes, provide direct_access_index_spec with:
- - embedding_vector_columns and schema_json
-
- Args:
- name: Fully qualified index name (catalog.schema.index_name)
- endpoint_name: Vector Search endpoint to host this index
- primary_key: Column name for the primary key
- index_type: "DELTA_SYNC" or "DIRECT_ACCESS"
- delta_sync_index_spec: Config for Delta Sync index. Keys:
- - source_table (str): Fully qualified source table name
- - embedding_source_columns (list): For managed embeddings,
- e.g. [{"name": "content", "embedding_model_endpoint_name": "databricks-gte-large-en"}]
- - embedding_vector_columns (list): For self-managed embeddings,
- e.g. [{"name": "embedding", "embedding_dimension": 768}]
- - pipeline_type (str): "TRIGGERED" or "CONTINUOUS"
- - columns_to_sync (list, optional): Column names to include
- direct_access_index_spec: Config for Direct Access index. Keys:
- - embedding_vector_columns (list): e.g. [{"name": "embedding", "embedding_dimension": 768}]
- - schema_json (str): JSON schema string
- - embedding_model_endpoint_name (str, optional): For query-time embedding
-
- Returns:
- Dictionary with index creation details
-
- Raises:
- Exception: If creation fails
- """
- client = get_workspace_client()
-
- try:
- from databricks.sdk.service.vectorsearch import (
- DeltaSyncVectorIndexSpecRequest,
- DirectAccessVectorIndexSpec,
- EmbeddingSourceColumn,
- EmbeddingVectorColumn,
- VectorIndexType,
- )
-
- kwargs: Dict[str, Any] = {
- "name": name,
- "endpoint_name": endpoint_name,
- "primary_key": primary_key,
- "index_type": VectorIndexType(index_type),
- }
-
- if index_type == "DELTA_SYNC" and delta_sync_index_spec:
- spec = delta_sync_index_spec
- ds_kwargs: Dict[str, Any] = {}
-
- if "source_table" in spec:
- ds_kwargs["source_table"] = spec["source_table"]
-
- if "pipeline_type" in spec:
- from databricks.sdk.service.vectorsearch import PipelineType
-
- ds_kwargs["pipeline_type"] = PipelineType(spec["pipeline_type"])
-
- if "embedding_source_columns" in spec:
- ds_kwargs["embedding_source_columns"] = [
- EmbeddingSourceColumn(**col) for col in spec["embedding_source_columns"]
- ]
-
- if "embedding_vector_columns" in spec:
- ds_kwargs["embedding_vector_columns"] = [
- EmbeddingVectorColumn(**col) for col in spec["embedding_vector_columns"]
- ]
-
- if "columns_to_sync" in spec:
- ds_kwargs["columns_to_sync"] = spec["columns_to_sync"]
-
- kwargs["delta_sync_index_spec"] = DeltaSyncVectorIndexSpecRequest(**ds_kwargs)
-
- elif index_type == "DIRECT_ACCESS" and direct_access_index_spec:
- spec = direct_access_index_spec
- da_kwargs: Dict[str, Any] = {}
-
- if "embedding_vector_columns" in spec:
- da_kwargs["embedding_vector_columns"] = [
- EmbeddingVectorColumn(**col) for col in spec["embedding_vector_columns"]
- ]
-
- if "schema_json" in spec:
- da_kwargs["schema_json"] = spec["schema_json"]
-
- if "embedding_model_endpoint_name" in spec:
- da_kwargs["embedding_source_columns"] = [
- EmbeddingSourceColumn(
- name="__query__",
- embedding_model_endpoint_name=spec["embedding_model_endpoint_name"],
- )
- ]
-
- kwargs["direct_access_index_spec"] = DirectAccessVectorIndexSpec(**da_kwargs)
-
- client.vector_search_indexes.create_index(**kwargs)
-
- return {
- "name": name,
- "endpoint_name": endpoint_name,
- "index_type": index_type,
- "primary_key": primary_key,
- "status": "CREATING",
- "message": f"Index '{name}' creation initiated. Use get_vs_index('{name}') to check status.",
- }
- except Exception as e:
- error_msg = str(e)
- if "ALREADY_EXISTS" in error_msg or "already exists" in error_msg.lower():
- return {
- "name": name,
- "status": "ALREADY_EXISTS",
- "error": f"Index '{name}' already exists",
- }
- raise Exception(f"Failed to create vector search index '{name}': {error_msg}")
-
-
-def get_vs_index(index_name: str) -> Dict[str, Any]:
- """
- Get Vector Search index status and details.
-
- Args:
- index_name: Fully qualified index name (catalog.schema.index_name)
-
- Returns:
- Dictionary with:
- - name: Index name
- - endpoint_name: Hosting endpoint
- - index_type: DELTA_SYNC or DIRECT_ACCESS
- - primary_key: Primary key column
- - state: Index state (ONLINE, PROVISIONING, etc.)
- - delta_sync_index_spec: Sync config (if DELTA_SYNC)
- - direct_access_index_spec: Config (if DIRECT_ACCESS)
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
-
- try:
- index = client.vector_search_indexes.get_index(index_name=index_name)
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "name": index_name,
- "state": "NOT_FOUND",
- "error": f"Index '{index_name}' not found",
- }
- raise Exception(f"Failed to get vector search index '{index_name}': {error_msg}")
-
- result: Dict[str, Any] = {
- "name": index.name,
- "endpoint_name": index.endpoint_name,
- "primary_key": index.primary_key,
- }
-
- if index.index_type:
- result["index_type"] = index.index_type.value
-
- if index.status:
- if index.status.ready:
- result["state"] = "ONLINE" if index.status.ready else "NOT_READY"
- if index.status.message:
- result["message"] = index.status.message
- if index.status.index_url:
- result["index_url"] = index.status.index_url
-
- if index.delta_sync_index_spec:
- spec = index.delta_sync_index_spec
- result["delta_sync_index_spec"] = {
- "source_table": spec.source_table,
- "pipeline_type": spec.pipeline_type.value if spec.pipeline_type else None,
- }
- if spec.pipeline_id:
- result["delta_sync_index_spec"]["pipeline_id"] = spec.pipeline_id
-
- return result
-
-
-def list_vs_indexes(endpoint_name: str) -> List[Dict[str, Any]]:
- """
- List all Vector Search indexes on an endpoint.
-
- Args:
- endpoint_name: Endpoint name to list indexes for
-
- Returns:
- List of index dictionaries with:
- - name: Index name
- - index_type: DELTA_SYNC or DIRECT_ACCESS
- - primary_key: Primary key column
- - state: Index state
-
- Raises:
- Exception: If API request fails
- """
- client = get_workspace_client()
-
- try:
- response = client.vector_search_indexes.list_indexes(
- endpoint_name=endpoint_name,
- )
- except Exception as e:
- raise Exception(f"Failed to list indexes on endpoint '{endpoint_name}': {str(e)}")
-
- result = []
- # SDK may return an object with .vector_indexes or a generator directly
- if hasattr(response, "vector_indexes") and response.vector_indexes:
- indexes = response.vector_indexes
- elif response:
- indexes = list(response)
- else:
- indexes = []
- for idx in indexes:
- entry: Dict[str, Any] = {
- "name": idx.name,
- }
-
- # primary_key may not exist on MiniVectorIndex
- try:
- if idx.primary_key:
- entry["primary_key"] = idx.primary_key
- except (AttributeError, KeyError):
- pass
-
- try:
- if idx.index_type:
- entry["index_type"] = idx.index_type.value
- except (AttributeError, KeyError):
- pass
-
- # status may not exist on MiniVectorIndex (from generator response)
- try:
- if idx.status and idx.status.ready is not None:
- entry["state"] = "ONLINE" if idx.status.ready else "NOT_READY"
- except (AttributeError, KeyError):
- pass
-
- result.append(entry)
-
- return result
-
-
-def delete_vs_index(index_name: str) -> Dict[str, Any]:
- """
- Delete a Vector Search index.
-
- Args:
- index_name: Fully qualified index name (catalog.schema.index_name)
-
- Returns:
- Dictionary with:
- - name: Index name
- - status: "deleted" or error info
-
- Raises:
- Exception: If deletion fails
- """
- client = get_workspace_client()
-
- try:
- client.vector_search_indexes.delete_index(index_name=index_name)
- return {
- "name": index_name,
- "status": "deleted",
- }
- except Exception as e:
- error_msg = str(e)
- if "not found" in error_msg.lower() or "does not exist" in error_msg.lower() or "404" in error_msg:
- return {
- "name": index_name,
- "status": "NOT_FOUND",
- "error": f"Index '{index_name}' not found",
- }
- raise Exception(f"Failed to delete vector search index '{index_name}': {error_msg}")
-
-
-def sync_vs_index(index_name: str) -> Dict[str, Any]:
- """
- Trigger a sync for a TRIGGERED Delta Sync index.
-
- Only applicable for Delta Sync indexes with pipeline_type=TRIGGERED.
- Continuous indexes sync automatically.
-
- Args:
- index_name: Fully qualified index name (catalog.schema.index_name)
-
- Returns:
- Dictionary with sync status
-
- Raises:
- Exception: If sync trigger fails
- """
- client = get_workspace_client()
-
- try:
- client.vector_search_indexes.sync_index(index_name=index_name)
- return {
- "name": index_name,
- "status": "SYNC_TRIGGERED",
- "message": f"Sync triggered for index '{index_name}'. Use get_vs_index() to check progress.",
- }
- except Exception as e:
- raise Exception(f"Failed to sync index '{index_name}': {str(e)}")
-
-
-def query_vs_index(
- index_name: str,
- columns: List[str],
- query_text: Optional[str] = None,
- query_vector: Optional[List[float]] = None,
- num_results: int = 5,
- filters_json: Optional[str] = None,
- filter_string: Optional[str] = None,
- query_type: Optional[str] = None,
-) -> Dict[str, Any]:
- """
- Query a Vector Search index for similar documents.
-
- Provide either query_text (for indexes with managed or attached embeddings)
- or query_vector (for pre-computed query embeddings).
-
- For filters:
- - Standard endpoints use filters_json (dict format): '{"category": "ai"}'
- - Storage-Optimized endpoints use filter_string (SQL syntax): "category = 'ai'"
-
- Args:
- index_name: Fully qualified index name (catalog.schema.index_name)
- columns: List of column names to return in results
- query_text: Text query (for managed/attached embedding models)
- query_vector: Pre-computed query embedding vector
- num_results: Number of results to return (default: 5)
- filters_json: JSON string of filters for Standard endpoints
- filter_string: SQL-like filter for Storage-Optimized endpoints
- query_type: Search algorithm: "ANN" (default) or "HYBRID" (vector + keyword)
-
- Returns:
- Dictionary with:
- - columns: Column names in results
- - data: List of result rows (similarity score is appended as last column)
- - num_results: Number of results returned
-
- Raises:
- Exception: If query fails
- """
- client = get_workspace_client()
-
- kwargs: Dict[str, Any] = {
- "index_name": index_name,
- "columns": columns,
- "num_results": num_results,
- }
-
- if query_text is not None:
- kwargs["query_text"] = query_text
- elif query_vector is not None:
- kwargs["query_vector"] = query_vector
- else:
- raise ValueError("Must provide either query_text or query_vector")
-
- if filters_json is not None:
- # Ensure filters_json is a string — callers may pass a dict
- if isinstance(filters_json, dict):
- filters_json = json.dumps(filters_json)
- kwargs["filters_json"] = filters_json
-
- if filter_string is not None:
- kwargs["filter_string"] = filter_string
-
- if query_type is not None:
- kwargs["query_type"] = query_type
-
- try:
- response = client.vector_search_indexes.query_index(**kwargs)
- except Exception as e:
- raise Exception(f"Failed to query index '{index_name}': {str(e)}")
-
- result: Dict[str, Any] = {}
-
- # Column names from manifest (SDK doesn't put them on result directly)
- try:
- if response.manifest and response.manifest.columns:
- result["columns"] = [c.name for c in response.manifest.columns]
- except (AttributeError, KeyError):
- pass
-
- if response.result:
- if response.result.data_array:
- result["data"] = response.result.data_array
- result["num_results"] = len(response.result.data_array)
- else:
- result["data"] = []
- result["num_results"] = 0
-
- if response.manifest:
- result["manifest"] = {
- "column_count": response.manifest.column_count,
- }
-
- return result
-
-
-def upsert_vs_data(
- index_name: str,
- inputs_json: str,
-) -> Dict[str, Any]:
- """
- Upsert data into a Direct Access Vector Search index.
-
- Args:
- index_name: Fully qualified index name (catalog.schema.index_name)
- inputs_json: JSON string of records to upsert. Each record must include
- the primary key and embedding vector columns.
- Example: '[{"id": "1", "text": "hello", "embedding": [0.1, 0.2, ...]}]'
-
- Returns:
- Dictionary with:
- - name: Index name
- - status: Upsert result status
- - num_records: Number of records upserted
-
- Raises:
- Exception: If upsert fails
- """
- client = get_workspace_client()
-
- try:
- # Ensure inputs_json is a string — callers may pass a list/dict
- if isinstance(inputs_json, (dict, list)):
- records = inputs_json
- inputs_json = json.dumps(inputs_json)
- else:
- records = json.loads(inputs_json)
- num_records = len(records) if isinstance(records, list) else 1
-
- response = client.vector_search_indexes.upsert_data_vector_index(
- index_name=index_name,
- inputs_json=inputs_json,
- )
-
- result: Dict[str, Any] = {
- "name": index_name,
- "status": "SUCCESS",
- "num_records": num_records,
- }
-
- if response and response.status:
- result["status"] = response.status.value if hasattr(response.status, "value") else str(response.status)
-
- return result
- except Exception as e:
- raise Exception(f"Failed to upsert data into index '{index_name}': {str(e)}")
-
-
-def delete_vs_data(
- index_name: str,
- primary_keys: List[str],
-) -> Dict[str, Any]:
- """
- Delete data from a Direct Access Vector Search index.
-
- Args:
- index_name: Fully qualified index name (catalog.schema.index_name)
- primary_keys: List of primary key values to delete
-
- Returns:
- Dictionary with:
- - name: Index name
- - status: Delete result status
- - num_deleted: Number of records requested for deletion
-
- Raises:
- Exception: If delete fails
- """
- client = get_workspace_client()
-
- try:
- response = client.vector_search_indexes.delete_data_vector_index(
- index_name=index_name,
- primary_keys=primary_keys,
- )
-
- result: Dict[str, Any] = {
- "name": index_name,
- "status": "SUCCESS",
- "num_deleted": len(primary_keys),
- }
-
- if response and response.status:
- result["status"] = response.status.value if hasattr(response.status, "value") else str(response.status)
-
- return result
- except Exception as e:
- raise Exception(f"Failed to delete data from index '{index_name}': {str(e)}")
-
-
-def scan_vs_index(
- index_name: str,
- num_results: int = 100,
-) -> Dict[str, Any]:
- """
- Scan a Vector Search index to retrieve all entries.
-
- Useful for debugging, exporting, or verifying index contents.
-
- Args:
- index_name: Fully qualified index name (catalog.schema.index_name)
- num_results: Maximum number of entries to return (default: 100)
-
- Returns:
- Dictionary with:
- - columns: Column names
- - data: List of index entries
- - num_results: Number of entries returned
-
- Raises:
- Exception: If scan fails
- """
- client = get_workspace_client()
-
- try:
- response = client.vector_search_indexes.scan_index(
- index_name=index_name,
- num_results=num_results,
- )
- except Exception as e:
- raise Exception(f"Failed to scan index '{index_name}': {str(e)}")
-
- result: Dict[str, Any] = {}
-
- # ScanVectorIndexResponse has .data (list of entries) and .last_primary_key
- # not .result like QueryVectorIndexResponse
- try:
- data = response.data
- if data:
- # data is a list of Struct/dict objects
- if isinstance(data, list) and len(data) > 0:
- # Extract column names from first entry
- first = data[0]
- if hasattr(first, "as_dict"):
- rows = [d.as_dict() for d in data]
- elif isinstance(first, dict):
- rows = data
- else:
- rows = data
-
- if rows and isinstance(rows[0], dict):
- result["columns"] = list(rows[0].keys())
- result["data"] = rows
- result["num_results"] = len(rows)
- else:
- result["data"] = []
- result["num_results"] = 0
- else:
- result["data"] = []
- result["num_results"] = 0
- except (AttributeError, KeyError):
- # Fallback: try the old .result pattern in case SDK changes
- try:
- if response.result:
- if hasattr(response.result, "column_names") and response.result.column_names:
- result["columns"] = response.result.column_names
- if response.result.data_array:
- result["data"] = response.result.data_array
- result["num_results"] = len(response.result.data_array)
- else:
- result["data"] = []
- result["num_results"] = 0
- except (AttributeError, KeyError):
- result["data"] = []
- result["num_results"] = 0
-
- return result
diff --git a/databricks-tools-core/docs/architecture.svg b/databricks-tools-core/docs/architecture.svg
deleted file mode 100644
index eafc80b8..00000000
--- a/databricks-tools-core/docs/architecture.svg
+++ /dev/null
@@ -1,113 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- Build Cool Stuff on Databricks
-
-
-
-
-
-
-
- SQL
-
-
-
-
- Jobs
-
-
-
-
- Pipelines
-
-
-
-
- Catalog
-
-
-
-
- Genie
-
-
-
-
- AI/BI
-
-
-
-
- + more
-
-
-
-
-
-
-
-
-
-
-
- enables
-
-
-
-
- Skills
- Knowledge & Patterns
- 17+ Databricks skills
-
-
-
-
- MCP Tools
- Executable Actions
- 45+ Databricks tools
-
-
-
-
-
-
-
-
-
-
- provides
-
-
-
-
- AI Dev Kit
- Trusted sources for your AI coding assistant
-
-
\ No newline at end of file
diff --git a/databricks-tools-core/pyproject.toml b/databricks-tools-core/pyproject.toml
deleted file mode 100644
index 9dbd6e37..00000000
--- a/databricks-tools-core/pyproject.toml
+++ /dev/null
@@ -1,48 +0,0 @@
-[build-system]
-requires = ["setuptools>=61.0", "wheel"]
-build-backend = "setuptools.build_meta"
-
-[project]
-name = "databricks-tools-core"
-version = "0.1.0"
-description = "High-level, AI-assistant-friendly functions for building Databricks projects"
-readme = "README.md"
-requires-python = ">=3.10"
-license = {file = "LICENSE.md"}
-authors = [
- {name = "Databricks"},
-]
-classifiers = [
- "Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.10",
- "Programming Language :: Python :: 3.11",
- "Programming Language :: Python :: 3.12",
-]
-dependencies = [
- "requests>=2.31.0",
- "pydantic>=2.0.0",
- "databricks-sdk>=0.81.0",
- "pyyaml>=6.0",
- "sqlglot>=20.0.0",
- "sqlfluff>=3.0.0",
- "plutoprint==0.19.0",
-]
-
-[project.optional-dependencies]
-dev = [
- "pytest>=7.0.0",
- "pytest-timeout>=2.0.0",
- "black>=23.0.0",
- "ruff>=0.1.0",
-]
-
-[tool.pytest.ini_options]
-testpaths = ["tests"]
-markers = [
- "integration: marks tests as integration tests (require Databricks connection)",
- "slow: marks tests as slow",
-]
-
-[tool.setuptools.packages.find]
-where = ["."]
-include = ["databricks_tools_core*"]
diff --git a/databricks-tools-core/pytest.ini b/databricks-tools-core/pytest.ini
deleted file mode 100644
index 012188f0..00000000
--- a/databricks-tools-core/pytest.ini
+++ /dev/null
@@ -1,11 +0,0 @@
-[pytest]
-testpaths = tests
-python_files = test_*.py
-python_classes = Test*
-python_functions = test_*
-addopts = -v --tb=short
-markers =
- integration: marks tests as integration tests (require Databricks connection)
- slow: marks tests as slow (may take a while to run)
-filterwarnings =
- ignore::DeprecationWarning
diff --git a/databricks-tools-core/requirements.txt b/databricks-tools-core/requirements.txt
deleted file mode 100644
index 3ec4ad94..00000000
--- a/databricks-tools-core/requirements.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-databricks-sdk>=0.81.0
-pydantic>=2.0.0
diff --git a/databricks-tools-core/tests/__init__.py b/databricks-tools-core/tests/__init__.py
deleted file mode 100644
index 900ff4af..00000000
--- a/databricks-tools-core/tests/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Tests for databricks-tools-core."""
diff --git a/databricks-tools-core/tests/conftest.py b/databricks-tools-core/tests/conftest.py
deleted file mode 100644
index ab854a2b..00000000
--- a/databricks-tools-core/tests/conftest.py
+++ /dev/null
@@ -1,331 +0,0 @@
-"""
-Pytest fixtures for databricks-tools-core integration tests.
-
-These fixtures set up and tear down test resources in Databricks.
-Requires a valid Databricks connection (via env vars or ~/.databrickscfg).
-"""
-
-import logging
-import os
-from pathlib import Path
-import pytest
-from databricks.sdk import WorkspaceClient
-
-# Load .env.test file if it exists
-_env_file = Path(__file__).parent.parent / ".env.test"
-if _env_file.exists():
- from dotenv import load_dotenv
-
- load_dotenv(_env_file)
- logging.getLogger(__name__).info(f"Loaded environment from {_env_file}")
-
-# Test catalog and schema names (configurable via env vars)
-TEST_CATALOG = os.environ.get("TEST_CATALOG", "ai_dev_kit_test")
-TEST_SCHEMA = os.environ.get("TEST_SCHEMA", "test_schema")
-TEST_VOLUME = os.environ.get("TEST_VOLUME", "test_volume")
-
-# Test data directory
-TEST_DATA_DIR = Path(__file__).parent / "integration" / "sql" / "test_data"
-
-logger = logging.getLogger(__name__)
-
-
-def pytest_configure(config):
- """Configure pytest with custom markers."""
- config.addinivalue_line("markers", "integration: mark test as integration test requiring Databricks")
-
-
-@pytest.fixture(scope="session")
-def workspace_client() -> WorkspaceClient:
- """
- Create a WorkspaceClient for the test session.
-
- Uses standard Databricks authentication:
- 1. DATABRICKS_HOST + DATABRICKS_TOKEN env vars
- 2. ~/.databrickscfg profile
- """
- try:
- client = WorkspaceClient()
- # Verify connection works
- client.current_user.me()
- logger.info(f"Connected to Databricks: {client.config.host}")
- return client
- except Exception as e:
- pytest.skip(f"Could not connect to Databricks: {e}")
-
-
-@pytest.fixture(scope="session")
-def test_catalog(workspace_client: WorkspaceClient) -> str:
- """
- Ensure test catalog exists.
-
- Returns the catalog name.
- """
- try:
- workspace_client.catalogs.get(TEST_CATALOG)
- logger.info(f"Using existing catalog: {TEST_CATALOG}")
- except Exception:
- logger.info(f"Creating catalog: {TEST_CATALOG}")
- workspace_client.catalogs.create(name=TEST_CATALOG)
-
- return TEST_CATALOG
-
-
-@pytest.fixture(scope="session")
-def test_schema(workspace_client: WorkspaceClient, test_catalog: str) -> str:
- """
- Create a fresh test schema (drops if exists).
-
- This ensures a clean state for each test run.
- Returns the schema name.
- """
- full_schema_name = f"{test_catalog}.{TEST_SCHEMA}"
-
- # Drop schema if exists (cascade to remove all objects)
- try:
- logger.info(f"Dropping existing schema: {full_schema_name}")
- workspace_client.schemas.delete(full_schema_name, force=True)
- except Exception as e:
- logger.debug(f"Schema delete failed (may not exist): {e}")
-
- # Create fresh schema
- logger.info(f"Creating schema: {full_schema_name}")
- try:
- workspace_client.schemas.create(
- name=TEST_SCHEMA,
- catalog_name=test_catalog,
- )
- except Exception as e:
- if "already exists" in str(e):
- logger.info(f"Schema already exists, reusing: {full_schema_name}")
- else:
- raise
-
- yield TEST_SCHEMA
-
- # Cleanup after all tests (optional - comment out to inspect test data)
- # try:
- # logger.info(f"Cleaning up schema: {full_schema_name}")
- # workspace_client.schemas.delete(full_schema_name)
- # except Exception as e:
- # logger.warning(f"Failed to cleanup schema: {e}")
-
-
-@pytest.fixture(scope="session")
-def warehouse_id(workspace_client: WorkspaceClient) -> str:
- """
- Get a running SQL warehouse for tests.
-
- Prefers shared endpoints, falls back to any running warehouse.
- """
- from databricks.sdk.service.sql import State
-
- warehouses = list(workspace_client.warehouses.list())
-
- # Priority: running shared endpoint
- for w in warehouses:
- if w.state == State.RUNNING and "shared" in (w.name or "").lower():
- logger.info(f"Using warehouse: {w.name} ({w.id})")
- return w.id
-
- # Fallback: any running warehouse
- for w in warehouses:
- if w.state == State.RUNNING:
- logger.info(f"Using warehouse: {w.name} ({w.id})")
- return w.id
-
- # No running warehouse found
- pytest.skip("No running SQL warehouse available for tests")
-
-
-@pytest.fixture(scope="module")
-def test_tables(
- workspace_client: WorkspaceClient,
- test_catalog: str,
- test_schema: str,
- warehouse_id: str,
-) -> dict:
- """
- Create test tables with sample data.
-
- Creates:
- - customers: Basic customer table
- - orders: Orders with foreign key to customers
- - products: Product catalog with various data types
-
- Returns dict with table names.
- """
- from databricks_tools_core.sql import execute_sql
-
- tables = {
- "customers": f"{test_catalog}.{test_schema}.customers",
- "orders": f"{test_catalog}.{test_schema}.orders",
- "products": f"{test_catalog}.{test_schema}.products",
- }
-
- # Create customers table
- execute_sql(
- sql_query=f"""
- CREATE OR REPLACE TABLE {tables["customers"]} (
- customer_id BIGINT,
- name STRING,
- email STRING,
- country STRING,
- created_at TIMESTAMP,
- is_active BOOLEAN
- )
- """,
- warehouse_id=warehouse_id,
- catalog=test_catalog,
- schema=test_schema,
- )
-
- # Insert customer data
- execute_sql(
- sql_query=f"""
- INSERT INTO {tables["customers"]} VALUES
- (1, 'Alice Smith', 'alice@example.com', 'USA', '2024-01-15 10:30:00', true),
- (2, 'Bob Johnson', 'bob@example.com', 'Canada', '2024-02-20 14:45:00', true),
- (3, 'Charlie Brown', 'charlie@example.com', 'UK', '2024-03-10 09:00:00', false),
- (4, 'Diana Ross', 'diana@example.com', 'USA', '2024-04-05 16:20:00', true),
- (5, 'Eve Wilson', 'eve@example.com', 'Germany', '2024-05-12 11:15:00', true)
- """,
- warehouse_id=warehouse_id,
- catalog=test_catalog,
- schema=test_schema,
- )
-
- # Create orders table
- execute_sql(
- sql_query=f"""
- CREATE OR REPLACE TABLE {tables["orders"]} (
- order_id BIGINT,
- customer_id BIGINT,
- amount DECIMAL(10, 2),
- status STRING,
- order_date DATE
- )
- """,
- warehouse_id=warehouse_id,
- catalog=test_catalog,
- schema=test_schema,
- )
-
- # Insert order data
- execute_sql(
- sql_query=f"""
- INSERT INTO {tables["orders"]} VALUES
- (101, 1, 150.00, 'completed', '2024-06-01'),
- (102, 1, 75.50, 'completed', '2024-06-15'),
- (103, 2, 200.00, 'pending', '2024-06-20'),
- (104, 3, 50.00, 'cancelled', '2024-06-22'),
- (105, 4, 300.00, 'completed', '2024-06-25'),
- (106, 5, 125.75, 'pending', '2024-06-28'),
- (107, 1, 89.99, 'completed', '2024-07-01'),
- (108, 2, 175.00, 'completed', '2024-07-05')
- """,
- warehouse_id=warehouse_id,
- catalog=test_catalog,
- schema=test_schema,
- )
-
- # Create products table with various data types
- execute_sql(
- sql_query=f"""
- CREATE OR REPLACE TABLE {tables["products"]} (
- product_id BIGINT,
- name STRING,
- category STRING,
- price DOUBLE,
- stock_quantity INT,
- rating FLOAT,
- tags ARRAY,
- created_at TIMESTAMP
- )
- """,
- warehouse_id=warehouse_id,
- catalog=test_catalog,
- schema=test_schema,
- )
-
- # Insert product data
- execute_sql(
- sql_query=f"""
- INSERT INTO {tables["products"]} VALUES
- (1, 'Laptop Pro', 'Electronics', 1299.99, 50, 4.5, ARRAY('tech', 'computer'), '2024-01-01 00:00:00'),
- (2, 'Wireless Mouse', 'Electronics', 29.99, 200, 4.2, ARRAY('tech', 'accessory'), '2024-01-15 00:00:00'),
- (3, 'Coffee Maker', 'Kitchen', 79.99, 75, 4.8, ARRAY('home', 'appliance'), '2024-02-01 00:00:00'),
- (4, 'Running Shoes', 'Sports', 119.99, 100, 4.3, ARRAY('fitness', 'footwear'), '2024-02-15 00:00:00'),
- (5, 'Desk Lamp', 'Office', 45.00, 150, 4.0, ARRAY('home', 'lighting'), '2024-03-01 00:00:00')
- """,
- warehouse_id=warehouse_id,
- catalog=test_catalog,
- schema=test_schema,
- )
-
- logger.info(f"Created test tables: {list(tables.keys())}")
- return tables
-
-
-@pytest.fixture(scope="module")
-def test_volume(
- workspace_client: WorkspaceClient,
- test_catalog: str,
- test_schema: str,
-) -> str:
- """
- Create a test volume and upload test files.
-
- Creates:
- - parquet_data/: Parquet files for testing
- - txt_files/: Text files for file listing tests
-
- Returns the volume name.
- """
- from databricks.sdk.service.catalog import VolumeType
-
- full_volume_name = f"{test_catalog}.{test_schema}.{TEST_VOLUME}"
- volume_path = f"/Volumes/{test_catalog}/{test_schema}/{TEST_VOLUME}"
-
- # Delete volume if exists (fresh start)
- try:
- logger.info(f"Deleting existing volume: {full_volume_name}")
- workspace_client.volumes.delete(full_volume_name)
- except Exception:
- pass # Volume doesn't exist, that's fine
-
- # Create the volume
- logger.info(f"Creating volume: {full_volume_name}")
- workspace_client.volumes.create(
- catalog_name=test_catalog,
- schema_name=test_schema,
- name=TEST_VOLUME,
- volume_type=VolumeType.MANAGED,
- )
-
- # Upload parquet files
- parquet_dir = TEST_DATA_DIR / "parquet"
- if parquet_dir.exists():
- for file_path in parquet_dir.glob("*.parquet"):
- remote_path = f"{volume_path}/parquet_data/{file_path.name}"
- logger.info(f"Uploading {file_path.name} to {remote_path}")
- with open(file_path, "rb") as f:
- workspace_client.files.upload(remote_path, f, overwrite=True)
-
- # Upload txt files
- txt_dir = TEST_DATA_DIR / "txt_files"
- if txt_dir.exists():
- for file_path in txt_dir.glob("*.txt"):
- remote_path = f"{volume_path}/txt_files/{file_path.name}"
- logger.info(f"Uploading {file_path.name} to {remote_path}")
- with open(file_path, "rb") as f:
- workspace_client.files.upload(remote_path, f, overwrite=True)
-
- logger.info(f"Created test volume with files: {TEST_VOLUME}")
- yield TEST_VOLUME
-
- # Cleanup (optional)
- # try:
- # workspace_client.volumes.delete(full_volume_name)
- # except Exception as e:
- # logger.warning(f"Failed to cleanup volume: {e}")
diff --git a/databricks-tools-core/tests/integration/__init__.py b/databricks-tools-core/tests/integration/__init__.py
deleted file mode 100644
index bae4707e..00000000
--- a/databricks-tools-core/tests/integration/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Integration tests requiring Databricks connection."""
diff --git a/databricks-tools-core/tests/integration/compute/__init__.py b/databricks-tools-core/tests/integration/compute/__init__.py
deleted file mode 100644
index 0c7a3f86..00000000
--- a/databricks-tools-core/tests/integration/compute/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Compute module integration tests."""
diff --git a/databricks-tools-core/tests/integration/compute/test_execution.py b/databricks-tools-core/tests/integration/compute/test_execution.py
deleted file mode 100644
index 0f9edb44..00000000
--- a/databricks-tools-core/tests/integration/compute/test_execution.py
+++ /dev/null
@@ -1,466 +0,0 @@
-"""
-Integration tests for compute execution functions.
-
-Tests execute_databricks_command and run_file_on_databricks (with language detection,
-workspace_path persistence).
-"""
-
-import logging
-import tempfile
-import pytest
-from pathlib import Path
-
-from databricks_tools_core.compute import (
- execute_databricks_command,
- run_file_on_databricks,
- list_clusters,
- get_best_cluster,
- destroy_context,
- NoRunningClusterError,
- ExecutionResult,
-)
-from databricks_tools_core.auth import get_workspace_client, get_current_username
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.fixture(scope="module")
-def shared_context():
- """
- Create a shared execution context for tests that need cluster execution.
-
- This speeds up tests by reusing the same context instead of creating
- a new one for each test (context creation takes ~5-10s).
- """
- # Get a running cluster
- cluster_id = get_best_cluster()
- if cluster_id is None:
- pytest.skip("No running cluster available")
-
- # Create context with first execution
- result = execute_databricks_command(
- code='print("Context initialized")',
- cluster_id=cluster_id,
- timeout=120,
- )
-
- if not result.success:
- pytest.fail(f"Failed to create shared context: {result.error}")
-
- yield {
- "cluster_id": result.cluster_id,
- "context_id": result.context_id,
- }
-
- # Cleanup
- try:
- destroy_context(result.cluster_id, result.context_id)
- except Exception:
- pass # Ignore cleanup errors
-
-
-@pytest.mark.integration
-class TestListClusters:
- """Tests for list_clusters function."""
-
- def test_list_clusters_running_only(self):
- """Should list running clusters quickly."""
- clusters = list_clusters(include_terminated=False)
-
- print("\n=== List Running Clusters ===")
- print(f"Found {len(clusters)} running clusters:")
- for c in clusters[:5]:
- print(f" - {c['cluster_name']} ({c['cluster_id']}) - {c['state']}")
-
- assert isinstance(clusters, list)
- # All should be running/pending states
- for c in clusters:
- assert c["state"] in ["RUNNING", "PENDING", "RESIZING", "RESTARTING"]
-
- def test_list_clusters_with_limit(self):
- """Should respect limit parameter."""
- clusters = list_clusters(limit=5)
-
- print("\n=== List Clusters (limit=5) ===")
- print(f"Found {len(clusters)} clusters")
-
- assert isinstance(clusters, list)
- assert len(clusters) <= 5
-
-
-@pytest.mark.integration
-class TestGetBestCluster:
- """Tests for get_best_cluster function."""
-
- def test_get_best_cluster(self):
- """Should return a running cluster ID or None."""
- cluster_id = get_best_cluster()
-
- print("\n=== Get Best Cluster ===")
- print(f"Best cluster ID: {cluster_id}")
-
- # Result can be None if no running clusters
- if cluster_id is not None:
- assert isinstance(cluster_id, str)
- assert len(cluster_id) > 0
-
-
-@pytest.mark.integration
-class TestExecuteDatabricksCommand:
- """Tests for execute_databricks_command function."""
-
- def test_simple_code_with_shared_context(self, shared_context):
- """Should execute simple code with shared context."""
- result = execute_databricks_command(
- code='print("Hello from shared context!")',
- cluster_id=shared_context["cluster_id"],
- context_id=shared_context["context_id"],
- timeout=120,
- )
-
- print("\n=== Shared Context Execution ===")
- print(f"Success: {result.success}")
- print(f"Output: {result.output}")
-
- assert result.success, f"Execution failed: {result.error}"
- assert "Hello" in result.output
- assert result.context_id == shared_context["context_id"]
-
- def test_context_variable_persistence(self, shared_context):
- """Should persist variables across executions in same context."""
- # Set a variable
- result1 = execute_databricks_command(
- code='test_var = 42\nprint(f"Set test_var = {test_var}")',
- cluster_id=shared_context["cluster_id"],
- context_id=shared_context["context_id"],
- timeout=120,
- )
-
- print("\n=== First Execution ===")
- print(f"Success: {result1.success}")
-
- assert result1.success, f"First execution failed: {result1.error}"
-
- # Read the variable back
- result2 = execute_databricks_command(
- code='print(f"test_var is still {test_var}")',
- cluster_id=shared_context["cluster_id"],
- context_id=shared_context["context_id"],
- timeout=120,
- )
-
- print("\n=== Second Execution ===")
- print(f"Success: {result2.success}")
- print(f"Output: {result2.output}")
-
- assert result2.success, f"Second execution failed: {result2.error}"
- assert "test_var is still 42" in result2.output
-
- def test_sql_execution(self, shared_context):
- """Should execute SQL queries."""
- result = execute_databricks_command(
- code="SELECT 1 + 1 as result",
- cluster_id=shared_context["cluster_id"],
- language="sql",
- timeout=120,
- )
-
- print("\n=== SQL Execution ===")
- print(f"Success: {result.success}")
- print(f"Output: {result.output}")
-
- assert result.success, f"SQL execution failed: {result.error}"
-
- def test_destroy_context_on_completion(self):
- """Should destroy context when requested."""
- try:
- result = execute_databricks_command(
- code='print("Destroying context after this")',
- timeout=120,
- destroy_context_on_completion=True,
- )
-
- print("\n=== Destroy Context On Completion ===")
- print(f"Success: {result.success}")
- print(f"Context Destroyed: {result.context_destroyed}")
-
- assert result.success, f"Execution failed: {result.error}"
- assert result.context_destroyed is True
- assert "destroyed" in result.message.lower()
-
- except NoRunningClusterError as e:
- pytest.skip(f"No running cluster available: {e}")
-
-
-@pytest.mark.integration
-class TestRunFileOnDatabricksBasic:
- """Basic tests for run_file_on_databricks function."""
-
- def test_simple_file_execution(self, shared_context):
- """Should execute a simple Python file."""
- code = 'print("Hello from file!")\nprint(2 + 2)'
-
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
- f.write(code)
- f.flush()
- temp_path = f.name
-
- try:
- result = run_file_on_databricks(
- file_path=temp_path,
- cluster_id=shared_context["cluster_id"],
- context_id=shared_context["context_id"],
- timeout=120,
- )
-
- print("\n=== File Execution Result ===")
- print(f"Success: {result.success}")
- print(f"Output: {result.output}")
-
- assert result.success, f"Execution failed: {result.error}"
- assert "Hello from file!" in result.output
-
- finally:
- Path(temp_path).unlink(missing_ok=True)
-
- def test_spark_code(self, shared_context):
- """Should execute Spark code."""
- code = """
-from pyspark.sql import SparkSession
-spark = SparkSession.builder.getOrCreate()
-df = spark.range(5)
-print(f"Row count: {df.count()}")
-"""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
- f.write(code)
- f.flush()
- temp_path = f.name
-
- try:
- result = run_file_on_databricks(
- file_path=temp_path,
- cluster_id=shared_context["cluster_id"],
- context_id=shared_context["context_id"],
- timeout=120,
- )
-
- print("\n=== Spark Execution Result ===")
- print(f"Success: {result.success}")
- print(f"Output: {result.output}")
-
- assert result.success, f"Spark execution failed: {result.error}"
- assert "Row count: 5" in result.output
-
- finally:
- Path(temp_path).unlink(missing_ok=True)
-
- def test_error_handling(self, shared_context):
- """Should capture Python errors with details."""
- code = "x = 1 / 0 # This will raise ZeroDivisionError"
-
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
- f.write(code)
- f.flush()
- temp_path = f.name
-
- try:
- result = run_file_on_databricks(
- file_path=temp_path,
- cluster_id=shared_context["cluster_id"],
- context_id=shared_context["context_id"],
- timeout=120,
- )
-
- print("\n=== Error Handling Result ===")
- print(f"Success: {result.success}")
- print(f"Error: {result.error[:200] if result.error else None}...")
-
- assert not result.success, "Should have failed with division by zero"
- assert result.error is not None
- assert "ZeroDivisionError" in result.error
-
- finally:
- Path(temp_path).unlink(missing_ok=True)
-
- def test_file_not_found(self):
- """Should handle missing file gracefully (no cluster needed)."""
- result = run_file_on_databricks(file_path="/nonexistent/path/to/file.py", timeout=120)
-
- print("\n=== File Not Found Result ===")
- print(f"Success: {result.success}")
- print(f"Error: {result.error}")
-
- assert not result.success
- assert "not found" in result.error.lower()
-
-
-@pytest.mark.integration
-class TestRunFileOnDatabricks:
- """Tests for run_file_on_databricks.
-
- Covers: language auto-detection, multi-language support, workspace_path persistence.
- """
-
- def test_python_auto_detect(self, shared_context):
- """Should auto-detect Python from .py extension."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
- f.write('print("auto-detected python")')
- temp_path = f.name
-
- try:
- result = run_file_on_databricks(
- file_path=temp_path,
- cluster_id=shared_context["cluster_id"],
- context_id=shared_context["context_id"],
- timeout=120,
- )
- assert result.success, f"Execution failed: {result.error}"
- assert "auto-detected python" in result.output
- finally:
- Path(temp_path).unlink(missing_ok=True)
-
- def test_sql_auto_detect(self, shared_context):
- """Should auto-detect SQL from .sql extension."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".sql", delete=False) as f:
- f.write("SELECT 42 as answer")
- temp_path = f.name
-
- try:
- result = run_file_on_databricks(
- file_path=temp_path,
- cluster_id=shared_context["cluster_id"],
- language=None, # should auto-detect
- timeout=120,
- )
-
- logger.info(f"SQL auto-detect: success={result.success}, output={result.output}")
-
- assert result.success, f"SQL execution failed: {result.error}"
- finally:
- Path(temp_path).unlink(missing_ok=True)
-
- def test_explicit_language_override(self, shared_context):
- """Should use explicit language even if extension differs."""
- # File has .txt extension but we specify python
- with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
- f.write('print("explicit python")')
- temp_path = f.name
-
- try:
- result = run_file_on_databricks(
- file_path=temp_path,
- cluster_id=shared_context["cluster_id"],
- context_id=shared_context["context_id"],
- language="python",
- timeout=120,
- )
- assert result.success, f"Execution failed: {result.error}"
- assert "explicit python" in result.output
- finally:
- Path(temp_path).unlink(missing_ok=True)
-
- def test_empty_file(self):
- """Should reject empty files gracefully."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
- f.write("")
- temp_path = f.name
-
- try:
- result = run_file_on_databricks(file_path=temp_path, timeout=120)
- assert not result.success
- assert "empty" in result.error.lower()
- finally:
- Path(temp_path).unlink(missing_ok=True)
-
- def test_returns_execution_result(self, shared_context):
- """Should return ExecutionResult type."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
- f.write('print("type check")')
- temp_path = f.name
-
- try:
- result = run_file_on_databricks(
- file_path=temp_path,
- cluster_id=shared_context["cluster_id"],
- context_id=shared_context["context_id"],
- timeout=120,
- )
- assert isinstance(result, ExecutionResult)
- assert result.success
- d = result.to_dict()
- assert isinstance(d, dict)
- assert "success" in d
- assert "cluster_id" in d
- assert "context_id" in d
- finally:
- Path(temp_path).unlink(missing_ok=True)
-
-
-@pytest.mark.integration
-class TestRunFileWorkspacePath:
- """Tests for run_file_on_databricks with workspace_path (persistent mode)."""
-
- @pytest.fixture(autouse=True)
- def _setup_cleanup(self):
- """Track workspace paths for cleanup."""
- self._paths_to_cleanup = []
- yield
- try:
- w = get_workspace_client()
- for path in self._paths_to_cleanup:
- try:
- w.workspace.delete(path=path, recursive=False)
- logger.info(f"Cleaned up: {path}")
- except Exception:
- pass
- except Exception:
- pass
-
- def test_workspace_path_uploads_notebook(self, shared_context):
- """Should upload file as notebook when workspace_path is provided."""
- username = get_current_username()
- ws_path = f"/Workspace/Users/{username}/.ai_dev_kit_test/file_persist_test"
- self._paths_to_cleanup.append(ws_path)
-
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
- f.write('print("persisted via run_file")')
- temp_path = f.name
-
- try:
- result = run_file_on_databricks(
- file_path=temp_path,
- cluster_id=shared_context["cluster_id"],
- context_id=shared_context["context_id"],
- workspace_path=ws_path,
- timeout=120,
- )
-
- logger.info(f"Workspace path result: success={result.success}")
-
- assert result.success, f"Execution failed: {result.error}"
- assert "persisted via run_file" in result.output
-
- # Verify notebook exists in workspace
- w = get_workspace_client()
- status = w.workspace.get_status(ws_path)
- assert status is not None
- finally:
- Path(temp_path).unlink(missing_ok=True)
-
- def test_workspace_path_none_no_upload(self, shared_context):
- """Without workspace_path, no notebook should be uploaded (ephemeral)."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
- f.write('print("ephemeral file")')
- temp_path = f.name
-
- try:
- result = run_file_on_databricks(
- file_path=temp_path,
- cluster_id=shared_context["cluster_id"],
- context_id=shared_context["context_id"],
- timeout=120,
- )
- assert result.success, f"Execution failed: {result.error}"
- # No workspace_path on ExecutionResult — just verify execution worked
- finally:
- Path(temp_path).unlink(missing_ok=True)
diff --git a/databricks-tools-core/tests/integration/compute/test_manage.py b/databricks-tools-core/tests/integration/compute/test_manage.py
deleted file mode 100644
index f5c7e1a7..00000000
--- a/databricks-tools-core/tests/integration/compute/test_manage.py
+++ /dev/null
@@ -1,326 +0,0 @@
-"""
-Integration tests for compute management functions.
-
-Tests create_cluster, modify_cluster, terminate_cluster, delete_cluster,
-list_node_types, list_spark_versions, create_sql_warehouse, modify_sql_warehouse,
-and delete_sql_warehouse.
-
-Requires a valid Databricks connection (e.g. DATABRICKS_CONFIG_PROFILE=E2-Demo).
-"""
-
-import logging
-import time
-import pytest
-
-from databricks_tools_core.compute import (
- create_cluster,
- modify_cluster,
- terminate_cluster,
- delete_cluster,
- list_node_types,
- list_spark_versions,
- create_sql_warehouse,
- modify_sql_warehouse,
- delete_sql_warehouse,
- get_cluster_status,
-)
-from databricks_tools_core.auth import get_workspace_client
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.fixture(scope="module")
-def managed_cluster():
- """Create a test cluster and clean it up after all tests.
-
- Yields the cluster_id. The cluster is permanently deleted after the
- test module completes.
- """
- result = create_cluster(
- name="ai-dev-kit-test-manage",
- num_workers=0,
- autotermination_minutes=10,
- )
-
- assert result["cluster_id"] is not None
- cluster_id = result["cluster_id"]
-
- logger.info(f"Created test cluster: {cluster_id}")
-
- yield cluster_id
-
- # Cleanup: permanently delete
- try:
- delete_cluster(cluster_id)
- logger.info(f"Cleaned up test cluster: {cluster_id}")
- except Exception as e:
- logger.warning(f"Failed to cleanup test cluster {cluster_id}: {e}")
-
-
-@pytest.fixture(scope="module")
-def managed_warehouse():
- """Create a test SQL warehouse and clean it up after all tests.
-
- Yields the warehouse_id. The warehouse is permanently deleted after the
- test module completes.
- """
- result = create_sql_warehouse(
- name="ai-dev-kit-test-manage",
- size="2X-Small",
- auto_stop_mins=10,
- enable_serverless=True,
- )
-
- assert result["warehouse_id"] is not None
- warehouse_id = result["warehouse_id"]
-
- logger.info(f"Created test warehouse: {warehouse_id}")
-
- yield warehouse_id
-
- # Cleanup: permanently delete
- try:
- delete_sql_warehouse(warehouse_id)
- logger.info(f"Cleaned up test warehouse: {warehouse_id}")
- except Exception as e:
- logger.warning(f"Failed to cleanup test warehouse {warehouse_id}: {e}")
-
-
-@pytest.mark.integration
-class TestListNodeTypes:
- """Tests for list_node_types function."""
-
- def test_list_node_types(self):
- """Should return a non-empty list of node types."""
- node_types = list_node_types()
-
- print(f"\n=== List Node Types ===")
- print(f"Found {len(node_types)} node types")
- for nt in node_types[:5]:
- print(f" - {nt['node_type_id']} ({nt['memory_mb']}MB, {nt['num_cores']} cores)")
-
- assert isinstance(node_types, list)
- assert len(node_types) > 0
- assert "node_type_id" in node_types[0]
- assert "memory_mb" in node_types[0]
-
- def test_node_type_has_expected_fields(self):
- """Each node type should have expected fields."""
- node_types = list_node_types()
- nt = node_types[0]
-
- assert "node_type_id" in nt
- assert "memory_mb" in nt
- assert "num_gpus" in nt
- assert "description" in nt
-
-
-@pytest.mark.integration
-class TestListSparkVersions:
- """Tests for list_spark_versions function."""
-
- def test_list_spark_versions(self):
- """Should return a non-empty list of spark versions."""
- versions = list_spark_versions()
-
- print(f"\n=== List Spark Versions ===")
- print(f"Found {len(versions)} versions")
- for v in versions[:5]:
- print(f" - {v['key']}: {v['name']}")
-
- assert isinstance(versions, list)
- assert len(versions) > 0
- assert "key" in versions[0]
- assert "name" in versions[0]
-
- def test_has_lts_versions(self):
- """Should include at least one LTS version."""
- versions = list_spark_versions()
- lts = [v for v in versions if "LTS" in (v["name"] or "")]
- assert len(lts) > 0, "No LTS versions found"
-
-
-@pytest.mark.integration
-class TestCreateCluster:
- """Tests for create_cluster function."""
-
- def test_create_cluster_returns_expected_fields(self, managed_cluster):
- """managed_cluster fixture validates create_cluster returns cluster_id.
-
- This test just verifies the cluster exists.
- """
- status = get_cluster_status(managed_cluster)
-
- print(f"\n=== Created Cluster Status ===")
- print(f"Cluster ID: {status['cluster_id']}")
- print(f"Name: {status['cluster_name']}")
- print(f"State: {status['state']}")
-
- assert status["cluster_id"] == managed_cluster
- assert status["cluster_name"] == "ai-dev-kit-test-manage"
-
-
-@pytest.mark.integration
-class TestTerminateCluster:
- """Tests for terminate_cluster function."""
-
- def test_terminate_cluster(self, managed_cluster):
- """Should terminate the cluster (reversible)."""
- result = terminate_cluster(managed_cluster)
-
- print(f"\n=== Terminate Cluster ===")
- print(f"Result: {result}")
-
- assert result["cluster_id"] == managed_cluster
- assert result["state"] in ("TERMINATING", "TERMINATED")
- assert "reversible" in result["message"].lower() or "terminated" in result["message"].lower()
-
-
-@pytest.mark.integration
-class TestModifyCluster:
- """Tests for modify_cluster function.
-
- Runs after TestTerminateCluster so the cluster is in a stable (TERMINATED/TERMINATING)
- state — the edit API rejects edits on PENDING clusters.
- """
-
- def _wait_for_terminated(self, cluster_id, timeout=120):
- """Wait until cluster reaches TERMINATED state."""
- import time
- start = time.time()
- while time.time() - start < timeout:
- status = get_cluster_status(cluster_id)
- if status["state"] == "TERMINATED":
- return
- time.sleep(5)
- pytest.fail(f"Cluster did not terminate within {timeout}s")
-
- def test_modify_cluster_name(self, managed_cluster):
- """Should modify cluster name."""
- self._wait_for_terminated(managed_cluster)
-
- result = modify_cluster(
- cluster_id=managed_cluster,
- name="ai-dev-kit-test-manage-renamed",
- )
-
- print(f"\n=== Modify Cluster ===")
- print(f"Result: {result}")
-
- assert result["cluster_id"] == managed_cluster
- assert result["cluster_name"] == "ai-dev-kit-test-manage-renamed"
- assert "updated" in result["message"].lower()
-
- # Rename back for other tests
- modify_cluster(
- cluster_id=managed_cluster,
- name="ai-dev-kit-test-manage",
- )
-
-
-@pytest.mark.integration
-class TestCreateSqlWarehouse:
- """Tests for create_sql_warehouse function."""
-
- def test_create_warehouse_returns_expected_fields(self, managed_warehouse):
- """managed_warehouse fixture validates create returns warehouse_id."""
- w = get_workspace_client()
- wh = w.warehouses.get(managed_warehouse)
-
- print(f"\n=== Created Warehouse ===")
- print(f"Warehouse ID: {wh.id}")
- print(f"Name: {wh.name}")
- print(f"State: {wh.state}")
-
- assert wh.id == managed_warehouse
- assert wh.name == "ai-dev-kit-test-manage"
-
-
-@pytest.mark.integration
-class TestModifySqlWarehouse:
- """Tests for modify_sql_warehouse function."""
-
- def test_modify_warehouse_name(self, managed_warehouse):
- """Should modify warehouse name."""
- result = modify_sql_warehouse(
- warehouse_id=managed_warehouse,
- name="ai-dev-kit-test-manage-renamed",
- )
-
- print(f"\n=== Modify Warehouse ===")
- print(f"Result: {result}")
-
- assert result["warehouse_id"] == managed_warehouse
- assert result["name"] == "ai-dev-kit-test-manage-renamed"
- assert "updated" in result["message"].lower()
-
- # Rename back
- modify_sql_warehouse(
- warehouse_id=managed_warehouse,
- name="ai-dev-kit-test-manage",
- )
-
-
-@pytest.mark.integration
-class TestDeleteCluster:
- """Tests for delete_cluster function."""
-
- def test_delete_cluster_warning_message(self):
- """Should include a permanent deletion warning in the response."""
- # Create a throwaway cluster for deletion test
- result = create_cluster(
- name="ai-dev-kit-test-delete",
- num_workers=0,
- autotermination_minutes=10,
- )
- cluster_id = result["cluster_id"]
-
- try:
- delete_result = delete_cluster(cluster_id)
-
- print(f"\n=== Delete Cluster ===")
- print(f"Result: {delete_result}")
-
- assert delete_result["cluster_id"] == cluster_id
- assert delete_result["state"] == "DELETED"
- assert "permanent" in delete_result["message"].lower()
- assert "warning" in delete_result["message"].lower()
- except Exception:
- # Best-effort cleanup if delete fails
- try:
- delete_cluster(cluster_id)
- except Exception:
- pass
- raise
-
-
-@pytest.mark.integration
-class TestDeleteSqlWarehouse:
- """Tests for delete_sql_warehouse function."""
-
- def test_delete_warehouse_warning_message(self):
- """Should include a permanent deletion warning in the response."""
- # Create a throwaway warehouse for deletion test
- result = create_sql_warehouse(
- name="ai-dev-kit-test-delete",
- size="2X-Small",
- auto_stop_mins=10,
- )
- warehouse_id = result["warehouse_id"]
-
- try:
- delete_result = delete_sql_warehouse(warehouse_id)
-
- print(f"\n=== Delete Warehouse ===")
- print(f"Result: {delete_result}")
-
- assert delete_result["warehouse_id"] == warehouse_id
- assert delete_result["state"] == "DELETED"
- assert "permanent" in delete_result["message"].lower()
- assert "warning" in delete_result["message"].lower()
- except Exception:
- try:
- delete_sql_warehouse(warehouse_id)
- except Exception:
- pass
- raise
diff --git a/databricks-tools-core/tests/integration/compute/test_serverless.py b/databricks-tools-core/tests/integration/compute/test_serverless.py
deleted file mode 100644
index 7f68bb42..00000000
--- a/databricks-tools-core/tests/integration/compute/test_serverless.py
+++ /dev/null
@@ -1,226 +0,0 @@
-"""
-Integration tests for serverless compute execution (run_code_on_serverless).
-
-Tests serverless Python/SQL execution, ephemeral vs persistent modes,
-workspace_path, error handling, and input validation.
-"""
-
-import logging
-import pytest
-
-from databricks_tools_core.compute import (
- run_code_on_serverless,
- ServerlessRunResult,
-)
-from databricks_tools_core.auth import get_workspace_client, get_current_username
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.mark.integration
-class TestServerlessInputValidation:
- """Tests for input validation (no cluster/serverless needed)."""
-
- def test_empty_code(self):
- """Should reject empty code without submitting a run."""
- result = run_code_on_serverless(code="")
- assert not result.success
- assert result.state == "INVALID_INPUT"
- assert "empty" in result.error.lower()
-
- def test_whitespace_only_code(self):
- """Should reject whitespace-only code."""
- result = run_code_on_serverless(code=" \n\n ")
- assert not result.success
- assert result.state == "INVALID_INPUT"
-
- def test_unsupported_language(self):
- """Should reject unsupported languages."""
- result = run_code_on_serverless(code="println('hi')", language="scala")
- assert not result.success
- assert result.state == "INVALID_INPUT"
- assert "scala" in result.error.lower()
-
- def test_result_is_serverless_run_result(self):
- """Should return ServerlessRunResult type even on validation errors."""
- result = run_code_on_serverless(code="", language="python")
- assert isinstance(result, ServerlessRunResult)
-
- def test_to_dict(self):
- """Should serialize to dict properly."""
- result = run_code_on_serverless(code="")
- d = result.to_dict()
- assert isinstance(d, dict)
- assert "success" in d
- assert "output" in d
- assert "error" in d
- assert "run_id" in d
- assert "state" in d
-
-
-@pytest.mark.integration
-class TestServerlessPythonExecution:
- """Tests for Python code execution on serverless compute."""
-
- def test_simple_python_dbutils_exit(self):
- """Should capture output from dbutils.notebook.exit()."""
- code = 'dbutils.notebook.exit("hello from serverless")'
- result = run_code_on_serverless(code=code, language="python")
-
- logger.info(f"Result: success={result.success}, output={result.output}, "
- f"duration={result.duration_seconds}s")
-
- assert result.success, f"Execution failed: {result.error}"
- assert "hello from serverless" in result.output
- assert result.run_id is not None
- assert result.run_url is not None
- assert result.duration_seconds is not None
- assert result.state == "SUCCESS"
-
- def test_python_computation(self):
- """Should execute computation and return result via dbutils.notebook.exit()."""
- code = """
-import math
-result = sum(math.factorial(i) for i in range(10))
-dbutils.notebook.exit(str(result))
-"""
- result = run_code_on_serverless(code=code, language="python")
-
- assert result.success, f"Execution failed: {result.error}"
- assert "409114" in result.output # sum of 0! through 9!
-
- def test_python_error_handling(self):
- """Should capture Python errors with traceback."""
- code = """
-x = 1 / 0
-"""
- result = run_code_on_serverless(code=code, language="python")
-
- logger.info(f"Error result: success={result.success}, error={result.error[:200] if result.error else None}")
-
- assert not result.success
- assert result.error is not None
- assert "ZeroDivisionError" in result.error
- assert result.state == "FAILED"
-
- def test_python_with_spark(self):
- """Should have access to Spark on serverless."""
- code = """
-df = spark.range(10)
-count = df.count()
-dbutils.notebook.exit(str(count))
-"""
- result = run_code_on_serverless(code=code, language="python")
-
- assert result.success, f"Execution failed: {result.error}"
- assert "10" in result.output
-
- def test_custom_run_name(self):
- """Should accept custom run name."""
- result = run_code_on_serverless(
- code='dbutils.notebook.exit("named run")',
- run_name="test_custom_name_integration",
- )
-
- assert result.success, f"Execution failed: {result.error}"
- assert "named run" in result.output
-
-
-@pytest.mark.integration
-class TestServerlessSQLExecution:
- """Tests for SQL execution on serverless compute."""
-
- def test_sql_ddl(self):
- """Should execute SQL DDL statements."""
- code = """
-CREATE DATABASE IF NOT EXISTS ai_dev_kit_serverless_test;
-"""
- result = run_code_on_serverless(code=code, language="sql")
-
- logger.info(f"SQL DDL result: success={result.success}, state={result.state}")
-
- assert result.success, f"SQL DDL failed: {result.error}"
-
-
-@pytest.mark.integration
-class TestServerlessEphemeralMode:
- """Tests for ephemeral mode (default - temp notebook cleaned up)."""
-
- def test_ephemeral_no_workspace_path_in_result(self):
- """Ephemeral mode should not include workspace_path in result."""
- result = run_code_on_serverless(
- code='dbutils.notebook.exit("ephemeral")',
- )
-
- assert result.success, f"Execution failed: {result.error}"
- assert result.workspace_path is None
-
- def test_ephemeral_to_dict_no_workspace_path(self):
- """Ephemeral mode should not include workspace_path in dict."""
- result = run_code_on_serverless(
- code='dbutils.notebook.exit("ephemeral dict")',
- )
-
- assert result.success, f"Execution failed: {result.error}"
- d = result.to_dict()
- assert "workspace_path" not in d
-
-
-@pytest.mark.integration
-class TestServerlessPersistentMode:
- """Tests for persistent mode (workspace_path provided, notebook saved)."""
-
- @pytest.fixture(autouse=True)
- def _setup_cleanup(self):
- """Track workspace paths for cleanup after each test."""
- self._paths_to_cleanup = []
- yield
- # Cleanup persisted notebooks
- try:
- w = get_workspace_client()
- for path in self._paths_to_cleanup:
- try:
- w.workspace.delete(path=path, recursive=False)
- logger.info(f"Cleaned up: {path}")
- except Exception:
- pass
- except Exception:
- pass
-
- def test_persistent_saves_notebook(self):
- """Persistent mode should save notebook at workspace_path."""
- username = get_current_username()
- ws_path = f"/Workspace/Users/{username}/.ai_dev_kit_test/persistent_test"
- self._paths_to_cleanup.append(ws_path)
-
- result = run_code_on_serverless(
- code='dbutils.notebook.exit("persisted!")',
- workspace_path=ws_path,
- )
-
- logger.info(f"Persistent result: success={result.success}, "
- f"workspace_path={result.workspace_path}")
-
- assert result.success, f"Execution failed: {result.error}"
- assert result.workspace_path == ws_path
- assert "persisted!" in result.output
-
- # Verify notebook exists in workspace
- w = get_workspace_client()
- status = w.workspace.get_status(ws_path)
- assert status is not None
-
- def test_persistent_to_dict_includes_workspace_path(self):
- """Persistent mode should include workspace_path in dict."""
- username = get_current_username()
- ws_path = f"/Workspace/Users/{username}/.ai_dev_kit_test/persistent_dict_test"
- self._paths_to_cleanup.append(ws_path)
-
- result = run_code_on_serverless(
- code='dbutils.notebook.exit("dict test")',
- workspace_path=ws_path,
- )
-
- assert result.success, f"Execution failed: {result.error}"
- d = result.to_dict()
- assert d["workspace_path"] == ws_path
diff --git a/databricks-tools-core/tests/integration/jobs/__init__.py b/databricks-tools-core/tests/integration/jobs/__init__.py
deleted file mode 100644
index 9d4deb08..00000000
--- a/databricks-tools-core/tests/integration/jobs/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""
-Integration tests for Jobs module.
-
-These tests verify job creation, management, and run operations
-using a real Databricks workspace.
-"""
diff --git a/databricks-tools-core/tests/integration/jobs/conftest.py b/databricks-tools-core/tests/integration/jobs/conftest.py
deleted file mode 100644
index 263aba77..00000000
--- a/databricks-tools-core/tests/integration/jobs/conftest.py
+++ /dev/null
@@ -1,118 +0,0 @@
-"""
-Pytest fixtures for jobs integration tests.
-
-Provides fixtures for job creation, cleanup, and test data.
-Uses serverless compute for all job executions.
-"""
-
-import base64
-import logging
-import pytest
-import uuid
-
-from databricks.sdk.service.workspace import ImportFormat, Language
-
-from databricks_tools_core.auth import get_workspace_client
-from databricks_tools_core.jobs import delete_job
-
-logger = logging.getLogger(__name__)
-
-# Fixed test job name prefix for easy identification and cleanup
-TEST_JOB_PREFIX = "ai_dev_kit_test_job"
-
-
-@pytest.fixture(scope="module")
-def test_job_name() -> str:
- """
- Generate a unique test job name for this test session.
-
- Uses UUID suffix to avoid conflicts if multiple test runs
- happen simultaneously.
- """
- unique_suffix = str(uuid.uuid4())[:8]
- return f"{TEST_JOB_PREFIX}_{unique_suffix}"
-
-
-@pytest.fixture(scope="module")
-def test_notebook_path() -> str:
- """
- Create a simple test notebook in the workspace.
-
- Returns the workspace path to the notebook.
- """
- w = get_workspace_client()
- user = w.current_user.me()
- notebook_path = f"/Users/{user.user_name}/test_jobs/test_notebook"
-
- # Create notebook with simple Python code
- # Keep it short to minimize serverless execution time
- notebook_content = """# Databricks notebook source
-# Test notebook for jobs integration tests
-print("Test job executed successfully")
-dbutils.notebook.exit("success")
-"""
-
- logger.info(f"Creating test notebook: {notebook_path}")
-
- try:
- # Create parent folder first
- parent_folder = "/".join(notebook_path.split("/")[:-1])
- w.workspace.mkdirs(parent_folder)
- logger.info(f"Created parent folder: {parent_folder}")
-
- # Import notebook (creates it if doesn't exist)
- # Content must be base64 encoded
- content_b64 = base64.b64encode(notebook_content.encode("utf-8")).decode("utf-8")
- w.workspace.import_(
- path=notebook_path,
- format=ImportFormat.SOURCE,
- language=Language.PYTHON,
- content=content_b64,
- overwrite=True,
- )
- logger.info(f"Test notebook created: {notebook_path}")
- except Exception as e:
- logger.error(f"Failed to create test notebook: {e}")
- raise
-
- yield notebook_path
-
- # Cleanup: Delete notebook and folder after tests
- try:
- logger.info(f"Cleaning up test notebook: {notebook_path}")
- w.workspace.delete(notebook_path)
- # Also try to delete the parent folder
- parent_folder = "/".join(notebook_path.split("/")[:-1])
- w.workspace.delete(parent_folder, recursive=True)
- except Exception as e:
- logger.warning(f"Failed to cleanup test notebook: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_job():
- """
- Fixture to track and cleanup test jobs created during tests.
-
- Usage:
- def test_create_job(cleanup_job):
- job = create_job(...)
- cleanup_job(job["job_id"]) # Register for cleanup
- """
- job_ids_to_cleanup = []
-
- def register_job(job_id: int):
- """Register a job ID for cleanup after the test."""
- if job_id and job_id not in job_ids_to_cleanup:
- job_ids_to_cleanup.append(job_id)
- logger.info(f"Registered job {job_id} for cleanup")
-
- yield register_job
-
- # Cleanup all registered jobs
- for job_id in job_ids_to_cleanup:
- try:
- logger.info(f"Cleaning up job: {job_id}")
- delete_job(job_id=job_id)
- logger.info(f"Job {job_id} deleted successfully")
- except Exception as e:
- logger.warning(f"Failed to cleanup job {job_id}: {e}")
diff --git a/databricks-tools-core/tests/integration/jobs/test_jobs.py b/databricks-tools-core/tests/integration/jobs/test_jobs.py
deleted file mode 100644
index cc40d5d6..00000000
--- a/databricks-tools-core/tests/integration/jobs/test_jobs.py
+++ /dev/null
@@ -1,343 +0,0 @@
-"""
-Integration tests for job CRUD operations.
-
-Tests the databricks_tools_core.jobs functions:
-- list_jobs
-- find_job_by_name
-- create_job
-- get_job
-- update_job
-- delete_job
-
-All tests use serverless compute for job execution.
-"""
-
-import logging
-import pytest
-
-from databricks_tools_core.jobs import (
- list_jobs,
- find_job_by_name,
- create_job,
- get_job,
- update_job,
- delete_job,
- JobError,
-)
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.mark.integration
-class TestListJobs:
- """Tests for listing jobs."""
-
- def test_list_jobs(self):
- """Should list jobs successfully using our core function."""
- logger.info("Testing list_jobs")
-
- # List jobs with limit
- jobs = list_jobs(limit=10)
-
- logger.info(f"Found {len(jobs)} jobs")
- for job in jobs[:3]:
- logger.info(f" - {job.get('name')} (ID: {job.get('job_id')})")
-
- # Verify response
- assert isinstance(jobs, list)
- # Verify dict structure
- if len(jobs) > 0:
- assert "job_id" in jobs[0]
- assert "name" in jobs[0]
-
- def test_list_jobs_with_name_filter(self):
- """Should filter jobs by name."""
- logger.info("Testing list_jobs with name filter")
-
- # This may or may not find jobs depending on workspace
- jobs = list_jobs(name="test", limit=5)
-
- logger.info(f"Found {len(jobs)} jobs matching 'test'")
- assert isinstance(jobs, list)
-
-
-@pytest.mark.integration
-class TestFindJobByName:
- """Tests for finding jobs by name."""
-
- def test_find_job_by_name_not_found(self):
- """Should return None when job doesn't exist."""
- non_existent_name = "this_job_definitely_does_not_exist_12345"
- job_id = find_job_by_name(name=non_existent_name)
-
- logger.info(f"Search for non-existent job '{non_existent_name}': {job_id}")
- assert job_id is None, "Should return None for non-existent job"
-
- def test_find_job_by_name_exists(
- self,
- test_notebook_path: str,
- cleanup_job,
- ):
- """Should find job when it exists."""
- # Create a test job using our core function (serverless)
- job_name = "test_find_by_name_job"
- tasks = [
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- }
- ]
- job = create_job(name=job_name, tasks=tasks)
- cleanup_job(job["job_id"])
-
- # Try to find it using our core function
- found_job_id = find_job_by_name(name=job_name)
-
- logger.info(f"Search for '{job_name}': {found_job_id}")
- assert found_job_id is not None, "Should find the created job"
- assert found_job_id == job["job_id"], "Should return correct job ID"
-
-
-@pytest.mark.integration
-class TestCreateJob:
- """Tests for creating jobs."""
-
- def test_create_job_basic(
- self,
- test_notebook_path: str,
- cleanup_job,
- ):
- """Should create a simple notebook task job with serverless."""
- job_name = "test_create_job_basic"
- logger.info(f"Creating job: {job_name}")
-
- # Create job using our core function (serverless - no cluster specified)
- tasks = [
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- "timeout_seconds": 3600,
- }
- ]
- job = create_job(name=job_name, tasks=tasks, max_concurrent_runs=1)
-
- # Register for cleanup
- cleanup_job(job["job_id"])
-
- logger.info(f"Job created: {job_name} (ID: {job['job_id']})")
-
- # Verify creation
- assert job["job_id"] is not None, "Job ID should be set"
-
- def test_create_job_with_multiple_tasks(
- self,
- test_notebook_path: str,
- cleanup_job,
- ):
- """Should create a job with multiple tasks."""
- job_name = "test_create_job_multi_task"
- logger.info(f"Creating multi-task job: {job_name}")
-
- # Create job with two tasks using our core function (serverless)
- tasks = [
- {
- "task_key": "task_1",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- },
- {
- "task_key": "task_2",
- "depends_on": [{"task_key": "task_1"}],
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- },
- ]
- job = create_job(name=job_name, tasks=tasks, max_concurrent_runs=1)
-
- # Register for cleanup
- cleanup_job(job["job_id"])
-
- logger.info(f"Multi-task job created: {job['job_id']}")
-
- # Verify creation by getting the job
- job_details = get_job(job_id=job["job_id"])
- assert job_details["job_id"] is not None
- assert len(job_details["settings"]["tasks"]) == 2, "Should have two tasks"
-
-
-@pytest.mark.integration
-class TestGetJob:
- """Tests for getting job details."""
-
- def test_get_job(
- self,
- test_notebook_path: str,
- cleanup_job,
- ):
- """Should get job details by ID."""
- # Create a test job
- job_name = "test_get_job"
- tasks = [
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- }
- ]
- created_job = create_job(name=job_name, tasks=tasks)
- cleanup_job(created_job["job_id"])
-
- logger.info(f"Getting job details for {created_job['job_id']}")
-
- # Get job details using our core function
- job = get_job(job_id=created_job["job_id"])
-
- logger.info(f"Retrieved job: {job['settings']['name']}")
-
- # Verify details
- assert job["job_id"] == created_job["job_id"], "Job ID should match"
- assert job["settings"]["name"] == job_name, "Job name should match"
- assert job["settings"]["tasks"] is not None, "Should have tasks"
-
- def test_get_job_not_found(self):
- """Should raise JobError when job doesn't exist."""
- non_existent_id = 999999999
-
- logger.info(f"Attempting to get non-existent job: {non_existent_id}")
-
- with pytest.raises(JobError) as exc_info:
- get_job(job_id=non_existent_id)
-
- logger.info(f"Expected JobError: {exc_info.value}")
- assert exc_info.value.job_id == non_existent_id
-
-
-@pytest.mark.integration
-class TestUpdateJob:
- """Tests for updating jobs."""
-
- def test_update_job_name(
- self,
- test_notebook_path: str,
- cleanup_job,
- ):
- """Should update job name successfully."""
- # Create a test job
- original_name = "test_update_job_original"
- tasks = [
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- }
- ]
- job = create_job(name=original_name, tasks=tasks)
- cleanup_job(job["job_id"])
-
- new_name = "test_update_job_renamed"
- logger.info(f"Updating job {job['job_id']} to '{new_name}'")
-
- # Update job name using our core function
- update_job(job_id=job["job_id"], name=new_name)
-
- # Verify update
- updated_job = get_job(job_id=job["job_id"])
- logger.info(f"Job updated: {updated_job['settings']['name']}")
-
- assert updated_job["settings"]["name"] == new_name, "Name should be updated"
- assert updated_job["job_id"] == job["job_id"], "Job ID should stay same"
-
- def test_update_job_max_concurrent_runs(
- self,
- test_notebook_path: str,
- cleanup_job,
- ):
- """Should update max concurrent runs."""
- # Create a test job with max_concurrent_runs=1
- tasks = [
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- }
- ]
- job = create_job(
- name="test_update_concurrent",
- tasks=tasks,
- max_concurrent_runs=1,
- )
- cleanup_job(job["job_id"])
-
- logger.info(f"Updating max_concurrent_runs for job {job['job_id']}")
-
- # Update using our core function
- update_job(job_id=job["job_id"], max_concurrent_runs=5)
-
- # Verify update
- updated_job = get_job(job_id=job["job_id"])
- max_runs = updated_job["settings"]["max_concurrent_runs"]
- logger.info(f"Max concurrent runs updated to: {max_runs}")
-
- assert max_runs == 5, "Max concurrent runs should be updated"
-
-
-@pytest.mark.integration
-class TestDeleteJob:
- """Tests for deleting jobs."""
-
- def test_delete_job(
- self,
- test_notebook_path: str,
- ):
- """Should delete a job successfully."""
- # Create a test job
- tasks = [
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- }
- ]
- job = create_job(name="test_delete_job", tasks=tasks)
- job_id = job["job_id"]
-
- logger.info(f"Deleting job: {job_id}")
-
- # Delete the job using our core function
- delete_job(job_id=job_id)
-
- logger.info(f"Job {job_id} deleted")
-
- # Verify deletion - try to get the job (should fail)
- with pytest.raises(JobError):
- get_job(job_id=job_id)
-
- def test_delete_job_not_found(self):
- """Should raise JobError when deleting non-existent job."""
- non_existent_id = 999999999
-
- logger.info(f"Deleting non-existent job: {non_existent_id}")
-
- with pytest.raises(JobError) as exc_info:
- delete_job(job_id=non_existent_id)
-
- logger.info(f"Expected JobError: {exc_info.value}")
- assert exc_info.value.job_id == non_existent_id
diff --git a/databricks-tools-core/tests/integration/jobs/test_runs.py b/databricks-tools-core/tests/integration/jobs/test_runs.py
deleted file mode 100644
index b435fb82..00000000
--- a/databricks-tools-core/tests/integration/jobs/test_runs.py
+++ /dev/null
@@ -1,412 +0,0 @@
-"""
-Integration tests for job run operations.
-
-Tests the databricks_tools_core.jobs functions:
-- run_job_now
-- get_run
-- cancel_run
-- list_runs
-- wait_for_run
-
-All tests use serverless compute for job execution.
-"""
-
-import logging
-import pytest
-import time
-
-from databricks_tools_core.jobs import (
- create_job,
- run_job_now,
- get_run,
- cancel_run,
- list_runs,
- wait_for_run,
- JobError,
- RunLifecycleState,
- RunResultState,
-)
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.mark.integration
-class TestRunJobNow:
- """Tests for triggering job runs."""
-
- def test_run_job_now(
- self,
- test_notebook_path: str,
- cleanup_job,
- ):
- """Should trigger a job run successfully."""
- # Create a test job (serverless)
- tasks = [
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- }
- ]
- job = create_job(name="test_run_job_now", tasks=tasks)
- cleanup_job(job["job_id"])
-
- logger.info(f"Triggering run for job: {job['job_id']}")
-
- # Run the job using our core function
- run_id = run_job_now(job_id=job["job_id"])
-
- logger.info(f"Run triggered: {run_id}")
-
- # Verify run was created
- assert run_id is not None, "Run ID should be returned"
- assert isinstance(run_id, int), "Run ID should be an integer"
-
- # Get run details to verify state
- run_details = get_run(run_id=run_id)
- logger.info(f"Run state: {run_details.get('state', {})}")
-
- assert run_details["run_id"] == run_id, "Run ID should match"
-
- def test_run_job_with_parameters(
- self,
- test_notebook_path: str,
- cleanup_job,
- ):
- """Should run job with notebook parameters."""
- # Create a test job (serverless)
- tasks = [
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- }
- ]
- job = create_job(name="test_run_with_params", tasks=tasks)
- cleanup_job(job["job_id"])
-
- logger.info(f"Running job with parameters: {job['job_id']}")
-
- # Run with parameters using our core function
- params = {"param1": "value1", "param2": "value2"}
- run_id = run_job_now(job_id=job["job_id"], notebook_params=params)
-
- logger.info(f"Run triggered with params: {run_id}")
-
- assert run_id is not None, "Run ID should be returned"
-
-
-@pytest.mark.integration
-class TestGetRun:
- """Tests for getting run details."""
-
- def test_get_run(
- self,
- test_notebook_path: str,
- cleanup_job,
- ):
- """Should get run status and details."""
- # Create and run a job (serverless)
- tasks = [
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- }
- ]
- job = create_job(name="test_get_run", tasks=tasks)
- cleanup_job(job["job_id"])
- run_id = run_job_now(job_id=job["job_id"])
-
- logger.info(f"Getting run details for: {run_id}")
-
- # Get run details using our core function
- run_details = get_run(run_id=run_id)
-
- logger.info(f"Run details: {run_details.get('state', {})}")
-
- # Verify details
- assert run_details["run_id"] == run_id, "Run ID should match"
- assert run_details["job_id"] == job["job_id"], "Job ID should match"
- assert "state" in run_details, "Should have state"
-
- def test_get_run_not_found(self):
- """Should raise JobError when run doesn't exist."""
- non_existent_run_id = 999999999
-
- logger.info(f"Attempting to get non-existent run: {non_existent_run_id}")
-
- with pytest.raises(JobError) as exc_info:
- get_run(run_id=non_existent_run_id)
-
- logger.info(f"Expected JobError: {exc_info.value}")
- assert exc_info.value.run_id == non_existent_run_id
-
-
-@pytest.mark.integration
-class TestCancelRun:
- """Tests for canceling job runs."""
-
- def test_cancel_run(
- self,
- test_notebook_path: str,
- cleanup_job,
- ):
- """Should cancel a running job."""
- # Create and run a job (serverless)
- tasks = [
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- }
- ]
- job = create_job(name="test_cancel_run", tasks=tasks)
- cleanup_job(job["job_id"])
- run_id = run_job_now(job_id=job["job_id"])
-
- logger.info(f"Canceling run: {run_id}")
-
- # Wait a moment to ensure run has started
- time.sleep(3)
-
- # Cancel the run using our core function
- cancel_run(run_id=run_id)
-
- logger.info(f"Run {run_id} cancel requested")
-
- # Poll for cancellation to take effect (serverless may take longer)
- max_wait = 30
- poll_interval = 3
- elapsed = 0
- lifecycle_state = None
-
- while elapsed < max_wait:
- time.sleep(poll_interval)
- elapsed += poll_interval
-
- run_details = get_run(run_id=run_id)
- state = run_details.get("state", {})
- lifecycle_state = state.get("life_cycle_state")
- logger.info(f"Run state after {elapsed}s: {lifecycle_state}")
-
- if lifecycle_state in ["TERMINATING", "TERMINATED"]:
- break
-
- # Run should be terminating or terminated
- assert lifecycle_state in [
- "TERMINATING",
- "TERMINATED",
- ], f"Run should be terminating or terminated, got {lifecycle_state}"
-
-
-@pytest.mark.integration
-class TestListRuns:
- """Tests for listing job runs."""
-
- def test_list_runs_for_job(
- self,
- test_notebook_path: str,
- cleanup_job,
- ):
- """Should list runs for a specific job."""
- # Create a job and trigger multiple runs (serverless)
- tasks = [
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- }
- ]
- job = create_job(
- name="test_list_runs",
- tasks=tasks,
- max_concurrent_runs=5, # Allow concurrent runs for this test
- )
- cleanup_job(job["job_id"])
-
- logger.info(f"Triggering 3 runs for job: {job['job_id']}")
-
- # Trigger 3 runs using our core function
- run_ids = []
- for i in range(3):
- run_id = run_job_now(job_id=job["job_id"])
- run_ids.append(run_id)
- logger.info(f"Run {i + 1} triggered: {run_id}")
- time.sleep(1) # Small delay between runs
-
- # List runs for this job using our core function
- logger.info(f"Listing runs for job: {job['job_id']}")
- runs = list_runs(job_id=job["job_id"], limit=10)
-
- logger.info(f"Found {len(runs)} runs for job {job['job_id']}")
-
- # Verify we got the runs
- assert len(runs) >= 3, f"Should have at least 3 runs, got {len(runs)}"
-
- # Verify our run IDs are in the list
- found_run_ids = {run["run_id"] for run in runs}
- for run_id in run_ids:
- assert run_id in found_run_ids, f"Run {run_id} should be in the list"
-
- def test_list_runs_with_limit(
- self,
- test_notebook_path: str,
- cleanup_job,
- ):
- """Should respect limit parameter when listing runs."""
- # Create a job (serverless)
- tasks = [
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- }
- ]
- job = create_job(
- name="test_list_runs_limit",
- tasks=tasks,
- max_concurrent_runs=5,
- )
- cleanup_job(job["job_id"])
-
- # Trigger 5 runs
- for _ in range(5):
- run_job_now(job_id=job["job_id"])
- time.sleep(1)
-
- # List with limit=3 using our core function
- runs = list_runs(job_id=job["job_id"], limit=3)
-
- logger.info(f"Listed runs with limit=3: found {len(runs)}")
-
- assert len(runs) <= 3, f"Should have at most 3 runs, got {len(runs)}"
-
-
-@pytest.mark.integration
-class TestWaitForRun:
- """Tests for waiting for run completion."""
-
- def test_wait_for_run_success(
- self,
- test_notebook_path: str,
- cleanup_job,
- ):
- """Should wait for run to complete successfully."""
- # Create and run a job (serverless)
- tasks = [
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- }
- ]
- job = create_job(name="test_wait_for_run", tasks=tasks)
- cleanup_job(job["job_id"])
- run_id = run_job_now(job_id=job["job_id"])
-
- logger.info(f"Waiting for run {run_id} to complete")
-
- # Wait for run using our core function (timeout after 5 minutes)
- result = wait_for_run(run_id=run_id, timeout=300, poll_interval=10)
-
- logger.info(f"Run completed: lifecycle={result.lifecycle_state}, result={result.result_state}")
-
- # Verify completion using our JobRunResult model
- # lifecycle_state is stored as string in JobRunResult
- assert result.lifecycle_state == RunLifecycleState.TERMINATED.value
- assert result.result_state in [
- RunResultState.SUCCESS.value,
- RunResultState.FAILED.value,
- RunResultState.CANCELED.value,
- ], "Run should have a result state"
-
- def test_wait_for_run_with_cancellation(
- self,
- test_notebook_path: str,
- cleanup_job,
- ):
- """Should wait for canceled run to complete."""
- # Create and run a job (serverless)
- tasks = [
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- }
- ]
- job = create_job(name="test_wait_canceled", tasks=tasks)
- cleanup_job(job["job_id"])
- run_id = run_job_now(job_id=job["job_id"])
-
- logger.info(f"Starting run {run_id} and canceling it")
-
- # Wait a moment, then cancel
- time.sleep(2)
- cancel_run(run_id=run_id)
-
- logger.info("Waiting for canceled run to complete")
-
- # Wait for run using our core function
- result = wait_for_run(run_id=run_id, timeout=300, poll_interval=10)
-
- logger.info(f"Canceled run completed: lifecycle={result.lifecycle_state}, result={result.result_state}")
-
- # Verify completion (states stored as strings in JobRunResult)
- assert result.lifecycle_state == RunLifecycleState.TERMINATED.value
- assert result.result_state == RunResultState.CANCELED.value
-
- def test_wait_for_run_result_object(
- self,
- test_notebook_path: str,
- cleanup_job,
- ):
- """Should return JobRunResult with all expected fields."""
- # Create and run a job (serverless)
- tasks = [
- {
- "task_key": "test_task",
- "notebook_task": {
- "notebook_path": test_notebook_path,
- "source": "WORKSPACE",
- },
- }
- ]
- job = create_job(name="test_wait_result", tasks=tasks)
- cleanup_job(job["job_id"])
- run_id = run_job_now(job_id=job["job_id"])
-
- logger.info(f"Waiting for run {run_id}")
-
- # Wait for run using our core function
- result = wait_for_run(run_id=run_id, timeout=300, poll_interval=10)
-
- # Verify JobRunResult fields
- assert result.run_id == run_id, "run_id should match"
- assert result.job_id == job["job_id"], "job_id should match"
- assert result.lifecycle_state is not None, "Should have lifecycle_state"
- assert result.run_page_url is not None, "Should have run_page_url"
-
- # Test to_dict() method
- result_dict = result.to_dict()
- assert "run_id" in result_dict
- assert "success" in result_dict
- assert "lifecycle_state" in result_dict
-
- logger.info(f"JobRunResult.to_dict(): {result_dict}")
diff --git a/databricks-tools-core/tests/integration/lakebase/__init__.py b/databricks-tools-core/tests/integration/lakebase/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/databricks-tools-core/tests/integration/lakebase/conftest.py b/databricks-tools-core/tests/integration/lakebase/conftest.py
deleted file mode 100644
index 98ada193..00000000
--- a/databricks-tools-core/tests/integration/lakebase/conftest.py
+++ /dev/null
@@ -1,204 +0,0 @@
-"""
-Pytest fixtures for Lakebase Provisioned integration tests.
-
-Provides session-scoped fixtures for:
-- Lakebase instance (CU_1, created once, shared across all Lakebase tests)
-- Source Delta table for synced table tests
-- Cleanup helpers
-"""
-
-import logging
-import time
-import uuid
-
-import pytest
-
-logger = logging.getLogger(__name__)
-
-# Unique suffix to avoid collisions across test runs
-_RUN_ID = str(uuid.uuid4())[:8]
-
-LB_INSTANCE_NAME = f"lb-test-{_RUN_ID}"
-LB_CATALOG_PREFIX = "lb_test_cat"
-LB_SYNCED_TABLE_PREFIX = "lb_test_sync"
-
-
-def _wait_for_instance_running(name: str, timeout: int = 1200, poll: int = 20):
- """Poll instance until RUNNING or timeout."""
- from databricks_tools_core.lakebase import get_lakebase_instance
-
- deadline = time.time() + timeout
- while time.time() < deadline:
- inst = get_lakebase_instance(name)
- state = inst.get("state", "")
- stopped = inst.get("stopped")
- logger.info(f"Instance '{name}' state: {state}, stopped: {stopped}")
-
- # State varies by SDK version (e.g. "DatabaseInstanceState.RUNNING",
- # "DatabaseInstanceState.AVAILABLE")
- state_upper = state.upper()
- if any(s in state_upper for s in ("RUNNING", "AVAILABLE", "ACTIVE")):
- return inst
- if state == "NOT_FOUND":
- raise RuntimeError(f"Instance '{name}' not found")
- if "FAILED" in state.upper() or "ERROR" in state.upper():
- raise RuntimeError(f"Instance '{name}' in terminal state: {state}")
- time.sleep(poll)
- raise TimeoutError(f"Instance '{name}' not RUNNING within {timeout}s")
-
-
-@pytest.fixture(scope="session")
-def lakebase_instance_name():
- """
- Create a CU_1 Lakebase instance for the test session.
-
- Waits for RUNNING state before yielding.
- Deletes the instance after all tests complete.
- """
- from databricks_tools_core.lakebase import (
- create_lakebase_instance,
- delete_lakebase_instance,
- get_lakebase_instance,
- )
-
- name = LB_INSTANCE_NAME
- logger.info(f"Creating Lakebase instance: {name}")
-
- # Check if it already exists (from a previous failed run)
- existing = get_lakebase_instance(name)
- if existing.get("state") != "NOT_FOUND":
- logger.info(f"Instance '{name}' already exists (state: {existing.get('state')}), reusing")
- # If it's stopped, we still proceed - tests that need it running will handle it
- else:
- result = create_lakebase_instance(
- name=name,
- capacity="CU_1",
- stopped=False,
- )
- logger.info(f"Instance creation result: {result}")
-
- # Wait for RUNNING
- _wait_for_instance_running(name)
- logger.info(f"Instance '{name}' is RUNNING")
-
- yield name
-
- # Teardown
- logger.info(f"Deleting Lakebase instance: {name}")
- try:
- delete_lakebase_instance(name=name, force=False, purge=True)
- logger.info(f"Instance '{name}' deleted")
- except Exception as e:
- logger.warning(f"Failed to delete instance '{name}': {e}")
-
-
-@pytest.fixture(scope="module")
-def source_delta_table(test_catalog, test_schema, warehouse_id):
- """
- Create a small Delta table for synced table tests.
-
- Returns the fully qualified table name.
- """
- from databricks_tools_core.sql import execute_sql
-
- table_name = f"{test_catalog}.{test_schema}.lb_source_{_RUN_ID}"
- logger.info(f"Creating source Delta table: {table_name}")
-
- execute_sql(
- sql_query=f"""
- CREATE OR REPLACE TABLE {table_name} (
- id BIGINT,
- name STRING,
- email STRING,
- score DOUBLE
- )
- """,
- warehouse_id=warehouse_id,
- )
-
- execute_sql(
- sql_query=f"""
- INSERT INTO {table_name} VALUES
- (1, 'Alice', 'alice@test.com', 95.5),
- (2, 'Bob', 'bob@test.com', 87.3),
- (3, 'Charlie', 'charlie@test.com', 92.1)
- """,
- warehouse_id=warehouse_id,
- )
-
- logger.info(f"Source Delta table created: {table_name}")
- yield table_name
-
- # Cleanup handled by schema teardown
-
-
-@pytest.fixture(scope="function")
-def unique_name() -> str:
- """Generate a unique name suffix for test resources."""
- return str(uuid.uuid4())[:8]
-
-
-@pytest.fixture(scope="function")
-def cleanup_instances():
- """Track and cleanup Lakebase instances created during a test."""
- from databricks_tools_core.lakebase import delete_lakebase_instance
-
- instances_to_cleanup = []
-
- def register(name: str):
- if name not in instances_to_cleanup:
- instances_to_cleanup.append(name)
- logger.info(f"Registered instance for cleanup: {name}")
-
- yield register
-
- for inst_name in instances_to_cleanup:
- try:
- logger.info(f"Cleaning up instance: {inst_name}")
- delete_lakebase_instance(name=inst_name, force=False, purge=True)
- except Exception as e:
- logger.warning(f"Failed to cleanup instance {inst_name}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_catalogs():
- """Track and cleanup Lakebase catalogs created during a test."""
- from databricks_tools_core.lakebase import delete_lakebase_catalog
-
- catalogs_to_cleanup = []
-
- def register(name: str):
- if name not in catalogs_to_cleanup:
- catalogs_to_cleanup.append(name)
- logger.info(f"Registered catalog for cleanup: {name}")
-
- yield register
-
- for cat_name in catalogs_to_cleanup:
- try:
- logger.info(f"Cleaning up Lakebase catalog: {cat_name}")
- delete_lakebase_catalog(cat_name)
- except Exception as e:
- logger.warning(f"Failed to cleanup catalog {cat_name}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_synced_tables():
- """Track and cleanup synced tables created during a test."""
- from databricks_tools_core.lakebase import delete_synced_table
-
- tables_to_cleanup = []
-
- def register(table_name: str):
- if table_name not in tables_to_cleanup:
- tables_to_cleanup.append(table_name)
- logger.info(f"Registered synced table for cleanup: {table_name}")
-
- yield register
-
- for tbl_name in tables_to_cleanup:
- try:
- logger.info(f"Cleaning up synced table: {tbl_name}")
- delete_synced_table(tbl_name)
- except Exception as e:
- logger.warning(f"Failed to cleanup synced table {tbl_name}: {e}")
diff --git a/databricks-tools-core/tests/integration/lakebase/test_catalogs.py b/databricks-tools-core/tests/integration/lakebase/test_catalogs.py
deleted file mode 100644
index b91ba67d..00000000
--- a/databricks-tools-core/tests/integration/lakebase/test_catalogs.py
+++ /dev/null
@@ -1,156 +0,0 @@
-"""
-Integration tests for Lakebase Unity Catalog registration.
-
-Tests:
-- create_lakebase_catalog
-- get_lakebase_catalog
-- delete_lakebase_catalog
-
-Requires a running Lakebase instance.
-"""
-
-import logging
-
-import pytest
-
-from databricks_tools_core.lakebase import (
- create_lakebase_catalog,
- delete_lakebase_catalog,
- get_lakebase_catalog,
-)
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.mark.integration
-class TestCreateCatalog:
- """Tests for registering a Lakebase instance as a UC catalog."""
-
- def test_create_catalog(self, lakebase_instance_name: str, unique_name, cleanup_catalogs):
- """Should register a Lakebase instance as a UC catalog."""
- catalog_name = f"lb_test_cat_{unique_name}"
- cleanup_catalogs(catalog_name)
-
- try:
- result = create_lakebase_catalog(
- name=catalog_name,
- instance_name=lakebase_instance_name,
- database_name="databricks_postgres",
- create_database_if_not_exists=True,
- )
- except Exception as e:
- if "CREATE CATALOG" in str(e) or "permission" in str(e).lower():
- pytest.skip("User lacks CREATE CATALOG permission on this metastore")
- raise
-
- logger.info(f"Create catalog result: {result}")
-
- assert result["name"] == catalog_name
- assert result["instance_name"] == lakebase_instance_name
- assert result["status"] in ("created", "ALREADY_EXISTS")
-
- def test_create_duplicate_catalog(self, lakebase_instance_name: str, unique_name, cleanup_catalogs):
- """Should return ALREADY_EXISTS for duplicate catalog."""
- catalog_name = f"lb_test_cat_dup_{unique_name}"
- cleanup_catalogs(catalog_name)
-
- # Create first
- try:
- create_lakebase_catalog(
- name=catalog_name,
- instance_name=lakebase_instance_name,
- database_name="databricks_postgres",
- create_database_if_not_exists=True,
- )
- except Exception as e:
- if "CREATE CATALOG" in str(e) or "permission" in str(e).lower():
- pytest.skip("User lacks CREATE CATALOG permission on this metastore")
- raise
-
- # Create again
- result = create_lakebase_catalog(
- name=catalog_name,
- instance_name=lakebase_instance_name,
- )
-
- assert result["status"] == "ALREADY_EXISTS"
-
-
-@pytest.mark.integration
-class TestGetCatalog:
- """Tests for getting Lakebase catalog details."""
-
- def test_get_catalog(self, lakebase_instance_name: str, unique_name, cleanup_catalogs):
- """Should return catalog details."""
- catalog_name = f"lb_test_cat_get_{unique_name}"
- cleanup_catalogs(catalog_name)
-
- # Create
- try:
- create_lakebase_catalog(
- name=catalog_name,
- instance_name=lakebase_instance_name,
- database_name="databricks_postgres",
- create_database_if_not_exists=True,
- )
- except Exception as e:
- if "CREATE CATALOG" in str(e) or "permission" in str(e).lower():
- pytest.skip("User lacks CREATE CATALOG permission on this metastore")
- raise
-
- # Get
- result = get_lakebase_catalog(catalog_name)
-
- logger.info(f"Get catalog result: {result}")
-
- assert result["name"] == catalog_name
- assert result.get("instance_name") == lakebase_instance_name
-
- def test_get_catalog_not_found(self):
- """Should return NOT_FOUND for non-existent catalog."""
- result = get_lakebase_catalog("nonexistent_lb_catalog_99999")
-
- assert result["status"] == "NOT_FOUND"
- assert "error" in result
-
-
-@pytest.mark.integration
-class TestDeleteCatalog:
- """Tests for deleting Lakebase catalogs."""
-
- def test_delete_catalog(self, lakebase_instance_name: str, unique_name, cleanup_catalogs):
- """Should delete a Lakebase catalog."""
- catalog_name = f"lb_test_cat_del_{unique_name}"
- # Don't register for cleanup since we're testing delete
-
- # Create
- try:
- create_lakebase_catalog(
- name=catalog_name,
- instance_name=lakebase_instance_name,
- database_name="databricks_postgres",
- create_database_if_not_exists=True,
- )
- except Exception as e:
- if "CREATE CATALOG" in str(e) or "permission" in str(e).lower():
- pytest.skip("User lacks CREATE CATALOG permission on this metastore")
- raise
-
- # Delete
- result = delete_lakebase_catalog(catalog_name)
-
- logger.info(f"Delete catalog result: {result}")
-
- assert result["name"] == catalog_name
- assert result["status"] == "deleted"
-
- # Verify it's gone
- verify = get_lakebase_catalog(catalog_name)
- assert verify["status"] == "NOT_FOUND"
-
- def test_delete_catalog_not_found(self):
- """Should return NOT_FOUND for non-existent catalog."""
- result = delete_lakebase_catalog("nonexistent_lb_catalog_99999")
-
- assert result["status"] == "NOT_FOUND"
- assert "error" in result
diff --git a/databricks-tools-core/tests/integration/lakebase/test_instances.py b/databricks-tools-core/tests/integration/lakebase/test_instances.py
deleted file mode 100644
index 2b60d4e3..00000000
--- a/databricks-tools-core/tests/integration/lakebase/test_instances.py
+++ /dev/null
@@ -1,149 +0,0 @@
-"""
-Integration tests for Lakebase Provisioned instance operations.
-
-Tests:
-- create_lakebase_instance
-- get_lakebase_instance
-- list_lakebase_instances
-- update_lakebase_instance
-- delete_lakebase_instance
-- generate_lakebase_credential
-"""
-
-import logging
-import time
-
-import pytest
-
-from databricks_tools_core.lakebase import (
- create_lakebase_instance,
- delete_lakebase_instance,
- generate_lakebase_credential,
- get_lakebase_instance,
- list_lakebase_instances,
- update_lakebase_instance,
-)
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.mark.integration
-class TestGetInstance:
- """Tests for getting instance details."""
-
- def test_get_instance_running(self, lakebase_instance_name: str):
- """Should return details for a running instance."""
- result = get_lakebase_instance(lakebase_instance_name)
-
- logger.info(f"Instance details: {result}")
-
- assert result["name"] == lakebase_instance_name
- assert "state" in result
- # Instance should be ready (state string varies by SDK)
- state = result["state"].upper()
- assert any(s in state for s in ("RUNNING", "AVAILABLE", "ACTIVE")), f"Unexpected state: {result['state']}"
- assert "capacity" in result
-
- def test_get_instance_not_found(self):
- """Should return NOT_FOUND for non-existent instance."""
- result = get_lakebase_instance("nonexistent-instance-xyz-99999")
-
- assert result["state"] == "NOT_FOUND"
- assert "error" in result
-
-
-@pytest.mark.integration
-class TestListInstances:
- """Tests for listing instances."""
-
- def test_list_instances(self, lakebase_instance_name: str):
- """Should list instances including the test instance."""
- instances = list_lakebase_instances()
-
- logger.info(f"Found {len(instances)} instances")
-
- assert isinstance(instances, list)
- assert len(instances) > 0
-
- names = [inst["name"] for inst in instances]
- assert lakebase_instance_name in names, f"Test instance '{lakebase_instance_name}' not in: {names}"
-
-
-@pytest.mark.integration
-class TestCreateInstance:
- """Tests for creating instances."""
-
- def test_create_instance(self, cleanup_instances, unique_name):
- """Should create a new instance."""
- name = f"lb-test-create-{unique_name}"
- cleanup_instances(name)
-
- result = create_lakebase_instance(
- name=name,
- capacity="CU_1",
- stopped=False,
- )
-
- logger.info(f"Create result: {result}")
-
- assert result["name"] == name
- assert result.get("capacity") == "CU_1"
- assert result["status"] in ("CREATING", "ALREADY_EXISTS")
-
- def test_create_duplicate_instance(self, lakebase_instance_name: str):
- """Should return ALREADY_EXISTS for duplicate instance."""
- result = create_lakebase_instance(
- name=lakebase_instance_name,
- capacity="CU_1",
- )
-
- assert result["status"] == "ALREADY_EXISTS"
-
-
-@pytest.mark.integration
-class TestUpdateInstance:
- """Tests for updating instances."""
-
- def test_update_instance_stop(self, lakebase_instance_name: str):
- """Should stop the session instance (it will be restarted by other tests)."""
- result = update_lakebase_instance(name=lakebase_instance_name, stopped=True)
- logger.info(f"Stop result: {result}")
- assert result["status"] == "UPDATED"
- assert result.get("stopped") is True
-
- # Give it a moment then restart
- time.sleep(5)
-
- # Start it back up
- result = update_lakebase_instance(name=lakebase_instance_name, stopped=False)
- logger.info(f"Start result: {result}")
- assert result["status"] == "UPDATED"
- assert result.get("stopped") is False
-
-
-@pytest.mark.integration
-class TestGenerateCredential:
- """Tests for generating database credentials."""
-
- def test_generate_credential(self, lakebase_instance_name: str):
- """Should generate an OAuth token for connecting."""
- result = generate_lakebase_credential(instance_names=[lakebase_instance_name])
-
- logger.info(f"Credential result keys: {list(result.keys())}")
-
- assert "token" in result, "Expected OAuth token in result"
- assert len(result["token"]) > 0
- assert result["instance_names"] == [lakebase_instance_name]
- assert "message" in result
-
-
-@pytest.mark.integration
-class TestDeleteInstance:
- """Tests for deleting instances."""
-
- def test_delete_instance_not_found(self):
- """Should return NOT_FOUND when deleting non-existent instance."""
- result = delete_lakebase_instance("nonexistent-instance-xyz-99999")
-
- assert result["status"] == "NOT_FOUND"
- assert "error" in result
diff --git a/databricks-tools-core/tests/integration/lakebase/test_synced_tables.py b/databricks-tools-core/tests/integration/lakebase/test_synced_tables.py
deleted file mode 100644
index d44b2a67..00000000
--- a/databricks-tools-core/tests/integration/lakebase/test_synced_tables.py
+++ /dev/null
@@ -1,168 +0,0 @@
-"""
-Integration tests for Lakebase synced table (reverse ETL) operations.
-
-Tests:
-- create_synced_table
-- get_synced_table
-- delete_synced_table
-
-Requires a running Lakebase instance and a source Delta table.
-"""
-
-import logging
-import time
-
-import pytest
-
-from databricks_tools_core.lakebase import (
- create_lakebase_catalog,
- create_synced_table,
- delete_synced_table,
- get_synced_table,
-)
-
-logger = logging.getLogger(__name__)
-
-
-def _create_catalog_or_skip(catalog_name: str, instance_name: str):
- """Create a Lakebase catalog, skipping test if permissions are lacking."""
- try:
- create_lakebase_catalog(
- name=catalog_name,
- instance_name=instance_name,
- database_name="databricks_postgres",
- create_database_if_not_exists=True,
- )
- except Exception as e:
- err = str(e)
- if "CREATE CATALOG" in err or "permission" in err.lower() or "storage root" in err.lower():
- pytest.skip(f"Cannot create catalog on this metastore: {err[:120]}")
- raise
-
-
-@pytest.mark.integration
-class TestCreateSyncedTable:
- """Tests for creating synced tables."""
-
- def test_create_synced_table(
- self,
- lakebase_instance_name: str,
- source_delta_table: str,
- unique_name,
- cleanup_synced_tables,
- cleanup_catalogs,
- ):
- """Should create a synced table from a source Delta table."""
- catalog_name = f"lb_sync_cat_{unique_name}"
- cleanup_catalogs(catalog_name)
-
- _create_catalog_or_skip(catalog_name, lakebase_instance_name)
- time.sleep(5)
-
- target_table = f"{catalog_name}.public.synced_{unique_name}"
- cleanup_synced_tables(target_table)
-
- result = create_synced_table(
- instance_name=lakebase_instance_name,
- source_table_name=source_delta_table,
- target_table_name=target_table,
- primary_key_columns=["id"],
- scheduling_policy="TRIGGERED",
- )
-
- logger.info(f"Create synced table result: {result}")
-
- assert result["instance_name"] == lakebase_instance_name
- assert result["source_table_name"] == source_delta_table
- assert result["target_table_name"] == target_table
- assert result["status"] in ("CREATING", "ALREADY_EXISTS")
-
-
-@pytest.mark.integration
-class TestGetSyncedTable:
- """Tests for getting synced table details."""
-
- def test_get_synced_table(
- self,
- lakebase_instance_name: str,
- source_delta_table: str,
- unique_name,
- cleanup_synced_tables,
- cleanup_catalogs,
- ):
- """Should return synced table details."""
- catalog_name = f"lb_sync_cat_get_{unique_name}"
- cleanup_catalogs(catalog_name)
-
- _create_catalog_or_skip(catalog_name, lakebase_instance_name)
- time.sleep(5)
-
- target_table = f"{catalog_name}.public.synced_get_{unique_name}"
- cleanup_synced_tables(target_table)
-
- create_synced_table(
- instance_name=lakebase_instance_name,
- source_table_name=source_delta_table,
- target_table_name=target_table,
- primary_key_columns=["id"],
- )
- time.sleep(5)
-
- result = get_synced_table(target_table)
-
- logger.info(f"Get synced table result: {result}")
-
- assert result["table_name"] == target_table
- assert result.get("instance_name") == lakebase_instance_name
- assert result.get("source_table_name") == source_delta_table
-
- def test_get_synced_table_not_found(self):
- """Should return NOT_FOUND for non-existent synced table."""
- result = get_synced_table("nonexistent_cat.nonexistent_schema.nonexistent_table")
-
- assert result["status"] == "NOT_FOUND"
- assert "error" in result
-
-
-@pytest.mark.integration
-class TestDeleteSyncedTable:
- """Tests for deleting synced tables."""
-
- def test_delete_synced_table(
- self,
- lakebase_instance_name: str,
- source_delta_table: str,
- unique_name,
- cleanup_catalogs,
- ):
- """Should delete a synced table."""
- catalog_name = f"lb_sync_cat_del_{unique_name}"
- cleanup_catalogs(catalog_name)
-
- _create_catalog_or_skip(catalog_name, lakebase_instance_name)
- time.sleep(5)
-
- target_table = f"{catalog_name}.public.synced_del_{unique_name}"
-
- create_synced_table(
- instance_name=lakebase_instance_name,
- source_table_name=source_delta_table,
- target_table_name=target_table,
- primary_key_columns=["id"],
- )
- time.sleep(5)
-
- # Delete
- result = delete_synced_table(target_table)
-
- logger.info(f"Delete synced table result: {result}")
-
- assert result["table_name"] == target_table
- assert result["status"] == "deleted"
-
- def test_delete_synced_table_not_found(self):
- """Should return NOT_FOUND for non-existent synced table."""
- result = delete_synced_table("nonexistent_cat.nonexistent_schema.nonexistent_table")
-
- assert result["status"] == "NOT_FOUND"
- assert "error" in result
diff --git a/databricks-tools-core/tests/integration/lakebase_autoscale/__init__.py b/databricks-tools-core/tests/integration/lakebase_autoscale/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/databricks-tools-core/tests/integration/lakebase_autoscale/conftest.py b/databricks-tools-core/tests/integration/lakebase_autoscale/conftest.py
deleted file mode 100644
index edef3450..00000000
--- a/databricks-tools-core/tests/integration/lakebase_autoscale/conftest.py
+++ /dev/null
@@ -1,157 +0,0 @@
-"""
-Shared fixtures for Lakebase Autoscaling integration tests.
-
-Strategy:
- - If an existing project with prefix "lb-auto-test-" exists, reuse it.
- - Otherwise, create a new project and wait for it to become READY.
- - The session-scoped project is NOT deleted at teardown so it can be
- reused across runs (manual cleanup is expected).
- - Branch / endpoint cleanup helpers are provided for tests that create
- child resources.
-"""
-
-import logging
-import time
-import uuid
-
-import pytest
-
-from databricks_tools_core.lakebase_autoscale import (
- create_project,
- delete_branch,
- delete_endpoint,
- delete_project,
- get_project,
- list_branches,
- list_projects,
-)
-
-logger = logging.getLogger(__name__)
-
-TEST_PREFIX = "lb-auto-test-"
-PROJECT_READY_TIMEOUT = 600 # seconds
-
-
-def _wait_for_project_ready(project_name: str, timeout: int = PROJECT_READY_TIMEOUT):
- """Poll until the project reaches a READY state."""
- deadline = time.time() + timeout
- while time.time() < deadline:
- proj = get_project(project_name)
- state = proj.get("state", "")
- if "READY" in str(state).upper():
- logger.info(f"Project {project_name} is READY")
- return proj
- logger.info(f"Project {project_name} state={state}, waiting...")
- time.sleep(15)
- raise TimeoutError(f"Project {project_name} did not become READY within {timeout}s")
-
-
-@pytest.fixture(scope="session")
-def lakebase_project_name():
- """
- Return a reusable test project name.
-
- Looks for an existing project with prefix ``lb-auto-test-``.
- If none exists, creates one and waits for it to be ready.
- """
- projects = list_projects()
- for p in projects:
- name = p.get("name", "")
- project_id = name.split("/")[-1] if "/" in name else name
- if project_id.startswith(TEST_PREFIX):
- state = p.get("state", "")
- if "READY" in str(state).upper() or state == "":
- logger.info(f"Reusing existing project: {name}")
- return name
-
- # No reusable project – create one
- suffix = uuid.uuid4().hex[:8]
- project_id = f"{TEST_PREFIX}{suffix}"
- logger.info(f"Creating new test project: {project_id}")
-
- result = create_project(
- project_id=project_id,
- display_name=f"Integration Test {suffix}",
- )
- project_name = result["name"]
- _wait_for_project_ready(project_name)
- return project_name
-
-
-@pytest.fixture(scope="session")
-def lakebase_default_branch(lakebase_project_name: str):
- """
- Return the default branch name for the test project.
-
- The default branch is created automatically with the project.
- Its ID is auto-generated by the service.
- """
- branches = list_branches(lakebase_project_name)
- default_branches = [b for b in branches if b.get("is_default") is True]
- if default_branches:
- branch_name = default_branches[0]["name"]
- else:
- branch_name = branches[0]["name"]
-
- logger.info(f"Default branch: {branch_name}")
- return branch_name
-
-
-@pytest.fixture(scope="session")
-def unique_name():
- """Return a short unique suffix for naming test resources."""
- return uuid.uuid4().hex[:6]
-
-
-@pytest.fixture
-def cleanup_branches():
- """Fixture that returns a helper to register branches for cleanup."""
- branches_to_delete: list[str] = []
-
- def _register(branch_name: str):
- branches_to_delete.append(branch_name)
-
- yield _register
-
- for name in reversed(branches_to_delete):
- try:
- delete_branch(name)
- logger.info(f"Cleaned up branch: {name}")
- except Exception as exc:
- logger.warning(f"Branch cleanup failed for {name}: {exc}")
-
-
-@pytest.fixture
-def cleanup_projects():
- """Fixture that returns a helper to register projects for cleanup."""
- projects_to_delete: list[str] = []
-
- def _register(project_name: str):
- projects_to_delete.append(project_name)
-
- yield _register
-
- for name in reversed(projects_to_delete):
- try:
- delete_project(name)
- logger.info(f"Cleaned up project: {name}")
- except Exception as exc:
- logger.warning(f"Project cleanup failed for {name}: {exc}")
-
-
-@pytest.fixture
-def cleanup_endpoints():
- """Fixture that returns a helper to register endpoints for cleanup."""
- endpoints_to_delete: list[str] = []
-
- def _register(endpoint_name: str):
- endpoints_to_delete.append(endpoint_name)
-
- yield _register
-
- for name in reversed(endpoints_to_delete):
- try:
- delete_endpoint(name)
- logger.info(f"Cleaned up endpoint: {name}")
- except Exception as exc:
- logger.warning(f"Endpoint cleanup failed for {name}: {exc}")
diff --git a/databricks-tools-core/tests/integration/lakebase_autoscale/test_branches.py b/databricks-tools-core/tests/integration/lakebase_autoscale/test_branches.py
deleted file mode 100644
index f0f24635..00000000
--- a/databricks-tools-core/tests/integration/lakebase_autoscale/test_branches.py
+++ /dev/null
@@ -1,157 +0,0 @@
-"""
-Integration tests for Lakebase Autoscaling branch operations.
-
-Tests:
-- List branches on a project
-- Get branch details
-- Create and delete a branch
-- Protect and unprotect a branch
-- Reset a branch from its parent
-"""
-
-import logging
-import time
-
-import pytest
-
-from databricks_tools_core.lakebase_autoscale import (
- create_branch,
- delete_branch,
- get_branch,
- list_branches,
- update_branch,
-)
-
-logger = logging.getLogger(__name__)
-
-
-def _wait_for_branch_ready(name: str, timeout: int = 300, poll: int = 10):
- """Poll branch until READY or timeout."""
- deadline = time.time() + timeout
- while time.time() < deadline:
- branch = get_branch(name)
- state = branch.get("state", "")
- logger.info(f"Branch '{name}' state: {state}")
- state_upper = state.upper()
- if "READY" in state_upper or "ACTIVE" in state_upper:
- return branch
- if state == "NOT_FOUND":
- raise RuntimeError(f"Branch '{name}' not found")
- if "FAILED" in state_upper or "ERROR" in state_upper:
- raise RuntimeError(f"Branch '{name}' in terminal state: {state}")
- time.sleep(poll)
- raise TimeoutError(f"Branch '{name}' not READY within {timeout}s")
-
-
-class TestListBranches:
- """Test listing branches."""
-
- def test_list_branches_returns_list(self, lakebase_project_name):
- """Listing branches should return at least the production branch."""
- branches = list_branches(lakebase_project_name)
- assert isinstance(branches, list)
- assert len(branches) > 0, "Expected at least one branch (production)"
-
- def test_default_branch_exists(self, lakebase_project_name):
- """A default branch should exist."""
- branches = list_branches(lakebase_project_name)
- default_branches = [b for b in branches if b.get("is_default") is True]
- assert len(default_branches) > 0, f"No default branch found in: {branches}"
-
-
-class TestGetBranch:
- """Test getting branch details."""
-
- def test_get_default_branch(self, lakebase_default_branch):
- """Getting the default branch should return its details."""
- branch = get_branch(lakebase_default_branch)
- assert branch["name"] == lakebase_default_branch
-
- def test_get_nonexistent_branch(self, lakebase_project_name):
- """Getting a non-existent branch should return NOT_FOUND."""
- result = get_branch(f"{lakebase_project_name}/branches/nonexistent")
- assert result["state"] == "NOT_FOUND"
-
-
-class TestBranchLifecycle:
- """Test branch create, update, and delete lifecycle."""
-
- @pytest.mark.slow
- def test_create_and_delete_branch(self, lakebase_project_name, cleanup_branches, unique_name):
- """Create a branch from production and then delete it."""
- branch_id = f"test-br-{unique_name}"
- branch_full_name = f"{lakebase_project_name}/branches/{branch_id}"
- cleanup_branches(branch_full_name)
-
- # Create
- result = create_branch(
- project_name=lakebase_project_name,
- branch_id=branch_id,
- ttl_seconds=7200, # 2 hours
- )
- assert result["name"] == branch_full_name
- assert result["status"] in ("CREATED", "ALREADY_EXISTS")
-
- # Wait for ready
- _wait_for_branch_ready(branch_full_name)
-
- # Verify
- branch = get_branch(branch_full_name)
- assert branch["name"] == branch_full_name
-
- # Delete
- del_result = delete_branch(branch_full_name)
- assert del_result["status"] == "deleted"
-
- @pytest.mark.slow
- def test_create_branch_no_expiry(self, lakebase_project_name, cleanup_branches, unique_name):
- """Create a permanent branch (no expiry)."""
- branch_id = f"test-perm-{unique_name}"
- branch_full_name = f"{lakebase_project_name}/branches/{branch_id}"
- cleanup_branches(branch_full_name)
-
- result = create_branch(
- project_name=lakebase_project_name,
- branch_id=branch_id,
- no_expiry=True,
- )
- assert result["name"] == branch_full_name
- assert result["status"] in ("CREATED", "ALREADY_EXISTS")
-
- _wait_for_branch_ready(branch_full_name)
-
- # Branch should not have an expire_time
- branch = get_branch(branch_full_name)
- # expire_time might be absent or None for permanent branches
- assert branch.get("expire_time") is None or "expire_time" not in branch
-
- @pytest.mark.slow
- def test_protect_and_unprotect_branch(self, lakebase_project_name, cleanup_branches, unique_name):
- """Create a branch, protect it, then unprotect it."""
- branch_id = f"test-prot-{unique_name}"
- branch_full_name = f"{lakebase_project_name}/branches/{branch_id}"
- cleanup_branches(branch_full_name)
-
- # Create
- create_branch(
- project_name=lakebase_project_name,
- branch_id=branch_id,
- no_expiry=True,
- )
- _wait_for_branch_ready(branch_full_name)
-
- # Protect
- result = update_branch(branch_full_name, is_protected=True)
- assert result["status"] == "UPDATED"
- assert result["is_protected"] is True
-
- # Verify protection
- branch = get_branch(branch_full_name)
- assert branch.get("is_protected") is True
-
- # Unprotect
- result = update_branch(branch_full_name, is_protected=False)
- assert result["status"] == "UPDATED"
-
- # NOTE: reset_branch is not yet available in the Databricks SDK.
- # Test will be added when the SDK supports it.
diff --git a/databricks-tools-core/tests/integration/lakebase_autoscale/test_computes.py b/databricks-tools-core/tests/integration/lakebase_autoscale/test_computes.py
deleted file mode 100644
index b4317718..00000000
--- a/databricks-tools-core/tests/integration/lakebase_autoscale/test_computes.py
+++ /dev/null
@@ -1,162 +0,0 @@
-"""
-Integration tests for Lakebase Autoscaling compute (endpoint) operations.
-
-Tests:
-- List endpoints on a branch
-- Get endpoint details
-- Create and delete endpoints
-- Update endpoint (resize)
-"""
-
-import logging
-import time
-
-import pytest
-
-from databricks_tools_core.lakebase_autoscale import (
- create_endpoint,
- delete_endpoint,
- get_endpoint,
- list_endpoints,
- update_endpoint,
-)
-
-logger = logging.getLogger(__name__)
-
-
-def _wait_for_endpoint_ready(name: str, timeout: int = 600, poll: int = 15):
- """Poll endpoint until ACTIVE or timeout."""
- deadline = time.time() + timeout
- while time.time() < deadline:
- ep = get_endpoint(name)
- state = ep.get("state", "")
- logger.info(f"Endpoint '{name}' state: {state}")
- state_upper = state.upper()
- if "ACTIVE" in state_upper or "READY" in state_upper:
- return ep
- if state == "NOT_FOUND":
- raise RuntimeError(f"Endpoint '{name}' not found")
- if "FAILED" in state_upper or "ERROR" in state_upper:
- raise RuntimeError(f"Endpoint '{name}' in terminal state: {state}")
- time.sleep(poll)
- raise TimeoutError(f"Endpoint '{name}' not ACTIVE within {timeout}s")
-
-
-class TestListEndpoints:
- """Test listing endpoints on a branch."""
-
- def test_list_endpoints_returns_list(self, lakebase_default_branch):
- """Listing endpoints should return at least the default primary endpoint."""
- endpoints = list_endpoints(lakebase_default_branch)
- assert isinstance(endpoints, list)
- assert len(endpoints) > 0, "Expected at least one endpoint on production"
-
- def test_default_endpoint_is_read_write(self, lakebase_default_branch):
- """The default endpoint should be read-write."""
- endpoints = list_endpoints(lakebase_default_branch)
- rw_endpoints = [ep for ep in endpoints if "READ_WRITE" in ep.get("endpoint_type", "")]
- assert len(rw_endpoints) > 0, "Expected a read-write endpoint on production"
-
-
-class TestGetEndpoint:
- """Test getting endpoint details."""
-
- def test_get_default_endpoint(self, lakebase_default_branch):
- """Getting the default endpoint should return details including host."""
- endpoints = list_endpoints(lakebase_default_branch)
- assert len(endpoints) > 0
-
- ep = get_endpoint(endpoints[0]["name"])
- assert "name" in ep
- assert ep.get("state") != "NOT_FOUND"
-
- def test_get_nonexistent_endpoint(self, lakebase_default_branch):
- """Getting a non-existent endpoint should return NOT_FOUND."""
- result = get_endpoint(f"{lakebase_default_branch}/endpoints/nonexistent")
- assert result["state"] == "NOT_FOUND"
-
-
-class TestEndpointLifecycle:
- """Test endpoint create, update, and delete lifecycle."""
-
- @pytest.mark.slow
- def test_create_and_delete_read_only_endpoint(
- self,
- lakebase_default_branch,
- cleanup_endpoints,
- unique_name,
- ):
- """Create a read-only endpoint, verify it, then delete it."""
- endpoint_id = f"test-ro-{unique_name}"
- ep_full_name = f"{lakebase_default_branch}/endpoints/{endpoint_id}"
- cleanup_endpoints(ep_full_name)
-
- # Create read-only endpoint
- result = create_endpoint(
- branch_name=lakebase_default_branch,
- endpoint_id=endpoint_id,
- endpoint_type="ENDPOINT_TYPE_READ_ONLY",
- autoscaling_limit_min_cu=0.5,
- autoscaling_limit_max_cu=2.0,
- )
- assert result["name"] == ep_full_name
- assert result["status"] in ("CREATED", "ALREADY_EXISTS")
-
- # Wait for active
- _wait_for_endpoint_ready(ep_full_name)
-
- # Verify
- ep = get_endpoint(ep_full_name)
- assert ep["name"] == ep_full_name
- assert "READ_ONLY" in ep.get("endpoint_type", "")
-
- # Delete
- del_result = delete_endpoint(ep_full_name)
- assert del_result["status"] == "deleted"
-
-
-class TestUpdateEndpoint:
- """Test resizing / updating endpoints."""
-
- @pytest.mark.slow
- def test_resize_default_endpoint(self, lakebase_default_branch):
- """Resize the default endpoint and verify the update."""
- endpoints = list_endpoints(lakebase_default_branch)
- assert len(endpoints) > 0
-
- ep_name = endpoints[0]["name"]
-
- # Get current state
- ep = get_endpoint(ep_name)
- current_min = ep.get("min_cu")
- current_max = ep.get("max_cu")
- logger.info(f"Current CU range: {current_min}-{current_max}")
-
- # Update to a known range
- target_min = 4.0
- target_max = 8.0
-
- result = update_endpoint(
- name=ep_name,
- autoscaling_limit_min_cu=target_min,
- autoscaling_limit_max_cu=target_max,
- )
- assert result["status"] == "UPDATED"
- assert result.get("min_cu") == target_min
- assert result.get("max_cu") == target_max
-
- # Restore original if known
- if current_min is not None and current_max is not None:
- update_endpoint(
- name=ep_name,
- autoscaling_limit_min_cu=current_min,
- autoscaling_limit_max_cu=current_max,
- )
-
- def test_update_no_changes(self, lakebase_default_branch):
- """Updating with no fields should return NO_CHANGES."""
- endpoints = list_endpoints(lakebase_default_branch)
- assert len(endpoints) > 0
-
- result = update_endpoint(name=endpoints[0]["name"])
- assert result["status"] == "NO_CHANGES"
diff --git a/databricks-tools-core/tests/integration/lakebase_autoscale/test_credentials.py b/databricks-tools-core/tests/integration/lakebase_autoscale/test_credentials.py
deleted file mode 100644
index f2a7e465..00000000
--- a/databricks-tools-core/tests/integration/lakebase_autoscale/test_credentials.py
+++ /dev/null
@@ -1,51 +0,0 @@
-"""
-Integration tests for Lakebase Autoscaling credential operations.
-
-Tests:
-- Generate database credentials (OAuth token) scoped to an endpoint
-"""
-
-import logging
-
-
-from databricks_tools_core.lakebase_autoscale import (
- generate_credential,
- list_endpoints,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class TestGenerateCredential:
- """Test credential generation (requires a project with an endpoint)."""
-
- def test_generate_credential_returns_token(self, lakebase_default_branch):
- """Generating credentials should return a token."""
- endpoints = list_endpoints(lakebase_default_branch)
- assert len(endpoints) > 0, "Expected at least one endpoint on production"
-
- ep_name = endpoints[0]["name"]
- result = generate_credential(endpoint=ep_name)
- assert "token" in result, "Expected 'token' in credential response"
- assert isinstance(result["token"], str)
- assert len(result["token"]) > 0, "Token should not be empty"
-
- def test_generate_credential_has_message(self, lakebase_default_branch):
- """Generated credential response should include usage instructions."""
- endpoints = list_endpoints(lakebase_default_branch)
- assert len(endpoints) > 0
-
- ep_name = endpoints[0]["name"]
- result = generate_credential(endpoint=ep_name)
- assert "message" in result
- assert "sslmode" in result["message"].lower() or "password" in result["message"].lower()
-
- def test_generate_credential_token_is_nontrivial(self, lakebase_default_branch):
- """The token should be a non-trivial string."""
- endpoints = list_endpoints(lakebase_default_branch)
- assert len(endpoints) > 0
-
- ep_name = endpoints[0]["name"]
- result = generate_credential(endpoint=ep_name)
- token = result["token"]
- assert len(token) > 20, f"Token seems too short ({len(token)} chars)"
diff --git a/databricks-tools-core/tests/integration/lakebase_autoscale/test_projects.py b/databricks-tools-core/tests/integration/lakebase_autoscale/test_projects.py
deleted file mode 100644
index 38077e53..00000000
--- a/databricks-tools-core/tests/integration/lakebase_autoscale/test_projects.py
+++ /dev/null
@@ -1,113 +0,0 @@
-"""
-Integration tests for Lakebase Autoscaling project operations.
-
-Tests:
-- List projects
-- Get project details
-- Update project display name
-- Create and delete a standalone project (full lifecycle)
-"""
-
-import logging
-import uuid
-
-import pytest
-
-from databricks_tools_core.lakebase_autoscale import (
- create_project,
- delete_project,
- get_project,
- list_projects,
- update_project,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class TestListProjects:
- """Test listing Lakebase Autoscaling projects."""
-
- def test_list_projects_returns_list(self, lakebase_project_name):
- """Listing projects should return a non-empty list."""
- projects = list_projects()
- assert isinstance(projects, list)
- assert len(projects) > 0, "Expected at least one project"
-
- def test_list_projects_contains_test_project(self, lakebase_project_name):
- """The test project should appear in the list."""
- projects = list_projects()
- names = [p["name"] for p in projects]
- assert lakebase_project_name in names, f"Test project '{lakebase_project_name}' not found in: {names}"
-
-
-class TestGetProject:
- """Test getting Lakebase Autoscaling project details."""
-
- def test_get_existing_project(self, lakebase_project_name):
- """Getting an existing project should return its details."""
- project = get_project(lakebase_project_name)
- assert project["name"] == lakebase_project_name
- assert "state" in project or "display_name" in project
-
- def test_get_project_without_prefix(self, lakebase_project_name):
- """Getting a project by ID without 'projects/' prefix should work."""
- project_id = lakebase_project_name.replace("projects/", "")
- project = get_project(project_id)
- assert project["name"] == lakebase_project_name
-
- def test_get_nonexistent_project(self):
- """Getting a non-existent project should return NOT_FOUND."""
- result = get_project("projects/nonexistent-project-xyz")
- assert result["state"] == "NOT_FOUND"
-
-
-class TestUpdateProject:
- """Test updating Lakebase Autoscaling project properties."""
-
- def test_update_display_name(self, lakebase_project_name):
- """Updating display name should succeed."""
- new_name = f"Updated Test {uuid.uuid4().hex[:6]}"
- result = update_project(lakebase_project_name, display_name=new_name)
- assert result["status"] == "UPDATED"
- assert result["display_name"] == new_name
-
- def test_update_no_changes(self, lakebase_project_name):
- """Updating with no fields should return NO_CHANGES."""
- result = update_project(lakebase_project_name)
- assert result["status"] == "NO_CHANGES"
-
-
-class TestProjectLifecycle:
- """Test full project create-delete lifecycle."""
-
- @pytest.mark.slow
- def test_create_and_delete_project(self, cleanup_projects):
- """Create a project, verify it exists, then delete it.
-
- NOTE: Project creation can take 10-15 minutes due to LRO provisioning.
- The SDK's operation.wait() blocks until the project is fully created.
- """
- project_id = f"lb-test-lifecycle-{uuid.uuid4().hex[:8]}"
- cleanup_projects(f"projects/{project_id}")
-
- # Create (operation.wait() blocks until provisioning completes)
- logger.info(f"Creating project '{project_id}' -- this may take 10-15 min...")
- result = create_project(
- project_id=project_id,
- display_name=f"Lifecycle Test {project_id}",
- pg_version="17",
- )
- assert result["name"] == f"projects/{project_id}"
- assert result["status"] in ("CREATED", "ALREADY_EXISTS")
- logger.info(f"Project '{project_id}' created: {result}")
-
- # Verify it's accessible
- proj = get_project(f"projects/{project_id}")
- assert proj["name"] == f"projects/{project_id}"
- assert "display_name" in proj
-
- # Delete (operation.wait() blocks until deletion completes)
- logger.info(f"Deleting project '{project_id}'...")
- del_result = delete_project(f"projects/{project_id}")
- assert del_result["status"] == "deleted"
- logger.info(f"Project '{project_id}' deleted")
diff --git a/databricks-tools-core/tests/integration/pdf/test_pdf_generation.py b/databricks-tools-core/tests/integration/pdf/test_pdf_generation.py
deleted file mode 100644
index 8db09214..00000000
--- a/databricks-tools-core/tests/integration/pdf/test_pdf_generation.py
+++ /dev/null
@@ -1,107 +0,0 @@
-"""Integration tests for PDF generation."""
-
-import pytest
-
-from databricks_tools_core.pdf import generate_and_upload_pdf
-from databricks_tools_core.pdf.generator import _convert_html_to_pdf
-
-
-@pytest.fixture
-def sample_html():
- """Sample HTML document for testing."""
- return """
-
-
-
-
-
- Test Document
- This is a simple test paragraph.
-
-
This is highlighted content.
-
-
- Item 1
- Item 2
- Item 3
-
-
-"""
-
-
-@pytest.fixture
-def test_config():
- """Test configuration using ai_dev_kit catalog."""
- return {
- "catalog": "ai_dev_kit",
- "schema": "test_pdf_generation",
- "volume": "raw_data",
- }
-
-
-@pytest.mark.integration
-class TestHTMLToPDF:
- """Test HTML to PDF conversion (local only, no Databricks connection)."""
-
- def test_convert_simple_html(self, sample_html, tmp_path):
- """Test converting HTML to PDF locally."""
- output_path = str(tmp_path / "test.pdf")
- success = _convert_html_to_pdf(sample_html, output_path)
-
- assert success, "HTML to PDF conversion failed"
- assert (tmp_path / "test.pdf").exists()
- assert (tmp_path / "test.pdf").stat().st_size > 0
-
-
-@pytest.mark.integration
-class TestGenerateAndUploadPDF:
- """Test PDF generation and upload to Unity Catalog volume."""
-
- def test_generate_and_upload_pdf(self, sample_html, test_config):
- """Test generating PDF from HTML and uploading to volume."""
- result = generate_and_upload_pdf(
- html_content=sample_html,
- filename="test_document.pdf",
- catalog=test_config["catalog"],
- schema=test_config["schema"],
- volume=test_config["volume"],
- )
-
- assert result.success, f"PDF generation failed: {result.error}"
- assert result.volume_path is not None
- assert result.volume_path.endswith(".pdf")
- assert test_config["catalog"] in result.volume_path
-
- def test_generate_and_upload_pdf_with_folder(self, sample_html, test_config):
- """Test generating PDF and uploading to a subfolder."""
- result = generate_and_upload_pdf(
- html_content=sample_html,
- filename="subfolder_test", # Without .pdf extension
- catalog=test_config["catalog"],
- schema=test_config["schema"],
- volume=test_config["volume"],
- folder="test_folder",
- )
-
- assert result.success, f"PDF generation failed: {result.error}"
- assert result.volume_path is not None
- assert result.volume_path.endswith(".pdf")
- assert "test_folder" in result.volume_path
-
- def test_generate_pdf_invalid_volume(self, sample_html, test_config):
- """Test error handling for invalid volume."""
- result = generate_and_upload_pdf(
- html_content=sample_html,
- filename="test.pdf",
- catalog=test_config["catalog"],
- schema=test_config["schema"],
- volume="nonexistent_volume",
- )
-
- assert not result.success
- assert result.error is not None
- assert "does not exist" in result.error
diff --git a/databricks-tools-core/tests/integration/sdp/__init__.py b/databricks-tools-core/tests/integration/sdp/__init__.py
deleted file mode 100644
index f2bbfa7f..00000000
--- a/databricks-tools-core/tests/integration/sdp/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Integration tests for Spark Declarative Pipelines (SDP)."""
diff --git a/databricks-tools-core/tests/integration/sdp/pipelines/nyctaxi_bronze.sql b/databricks-tools-core/tests/integration/sdp/pipelines/nyctaxi_bronze.sql
deleted file mode 100644
index 983361d4..00000000
--- a/databricks-tools-core/tests/integration/sdp/pipelines/nyctaxi_bronze.sql
+++ /dev/null
@@ -1,10 +0,0 @@
--- Bronze layer: Read raw NYC taxi trip data
-CREATE OR REFRESH STREAMING TABLE bronze_trips
-AS SELECT
- tpep_pickup_datetime,
- tpep_dropoff_datetime,
- trip_distance,
- fare_amount,
- pickup_zip,
- dropoff_zip
-FROM STREAM samples.nyctaxi.trips
diff --git a/databricks-tools-core/tests/integration/sdp/pipelines/nyctaxi_silver.sql b/databricks-tools-core/tests/integration/sdp/pipelines/nyctaxi_silver.sql
deleted file mode 100644
index e402b9c9..00000000
--- a/databricks-tools-core/tests/integration/sdp/pipelines/nyctaxi_silver.sql
+++ /dev/null
@@ -1,16 +0,0 @@
--- Silver layer: Enriched NYC taxi trip data with calculated fields
-CREATE OR REFRESH MATERIALIZED VIEW silver_trips
-AS SELECT
- tpep_pickup_datetime,
- tpep_dropoff_datetime,
- trip_distance,
- fare_amount,
- pickup_zip,
- dropoff_zip,
- -- Calculated fields
- TIMESTAMPDIFF(MINUTE, tpep_pickup_datetime, tpep_dropoff_datetime) AS trip_duration_minutes,
- CASE
- WHEN fare_amount > 0 THEN fare_amount / NULLIF(trip_distance, 0)
- ELSE 0
- END AS fare_per_mile
-FROM LIVE.bronze_trips
diff --git a/databricks-tools-core/tests/integration/sdp/test_pipelines.py b/databricks-tools-core/tests/integration/sdp/test_pipelines.py
deleted file mode 100644
index ed63058b..00000000
--- a/databricks-tools-core/tests/integration/sdp/test_pipelines.py
+++ /dev/null
@@ -1,448 +0,0 @@
-"""
-Integration tests for SDP pipeline management functions.
-
-Tests:
-- create_or_update_pipeline (create, update, run, wait)
-- delete_pipeline
-"""
-
-import logging
-import pytest
-from pathlib import Path
-
-from databricks_tools_core.spark_declarative_pipelines.pipelines import (
- create_or_update_pipeline,
- delete_pipeline,
- find_pipeline_by_name,
-)
-from databricks_tools_core.file.workspace import upload_to_workspace
-
-
-logger = logging.getLogger(__name__)
-
-# Path to test pipeline files
-PIPELINES_DIR = Path(__file__).parent / "pipelines"
-
-# Fixed pipeline name for consistent cleanup
-TEST_PIPELINE_NAME = "ai_dev_kit_test_sdp_pipeline"
-
-# Dedicated schema for SDP tests (separate from SQL tests)
-SDP_TEST_SCHEMA = "test_sdp_schema"
-
-
-@pytest.fixture(scope="module")
-def pipeline_name() -> str:
- """Return the fixed test pipeline name."""
- return TEST_PIPELINE_NAME
-
-
-@pytest.fixture(scope="module")
-def sdp_test_schema(workspace_client, test_catalog: str) -> str:
- """
- Create a dedicated schema for SDP tests.
-
- Uses a separate schema from other tests to avoid conflicts.
- """
- full_schema_name = f"{test_catalog}.{SDP_TEST_SCHEMA}"
-
- # Drop schema if exists (with force to cascade)
- try:
- logger.info(f"Dropping existing SDP test schema: {full_schema_name}")
- workspace_client.schemas.delete(full_schema_name, force=True)
- except Exception:
- pass # Schema doesn't exist, that's fine
-
- # Create fresh schema
- logger.info(f"Creating SDP test schema: {full_schema_name}")
- workspace_client.schemas.create(
- name=SDP_TEST_SCHEMA,
- catalog_name=test_catalog,
- )
-
- yield SDP_TEST_SCHEMA
-
- # Cleanup - drop schema after tests
- try:
- logger.info(f"Cleaning up SDP test schema: {full_schema_name}")
- workspace_client.schemas.delete(full_schema_name, force=True)
- except Exception as e:
- logger.warning(f"Failed to cleanup SDP test schema: {e}")
-
-
-@pytest.fixture(scope="module")
-def workspace_path(workspace_client) -> str:
- """
- Get workspace path for test pipeline files.
-
- Uses the current user's home folder.
- """
- user = workspace_client.current_user.me()
- return f"/Workspace/Users/{user.user_name}/test_sdp/{TEST_PIPELINE_NAME}"
-
-
-@pytest.fixture(scope="module")
-def clean_pipeline(pipeline_name: str):
- """
- Ensure pipeline doesn't exist before tests start.
-
- This cleans up any leftover pipeline from a previous failed run.
- """
- # Check if pipeline exists and delete it
- existing_id = find_pipeline_by_name(pipeline_name)
- if existing_id:
- logger.info(f"Cleaning up existing pipeline: {pipeline_name} ({existing_id})")
- try:
- delete_pipeline(existing_id)
- logger.info("Existing pipeline deleted")
- except Exception as e:
- logger.warning(f"Failed to delete existing pipeline: {e}")
-
- yield pipeline_name
-
- # Cleanup after all tests - delete pipeline if it still exists
- existing_id = find_pipeline_by_name(pipeline_name)
- if existing_id:
- logger.info(f"Final cleanup of pipeline: {pipeline_name}")
- try:
- delete_pipeline(existing_id)
- except Exception as e:
- logger.warning(f"Failed to cleanup pipeline: {e}")
-
-
-@pytest.fixture(scope="module")
-def uploaded_pipeline_files(workspace_client, workspace_path: str):
- """Upload test pipeline files to workspace."""
- logger.info(f"Uploading pipeline files to {workspace_path}")
-
- # Upload the pipelines folder to workspace
- result = upload_to_workspace(
- local_path=str(PIPELINES_DIR),
- workspace_path=workspace_path,
- overwrite=True,
- )
-
- assert result.success, f"Failed to upload pipeline files: {result.get_failed_uploads()}"
- logger.info(f"Uploaded {result.successful} files successfully")
-
- yield workspace_path
-
- # Cleanup: delete uploaded files after tests
- try:
- logger.info(f"Cleaning up workspace files: {workspace_path}")
- workspace_client.workspace.delete(workspace_path, recursive=True)
- except Exception as e:
- logger.warning(f"Failed to cleanup workspace files: {e}")
-
-
-@pytest.mark.integration
-class TestCreateOrUpdatePipeline:
- """Tests for create_or_update_pipeline function."""
-
- def test_create_pipeline_with_bronze_only(
- self,
- test_catalog: str,
- sdp_test_schema: str,
- clean_pipeline: str,
- uploaded_pipeline_files: str,
- ):
- """Should create a new pipeline with bronze layer only."""
- pipeline_name = clean_pipeline
- workspace_path = uploaded_pipeline_files
- bronze_path = f"{workspace_path}/nyctaxi_bronze.sql"
-
- logger.info(f"Creating pipeline: {pipeline_name}")
- logger.info(f"Catalog: {test_catalog}, Schema: {sdp_test_schema}")
- logger.info(f"Bronze path: {bronze_path}")
-
- result = create_or_update_pipeline(
- name=pipeline_name,
- root_path=workspace_path,
- catalog=test_catalog,
- schema=sdp_test_schema,
- workspace_file_paths=[bronze_path],
- start_run=True,
- wait_for_completion=True,
- full_refresh=True,
- timeout=600, # 10 minutes
- )
-
- logger.info(f"Pipeline result: {result.to_dict()}")
-
- # Verify creation
- assert result.pipeline_id is not None, "Pipeline ID should be set"
- assert result.pipeline_name == pipeline_name
- assert result.created is True, "Pipeline should be newly created"
- assert result.success is True, f"Pipeline run failed: {result.error_message}. Errors: {result.errors}"
- assert result.state == "COMPLETED", f"Expected COMPLETED, got {result.state}"
- assert result.duration_seconds is not None
- assert result.duration_seconds > 0
-
- def test_update_pipeline_with_silver_layer(
- self,
- test_catalog: str,
- sdp_test_schema: str,
- clean_pipeline: str,
- uploaded_pipeline_files: str,
- ):
- """Should update existing pipeline by adding silver layer."""
- pipeline_name = clean_pipeline
- workspace_path = uploaded_pipeline_files
- bronze_path = f"{workspace_path}/nyctaxi_bronze.sql"
- silver_path = f"{workspace_path}/nyctaxi_silver.sql"
-
- logger.info(f"Updating pipeline with silver layer: {pipeline_name}")
-
- # Update pipeline with both bronze and silver
- result = create_or_update_pipeline(
- name=pipeline_name,
- root_path=workspace_path,
- catalog=test_catalog,
- schema=sdp_test_schema,
- workspace_file_paths=[bronze_path, silver_path],
- start_run=True,
- wait_for_completion=True,
- full_refresh=True,
- timeout=600,
- )
-
- logger.info(f"Pipeline update result: {result.to_dict()}")
-
- # Verify update (not creation)
- assert result.pipeline_id is not None
- assert result.pipeline_name == pipeline_name
- assert result.created is False, "Pipeline should be updated, not created"
- assert result.success is True, f"Pipeline run failed: {result.error_message}. Errors: {result.errors}"
- assert result.state == "COMPLETED", f"Expected COMPLETED, got {result.state}"
-
- def test_find_pipeline_by_name(self, clean_pipeline: str):
- """Should find existing pipeline by name."""
- pipeline_name = clean_pipeline
- pipeline_id = find_pipeline_by_name(pipeline_name)
-
- assert pipeline_id is not None, f"Pipeline '{pipeline_name}' not found"
-
- def test_delete_pipeline(self, clean_pipeline: str):
- """Should delete the test pipeline."""
- pipeline_name = clean_pipeline
-
- # First find the pipeline
- pipeline_id = find_pipeline_by_name(pipeline_name)
- assert pipeline_id is not None, "Pipeline should exist before deletion"
-
- logger.info(f"Deleting pipeline: {pipeline_name} ({pipeline_id})")
-
- # Delete it
- delete_pipeline(pipeline_id)
-
- # Verify deletion
- found_id = find_pipeline_by_name(pipeline_name)
- assert found_id is None, "Pipeline should not exist after deletion"
-
- logger.info("Pipeline deleted successfully")
-
-
-# Separate test pipeline name for extra_settings tests
-TEST_PIPELINE_NAME_EXTRA = "ai_dev_kit_test_sdp_extra_settings"
-
-
-@pytest.fixture(scope="module")
-def clean_pipeline_extra(pipeline_name_extra: str):
- """
- Ensure pipeline for extra_settings tests doesn't exist before tests start.
- """
- # Check if pipeline exists and delete it
- existing_id = find_pipeline_by_name(pipeline_name_extra)
- if existing_id:
- logger.info(f"Cleaning up existing pipeline: {pipeline_name_extra} ({existing_id})")
- try:
- delete_pipeline(existing_id)
- logger.info("Existing pipeline deleted")
- except Exception as e:
- logger.warning(f"Failed to delete existing pipeline: {e}")
-
- yield pipeline_name_extra
-
- # Cleanup after all tests - delete pipeline if it still exists
- existing_id = find_pipeline_by_name(pipeline_name_extra)
- if existing_id:
- logger.info(f"Final cleanup of pipeline: {pipeline_name_extra}")
- try:
- delete_pipeline(existing_id)
- except Exception as e:
- logger.warning(f"Failed to cleanup pipeline: {e}")
-
-
-@pytest.fixture(scope="module")
-def pipeline_name_extra() -> str:
- """Return the fixed test pipeline name for extra_settings tests."""
- return TEST_PIPELINE_NAME_EXTRA
-
-
-@pytest.mark.integration
-class TestPipelineExtraSettings:
- """Tests for extra_settings parameter in pipeline functions."""
-
- def test_create_pipeline_with_development_mode(
- self,
- test_catalog: str,
- sdp_test_schema: str,
- clean_pipeline_extra: str,
- uploaded_pipeline_files: str,
- ):
- """Should create a pipeline with development mode enabled via extra_settings."""
- pipeline_name = clean_pipeline_extra
- workspace_path = uploaded_pipeline_files
- bronze_path = f"{workspace_path}/nyctaxi_bronze.sql"
-
- logger.info(f"Creating pipeline with extra_settings: {pipeline_name}")
-
- # Create with development=True and custom tags
- extra_settings = {
- "development": True,
- "tags": {"test": "extra_settings", "environment": "test"},
- }
-
- result = create_or_update_pipeline(
- name=pipeline_name,
- root_path=workspace_path,
- catalog=test_catalog,
- schema=sdp_test_schema,
- workspace_file_paths=[bronze_path],
- start_run=False, # Don't run, just create
- extra_settings=extra_settings,
- )
-
- logger.info(f"Pipeline result: {result.to_dict()}")
-
- # Verify creation
- assert result.pipeline_id is not None, "Pipeline ID should be set"
- assert result.pipeline_name == pipeline_name
- assert result.created is True, "Pipeline should be newly created"
- assert result.success is True, f"Pipeline creation failed: {result.error_message}"
-
- # Verify the extra settings were applied by fetching the pipeline
- from databricks_tools_core.spark_declarative_pipelines.pipelines import get_pipeline
-
- pipeline_details = get_pipeline(result.pipeline_id)
-
- # Check development mode is set
- assert pipeline_details.spec.development is True, "Development mode should be True"
-
- # Check tags are set
- assert pipeline_details.spec.tags is not None, "Tags should be set"
- assert pipeline_details.spec.tags.get("test") == "extra_settings"
-
- def test_update_pipeline_with_extra_settings(
- self,
- test_catalog: str,
- sdp_test_schema: str,
- clean_pipeline_extra: str,
- uploaded_pipeline_files: str,
- ):
- """Should update a pipeline with new extra_settings."""
- pipeline_name = clean_pipeline_extra
- workspace_path = uploaded_pipeline_files
- bronze_path = f"{workspace_path}/nyctaxi_bronze.sql"
-
- logger.info(f"Updating pipeline with new extra_settings: {pipeline_name}")
-
- # Update with new tags
- extra_settings = {
- "development": True,
- "tags": {"test": "updated", "version": "2"},
- }
-
- result = create_or_update_pipeline(
- name=pipeline_name,
- root_path=workspace_path,
- catalog=test_catalog,
- schema=sdp_test_schema,
- workspace_file_paths=[bronze_path],
- start_run=False,
- extra_settings=extra_settings,
- )
-
- logger.info(f"Pipeline update result: {result.to_dict()}")
-
- # Verify update (not creation)
- assert result.created is False, "Pipeline should be updated, not created"
- assert result.success is True, f"Pipeline update failed: {result.error_message}"
-
- # Verify the updated settings
- from databricks_tools_core.spark_declarative_pipelines.pipelines import get_pipeline
-
- pipeline_details = get_pipeline(result.pipeline_id)
-
- assert pipeline_details.spec.tags.get("test") == "updated"
- assert pipeline_details.spec.tags.get("version") == "2"
-
- def test_create_pipeline_with_configuration(
- self,
- test_catalog: str,
- sdp_test_schema: str,
- clean_pipeline_extra: str,
- uploaded_pipeline_files: str,
- ):
- """Should create a pipeline with custom configuration dict."""
- pipeline_name = clean_pipeline_extra
- workspace_path = uploaded_pipeline_files
- bronze_path = f"{workspace_path}/nyctaxi_bronze.sql"
-
- # First delete the existing pipeline to test creation
- existing_id = find_pipeline_by_name(pipeline_name)
- if existing_id:
- delete_pipeline(existing_id)
-
- logger.info(f"Creating pipeline with configuration: {pipeline_name}")
-
- # Create with custom configuration
- extra_settings = {
- "configuration": {
- "spark.sql.shuffle.partitions": "10",
- "pipelines.numRetries": "3",
- },
- }
-
- result = create_or_update_pipeline(
- name=pipeline_name,
- root_path=workspace_path,
- catalog=test_catalog,
- schema=sdp_test_schema,
- workspace_file_paths=[bronze_path],
- start_run=False,
- extra_settings=extra_settings,
- )
-
- logger.info(f"Pipeline result: {result.to_dict()}")
-
- assert result.created is True, "Pipeline should be newly created"
- assert result.success is True, f"Pipeline creation failed: {result.error_message}"
-
- # Verify configuration
- from databricks_tools_core.spark_declarative_pipelines.pipelines import get_pipeline
-
- pipeline_details = get_pipeline(result.pipeline_id)
-
- assert pipeline_details.spec.configuration is not None
- assert pipeline_details.spec.configuration.get("spark.sql.shuffle.partitions") == "10"
-
- def test_delete_extra_settings_pipeline(self, clean_pipeline_extra: str):
- """Should delete the extra_settings test pipeline."""
- pipeline_name = clean_pipeline_extra
-
- # First find the pipeline
- pipeline_id = find_pipeline_by_name(pipeline_name)
- if pipeline_id is None:
- logger.info(f"Pipeline '{pipeline_name}' already deleted, skipping")
- return
-
- logger.info(f"Deleting pipeline: {pipeline_name} ({pipeline_id})")
-
- # Delete it
- delete_pipeline(pipeline_id)
-
- # Verify deletion
- found_id = find_pipeline_by_name(pipeline_name)
- assert found_id is None, "Pipeline should not exist after deletion"
-
- logger.info("Extra settings pipeline deleted successfully")
diff --git a/databricks-tools-core/tests/integration/sql/__init__.py b/databricks-tools-core/tests/integration/sql/__init__.py
deleted file mode 100644
index 85eb6e7e..00000000
--- a/databricks-tools-core/tests/integration/sql/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""SQL module integration tests."""
diff --git a/databricks-tools-core/tests/integration/sql/test_data/generate_test_files.py b/databricks-tools-core/tests/integration/sql/test_data/generate_test_files.py
deleted file mode 100644
index e6bdd178..00000000
--- a/databricks-tools-core/tests/integration/sql/test_data/generate_test_files.py
+++ /dev/null
@@ -1,63 +0,0 @@
-"""
-Generate test files for volume folder integration tests.
-
-Creates:
-- parquet/: Parquet files with sample data
-- txt_files/: Simple text files
-"""
-
-from pathlib import Path
-
-import pandas as pd
-
-# Get directory of this script
-TEST_DATA_DIR = Path(__file__).parent
-
-
-def generate_parquet_data():
- """Generate parquet files with sample data."""
- parquet_dir = TEST_DATA_DIR / "parquet"
- parquet_dir.mkdir(exist_ok=True)
-
- # Create sample data
- data = {
- "id": [1, 2, 3, 4, 5],
- "name": ["Alice", "Bob", "Charlie", "Diana", "Eve"],
- "age": [25, 30, 35, 28, 32],
- "salary": [50000.0, 60000.0, 75000.0, 55000.0, 80000.0],
- "department": ["Engineering", "Sales", "Engineering", "Marketing", "Sales"],
- }
- df = pd.DataFrame(data)
-
- # Save as parquet
- output_path = parquet_dir / "employees.parquet"
- df.to_parquet(output_path, index=False)
- print(f"Created: {output_path}")
-
- return parquet_dir
-
-
-def generate_txt_files():
- """Generate text files for file listing tests."""
- txt_dir = TEST_DATA_DIR / "txt_files"
- txt_dir.mkdir(exist_ok=True)
-
- files = [
- ("readme.txt", "This is a test readme file.\nIt has multiple lines."),
- ("data.txt", "id,name,value\n1,foo,100\n2,bar,200\n3,baz,300"),
- ("notes.txt", "Some random notes for testing."),
- ]
-
- for filename, content in files:
- file_path = txt_dir / filename
- file_path.write_text(content)
- print(f"Created: {file_path}")
-
- return txt_dir
-
-
-if __name__ == "__main__":
- print("Generating test data files...")
- generate_parquet_data()
- generate_txt_files()
- print("Done!")
diff --git a/databricks-tools-core/tests/integration/sql/test_data/parquet/employees.parquet b/databricks-tools-core/tests/integration/sql/test_data/parquet/employees.parquet
deleted file mode 100644
index 2190a43b..00000000
Binary files a/databricks-tools-core/tests/integration/sql/test_data/parquet/employees.parquet and /dev/null differ
diff --git a/databricks-tools-core/tests/integration/sql/test_data/txt_files/data.txt b/databricks-tools-core/tests/integration/sql/test_data/txt_files/data.txt
deleted file mode 100644
index 8ae44a2b..00000000
--- a/databricks-tools-core/tests/integration/sql/test_data/txt_files/data.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-id,name,value
-1,foo,100
-2,bar,200
-3,baz,300
\ No newline at end of file
diff --git a/databricks-tools-core/tests/integration/sql/test_data/txt_files/notes.txt b/databricks-tools-core/tests/integration/sql/test_data/txt_files/notes.txt
deleted file mode 100644
index c023b7d9..00000000
--- a/databricks-tools-core/tests/integration/sql/test_data/txt_files/notes.txt
+++ /dev/null
@@ -1 +0,0 @@
-Some random notes for testing.
\ No newline at end of file
diff --git a/databricks-tools-core/tests/integration/sql/test_data/txt_files/readme.txt b/databricks-tools-core/tests/integration/sql/test_data/txt_files/readme.txt
deleted file mode 100644
index 2c4ddfb0..00000000
--- a/databricks-tools-core/tests/integration/sql/test_data/txt_files/readme.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-This is a test readme file.
-It has multiple lines.
\ No newline at end of file
diff --git a/databricks-tools-core/tests/integration/sql/test_sql.py b/databricks-tools-core/tests/integration/sql/test_sql.py
deleted file mode 100644
index 665bee92..00000000
--- a/databricks-tools-core/tests/integration/sql/test_sql.py
+++ /dev/null
@@ -1,289 +0,0 @@
-"""
-Integration tests for SQL execution functions.
-
-Tests:
-- execute_sql
-- execute_sql_multi
-"""
-
-import pytest
-from databricks_tools_core.sql import execute_sql, execute_sql_multi, SQLExecutionError
-
-
-@pytest.mark.integration
-class TestExecuteSQL:
- """Tests for execute_sql function."""
-
- def test_simple_select(self, warehouse_id):
- """Should execute a simple SELECT statement."""
- result = execute_sql(
- sql_query="SELECT 1 as num, 'hello' as greeting",
- warehouse_id=warehouse_id,
- )
-
- assert isinstance(result, list)
- assert len(result) == 1
- assert result[0]["num"] == "1" or result[0]["num"] == 1
- assert result[0]["greeting"] == "hello"
-
- def test_select_with_multiple_rows(self, warehouse_id):
- """Should return multiple rows correctly."""
- result = execute_sql(
- sql_query="""
- SELECT * FROM (
- VALUES (1, 'a'), (2, 'b'), (3, 'c')
- ) AS t(id, letter)
- """,
- warehouse_id=warehouse_id,
- )
-
- assert len(result) == 3
- assert all("id" in row and "letter" in row for row in result)
-
- def test_select_from_test_table(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Should query data from test tables."""
- result = execute_sql(
- sql_query=f"SELECT * FROM {test_tables['customers']} ORDER BY customer_id",
- warehouse_id=warehouse_id,
- catalog=test_catalog,
- schema=test_schema,
- )
-
- assert len(result) == 5
- assert result[0]["name"] == "Alice Smith"
- assert result[4]["name"] == "Eve Wilson"
-
- def test_select_with_catalog_schema_context(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Should use catalog/schema context for unqualified table names."""
- result = execute_sql(
- sql_query="SELECT COUNT(*) as cnt FROM customers",
- warehouse_id=warehouse_id,
- catalog=test_catalog,
- schema=test_schema,
- )
-
- assert len(result) == 1
- count = int(result[0]["cnt"])
- assert count == 5
-
- def test_aggregate_query(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Should handle aggregate functions."""
- result = execute_sql(
- sql_query=f"""
- SELECT
- status,
- COUNT(*) as order_count,
- SUM(amount) as total_amount
- FROM {test_tables["orders"]}
- GROUP BY status
- ORDER BY status
- """,
- warehouse_id=warehouse_id,
- )
-
- assert len(result) == 3 # cancelled, completed, pending
- statuses = [row["status"] for row in result]
- assert "completed" in statuses
- assert "pending" in statuses
-
- def test_join_query(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Should handle JOIN operations."""
- result = execute_sql(
- sql_query=f"""
- SELECT
- c.name,
- COUNT(o.order_id) as order_count
- FROM {test_tables["customers"]} c
- LEFT JOIN {test_tables["orders"]} o ON c.customer_id = o.customer_id
- GROUP BY c.name
- ORDER BY order_count DESC
- """,
- warehouse_id=warehouse_id,
- )
-
- assert len(result) == 5
- # Alice has the most orders (3)
- assert result[0]["name"] == "Alice Smith"
-
- def test_auto_selects_warehouse(self, test_catalog, test_schema, test_tables):
- """Should auto-select warehouse if not provided."""
- result = execute_sql(
- sql_query=f"SELECT COUNT(*) as cnt FROM {test_tables['customers']}",
- # warehouse_id not provided
- )
-
- assert len(result) == 1
- assert int(result[0]["cnt"]) == 5
-
- def test_invalid_sql_raises_error(self, warehouse_id):
- """Should raise SQLExecutionError for invalid SQL."""
- with pytest.raises(SQLExecutionError) as exc_info:
- execute_sql(
- sql_query="SELECT * FROM nonexistent_table_xyz",
- warehouse_id=warehouse_id,
- )
-
- assert "TABLE_OR_VIEW_NOT_FOUND" in str(exc_info.value).upper() or "NOT FOUND" in str(exc_info.value).upper()
-
- def test_syntax_error_raises_error(self, warehouse_id):
- """Should raise SQLExecutionError for syntax errors."""
- with pytest.raises(SQLExecutionError) as exc_info:
- execute_sql(
- sql_query="SELEC * FORM table", # typos
- warehouse_id=warehouse_id,
- )
-
- error_msg = str(exc_info.value).upper()
- assert "SYNTAX" in error_msg or "PARSE" in error_msg or "ERROR" in error_msg
-
-
-@pytest.mark.integration
-class TestExecuteSQLMulti:
- """Tests for execute_sql_multi function."""
-
- def test_multiple_independent_statements(self, warehouse_id, test_catalog, test_schema):
- """Should execute multiple independent statements."""
- result = execute_sql_multi(
- sql_content=f"""
- CREATE OR REPLACE TABLE {test_catalog}.{test_schema}.multi_test_1
- AS SELECT 1 as id;
- CREATE OR REPLACE TABLE {test_catalog}.{test_schema}.multi_test_2
- AS SELECT 2 as id;
- SELECT * FROM {test_catalog}.{test_schema}.multi_test_1;
- SELECT * FROM {test_catalog}.{test_schema}.multi_test_2;
- """,
- warehouse_id=warehouse_id,
- catalog=test_catalog,
- schema=test_schema,
- )
-
- assert "results" in result
- assert "execution_summary" in result
-
- # All 4 statements should succeed
- results = result["results"]
- assert len(results) == 4
- assert all(r["status"] == "success" for r in results.values())
-
- def test_dependency_analysis(self, warehouse_id, test_catalog, test_schema):
- """Should analyze dependencies and execute in correct order."""
- result = execute_sql_multi(
- sql_content=f"""
- CREATE OR REPLACE TABLE {test_catalog}.{test_schema}.dep_base
- AS SELECT 1 as id, 'a' as val;
- CREATE OR REPLACE TABLE {test_catalog}.{test_schema}.dep_derived
- AS SELECT * FROM {test_catalog}.{test_schema}.dep_base WHERE id = 1;
- SELECT COUNT(*) FROM {test_catalog}.{test_schema}.dep_derived;
- """,
- warehouse_id=warehouse_id,
- )
-
- summary = result["execution_summary"]
-
- # Should have multiple groups due to dependencies
- assert summary["total_groups"] >= 2
-
- # All should succeed
- assert all(r["status"] == "success" for r in result["results"].values())
-
- def test_parallel_execution_info(self, warehouse_id, test_catalog, test_schema):
- """Should indicate which queries ran in parallel."""
- result = execute_sql_multi(
- sql_content=f"""
- CREATE OR REPLACE TABLE {test_catalog}.{test_schema}.para_1 AS SELECT 1 as id;
- CREATE OR REPLACE TABLE {test_catalog}.{test_schema}.para_2 AS SELECT 2 as id;
- CREATE OR REPLACE TABLE {test_catalog}.{test_schema}.para_3 AS SELECT 3 as id;
- """,
- warehouse_id=warehouse_id,
- )
-
- summary = result["execution_summary"]
-
- # These are independent, should be in same group (parallel)
- groups = summary["groups"]
- assert len(groups) >= 1
-
- # First group should have all 3 (parallel)
- first_group = groups[0]
- assert first_group["group_size"] == 3
- assert first_group["is_parallel"] is True
-
- def test_stops_on_error(self, warehouse_id, test_catalog, test_schema):
- """Should stop execution when a query fails."""
- result = execute_sql_multi(
- sql_content=f"""
- CREATE OR REPLACE TABLE {test_catalog}.{test_schema}.stop_test AS SELECT 1 as id;
- SELECT * FROM nonexistent_table_xyz_123;
- CREATE OR REPLACE TABLE {test_catalog}.{test_schema}.should_not_run AS SELECT 2 as id;
- """,
- warehouse_id=warehouse_id,
- )
-
- summary = result["execution_summary"]
-
- # Should have stopped after error
- assert summary["stopped_after_group"] is not None
-
- # Check that we have at least one error
- has_error = any(r["status"] == "error" for r in result["results"].values())
- assert has_error
-
- def test_execution_summary_structure(self, warehouse_id, test_catalog, test_schema):
- """Should return proper execution summary."""
- result = execute_sql_multi(
- sql_content="""
- SELECT 1 as a;
- SELECT 2 as b;
- """,
- warehouse_id=warehouse_id,
- )
-
- summary = result["execution_summary"]
-
- assert "total_queries" in summary
- assert "total_groups" in summary
- assert "total_time" in summary
- assert "groups" in summary
-
- assert summary["total_queries"] == 2
- assert isinstance(summary["total_time"], float)
-
- def test_result_contains_query_details(self, warehouse_id, test_catalog, test_schema):
- """Each result should contain query details."""
- result = execute_sql_multi(
- sql_content="SELECT 1 as num; SELECT 2 as num;",
- warehouse_id=warehouse_id,
- )
-
- for _idx, query_result in result["results"].items():
- assert "query_index" in query_result
- assert "status" in query_result
- assert "execution_time" in query_result
- assert "query_preview" in query_result
- assert "group_number" in query_result
-
- def test_handles_comments(self, warehouse_id):
- """Should handle SQL comments correctly."""
- result = execute_sql_multi(
- sql_content="""
- -- This is a comment
- SELECT 1 as first;
- /* Multi-line
- comment */
- SELECT 2 as second;
- """,
- warehouse_id=warehouse_id,
- )
-
- assert len(result["results"]) == 2
- assert all(r["status"] == "success" for r in result["results"].values())
-
- def test_auto_selects_warehouse(self, test_catalog, test_schema):
- """Should auto-select warehouse if not provided."""
- result = execute_sql_multi(
- sql_content="SELECT 1 as num;",
- # warehouse_id not provided
- )
-
- assert len(result["results"]) == 1
- assert result["results"][0]["status"] == "success"
diff --git a/databricks-tools-core/tests/integration/sql/test_table_stats.py b/databricks-tools-core/tests/integration/sql/test_table_stats.py
deleted file mode 100644
index b53e6ca4..00000000
--- a/databricks-tools-core/tests/integration/sql/test_table_stats.py
+++ /dev/null
@@ -1,471 +0,0 @@
-"""
-Integration tests for table statistics functions.
-
-Tests:
-- get_table_stats_and_schema
-- TableStatLevel (NONE, SIMPLE, DETAILED)
-- GLOB pattern matching
-- Caching behavior
-"""
-
-import pytest
-from databricks_tools_core.sql import (
- get_table_stats_and_schema,
- TableStatLevel,
- TableSchemaResult,
-)
-
-
-@pytest.mark.integration
-class TestGetTableDetails:
- """Tests for get_table_stats_and_schema function."""
-
- def test_get_all_tables(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Should list all tables when table_names is empty."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=[],
- warehouse_id=warehouse_id,
- )
-
- assert isinstance(result, TableSchemaResult)
- assert result.catalog == test_catalog
- assert result.schema_name == test_schema
- assert len(result.tables) >= 3 # customers, orders, products
-
- table_names = [t.name.split(".")[-1] for t in result.tables]
- assert "customers" in table_names
- assert "orders" in table_names
- assert "products" in table_names
-
- def test_get_specific_tables(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Should get specific tables by exact name."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["customers", "orders"],
- warehouse_id=warehouse_id,
- )
-
- assert len(result.tables) == 2
- table_names = [t.name.split(".")[-1] for t in result.tables]
- assert "customers" in table_names
- assert "orders" in table_names
- assert "products" not in table_names
-
- def test_glob_pattern_star(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Should filter tables using * glob pattern."""
- # First create tables with a common prefix
- from databricks_tools_core.sql import execute_sql
-
- execute_sql(
- sql_query=f"CREATE OR REPLACE TABLE {test_catalog}.{test_schema}.raw_sales AS SELECT 1 as id",
- warehouse_id=warehouse_id,
- )
- execute_sql(
- sql_query=f"CREATE OR REPLACE TABLE {test_catalog}.{test_schema}.raw_inventory AS SELECT 2 as id",
- warehouse_id=warehouse_id,
- )
-
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["raw_*"],
- warehouse_id=warehouse_id,
- )
-
- assert len(result.tables) >= 2
- for table in result.tables:
- assert table.name.split(".")[-1].startswith("raw_")
-
- def test_glob_pattern_question_mark(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Should filter tables using ? glob pattern."""
- from databricks_tools_core.sql import execute_sql
-
- execute_sql(
- sql_query=f"CREATE OR REPLACE TABLE {test_catalog}.{test_schema}.dim_a AS SELECT 1 as id",
- warehouse_id=warehouse_id,
- )
- execute_sql(
- sql_query=f"CREATE OR REPLACE TABLE {test_catalog}.{test_schema}.dim_b AS SELECT 2 as id",
- warehouse_id=warehouse_id,
- )
-
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["dim_?"],
- warehouse_id=warehouse_id,
- )
-
- assert len(result.tables) >= 2
- for table in result.tables:
- name = table.name.split(".")[-1]
- assert name.startswith("dim_") and len(name) == 5
-
- def test_mixed_patterns_and_exact(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Should handle mix of glob patterns and exact names."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["customers", "raw_*"],
- warehouse_id=warehouse_id,
- )
-
- table_names = [t.name.split(".")[-1] for t in result.tables]
- assert "customers" in table_names
- # Should also include raw_* tables if they exist
-
-
-@pytest.mark.integration
-class TestTableStatLevelNone:
- """Tests for TableStatLevel.NONE (DDL only, no stats)."""
-
- def test_stat_level_none_returns_ddl(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Should return DDL without column stats."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["customers"],
- table_stat_level=TableStatLevel.NONE,
- warehouse_id=warehouse_id,
- )
-
- assert len(result.tables) == 1
- table = result.tables[0]
-
- # Should have DDL
- assert table.ddl is not None
- assert "CREATE" in table.ddl.upper()
-
- # Should NOT have column details
- assert table.column_details is None
-
- # Should NOT have row count
- assert table.total_rows is None
-
- def test_stat_level_none_is_fast(self, warehouse_id, test_catalog, test_schema, test_tables):
- """NONE level should be faster than SIMPLE/DETAILED."""
- import time
-
- # NONE level
- start = time.time()
- get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["customers"],
- table_stat_level=TableStatLevel.NONE,
- warehouse_id=warehouse_id,
- )
- none_time = time.time() - start
-
- # SIMPLE level (should be slower due to stats collection)
- start = time.time()
- get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["orders"], # Different table to avoid cache
- table_stat_level=TableStatLevel.SIMPLE,
- warehouse_id=warehouse_id,
- )
- simple_time = time.time() - start
-
- # NONE should generally be faster (though not guaranteed due to caching)
- # Just ensure both complete without error
- assert none_time >= 0
- assert simple_time >= 0
-
-
-@pytest.mark.integration
-class TestTableStatLevelSimple:
- """Tests for TableStatLevel.SIMPLE (basic stats with caching)."""
-
- def test_stat_level_simple_has_basic_stats(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Should return basic column statistics."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["customers"],
- table_stat_level=TableStatLevel.SIMPLE,
- warehouse_id=warehouse_id,
- )
-
- assert len(result.tables) == 1
- table = result.tables[0]
-
- # Should have DDL
- assert table.ddl is not None
-
- # Should have column details
- assert table.column_details is not None
- assert len(table.column_details) > 0
-
- # Check a known column
- if "name" in table.column_details:
- name_col = table.column_details["name"]
- assert name_col.name == "name"
- assert name_col.data_type is not None
-
- def test_stat_level_simple_excludes_heavy_stats(self, warehouse_id, test_catalog, test_schema, test_tables):
- """SIMPLE level should not include histograms/percentiles."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["orders"],
- table_stat_level=TableStatLevel.SIMPLE,
- warehouse_id=warehouse_id,
- )
-
- table = result.tables[0]
- for col_name, col in table.column_details.items():
- # Heavy stats should be None in SIMPLE mode
- assert col.histogram is None, f"Column {col_name} should not have histogram"
- assert col.stddev is None, f"Column {col_name} should not have stddev"
- assert col.q1 is None, f"Column {col_name} should not have q1"
- assert col.median is None, f"Column {col_name} should not have median"
- assert col.q3 is None, f"Column {col_name} should not have q3"
-
- def test_caching_works(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Second call should use cache and be faster."""
- import time
-
- # First call (cache miss)
- start = time.time()
- result1 = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["products"],
- table_stat_level=TableStatLevel.SIMPLE,
- warehouse_id=warehouse_id,
- )
-
- # Second call (should hit cache)
- start = time.time()
- result2 = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["products"],
- table_stat_level=TableStatLevel.SIMPLE,
- warehouse_id=warehouse_id,
- )
- second_time = time.time() - start
-
- # Results should be equivalent
- assert len(result1.tables) == len(result2.tables)
-
- # Second call should generally be faster (cache hit)
- # Note: Not always guaranteed due to other factors
- assert second_time >= 0
-
-
-@pytest.mark.integration
-class TestTableStatLevelDetailed:
- """Tests for TableStatLevel.DETAILED (full stats)."""
-
- def test_stat_level_detailed_has_all_stats(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Should return all column statistics including heavy ones."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["orders"],
- table_stat_level=TableStatLevel.DETAILED,
- warehouse_id=warehouse_id,
- )
-
- assert len(result.tables) == 1
- table = result.tables[0]
-
- # Should have column details
- assert table.column_details is not None
-
- # Check numeric column (amount) has detailed stats
- if "amount" in table.column_details:
- amount_col = table.column_details["amount"]
- # These should be present for numeric columns
- assert amount_col.min is not None or amount_col.max is not None
- # Histogram may or may not be present depending on data
-
- def test_stat_level_detailed_has_row_count(self, warehouse_id, test_catalog, test_schema, test_tables):
- """DETAILED level should include total row count."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["customers"],
- table_stat_level=TableStatLevel.DETAILED,
- warehouse_id=warehouse_id,
- )
-
- table = result.tables[0]
- assert table.total_rows is not None
- assert table.total_rows == 5 # We inserted 5 customers
-
- def test_stat_level_detailed_has_sample_data(self, warehouse_id, test_catalog, test_schema, test_tables):
- """DETAILED level should include sample data."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["customers"],
- table_stat_level=TableStatLevel.DETAILED,
- warehouse_id=warehouse_id,
- )
-
- table = result.tables[0]
- assert table.sample_data is not None
- assert len(table.sample_data) > 0
- assert len(table.sample_data) <= 10 # SAMPLE_ROW_COUNT
-
- def test_categorical_columns_have_value_counts(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Categorical columns with low cardinality should have value_counts."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["orders"],
- table_stat_level=TableStatLevel.DETAILED,
- warehouse_id=warehouse_id,
- )
-
- table = result.tables[0]
-
- # status column is categorical with 3 values
- if "status" in table.column_details:
- status_col = table.column_details["status"]
- # Should have value_counts for low-cardinality categorical
- if status_col.value_counts:
- assert "completed" in status_col.value_counts
- assert "pending" in status_col.value_counts
-
-
-@pytest.mark.integration
-class TestTableInfoStructure:
- """Tests for TableInfo structure and content."""
-
- def test_table_info_has_full_name(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Table name should be fully qualified."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["customers"],
- warehouse_id=warehouse_id,
- )
-
- table = result.tables[0]
- expected_full_name = f"{test_catalog}.{test_schema}.customers"
- assert table.name == expected_full_name
-
- def test_ddl_contains_columns(self, warehouse_id, test_catalog, test_schema, test_tables):
- """DDL should contain column definitions."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["customers"],
- table_stat_level=TableStatLevel.NONE,
- warehouse_id=warehouse_id,
- )
-
- table = result.tables[0]
- ddl_upper = table.ddl.upper()
-
- assert "CUSTOMER_ID" in ddl_upper
- assert "NAME" in ddl_upper
- assert "EMAIL" in ddl_upper
-
- def test_column_types_detected(self, warehouse_id, test_catalog, test_schema, test_tables):
- """Column types should be correctly detected."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["products"],
- table_stat_level=TableStatLevel.DETAILED,
- warehouse_id=warehouse_id,
- )
-
- table = result.tables[0]
- cols = table.column_details
-
- # Check various column types are detected
- if "price" in cols:
- assert cols["price"].data_type in ["numeric", "double", "float"]
-
- if "tags" in cols:
- assert cols["tags"].data_type == "array"
-
- if "created_at" in cols:
- assert cols["created_at"].data_type == "timestamp"
-
-
-@pytest.mark.integration
-class TestTableSchemaResult:
- """Tests for TableSchemaResult methods."""
-
- def test_table_count_property(self, warehouse_id, test_catalog, test_schema, test_tables):
- """table_count property should return correct count."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["customers", "orders"],
- warehouse_id=warehouse_id,
- )
-
- assert result.table_count == 2
-
- def test_keep_basic_stats_method(self, warehouse_id, test_catalog, test_schema, test_tables):
- """keep_basic_stats() should remove heavy stats."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["orders"],
- table_stat_level=TableStatLevel.DETAILED,
- warehouse_id=warehouse_id,
- )
-
- basic_result = result.keep_basic_stats()
-
- # Should still have tables
- assert len(basic_result.tables) == len(result.tables)
-
- # But without heavy stats
- for table in basic_result.tables:
- if table.column_details:
- for col in table.column_details.values():
- assert col.histogram is None
- assert col.stddev is None
-
- def test_remove_stats_method(self, warehouse_id, test_catalog, test_schema, test_tables):
- """remove_stats() should remove all column details."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["customers"],
- table_stat_level=TableStatLevel.DETAILED,
- warehouse_id=warehouse_id,
- )
-
- minimal_result = result.remove_stats()
-
- assert len(minimal_result.tables) == 1
- table = minimal_result.tables[0]
-
- # DDL should still be present
- assert table.ddl is not None
-
- # But no column details
- assert table.column_details is None
- assert table.total_rows is None
-
-
-@pytest.mark.integration
-class TestAutoWarehouseSelection:
- """Tests for automatic warehouse selection."""
-
- def test_auto_selects_warehouse(self, test_catalog, test_schema, test_tables):
- """Should auto-select warehouse if not provided."""
- result = get_table_stats_and_schema(
- catalog=test_catalog,
- schema=test_schema,
- table_names=["customers"],
- # warehouse_id not provided
- )
-
- assert len(result.tables) == 1
- assert result.tables[0].ddl is not None
diff --git a/databricks-tools-core/tests/integration/sql/test_volume_folder_stats.py b/databricks-tools-core/tests/integration/sql/test_volume_folder_stats.py
deleted file mode 100644
index a4c6ec2a..00000000
--- a/databricks-tools-core/tests/integration/sql/test_volume_folder_stats.py
+++ /dev/null
@@ -1,309 +0,0 @@
-"""
-Integration tests for volume folder statistics functions.
-
-Tests:
-- get_volume_folder_details
-- format="parquet" - read parquet data and compute stats
-- format="file" - list files only
-- TableStatLevel variants
-"""
-
-import pytest
-from databricks_tools_core.sql import (
- get_volume_folder_details,
- TableStatLevel,
- TableSchemaResult,
- DataSourceInfo,
- VolumeFileInfo,
-)
-
-
-@pytest.mark.integration
-class TestGetVolumeFolderDetailsParquet:
- """Tests for get_volume_folder_details with parquet format."""
-
- def test_parquet_basic_info(self, warehouse_id, test_catalog, test_schema, test_volume):
- """Should read parquet files and return basic info."""
- volume_path = f"{test_catalog}/{test_schema}/{test_volume}/parquet_data"
-
- result = get_volume_folder_details(
- volume_path=volume_path,
- format="parquet",
- table_stat_level=TableStatLevel.SIMPLE,
- warehouse_id=warehouse_id,
- )
-
- assert isinstance(result, TableSchemaResult)
- assert len(result.tables) == 1
- info = result.tables[0]
- assert isinstance(info, DataSourceInfo)
- assert info.error is None, f"Unexpected error: {info.error}"
- assert info.format == "parquet"
- assert info.total_files >= 1
- assert info.total_rows == 5 # We have 5 rows in test data
-
- def test_parquet_column_details(self, warehouse_id, test_catalog, test_schema, test_volume):
- """Should return column statistics for parquet data."""
- volume_path = f"{test_catalog}/{test_schema}/{test_volume}/parquet_data"
-
- result = get_volume_folder_details(
- volume_path=volume_path,
- format="parquet",
- table_stat_level=TableStatLevel.SIMPLE,
- warehouse_id=warehouse_id,
- )
-
- info = result.tables[0]
- assert info.error is None
- assert info.column_details is not None
- assert len(info.column_details) == 5 # id, name, age, salary, department
-
- # Check specific columns
- assert "id" in info.column_details
- assert "name" in info.column_details
- assert "age" in info.column_details
- assert "salary" in info.column_details
- assert "department" in info.column_details
-
- # Check age column has numeric stats
- age_col = info.column_details["age"]
- assert age_col.name == "age"
- assert age_col.total_count == 5
-
- def test_parquet_numeric_stats(self, warehouse_id, test_catalog, test_schema, test_volume):
- """Should compute min/max/avg for numeric columns."""
- volume_path = f"{test_catalog}/{test_schema}/{test_volume}/parquet_data"
-
- result = get_volume_folder_details(
- volume_path=volume_path,
- format="parquet",
- table_stat_level=TableStatLevel.SIMPLE,
- warehouse_id=warehouse_id,
- )
-
- info = result.tables[0]
- assert info.error is None
-
- # Check salary column stats - may be inferred as numeric or string
- # depending on how read_files returns the data
- salary_col = info.column_details["salary"]
- assert salary_col.total_count == 5
- assert salary_col.unique_count == 5 # All unique salaries
-
- # If recognized as numeric, should have min/max/avg
- # If recognized as string, should have samples
- if salary_col.min is not None:
- # Verify values (50000, 60000, 75000, 55000, 80000)
- assert salary_col.min == 50000.0
- assert salary_col.max == 80000.0
- else:
- # String type - should have samples
- assert salary_col.samples is not None
- assert len(salary_col.samples) > 0
-
- def test_parquet_detailed_has_samples(self, warehouse_id, test_catalog, test_schema, test_volume):
- """DETAILED level should include sample data."""
- volume_path = f"{test_catalog}/{test_schema}/{test_volume}/parquet_data"
-
- result = get_volume_folder_details(
- volume_path=volume_path,
- format="parquet",
- table_stat_level=TableStatLevel.DETAILED,
- warehouse_id=warehouse_id,
- )
-
- info = result.tables[0]
- assert info.error is None
- assert info.sample_data is not None
- assert len(info.sample_data) > 0
- assert len(info.sample_data) <= 5
-
- # Check sample data has expected columns
- first_row = info.sample_data[0]
- assert "id" in first_row
- assert "name" in first_row
-
- def test_parquet_stat_level_none(self, warehouse_id, test_catalog, test_schema, test_volume):
- """NONE level should return schema without stats."""
- volume_path = f"{test_catalog}/{test_schema}/{test_volume}/parquet_data"
-
- result = get_volume_folder_details(
- volume_path=volume_path,
- format="parquet",
- table_stat_level=TableStatLevel.NONE,
- warehouse_id=warehouse_id,
- )
-
- info = result.tables[0]
- assert info.error is None
- assert info.total_rows == 5 # Row count still present
- # Column details should have schema but minimal stats
- assert info.column_details is not None
-
-
-@pytest.mark.integration
-class TestGetVolumeFolderDetailsFile:
- """Tests for get_volume_folder_details with format='file' (listing only)."""
-
- def test_file_listing(self, test_catalog, test_schema, test_volume):
- """format='file' should list files without reading data."""
- volume_path = f"{test_catalog}/{test_schema}/{test_volume}/txt_files"
-
- result = get_volume_folder_details(
- volume_path=volume_path,
- format="file",
- )
-
- assert isinstance(result, TableSchemaResult)
- info = result.tables[0]
- assert info.error is None, f"Unexpected error: {info.error}"
- assert info.format == "file"
- assert info.total_files == 3 # readme.txt, data.txt, notes.txt
-
- # Should have file list
- assert info.files is not None
- assert len(info.files) == 3
-
- # Check file info structure
- file_names = [f.name for f in info.files]
- assert "readme.txt" in file_names
- assert "data.txt" in file_names
- assert "notes.txt" in file_names
-
- def test_file_listing_has_size(self, test_catalog, test_schema, test_volume):
- """File listing should include file sizes."""
- volume_path = f"{test_catalog}/{test_schema}/{test_volume}/txt_files"
-
- result = get_volume_folder_details(
- volume_path=volume_path,
- format="file",
- )
-
- info = result.tables[0]
- assert info.error is None
- assert info.files is not None
-
- # All files should have size
- for file_info in info.files:
- assert isinstance(file_info, VolumeFileInfo)
- assert file_info.size_bytes is not None or file_info.is_directory
-
- def test_file_listing_no_data_reading(self, test_catalog, test_schema, test_volume):
- """format='file' should not read data (no column_details, no rows)."""
- volume_path = f"{test_catalog}/{test_schema}/{test_volume}/txt_files"
-
- result = get_volume_folder_details(
- volume_path=volume_path,
- format="file",
- )
-
- info = result.tables[0]
- assert info.error is None
- # Should not have data-related fields
- assert info.column_details is None
- assert info.total_rows is None
- assert info.sample_data is None
-
-
-@pytest.mark.integration
-class TestVolumeFolderPathFormats:
- """Tests for different volume path formats."""
-
- def test_short_path_format(self, warehouse_id, test_catalog, test_schema, test_volume):
- """Should accept catalog/schema/volume/path format."""
- volume_path = f"{test_catalog}/{test_schema}/{test_volume}/parquet_data"
-
- result = get_volume_folder_details(
- volume_path=volume_path,
- format="parquet",
- warehouse_id=warehouse_id,
- )
-
- info = result.tables[0]
- assert info.error is None
- assert "/Volumes/" in info.name
-
- def test_full_volumes_path_format(self, warehouse_id, test_catalog, test_schema, test_volume):
- """Should accept /Volumes/catalog/schema/volume/path format."""
- volume_path = f"/Volumes/{test_catalog}/{test_schema}/{test_volume}/parquet_data"
-
- result = get_volume_folder_details(
- volume_path=volume_path,
- format="parquet",
- warehouse_id=warehouse_id,
- )
-
- info = result.tables[0]
- assert info.error is None
- assert info.name == volume_path
-
-
-@pytest.mark.integration
-class TestVolumeFolderErrors:
- """Tests for error handling."""
-
- def test_nonexistent_path(self, test_catalog, test_schema, test_volume):
- """Should return error for non-existent path."""
- volume_path = f"{test_catalog}/{test_schema}/{test_volume}/nonexistent_folder"
-
- result = get_volume_folder_details(
- volume_path=volume_path,
- format="file",
- )
-
- info = result.tables[0]
- assert info.error is not None
- assert "not found" in info.error.lower() or "empty" in info.error.lower()
-
- def test_empty_folder(self, workspace_client, test_catalog, test_schema, test_volume):
- """Should handle empty folders gracefully."""
- # Create an empty folder by uploading and deleting a file
- volume_path = f"/Volumes/{test_catalog}/{test_schema}/{test_volume}/empty_folder"
-
- # Upload a temp file to create the folder
- from io import BytesIO
-
- temp_path = f"{volume_path}/temp.txt"
- workspace_client.files.upload(temp_path, BytesIO(b"temp"), overwrite=True)
- workspace_client.files.delete(temp_path)
-
- result = get_volume_folder_details(
- volume_path=f"{test_catalog}/{test_schema}/{test_volume}/empty_folder",
- format="file",
- )
-
- info = result.tables[0]
- # Should handle empty gracefully (either error or empty list)
- if info.error is None:
- assert info.total_files == 0
- else:
- assert "empty" in info.error.lower()
-
-
-@pytest.mark.integration
-class TestVolumeFolderResultMethods:
- """Tests for TableSchemaResult methods with volume folder data."""
-
- def test_keep_basic_stats(self, warehouse_id, test_catalog, test_schema, test_volume):
- """keep_basic_stats() should remove heavy stats."""
- volume_path = f"{test_catalog}/{test_schema}/{test_volume}/parquet_data"
-
- result = get_volume_folder_details(
- volume_path=volume_path,
- format="parquet",
- table_stat_level=TableStatLevel.DETAILED,
- warehouse_id=warehouse_id,
- )
-
- basic_result = result.keep_basic_stats()
-
- info = basic_result.tables[0]
- assert info.error is None
- assert info.column_details is not None
-
- # Should still have basic column info
- salary_col = info.column_details["salary"]
- assert salary_col.total_count == 5
-
- # Sample data should be removed
- assert info.sample_data is None
diff --git a/databricks-tools-core/tests/integration/sql/test_warehouse.py b/databricks-tools-core/tests/integration/sql/test_warehouse.py
deleted file mode 100644
index 146863a6..00000000
--- a/databricks-tools-core/tests/integration/sql/test_warehouse.py
+++ /dev/null
@@ -1,93 +0,0 @@
-"""
-Integration tests for SQL warehouse functions.
-
-Tests:
-- list_warehouses
-- get_best_warehouse
-"""
-
-import pytest
-from databricks_tools_core.sql import list_warehouses, get_best_warehouse
-
-
-@pytest.mark.integration
-class TestListWarehouses:
- """Tests for list_warehouses function."""
-
- def test_list_warehouses_returns_list(self):
- """Should return a list of warehouses."""
- warehouses = list_warehouses()
-
- assert isinstance(warehouses, list)
- assert len(warehouses) > 0, "Expected at least one warehouse"
-
- def test_list_warehouses_structure(self):
- """Each warehouse should have expected fields."""
- warehouses = list_warehouses()
-
- for w in warehouses:
- assert "id" in w, "Warehouse should have 'id'"
- assert "name" in w, "Warehouse should have 'name'"
- assert "state" in w, "Warehouse should have 'state'"
- assert w["id"] is not None
- assert w["name"] is not None
-
- def test_list_warehouses_running_first(self):
- """Running warehouses should be listed first."""
- warehouses = list_warehouses()
-
- # Find first non-running warehouse
- first_non_running_idx = None
- for i, w in enumerate(warehouses):
- if w["state"] != "RUNNING":
- first_non_running_idx = i
- break
-
- # If we found a non-running warehouse, all before it should be running
- if first_non_running_idx is not None:
- for i in range(first_non_running_idx):
- assert warehouses[i]["state"] == "RUNNING", f"Warehouse at index {i} should be RUNNING"
-
- def test_list_warehouses_with_limit(self):
- """Should respect the limit parameter."""
- warehouses_5 = list_warehouses(limit=5)
- warehouses_2 = list_warehouses(limit=2)
-
- assert len(warehouses_2) <= 2
- assert len(warehouses_5) <= 5
-
-
-@pytest.mark.integration
-class TestGetBestWarehouse:
- """Tests for get_best_warehouse function."""
-
- def test_get_best_warehouse_returns_string(self):
- """Should return a warehouse ID string."""
- warehouse_id = get_best_warehouse()
-
- # May be None if no warehouses available
- if warehouse_id is not None:
- assert isinstance(warehouse_id, str)
- assert len(warehouse_id) > 0
-
- def test_get_best_warehouse_returns_valid_id(self):
- """Returned ID should be in the warehouse list."""
- warehouse_id = get_best_warehouse()
-
- if warehouse_id is not None:
- warehouses = list_warehouses()
- warehouse_ids = [w["id"] for w in warehouses]
- assert warehouse_id in warehouse_ids, f"Warehouse ID {warehouse_id} not found in list"
-
- def test_get_best_warehouse_prefers_running(self):
- """Should prefer running warehouses."""
- warehouse_id = get_best_warehouse()
-
- if warehouse_id is not None:
- warehouses = list_warehouses()
- selected = next((w for w in warehouses if w["id"] == warehouse_id), None)
-
- # If there are any running warehouses, selected should be running
- running_exists = any(w["state"] == "RUNNING" for w in warehouses)
- if running_exists:
- assert selected["state"] == "RUNNING", "Should select a running warehouse when available"
diff --git a/databricks-tools-core/tests/integration/unity_catalog/__init__.py b/databricks-tools-core/tests/integration/unity_catalog/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/databricks-tools-core/tests/integration/unity_catalog/conftest.py b/databricks-tools-core/tests/integration/unity_catalog/conftest.py
deleted file mode 100644
index 5f495c30..00000000
--- a/databricks-tools-core/tests/integration/unity_catalog/conftest.py
+++ /dev/null
@@ -1,232 +0,0 @@
-"""
-Pytest fixtures for Unity Catalog integration tests.
-
-Provides fixtures for:
-- UC test schema (dedicated for UC tests to avoid conflicts)
-- Test tables for tag/security/grant tests
-- Cleanup helpers for catalogs, volumes, shares, recipients, monitors
-"""
-
-import logging
-import uuid
-import pytest
-
-from databricks_tools_core.auth import get_workspace_client
-from databricks_tools_core.sql import execute_sql
-
-logger = logging.getLogger(__name__)
-
-# Dedicated UC test schema (separate from root conftest test_schema)
-UC_TEST_SCHEMA = "uc_test_schema"
-UC_TEST_PREFIX = "uc_test"
-
-
-@pytest.fixture(scope="module")
-def uc_test_schema(test_catalog: str, warehouse_id: str) -> str:
- """
- Create a dedicated schema for UC integration tests.
-
- Drops and recreates to ensure clean state.
- Yields the schema name.
- """
- full_schema_name = f"{test_catalog}.{UC_TEST_SCHEMA}"
-
- # Drop if exists
- try:
- logger.info(f"Dropping existing UC test schema: {full_schema_name}")
- w = get_workspace_client()
- w.schemas.delete(full_schema_name, force=True)
- except Exception as e:
- logger.debug(f"Schema drop failed (may not exist): {e}")
-
- # Create fresh
- logger.info(f"Creating UC test schema: {full_schema_name}")
- try:
- w = get_workspace_client()
- w.schemas.create(name=UC_TEST_SCHEMA, catalog_name=test_catalog)
- except Exception as e:
- if "already exists" in str(e).lower():
- logger.info(f"Schema already exists, reusing: {full_schema_name}")
- else:
- raise
-
- yield UC_TEST_SCHEMA
-
- # Cleanup
- try:
- logger.info(f"Cleaning up UC test schema: {full_schema_name}")
- w = get_workspace_client()
- w.schemas.delete(full_schema_name, force=True)
- except Exception as e:
- logger.warning(f"Failed to cleanup UC test schema: {e}")
-
-
-@pytest.fixture(scope="module")
-def uc_test_table(
- test_catalog: str,
- uc_test_schema: str,
- warehouse_id: str,
-) -> str:
- """
- Create a test table for UC operations (tags, grants, security policies).
-
- Returns the full table name.
- """
- table_name = f"{UC_TEST_PREFIX}_employees"
- full_name = f"{test_catalog}.{uc_test_schema}.{table_name}"
-
- logger.info(f"Creating UC test table: {full_name}")
- execute_sql(
- sql_query=f"""
- CREATE OR REPLACE TABLE {full_name} (
- employee_id BIGINT,
- name STRING,
- email STRING,
- department STRING,
- salary DOUBLE,
- hire_date DATE,
- is_active BOOLEAN
- )
- """,
- warehouse_id=warehouse_id,
- )
-
- execute_sql(
- sql_query=f"""
- INSERT INTO {full_name} VALUES
- (1, 'Alice Smith', 'alice@company.com', 'Engineering', 120000.00, '2022-01-15', true),
- (2, 'Bob Johnson', 'bob@company.com', 'Marketing', 95000.00, '2022-03-20', true),
- (3, 'Charlie Brown', 'charlie@company.com', 'Engineering', 110000.00, '2022-06-10', false),
- (4, 'Diana Ross', 'diana@company.com', 'Finance', 105000.00, '2023-01-05', true),
- (5, 'Eve Wilson', 'eve@company.com', 'Engineering', 130000.00, '2023-04-12', true)
- """,
- warehouse_id=warehouse_id,
- )
-
- logger.info(f"UC test table created: {full_name}")
- yield full_name
-
- # Cleanup handled by schema drop
-
-
-@pytest.fixture(scope="function")
-def unique_name() -> str:
- """Generate a unique name suffix for test resources."""
- return str(uuid.uuid4())[:8]
-
-
-@pytest.fixture(scope="function")
-def cleanup_volumes():
- """
- Track and cleanup volumes created during tests.
-
- Usage:
- def test_create_volume(cleanup_volumes):
- vol = create_volume(...)
- cleanup_volumes(full_volume_name)
- """
- from databricks_tools_core.unity_catalog import delete_volume
-
- volumes_to_cleanup = []
-
- def register(full_volume_name: str):
- if full_volume_name not in volumes_to_cleanup:
- volumes_to_cleanup.append(full_volume_name)
- logger.info(f"Registered volume for cleanup: {full_volume_name}")
-
- yield register
-
- for vol_name in volumes_to_cleanup:
- try:
- logger.info(f"Cleaning up volume: {vol_name}")
- delete_volume(vol_name)
- except Exception as e:
- logger.warning(f"Failed to cleanup volume {vol_name}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_shares():
- """Track and cleanup shares created during tests."""
- from databricks_tools_core.unity_catalog import delete_share
-
- shares_to_cleanup = []
-
- def register(share_name: str):
- if share_name not in shares_to_cleanup:
- shares_to_cleanup.append(share_name)
- logger.info(f"Registered share for cleanup: {share_name}")
-
- yield register
-
- for name in shares_to_cleanup:
- try:
- logger.info(f"Cleaning up share: {name}")
- delete_share(name)
- except Exception as e:
- logger.warning(f"Failed to cleanup share {name}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_recipients():
- """Track and cleanup recipients created during tests."""
- from databricks_tools_core.unity_catalog import delete_recipient
-
- recipients_to_cleanup = []
-
- def register(recipient_name: str):
- if recipient_name not in recipients_to_cleanup:
- recipients_to_cleanup.append(recipient_name)
- logger.info(f"Registered recipient for cleanup: {recipient_name}")
-
- yield register
-
- for name in recipients_to_cleanup:
- try:
- logger.info(f"Cleaning up recipient: {name}")
- delete_recipient(name)
- except Exception as e:
- logger.warning(f"Failed to cleanup recipient {name}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_monitors():
- """Track and cleanup monitors created during tests."""
- from databricks_tools_core.unity_catalog import delete_monitor
-
- monitors_to_cleanup = []
-
- def register(table_name: str):
- if table_name not in monitors_to_cleanup:
- monitors_to_cleanup.append(table_name)
- logger.info(f"Registered monitor for cleanup: {table_name}")
-
- yield register
-
- for tbl in monitors_to_cleanup:
- try:
- logger.info(f"Cleaning up monitor on: {tbl}")
- delete_monitor(tbl)
- except Exception as e:
- logger.warning(f"Failed to cleanup monitor {tbl}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_functions():
- """Track and cleanup UC functions created during tests."""
- from databricks_tools_core.unity_catalog import delete_function
-
- functions_to_cleanup = []
-
- def register(full_function_name: str):
- if full_function_name not in functions_to_cleanup:
- functions_to_cleanup.append(full_function_name)
- logger.info(f"Registered function for cleanup: {full_function_name}")
-
- yield register
-
- for fn_name in functions_to_cleanup:
- try:
- logger.info(f"Cleaning up function: {fn_name}")
- delete_function(fn_name, force=True)
- except Exception as e:
- logger.warning(f"Failed to cleanup function {fn_name}: {e}")
diff --git a/databricks-tools-core/tests/integration/unity_catalog/test_catalogs.py b/databricks-tools-core/tests/integration/unity_catalog/test_catalogs.py
deleted file mode 100644
index 61e1c25a..00000000
--- a/databricks-tools-core/tests/integration/unity_catalog/test_catalogs.py
+++ /dev/null
@@ -1,144 +0,0 @@
-"""
-Integration tests for Unity Catalog - Catalog operations.
-
-Tests:
-- list_catalogs
-- get_catalog
-- create_catalog
-- update_catalog
-- delete_catalog
-"""
-
-import logging
-import pytest
-
-from databricks_tools_core.unity_catalog import (
- list_catalogs,
- get_catalog,
- create_catalog,
- update_catalog,
- delete_catalog,
-)
-
-logger = logging.getLogger(__name__)
-
-TEST_CATALOG_PREFIX = "uc_test_catalog"
-
-
-@pytest.mark.integration
-class TestListCatalogs:
- """Tests for listing catalogs."""
-
- def test_list_catalogs(self):
- """Should list catalogs in the metastore."""
- catalogs = list_catalogs()
-
- logger.info(f"Found {len(catalogs)} catalogs")
- for cat in catalogs[:5]:
- logger.info(f" - {cat.name}")
-
- assert isinstance(catalogs, list)
- assert len(catalogs) > 0, "Should have at least one catalog"
-
- def test_list_catalogs_contains_main(self):
- """Should contain the main catalog."""
- catalogs = list_catalogs()
- catalog_names = [c.name for c in catalogs]
-
- # Most workspaces have a 'main' catalog
- assert any(name in catalog_names for name in ["main", "hive_metastore"]), (
- f"Expected 'main' or 'hive_metastore' in catalogs: {catalog_names[:10]}"
- )
-
-
-@pytest.mark.integration
-class TestGetCatalog:
- """Tests for getting catalog details."""
-
- def test_get_catalog(self, test_catalog: str):
- """Should get catalog details by name."""
- catalog = get_catalog(test_catalog)
-
- logger.info(f"Got catalog: {catalog.name} (owner: {catalog.owner})")
-
- assert catalog.name == test_catalog
- assert catalog.owner is not None
-
- def test_get_catalog_not_found(self):
- """Should raise error for non-existent catalog."""
- with pytest.raises(Exception) as exc_info:
- get_catalog("nonexistent_catalog_xyz_12345")
-
- error_msg = str(exc_info.value).lower()
- logger.info(f"Expected error: {exc_info.value}")
- assert "not found" in error_msg or "does not exist" in error_msg
-
-
-@pytest.mark.integration
-class TestCatalogCRUD:
- """Tests for catalog create, update, delete lifecycle."""
-
- def test_create_and_delete_catalog(self, unique_name: str):
- """Should create and delete a catalog."""
- catalog_name = f"{TEST_CATALOG_PREFIX}_{unique_name}"
-
- try:
- # Create
- logger.info(f"Creating catalog: {catalog_name}")
- catalog = create_catalog(
- name=catalog_name,
- comment="Integration test catalog",
- )
-
- assert catalog.name == catalog_name
- assert catalog.comment == "Integration test catalog"
- logger.info(f"Catalog created: {catalog.name}")
-
- # Verify via get
- fetched = get_catalog(catalog_name)
- assert fetched.name == catalog_name
-
- finally:
- # Cleanup
- try:
- logger.info(f"Deleting catalog: {catalog_name}")
- delete_catalog(catalog_name, force=True)
- logger.info(f"Catalog deleted: {catalog_name}")
- except Exception as e:
- logger.warning(f"Failed to cleanup catalog: {e}")
-
- def test_update_catalog_comment(self, unique_name: str):
- """Should update catalog comment."""
- catalog_name = f"{TEST_CATALOG_PREFIX}_{unique_name}"
-
- try:
- # Create
- create_catalog(name=catalog_name, comment="Original comment")
-
- # Update comment
- logger.info(f"Updating catalog comment: {catalog_name}")
- updated = update_catalog(
- catalog_name=catalog_name,
- comment="Updated comment",
- )
-
- assert updated.comment == "Updated comment"
- logger.info(f"Catalog comment updated: {updated.comment}")
-
- finally:
- try:
- delete_catalog(catalog_name, force=True)
- except Exception as e:
- logger.warning(f"Failed to cleanup catalog: {e}")
-
- def test_update_catalog_no_fields_raises(self, test_catalog: str):
- """Should raise ValueError when no fields provided."""
- with pytest.raises(ValueError) as exc_info:
- update_catalog(catalog_name=test_catalog)
-
- assert "at least one field" in str(exc_info.value).lower()
-
- def test_delete_catalog_not_found(self):
- """Should raise error when deleting non-existent catalog."""
- with pytest.raises(Exception):
- delete_catalog("nonexistent_catalog_xyz_12345")
diff --git a/databricks-tools-core/tests/integration/unity_catalog/test_connections.py b/databricks-tools-core/tests/integration/unity_catalog/test_connections.py
deleted file mode 100644
index 904b96e2..00000000
--- a/databricks-tools-core/tests/integration/unity_catalog/test_connections.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""
-Integration tests for Unity Catalog - Connection operations.
-
-Tests:
-- list_connections
-- get_connection
-
-Note: Creating connections requires external database credentials
-(Snowflake, PostgreSQL, etc.) and is not tested in CI.
-"""
-
-import logging
-import pytest
-
-from databricks_tools_core.unity_catalog import (
- list_connections,
- get_connection,
-)
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.mark.integration
-class TestListConnections:
- """Tests for listing Lakehouse Federation connections."""
-
- def test_list_connections(self):
- """Should list foreign connections (may be empty)."""
- connections = list_connections()
-
- logger.info(f"Found {len(connections)} connections")
- for conn in connections[:5]:
- logger.info(f" - {conn.name} (type: {conn.connection_type})")
-
- assert isinstance(connections, list)
-
- def test_list_connections_structure(self):
- """Should return ConnectionInfo objects."""
- connections = list_connections()
-
- if len(connections) > 0:
- conn = connections[0]
- assert hasattr(conn, "name")
- assert hasattr(conn, "connection_type")
- logger.info(f"First connection: {conn.name} ({conn.connection_type})")
-
-
-@pytest.mark.integration
-class TestGetConnection:
- """Tests for getting a specific connection."""
-
- def test_get_existing_connection(self):
- """Should get details of an existing connection."""
- connections = list_connections()
- if not connections:
- pytest.skip("No connections in workspace")
-
- conn_name = connections[0].name
- logger.info(f"Getting connection: {conn_name}")
-
- conn = get_connection(conn_name)
- assert conn.name == conn_name
- logger.info(f"Got connection: {conn.name} (type: {conn.connection_type})")
-
- def test_get_nonexistent_connection(self):
- """Should raise error for non-existent connection."""
- with pytest.raises(Exception):
- get_connection("nonexistent_connection_xyz_12345")
diff --git a/databricks-tools-core/tests/integration/unity_catalog/test_grants.py b/databricks-tools-core/tests/integration/unity_catalog/test_grants.py
deleted file mode 100644
index ddbf52db..00000000
--- a/databricks-tools-core/tests/integration/unity_catalog/test_grants.py
+++ /dev/null
@@ -1,188 +0,0 @@
-"""
-Integration tests for Unity Catalog - Grant operations.
-
-Tests:
-- grant_privileges
-- revoke_privileges
-- get_grants
-- get_effective_grants
-"""
-
-import logging
-import pytest
-
-from databricks_tools_core.unity_catalog import (
- grant_privileges,
- revoke_privileges,
- get_grants,
- get_effective_grants,
-)
-from databricks_tools_core.auth import get_workspace_client
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.fixture(scope="module")
-def current_user() -> str:
- """Get the current user's email for grant tests."""
- w = get_workspace_client()
- return w.current_user.me().user_name
-
-
-@pytest.mark.integration
-class TestGetGrants:
- """Tests for getting grants on objects."""
-
- def test_get_grants_on_catalog(self, test_catalog: str):
- """Should get grants on a catalog."""
- result = get_grants(
- securable_type="catalog",
- full_name=test_catalog,
- )
-
- logger.info(f"Grants on catalog {test_catalog}: {len(result['assignments'])} assignments")
- for a in result["assignments"]:
- logger.info(f" - {a['principal']}: {a['privileges']}")
-
- assert result["securable_type"] == "catalog"
- assert result["full_name"] == test_catalog
- assert isinstance(result["assignments"], list)
-
- def test_get_grants_on_schema(self, test_catalog: str, uc_test_schema: str):
- """Should get grants on a schema."""
- full_name = f"{test_catalog}.{uc_test_schema}"
- result = get_grants(
- securable_type="schema",
- full_name=full_name,
- )
-
- logger.info(f"Grants on schema {full_name}: {len(result['assignments'])} assignments")
- assert result["securable_type"] == "schema"
- assert isinstance(result["assignments"], list)
-
- def test_get_grants_filtered_by_principal(self, test_catalog: str, current_user: str):
- """Should filter grants by principal."""
- result = get_grants(
- securable_type="catalog",
- full_name=test_catalog,
- principal=current_user,
- )
-
- logger.info(f"Grants for {current_user}: {result['assignments']}")
- assert isinstance(result["assignments"], list)
-
-
-@pytest.mark.integration
-class TestGetEffectiveGrants:
- """Tests for getting effective (inherited) grants."""
-
- def test_get_effective_grants_on_table(self, uc_test_table: str):
- """Should get effective grants on a table (including inherited)."""
- result = get_effective_grants(
- securable_type="table",
- full_name=uc_test_table,
- )
-
- logger.info(f"Effective grants on {uc_test_table}: {len(result['effective_assignments'])} assignments")
- for a in result["effective_assignments"][:3]:
- logger.info(f" - {a['principal']}: {len(a['privileges'])} privileges")
-
- assert result["securable_type"] == "table"
- assert isinstance(result["effective_assignments"], list)
-
-
-@pytest.mark.integration
-class TestGrantRevoke:
- """Tests for granting and revoking privileges."""
-
- def test_grant_and_revoke_on_schema(self, test_catalog: str, uc_test_schema: str):
- """Should grant and then revoke SELECT on a schema."""
- full_name = f"{test_catalog}.{uc_test_schema}"
- principal = "account users"
-
- # Grant
- logger.info(f"Granting SELECT to '{principal}' on {full_name}")
- grant_result = grant_privileges(
- securable_type="schema",
- full_name=full_name,
- principal=principal,
- privileges=["SELECT"],
- )
-
- assert grant_result["status"] == "granted"
- assert grant_result["principal"] == principal
- logger.info(f"Grant result: {grant_result['status']}")
-
- # Verify grant exists
- grants = get_grants(
- securable_type="schema",
- full_name=full_name,
- principal=principal,
- )
- principals = [a["principal"] for a in grants["assignments"]]
- assert principal in principals, f"Expected '{principal}' in grants"
-
- # Revoke
- logger.info(f"Revoking SELECT from '{principal}' on {full_name}")
- revoke_result = revoke_privileges(
- securable_type="schema",
- full_name=full_name,
- principal=principal,
- privileges=["SELECT"],
- )
-
- assert revoke_result["status"] == "revoked"
- logger.info(f"Revoke result: {revoke_result['status']}")
-
- def test_grant_multiple_privileges(self, test_catalog: str, uc_test_schema: str):
- """Should grant multiple privileges at once."""
- full_name = f"{test_catalog}.{uc_test_schema}"
- principal = "account users"
-
- try:
- result = grant_privileges(
- securable_type="schema",
- full_name=full_name,
- principal=principal,
- privileges=["SELECT", "MODIFY"],
- )
-
- assert result["status"] == "granted"
- assert len(result["privileges"]) == 2
- logger.info(f"Multiple privileges granted: {result['privileges']}")
-
- finally:
- # Cleanup
- try:
- revoke_privileges(
- securable_type="schema",
- full_name=full_name,
- principal=principal,
- privileges=["SELECT", "MODIFY"],
- )
- except Exception:
- pass
-
- def test_invalid_securable_type_raises(self):
- """Should raise ValueError for invalid securable type."""
- with pytest.raises(ValueError) as exc_info:
- grant_privileges(
- securable_type="invalid_type",
- full_name="test",
- principal="user",
- privileges=["SELECT"],
- )
-
- assert "invalid securable_type" in str(exc_info.value).lower()
-
- def test_invalid_privilege_raises(self, test_catalog: str):
- """Should raise ValueError for invalid privilege."""
- with pytest.raises(ValueError) as exc_info:
- grant_privileges(
- securable_type="catalog",
- full_name=test_catalog,
- principal="account users",
- privileges=["INVALID_PRIVILEGE_XYZ"],
- )
-
- assert "invalid privilege" in str(exc_info.value).lower()
diff --git a/databricks-tools-core/tests/integration/unity_catalog/test_monitors.py b/databricks-tools-core/tests/integration/unity_catalog/test_monitors.py
deleted file mode 100644
index 5134baae..00000000
--- a/databricks-tools-core/tests/integration/unity_catalog/test_monitors.py
+++ /dev/null
@@ -1,111 +0,0 @@
-"""
-Integration tests for Unity Catalog - Quality Monitor operations.
-
-Tests:
-- create_monitor
-- get_monitor
-- run_monitor_refresh
-- list_monitor_refreshes
-- delete_monitor
-
-Note: Quality monitors require Lakehouse Monitoring to be enabled.
-Tests skip if the feature is not available.
-"""
-
-import logging
-import pytest
-
-from databricks_tools_core.unity_catalog import (
- create_monitor,
- get_monitor,
- list_monitor_refreshes,
-)
-
-logger = logging.getLogger(__name__)
-
-UC_TEST_PREFIX = "uc_test"
-
-
-@pytest.mark.integration
-@pytest.mark.slow
-class TestMonitorCRUD:
- """Tests for monitor create, get, refresh, delete lifecycle."""
-
- def test_create_and_delete_monitor(
- self,
- test_catalog: str,
- uc_test_schema: str,
- uc_test_table: str,
- warehouse_id: str,
- cleanup_monitors,
- ):
- """Should create and delete a quality monitor."""
- output_schema = f"{test_catalog}.{uc_test_schema}"
-
- try:
- logger.info(f"Creating monitor on: {uc_test_table}")
- monitor = create_monitor(
- table_name=uc_test_table,
- output_schema_name=output_schema,
- )
- cleanup_monitors(uc_test_table)
-
- assert monitor is not None
- logger.info(f"Monitor created: {monitor}")
-
- # Get monitor
- fetched = get_monitor(uc_test_table)
- assert fetched is not None
- logger.info(f"Monitor fetched: {fetched}")
-
- except Exception as e:
- if (
- "FEATURE_NOT_ENABLED" in str(e).upper()
- or "not enabled" in str(e).lower()
- or "NOT_FOUND" in str(e).upper()
- ):
- pytest.skip(f"Quality monitors not available: {e}")
- raise
-
- def test_list_monitor_refreshes(
- self,
- test_catalog: str,
- uc_test_schema: str,
- uc_test_table: str,
- warehouse_id: str,
- cleanup_monitors,
- ):
- """Should list refresh history for a monitor."""
- output_schema = f"{test_catalog}.{uc_test_schema}"
-
- try:
- # Create monitor
- create_monitor(
- table_name=uc_test_table,
- output_schema_name=output_schema,
- )
- cleanup_monitors(uc_test_table)
-
- # List refreshes (may be empty if no refresh has run)
- refreshes = list_monitor_refreshes(uc_test_table)
- assert isinstance(refreshes, list)
- logger.info(f"Monitor refreshes: {len(refreshes)}")
-
- except Exception as e:
- if (
- "FEATURE_NOT_ENABLED" in str(e).upper()
- or "not enabled" in str(e).lower()
- or "NOT_FOUND" in str(e).upper()
- ):
- pytest.skip(f"Quality monitors not available: {e}")
- raise
-
- def test_get_nonexistent_monitor(self):
- """Should raise error for table without monitor."""
- try:
- with pytest.raises(Exception):
- get_monitor("nonexistent_catalog.nonexistent_schema.nonexistent_table")
- except Exception as e:
- if "FEATURE_NOT_ENABLED" in str(e).upper():
- pytest.skip("Quality monitors not available")
- raise
diff --git a/databricks-tools-core/tests/integration/unity_catalog/test_security_policies.py b/databricks-tools-core/tests/integration/unity_catalog/test_security_policies.py
deleted file mode 100644
index 3f2d661d..00000000
--- a/databricks-tools-core/tests/integration/unity_catalog/test_security_policies.py
+++ /dev/null
@@ -1,210 +0,0 @@
-"""
-Integration tests for Unity Catalog - Security Policy operations.
-
-Tests:
-- create_security_function
-- set_row_filter
-- drop_row_filter
-- set_column_mask
-- drop_column_mask
-"""
-
-import logging
-import pytest
-
-from databricks_tools_core.unity_catalog import (
- create_security_function,
- set_row_filter,
- drop_row_filter,
- set_column_mask,
- drop_column_mask,
-)
-
-logger = logging.getLogger(__name__)
-
-UC_TEST_PREFIX = "uc_test"
-
-
-@pytest.mark.integration
-class TestCreateSecurityFunction:
- """Tests for creating security functions."""
-
- def test_create_row_filter_function(
- self,
- test_catalog: str,
- uc_test_schema: str,
- unique_name: str,
- warehouse_id: str,
- cleanup_functions,
- ):
- """Should create a row filter function."""
- fn_name = f"{test_catalog}.{uc_test_schema}.{UC_TEST_PREFIX}_row_filter_{unique_name}"
- cleanup_functions(fn_name)
-
- result = create_security_function(
- function_name=fn_name,
- parameter_name="dept",
- parameter_type="STRING",
- return_type="BOOLEAN",
- function_body="RETURN dept = 'Engineering'",
- comment="Test row filter function",
- warehouse_id=warehouse_id,
- )
-
- assert result["status"] == "created"
- assert result["function_name"] == fn_name
- logger.info(f"Row filter function created: {fn_name}")
-
- def test_create_column_mask_function(
- self,
- test_catalog: str,
- uc_test_schema: str,
- unique_name: str,
- warehouse_id: str,
- cleanup_functions,
- ):
- """Should create a column mask function."""
- fn_name = f"{test_catalog}.{uc_test_schema}.{UC_TEST_PREFIX}_col_mask_{unique_name}"
- cleanup_functions(fn_name)
-
- result = create_security_function(
- function_name=fn_name,
- parameter_name="val",
- parameter_type="STRING",
- return_type="STRING",
- function_body="RETURN CASE WHEN is_account_group_member('admins') THEN val ELSE '***' END",
- warehouse_id=warehouse_id,
- )
-
- assert result["status"] == "created"
- logger.info(f"Column mask function created: {fn_name}")
-
-
-@pytest.mark.integration
-class TestRowFilter:
- """Tests for row filter operations."""
-
- def test_set_and_drop_row_filter(
- self,
- test_catalog: str,
- uc_test_schema: str,
- uc_test_table: str,
- unique_name: str,
- warehouse_id: str,
- cleanup_functions,
- ):
- """Should set and then drop a row filter on a table."""
- fn_name = f"{test_catalog}.{uc_test_schema}.{UC_TEST_PREFIX}_rf_{unique_name}"
- cleanup_functions(fn_name)
-
- # Create filter function
- create_security_function(
- function_name=fn_name,
- parameter_name="dept",
- parameter_type="STRING",
- return_type="BOOLEAN",
- function_body="RETURN dept = 'Engineering'",
- warehouse_id=warehouse_id,
- )
-
- # Set row filter
- logger.info(f"Setting row filter on {uc_test_table}")
- set_result = set_row_filter(
- table_name=uc_test_table,
- filter_function=fn_name,
- filter_columns=["department"],
- warehouse_id=warehouse_id,
- )
-
- assert set_result["status"] == "row_filter_set"
- assert set_result["function"] == fn_name
- logger.info(f"Row filter set: {set_result['sql']}")
-
- # Drop row filter
- logger.info(f"Dropping row filter from {uc_test_table}")
- drop_result = drop_row_filter(
- table_name=uc_test_table,
- warehouse_id=warehouse_id,
- )
-
- assert drop_result["status"] == "row_filter_dropped"
- logger.info(f"Row filter dropped: {drop_result['sql']}")
-
-
-@pytest.mark.integration
-class TestColumnMask:
- """Tests for column mask operations."""
-
- def test_set_and_drop_column_mask(
- self,
- test_catalog: str,
- uc_test_schema: str,
- uc_test_table: str,
- unique_name: str,
- warehouse_id: str,
- cleanup_functions,
- ):
- """Should set and then drop a column mask."""
- fn_name = f"{test_catalog}.{uc_test_schema}.{UC_TEST_PREFIX}_cm_{unique_name}"
- cleanup_functions(fn_name)
-
- # Create mask function
- create_security_function(
- function_name=fn_name,
- parameter_name="val",
- parameter_type="STRING",
- return_type="STRING",
- function_body="RETURN CASE WHEN is_account_group_member('admins') THEN val ELSE '***@***.com' END",
- warehouse_id=warehouse_id,
- )
-
- # Set column mask
- logger.info(f"Setting column mask on {uc_test_table}.email")
- set_result = set_column_mask(
- table_name=uc_test_table,
- column_name="email",
- mask_function=fn_name,
- warehouse_id=warehouse_id,
- )
-
- assert set_result["status"] == "column_mask_set"
- assert set_result["column"] == "email"
- logger.info(f"Column mask set: {set_result['sql']}")
-
- # Drop column mask
- logger.info(f"Dropping column mask from {uc_test_table}.email")
- drop_result = drop_column_mask(
- table_name=uc_test_table,
- column_name="email",
- warehouse_id=warehouse_id,
- )
-
- assert drop_result["status"] == "column_mask_dropped"
- logger.info(f"Column mask dropped: {drop_result['sql']}")
-
-
-@pytest.mark.integration
-class TestSecurityPolicyValidation:
- """Tests for input validation in security policy functions."""
-
- def test_invalid_identifier_raises(self):
- """Should raise ValueError for invalid SQL identifiers."""
- with pytest.raises(ValueError) as exc_info:
- set_row_filter(
- table_name="DROP TABLE; --",
- filter_function="fn",
- filter_columns=["col"],
- )
-
- assert "invalid sql identifier" in str(exc_info.value).lower()
-
- def test_invalid_column_identifier_raises(self):
- """Should raise ValueError for invalid column in filter."""
- with pytest.raises(ValueError) as exc_info:
- set_row_filter(
- table_name="catalog.schema.table",
- filter_function="catalog.schema.fn",
- filter_columns=["col; DROP TABLE--"],
- )
-
- assert "invalid sql identifier" in str(exc_info.value).lower()
diff --git a/databricks-tools-core/tests/integration/unity_catalog/test_sharing.py b/databricks-tools-core/tests/integration/unity_catalog/test_sharing.py
deleted file mode 100644
index 4eb1dfc9..00000000
--- a/databricks-tools-core/tests/integration/unity_catalog/test_sharing.py
+++ /dev/null
@@ -1,245 +0,0 @@
-"""
-Integration tests for Unity Catalog - Delta Sharing operations.
-
-Tests:
-- list_shares, create_share, get_share, delete_share
-- add_table_to_share, remove_table_from_share
-- grant_share_to_recipient, revoke_share_from_recipient
-- list_recipients, create_recipient, get_recipient, delete_recipient
-- list_providers
-
-Note: Delta Sharing must be enabled in the workspace.
-Tests skip if the feature is not available.
-"""
-
-import logging
-import pytest
-
-from databricks_tools_core.unity_catalog import (
- list_shares,
- get_share,
- create_share,
- add_table_to_share,
- remove_table_from_share,
- grant_share_to_recipient,
- revoke_share_from_recipient,
- list_recipients,
- get_recipient,
- create_recipient,
- list_providers,
-)
-
-logger = logging.getLogger(__name__)
-
-UC_TEST_PREFIX = "uc_test"
-
-
-def _is_sharing_error(e: Exception) -> bool:
- """Check if error indicates sharing is not available."""
- msg = str(e).upper()
- return any(
- kw in msg
- for kw in [
- "FEATURE_NOT_ENABLED",
- "NOT_FOUND",
- "FORBIDDEN",
- "PERMISSION_DENIED",
- "DELTA_SHARING",
- "NOT_AVAILABLE",
- ]
- )
-
-
-@pytest.mark.integration
-class TestListShares:
- """Tests for listing shares."""
-
- def test_list_shares(self):
- """Should list shares (may be empty)."""
- try:
- shares = list_shares()
- logger.info(f"Found {len(shares)} shares")
- assert isinstance(shares, list)
- except Exception as e:
- if _is_sharing_error(e):
- pytest.skip(f"Delta Sharing not available: {e}")
- raise
-
-
-@pytest.mark.integration
-class TestShareCRUD:
- """Tests for share create, get, delete lifecycle."""
-
- def test_create_and_delete_share(self, unique_name: str, cleanup_shares):
- """Should create and delete a share."""
- share_name = f"{UC_TEST_PREFIX}_share_{unique_name}"
-
- try:
- logger.info(f"Creating share: {share_name}")
- share = create_share(
- name=share_name,
- comment="Integration test share",
- )
- cleanup_shares(share_name)
-
- assert share["name"] == share_name
- logger.info(f"Share created: {share['name']}")
-
- # Get share
- fetched = get_share(share_name)
- assert fetched["name"] == share_name
- logger.info(f"Share fetched: {fetched['name']}")
-
- except Exception as e:
- if _is_sharing_error(e):
- pytest.skip(f"Delta Sharing not available: {e}")
- raise
-
- def test_add_and_remove_table_from_share(
- self,
- uc_test_table: str,
- unique_name: str,
- cleanup_shares,
- ):
- """Should add and remove a table from a share."""
- share_name = f"{UC_TEST_PREFIX}_share_tbl_{unique_name}"
-
- try:
- # Create share
- create_share(name=share_name)
- cleanup_shares(share_name)
-
- # Add table
- logger.info(f"Adding table {uc_test_table} to share {share_name}")
- result = add_table_to_share(
- share_name=share_name,
- table_name=uc_test_table,
- )
- assert result is not None
- logger.info("Table added to share")
-
- # Verify table is in share
- share = get_share(share_name, include_shared_data=True)
- objects = share.get("objects", [])
- logger.info(f"Share has {len(objects)} objects")
-
- # Remove table
- logger.info("Removing table from share")
- remove_result = remove_table_from_share(
- share_name=share_name,
- table_name=uc_test_table,
- )
- assert remove_result is not None
- logger.info("Table removed from share")
-
- except Exception as e:
- if _is_sharing_error(e):
- pytest.skip(f"Delta Sharing not available: {e}")
- raise
-
-
-@pytest.mark.integration
-class TestRecipientCRUD:
- """Tests for recipient create, get, delete lifecycle."""
-
- def test_create_and_delete_recipient(self, unique_name: str, cleanup_recipients):
- """Should create and delete a sharing recipient."""
- recipient_name = f"{UC_TEST_PREFIX}_recipient_{unique_name}"
-
- try:
- logger.info(f"Creating recipient: {recipient_name}")
- recipient = create_recipient(
- name=recipient_name,
- authentication_type="TOKEN",
- comment="Integration test recipient",
- )
- cleanup_recipients(recipient_name)
-
- assert recipient["name"] == recipient_name
- logger.info(f"Recipient created: {recipient['name']}")
-
- # Get recipient
- fetched = get_recipient(recipient_name)
- assert fetched["name"] == recipient_name
- logger.info(f"Recipient fetched: {fetched['name']}")
-
- except Exception as e:
- if _is_sharing_error(e):
- pytest.skip(f"Delta Sharing not available: {e}")
- raise
-
- def test_list_recipients(self):
- """Should list recipients."""
- try:
- recipients = list_recipients()
- logger.info(f"Found {len(recipients)} recipients")
- assert isinstance(recipients, list)
- except Exception as e:
- if _is_sharing_error(e):
- pytest.skip(f"Delta Sharing not available: {e}")
- raise
-
-
-@pytest.mark.integration
-class TestSharePermissions:
- """Tests for share grant/revoke to recipients."""
-
- def test_grant_and_revoke_share(
- self,
- uc_test_table: str,
- unique_name: str,
- cleanup_shares,
- cleanup_recipients,
- ):
- """Should grant and revoke share permissions to a recipient."""
- share_name = f"{UC_TEST_PREFIX}_share_perm_{unique_name}"
- recipient_name = f"{UC_TEST_PREFIX}_recip_perm_{unique_name}"
-
- try:
- # Create share with table
- create_share(name=share_name)
- cleanup_shares(share_name)
- add_table_to_share(share_name=share_name, table_name=uc_test_table)
-
- # Create recipient
- create_recipient(name=recipient_name, authentication_type="TOKEN")
- cleanup_recipients(recipient_name)
-
- # Grant
- logger.info(f"Granting share '{share_name}' to recipient '{recipient_name}'")
- grant_result = grant_share_to_recipient(
- share_name=share_name,
- recipient_name=recipient_name,
- )
- assert grant_result["status"] == "granted"
- logger.info(f"Share granted: {grant_result}")
-
- # Revoke
- logger.info("Revoking share from recipient")
- revoke_result = revoke_share_from_recipient(
- share_name=share_name,
- recipient_name=recipient_name,
- )
- assert revoke_result["status"] == "revoked"
- logger.info(f"Share revoked: {revoke_result}")
-
- except Exception as e:
- if _is_sharing_error(e):
- pytest.skip(f"Delta Sharing not available: {e}")
- raise
-
-
-@pytest.mark.integration
-class TestListProviders:
- """Tests for listing sharing providers."""
-
- def test_list_providers(self):
- """Should list sharing providers (may be empty)."""
- try:
- providers = list_providers()
- logger.info(f"Found {len(providers)} providers")
- assert isinstance(providers, list)
- except Exception as e:
- if _is_sharing_error(e):
- pytest.skip(f"Delta Sharing not available: {e}")
- raise
diff --git a/databricks-tools-core/tests/integration/unity_catalog/test_storage.py b/databricks-tools-core/tests/integration/unity_catalog/test_storage.py
deleted file mode 100644
index 0fb23d01..00000000
--- a/databricks-tools-core/tests/integration/unity_catalog/test_storage.py
+++ /dev/null
@@ -1,166 +0,0 @@
-"""
-Integration tests for Unity Catalog - Storage operations.
-
-Tests:
-- list_storage_credentials
-- list_external_locations
-- validate_storage_credential
-
-Note: Create/update/delete operations for storage credentials and external
-locations require cloud-specific resources (IAM roles, access connectors, etc.)
-and are tested as read-only list operations in CI. Full CRUD requires manual
-setup with proper cloud credentials.
-"""
-
-import logging
-import pytest
-
-from databricks_tools_core.unity_catalog import (
- list_storage_credentials,
- get_storage_credential,
- create_storage_credential,
- list_external_locations,
- get_external_location,
- validate_storage_credential,
-)
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.mark.integration
-class TestListStorageCredentials:
- """Tests for listing storage credentials."""
-
- def test_list_storage_credentials(self):
- """Should list storage credentials in the metastore."""
- credentials = list_storage_credentials()
-
- logger.info(f"Found {len(credentials)} storage credentials")
- for cred in credentials[:5]:
- logger.info(f" - {cred.name} (read_only: {cred.read_only})")
-
- assert isinstance(credentials, list)
- # Most workspaces have at least the default credential
- assert len(credentials) >= 0
-
- def test_list_storage_credentials_structure(self):
- """Should return StorageCredentialInfo objects."""
- credentials = list_storage_credentials()
-
- if len(credentials) > 0:
- cred = credentials[0]
- assert hasattr(cred, "name")
- assert hasattr(cred, "owner")
- logger.info(f"First credential: {cred.name}")
-
-
-@pytest.mark.integration
-class TestGetStorageCredential:
- """Tests for getting a specific storage credential."""
-
- def test_get_existing_credential(self):
- """Should get details of an existing credential."""
- credentials = list_storage_credentials()
- if not credentials:
- pytest.skip("No storage credentials in workspace")
-
- cred_name = credentials[0].name
- logger.info(f"Getting credential: {cred_name}")
-
- cred = get_storage_credential(cred_name)
- assert cred.name == cred_name
- logger.info(f"Got credential: {cred.name} (owner: {cred.owner})")
-
- def test_get_nonexistent_credential(self):
- """Should raise error for non-existent credential."""
- with pytest.raises(Exception):
- get_storage_credential("nonexistent_credential_xyz_12345")
-
-
-@pytest.mark.integration
-class TestValidateStorageCredential:
- """Tests for validating storage credentials."""
-
- def test_validate_existing_credential(self):
- """Should validate an existing storage credential."""
- credentials = list_storage_credentials()
- if not credentials:
- pytest.skip("No storage credentials in workspace")
-
- cred_name = credentials[0].name
- logger.info(f"Validating credential: {cred_name}")
-
- try:
- result = validate_storage_credential(name=cred_name)
- assert isinstance(result, dict)
- assert "results" in result
- logger.info(f"Validation result: {result}")
- except Exception as e:
- # Some credentials may not support validation without a URL
- logger.info(f"Validation requires URL: {e}")
-
-
-@pytest.mark.integration
-class TestCreateStorageCredentialValidation:
- """Tests for storage credential creation validation."""
-
- def test_create_without_cloud_credentials_raises(self):
- """Should raise ValueError when no cloud credentials provided."""
- with pytest.raises(ValueError) as exc_info:
- create_storage_credential(
- name="test_bad_credential",
- )
-
- assert (
- "aws_iam_role_arn" in str(exc_info.value).lower()
- or "azure_access_connector_id" in str(exc_info.value).lower()
- )
-
-
-@pytest.mark.integration
-class TestListExternalLocations:
- """Tests for listing external locations."""
-
- def test_list_external_locations(self):
- """Should list external locations."""
- locations = list_external_locations()
-
- logger.info(f"Found {len(locations)} external locations")
- for loc in locations[:5]:
- logger.info(f" - {loc.name} -> {loc.url}")
-
- assert isinstance(locations, list)
-
- def test_list_external_locations_structure(self):
- """Should return ExternalLocationInfo objects."""
- locations = list_external_locations()
-
- if len(locations) > 0:
- loc = locations[0]
- assert hasattr(loc, "name")
- assert hasattr(loc, "url")
- assert hasattr(loc, "credential_name")
- logger.info(f"First location: {loc.name} -> {loc.url}")
-
-
-@pytest.mark.integration
-class TestGetExternalLocation:
- """Tests for getting a specific external location."""
-
- def test_get_existing_location(self):
- """Should get details of an existing external location."""
- locations = list_external_locations()
- if not locations:
- pytest.skip("No external locations in workspace")
-
- loc_name = locations[0].name
- logger.info(f"Getting external location: {loc_name}")
-
- loc = get_external_location(loc_name)
- assert loc.name == loc_name
- logger.info(f"Got location: {loc.name} (url: {loc.url})")
-
- def test_get_nonexistent_location(self):
- """Should raise error for non-existent location."""
- with pytest.raises(Exception):
- get_external_location("nonexistent_location_xyz_12345")
diff --git a/databricks-tools-core/tests/integration/unity_catalog/test_tags.py b/databricks-tools-core/tests/integration/unity_catalog/test_tags.py
deleted file mode 100644
index 1532da69..00000000
--- a/databricks-tools-core/tests/integration/unity_catalog/test_tags.py
+++ /dev/null
@@ -1,243 +0,0 @@
-"""
-Integration tests for Unity Catalog - Tag and Comment operations.
-
-Tests:
-- set_tags
-- unset_tags
-- set_comment
-- query_table_tags
-- query_column_tags
-"""
-
-import logging
-import pytest
-
-from databricks_tools_core.unity_catalog import (
- set_tags,
- unset_tags,
- set_comment,
- query_table_tags,
- query_column_tags,
-)
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.mark.integration
-class TestSetTags:
- """Tests for setting tags on UC objects."""
-
- def test_set_tags_on_table(self, uc_test_table: str, warehouse_id: str):
- """Should set tags on a table."""
- # Use unique prefixed tag names to avoid workspace governed tag policies
- result = set_tags(
- object_type="table",
- full_name=uc_test_table,
- tags={"uc_test_tier": "silver", "uc_test_owner": "data-eng"},
- warehouse_id=warehouse_id,
- )
-
- assert result["status"] == "tags_set"
- assert result["tags"]["uc_test_tier"] == "silver"
- logger.info(f"Tags set on {uc_test_table}: {result['tags']}")
-
- def test_set_tags_on_column(self, uc_test_table: str, warehouse_id: str):
- """Should set tags on a column."""
- # Use unique prefixed tag names to avoid workspace governed tag policies
- result = set_tags(
- object_type="column",
- full_name=uc_test_table,
- column_name="email",
- tags={"uc_test_col_class": "sensitive", "uc_test_col_level": "high"},
- warehouse_id=warehouse_id,
- )
-
- assert result["status"] == "tags_set"
- logger.info(f"Column tags set: {result['tags']}")
-
- def test_set_tags_on_schema(self, test_catalog: str, uc_test_schema: str, warehouse_id: str):
- """Should set tags on a schema."""
- full_name = f"{test_catalog}.{uc_test_schema}"
- # Use unique prefixed tag names to avoid workspace governed tag policies
- result = set_tags(
- object_type="schema",
- full_name=full_name,
- tags={"uc_test_deploy": "testing"},
- warehouse_id=warehouse_id,
- )
-
- assert result["status"] == "tags_set"
- logger.info(f"Schema tags set: {result['tags']}")
-
- def test_set_tags_column_without_column_name_raises(self, uc_test_table: str):
- """Should raise ValueError when column_name missing for column type."""
- with pytest.raises(ValueError) as exc_info:
- set_tags(
- object_type="column",
- full_name=uc_test_table,
- tags={"test": "value"},
- )
-
- assert "column_name" in str(exc_info.value).lower()
-
- def test_set_tags_invalid_object_type_raises(self, uc_test_table: str):
- """Should raise ValueError for invalid object type."""
- with pytest.raises(ValueError) as exc_info:
- set_tags(
- object_type="invalid_type",
- full_name=uc_test_table,
- tags={"test": "value"},
- )
-
- assert "object_type" in str(exc_info.value).lower()
-
-
-@pytest.mark.integration
-class TestUnsetTags:
- """Tests for removing tags from UC objects."""
-
- def test_unset_tags_on_table(self, uc_test_table: str, warehouse_id: str):
- """Should remove tags from a table."""
- # Set tags first
- set_tags(
- object_type="table",
- full_name=uc_test_table,
- tags={"temp_tag": "to_remove"},
- warehouse_id=warehouse_id,
- )
-
- # Unset
- result = unset_tags(
- object_type="table",
- full_name=uc_test_table,
- tag_names=["temp_tag"],
- warehouse_id=warehouse_id,
- )
-
- assert result["status"] == "tags_unset"
- assert "temp_tag" in result["tag_names"]
- logger.info(f"Tags unset: {result['tag_names']}")
-
- def test_unset_tags_on_column(self, uc_test_table: str, warehouse_id: str):
- """Should remove tags from a column."""
- # Set first
- set_tags(
- object_type="column",
- full_name=uc_test_table,
- column_name="salary",
- tags={"temp_col_tag": "remove_me"},
- warehouse_id=warehouse_id,
- )
-
- # Unset
- result = unset_tags(
- object_type="column",
- full_name=uc_test_table,
- column_name="salary",
- tag_names=["temp_col_tag"],
- warehouse_id=warehouse_id,
- )
-
- assert result["status"] == "tags_unset"
- logger.info(f"Column tags unset: {result['tag_names']}")
-
-
-@pytest.mark.integration
-class TestSetComment:
- """Tests for setting comments on UC objects."""
-
- def test_set_comment_on_table(self, uc_test_table: str, warehouse_id: str):
- """Should set a comment on a table."""
- result = set_comment(
- object_type="table",
- full_name=uc_test_table,
- comment_text="Employee records for UC testing",
- warehouse_id=warehouse_id,
- )
-
- assert result["status"] == "comment_set"
- logger.info(f"Comment set on {uc_test_table}")
-
- def test_set_comment_on_column(self, uc_test_table: str, warehouse_id: str):
- """Should set a comment on a column."""
- result = set_comment(
- object_type="column",
- full_name=uc_test_table,
- column_name="salary",
- comment_text="Annual salary in USD",
- warehouse_id=warehouse_id,
- )
-
- assert result["status"] == "comment_set"
- logger.info("Column comment set on salary")
-
- def test_set_comment_column_without_name_raises(self, uc_test_table: str):
- """Should raise ValueError when column_name missing for column type."""
- with pytest.raises(ValueError) as exc_info:
- set_comment(
- object_type="column",
- full_name=uc_test_table,
- comment_text="test",
- )
-
- assert "column_name" in str(exc_info.value).lower()
-
-
-@pytest.mark.integration
-class TestQueryTags:
- """Tests for querying tags from information_schema."""
-
- def test_query_table_tags(self, test_catalog: str, uc_test_table: str, warehouse_id: str):
- """Should query table tags from information_schema."""
- # Ensure tags exist
- set_tags(
- object_type="table",
- full_name=uc_test_table,
- tags={"query_test": "yes"},
- warehouse_id=warehouse_id,
- )
-
- results = query_table_tags(
- catalog_filter=test_catalog,
- tag_name="query_test",
- limit=10,
- warehouse_id=warehouse_id,
- )
-
- logger.info(f"Found {len(results)} table tag entries")
- assert isinstance(results, list)
- # May take time for tags to propagate to information_schema
- if len(results) > 0:
- assert "tag_name" in results[0] or "TAG_NAME" in str(results[0].keys())
-
- def test_query_column_tags(self, test_catalog: str, uc_test_table: str, warehouse_id: str):
- """Should query column tags from information_schema."""
- # Ensure column tags exist
- set_tags(
- object_type="column",
- full_name=uc_test_table,
- column_name="email",
- tags={"col_query_test": "yes"},
- warehouse_id=warehouse_id,
- )
-
- results = query_column_tags(
- catalog_filter=test_catalog,
- tag_name="col_query_test",
- limit=10,
- warehouse_id=warehouse_id,
- )
-
- logger.info(f"Found {len(results)} column tag entries")
- assert isinstance(results, list)
-
- def test_query_table_tags_no_results(self, warehouse_id: str):
- """Should return empty list for non-matching filter."""
- results = query_table_tags(
- tag_name="nonexistent_tag_xyz_12345",
- limit=10,
- warehouse_id=warehouse_id,
- )
-
- assert isinstance(results, list)
- assert len(results) == 0
diff --git a/databricks-tools-core/tests/integration/unity_catalog/test_volumes.py b/databricks-tools-core/tests/integration/unity_catalog/test_volumes.py
deleted file mode 100644
index 604666f6..00000000
--- a/databricks-tools-core/tests/integration/unity_catalog/test_volumes.py
+++ /dev/null
@@ -1,196 +0,0 @@
-"""
-Integration tests for Unity Catalog - Volume operations.
-
-Tests:
-- list_volumes
-- get_volume
-- create_volume
-- update_volume
-- delete_volume
-"""
-
-import logging
-import pytest
-
-from databricks_tools_core.unity_catalog import (
- list_volumes,
- get_volume,
- create_volume,
- update_volume,
- delete_volume,
-)
-
-logger = logging.getLogger(__name__)
-
-UC_TEST_PREFIX = "uc_test"
-
-
-@pytest.mark.integration
-class TestListVolumes:
- """Tests for listing volumes."""
-
- def test_list_volumes(self, test_catalog: str, uc_test_schema: str):
- """Should list volumes in schema (may be empty)."""
- volumes = list_volumes(
- catalog_name=test_catalog,
- schema_name=uc_test_schema,
- )
-
- logger.info(f"Found {len(volumes)} volumes in {test_catalog}.{uc_test_schema}")
- assert isinstance(volumes, list)
-
-
-@pytest.mark.integration
-class TestVolumeCRUD:
- """Tests for volume create, get, update, delete lifecycle."""
-
- def test_create_managed_volume(
- self,
- test_catalog: str,
- uc_test_schema: str,
- unique_name: str,
- cleanup_volumes,
- ):
- """Should create a managed volume."""
- vol_name = f"{UC_TEST_PREFIX}_vol_{unique_name}"
- full_name = f"{test_catalog}.{uc_test_schema}.{vol_name}"
-
- logger.info(f"Creating managed volume: {full_name}")
- vol = create_volume(
- catalog_name=test_catalog,
- schema_name=uc_test_schema,
- name=vol_name,
- volume_type="MANAGED",
- comment="Test managed volume",
- )
- cleanup_volumes(full_name)
-
- assert vol.name == vol_name
- logger.info(f"Volume created: {vol.full_name}")
-
- def test_get_volume(
- self,
- test_catalog: str,
- uc_test_schema: str,
- unique_name: str,
- cleanup_volumes,
- ):
- """Should get volume details."""
- vol_name = f"{UC_TEST_PREFIX}_vol_{unique_name}"
- full_name = f"{test_catalog}.{uc_test_schema}.{vol_name}"
-
- create_volume(
- catalog_name=test_catalog,
- schema_name=uc_test_schema,
- name=vol_name,
- comment="Get test volume",
- )
- cleanup_volumes(full_name)
-
- fetched = get_volume(full_name)
-
- logger.info(f"Got volume: {fetched.full_name} (type: {fetched.volume_type})")
- assert fetched.name == vol_name
- assert fetched.catalog_name == test_catalog
- assert fetched.schema_name == uc_test_schema
-
- def test_update_volume_comment(
- self,
- test_catalog: str,
- uc_test_schema: str,
- unique_name: str,
- cleanup_volumes,
- ):
- """Should update volume comment."""
- vol_name = f"{UC_TEST_PREFIX}_vol_{unique_name}"
- full_name = f"{test_catalog}.{uc_test_schema}.{vol_name}"
-
- create_volume(
- catalog_name=test_catalog,
- schema_name=uc_test_schema,
- name=vol_name,
- comment="Original",
- )
- cleanup_volumes(full_name)
-
- updated = update_volume(
- full_volume_name=full_name,
- comment="Updated comment",
- )
-
- logger.info(f"Updated volume comment: {updated.comment}")
- assert updated.comment == "Updated comment"
-
- def test_update_volume_no_fields_raises(self):
- """Should raise ValueError when no update fields provided."""
- with pytest.raises(ValueError) as exc_info:
- update_volume(full_volume_name="cat.sch.vol")
-
- assert "at least one field" in str(exc_info.value).lower()
-
- def test_create_external_volume_without_location_raises(
- self,
- test_catalog: str,
- uc_test_schema: str,
- ):
- """Should raise ValueError for EXTERNAL volume without storage_location."""
- with pytest.raises(ValueError) as exc_info:
- create_volume(
- catalog_name=test_catalog,
- schema_name=uc_test_schema,
- name="bad_external",
- volume_type="EXTERNAL",
- )
-
- assert "storage_location" in str(exc_info.value).lower()
-
- def test_delete_volume(
- self,
- test_catalog: str,
- uc_test_schema: str,
- unique_name: str,
- ):
- """Should delete a volume."""
- vol_name = f"{UC_TEST_PREFIX}_vol_del_{unique_name}"
- full_name = f"{test_catalog}.{uc_test_schema}.{vol_name}"
-
- create_volume(
- catalog_name=test_catalog,
- schema_name=uc_test_schema,
- name=vol_name,
- )
-
- logger.info(f"Deleting volume: {full_name}")
- delete_volume(full_name)
- logger.info(f"Volume deleted: {full_name}")
-
- # Verify deletion
- with pytest.raises(Exception):
- get_volume(full_name)
-
- def test_list_volumes_after_create(
- self,
- test_catalog: str,
- uc_test_schema: str,
- unique_name: str,
- cleanup_volumes,
- ):
- """Should list volumes and include newly created one."""
- vol_name = f"{UC_TEST_PREFIX}_vol_list_{unique_name}"
- full_name = f"{test_catalog}.{uc_test_schema}.{vol_name}"
-
- create_volume(
- catalog_name=test_catalog,
- schema_name=uc_test_schema,
- name=vol_name,
- )
- cleanup_volumes(full_name)
-
- volumes = list_volumes(
- catalog_name=test_catalog,
- schema_name=uc_test_schema,
- )
- volume_names = [v.name for v in volumes]
-
- logger.info(f"Volumes in schema: {volume_names}")
- assert vol_name in volume_names
diff --git a/databricks-tools-core/tests/integration/vector_search/__init__.py b/databricks-tools-core/tests/integration/vector_search/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/databricks-tools-core/tests/integration/vector_search/conftest.py b/databricks-tools-core/tests/integration/vector_search/conftest.py
deleted file mode 100644
index 298f3212..00000000
--- a/databricks-tools-core/tests/integration/vector_search/conftest.py
+++ /dev/null
@@ -1,221 +0,0 @@
-"""
-Pytest fixtures for Vector Search integration tests.
-
-Provides session-scoped fixtures for:
-- VS endpoint (created once, shared across all VS tests)
-- Direct Access index (for query/upsert/scan tests)
-- Cleanup helpers
-"""
-
-import json
-import logging
-import time
-import uuid
-
-import pytest
-
-logger = logging.getLogger(__name__)
-
-# Unique suffix to avoid collisions across test runs
-_RUN_ID = str(uuid.uuid4())[:8]
-
-VS_ENDPOINT_NAME = f"vs_test_ep_{_RUN_ID}"
-VS_INDEX_NAME_TEMPLATE = "{catalog}.{schema}.vs_test_idx_{suffix}"
-EMBEDDING_DIM = 8 # small dimension for fast tests
-
-# Schema for Direct Access index
-DA_SCHEMA = {
- "id": "int",
- "text": "string",
- "embedding": "array",
-}
-
-
-def _wait_for_endpoint_online(name: str, timeout: int = 1200, poll: int = 15):
- """Poll endpoint until ONLINE or timeout."""
- from databricks_tools_core.vector_search import get_vs_endpoint
-
- deadline = time.time() + timeout
- while time.time() < deadline:
- ep = get_vs_endpoint(name)
- state = ep.get("state")
- logger.info(f"Endpoint '{name}' state: {state}")
- if state == "ONLINE":
- return ep
- if state in ("OFFLINE", "NOT_FOUND"):
- raise RuntimeError(f"Endpoint '{name}' in terminal state: {state}")
- time.sleep(poll)
- raise TimeoutError(f"Endpoint '{name}' not ONLINE within {timeout}s")
-
-
-def _wait_for_index_online(index_name: str, timeout: int = 600, poll: int = 15):
- """Poll index until ONLINE or timeout.
-
- Handles transient InternalError from the VS service gracefully
- by treating them as 'still provisioning'.
- """
- from databricks_tools_core.vector_search import get_vs_index
-
- deadline = time.time() + timeout
- while time.time() < deadline:
- try:
- idx = get_vs_index(index_name)
- except Exception as e:
- # Transient server errors during provisioning - keep retrying
- logger.warning(f"Transient error polling index '{index_name}': {e}")
- time.sleep(poll)
- continue
- state = idx.get("state")
- logger.info(f"Index '{index_name}' state: {state}")
- if state == "ONLINE":
- return idx
- if state == "NOT_FOUND":
- raise RuntimeError(f"Index '{index_name}' not found")
- time.sleep(poll)
- raise TimeoutError(f"Index '{index_name}' not ONLINE within {timeout}s")
-
-
-@pytest.fixture(scope="session")
-def vs_endpoint_name():
- """
- Create a Standard VS endpoint for the test session.
-
- Waits for ONLINE state before yielding.
- Deletes the endpoint after all tests complete.
- """
- from databricks_tools_core.vector_search import (
- create_vs_endpoint,
- delete_vs_endpoint,
- get_vs_endpoint,
- )
-
- name = VS_ENDPOINT_NAME
- logger.info(f"Creating VS endpoint: {name}")
-
- # Check if it already exists (from a previous failed run)
- existing = get_vs_endpoint(name)
- if existing.get("state") != "NOT_FOUND":
- logger.info(f"Endpoint '{name}' already exists, reusing")
- else:
- result = create_vs_endpoint(name=name, endpoint_type="STANDARD")
- logger.info(f"Endpoint creation result: {result}")
-
- # Wait for ONLINE
- _wait_for_endpoint_online(name)
- logger.info(f"Endpoint '{name}' is ONLINE")
-
- yield name
-
- # Teardown
- logger.info(f"Deleting VS endpoint: {name}")
- try:
- delete_vs_endpoint(name)
- logger.info(f"Endpoint '{name}' deleted")
- except Exception as e:
- logger.warning(f"Failed to delete endpoint '{name}': {e}")
-
-
-@pytest.fixture(scope="module")
-def vs_direct_index_name(vs_endpoint_name, test_catalog, test_schema):
- """
- Create a Direct Access index for query/data tests.
-
- Depends on test_catalog/test_schema to ensure the catalog exists.
- Uses a small embedding dimension for speed.
- Yields the fully qualified index name.
- """
- from databricks_tools_core.vector_search import (
- create_vs_index,
- delete_vs_index,
- get_vs_index,
- )
-
- suffix = str(uuid.uuid4())[:8]
- index_name = VS_INDEX_NAME_TEMPLATE.format(catalog=test_catalog, schema=test_schema, suffix=suffix)
- logger.info(f"Creating Direct Access index: {index_name}")
-
- # Check if it already exists
- existing = get_vs_index(index_name)
- if existing.get("state") != "NOT_FOUND":
- logger.info(f"Index '{index_name}' already exists, reusing")
- else:
- result = create_vs_index(
- name=index_name,
- endpoint_name=vs_endpoint_name,
- primary_key="id",
- index_type="DIRECT_ACCESS",
- direct_access_index_spec={
- "embedding_vector_columns": [
- {
- "name": "embedding",
- "embedding_dimension": EMBEDDING_DIM,
- }
- ],
- "schema_json": json.dumps(DA_SCHEMA),
- },
- )
- logger.info(f"Index creation result: {result}")
-
- # Wait for ONLINE
- _wait_for_index_online(index_name)
- logger.info(f"Index '{index_name}' is ONLINE")
-
- yield index_name
-
- # Teardown
- logger.info(f"Deleting Direct Access index: {index_name}")
- try:
- delete_vs_index(index_name)
- logger.info(f"Index '{index_name}' deleted")
- except Exception as e:
- logger.warning(f"Failed to delete index '{index_name}': {e}")
-
-
-@pytest.fixture(scope="function")
-def unique_suffix() -> str:
- """Generate a unique suffix for throwaway resources."""
- return str(uuid.uuid4())[:8]
-
-
-@pytest.fixture(scope="function")
-def cleanup_indexes():
- """Track and cleanup indexes created during a test."""
- from databricks_tools_core.vector_search import delete_vs_index
-
- indexes_to_cleanup = []
-
- def register(index_name: str):
- if index_name not in indexes_to_cleanup:
- indexes_to_cleanup.append(index_name)
- logger.info(f"Registered index for cleanup: {index_name}")
-
- yield register
-
- for idx_name in indexes_to_cleanup:
- try:
- logger.info(f"Cleaning up index: {idx_name}")
- delete_vs_index(idx_name)
- except Exception as e:
- logger.warning(f"Failed to cleanup index {idx_name}: {e}")
-
-
-@pytest.fixture(scope="function")
-def cleanup_endpoints():
- """Track and cleanup endpoints created during a test."""
- from databricks_tools_core.vector_search import delete_vs_endpoint
-
- endpoints_to_cleanup = []
-
- def register(name: str):
- if name not in endpoints_to_cleanup:
- endpoints_to_cleanup.append(name)
- logger.info(f"Registered endpoint for cleanup: {name}")
-
- yield register
-
- for ep_name in endpoints_to_cleanup:
- try:
- logger.info(f"Cleaning up endpoint: {ep_name}")
- delete_vs_endpoint(ep_name)
- except Exception as e:
- logger.warning(f"Failed to cleanup endpoint {ep_name}: {e}")
diff --git a/databricks-tools-core/tests/integration/vector_search/test_endpoints.py b/databricks-tools-core/tests/integration/vector_search/test_endpoints.py
deleted file mode 100644
index 608df59d..00000000
--- a/databricks-tools-core/tests/integration/vector_search/test_endpoints.py
+++ /dev/null
@@ -1,100 +0,0 @@
-"""
-Integration tests for Vector Search endpoint operations.
-
-Tests:
-- create_vs_endpoint
-- get_vs_endpoint
-- list_vs_endpoints
-- delete_vs_endpoint
-"""
-
-import logging
-import uuid
-
-import pytest
-
-from databricks_tools_core.vector_search import (
- create_vs_endpoint,
- delete_vs_endpoint,
- get_vs_endpoint,
- list_vs_endpoints,
-)
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.mark.integration
-class TestGetEndpoint:
- """Tests for getting endpoint details."""
-
- def test_get_endpoint_online(self, vs_endpoint_name: str):
- """Should return details for a running endpoint."""
- result = get_vs_endpoint(vs_endpoint_name)
-
- logger.info(f"Endpoint details: {result}")
-
- assert result["name"] == vs_endpoint_name
- assert result["state"] == "ONLINE"
- assert result.get("endpoint_type") == "STANDARD"
- assert result.get("error") is None
-
- def test_get_endpoint_not_found(self):
- """Should return NOT_FOUND for non-existent endpoint."""
- result = get_vs_endpoint("nonexistent_endpoint_xyz_99999")
-
- assert result["state"] == "NOT_FOUND"
- assert "error" in result
-
-
-@pytest.mark.integration
-class TestListEndpoints:
- """Tests for listing endpoints."""
-
- def test_list_endpoints(self, vs_endpoint_name: str):
- """Should list endpoints including the test endpoint."""
- endpoints = list_vs_endpoints()
-
- logger.info(f"Found {len(endpoints)} endpoints")
-
- assert isinstance(endpoints, list)
- assert len(endpoints) > 0
-
- # Find our test endpoint
- names = [ep["name"] for ep in endpoints]
- assert vs_endpoint_name in names, f"Test endpoint '{vs_endpoint_name}' not found in: {names}"
-
-
-@pytest.mark.integration
-class TestCreateEndpoint:
- """Tests for creating endpoints."""
-
- def test_create_endpoint(self, cleanup_endpoints):
- """Should create a new endpoint and return creation info."""
- name = f"vs_test_create_{uuid.uuid4().hex[:8]}"
- cleanup_endpoints(name)
-
- result = create_vs_endpoint(name=name, endpoint_type="STANDARD")
-
- logger.info(f"Create result: {result}")
-
- assert result["name"] == name
- assert result["endpoint_type"] == "STANDARD"
- assert result["status"] in ("CREATING", "ALREADY_EXISTS")
-
- def test_create_duplicate_endpoint(self, vs_endpoint_name: str):
- """Should return ALREADY_EXISTS for duplicate endpoint."""
- result = create_vs_endpoint(name=vs_endpoint_name, endpoint_type="STANDARD")
-
- assert result["status"] == "ALREADY_EXISTS"
-
-
-@pytest.mark.integration
-class TestDeleteEndpoint:
- """Tests for deleting endpoints."""
-
- def test_delete_endpoint_not_found(self):
- """Should return NOT_FOUND when deleting non-existent endpoint."""
- result = delete_vs_endpoint("nonexistent_endpoint_xyz_99999")
-
- assert result["status"] == "NOT_FOUND"
- assert "error" in result
diff --git a/databricks-tools-core/tests/integration/vector_search/test_indexes.py b/databricks-tools-core/tests/integration/vector_search/test_indexes.py
deleted file mode 100644
index b3eb0a24..00000000
--- a/databricks-tools-core/tests/integration/vector_search/test_indexes.py
+++ /dev/null
@@ -1,157 +0,0 @@
-"""
-Integration tests for Vector Search index operations.
-
-Tests:
-- create_vs_index
-- get_vs_index
-- list_vs_indexes
-- delete_vs_index
-- sync_vs_index
-"""
-
-import json
-import logging
-import uuid
-
-import pytest
-
-from databricks_tools_core.vector_search import (
- create_vs_index,
- delete_vs_index,
- get_vs_index,
- list_vs_indexes,
- sync_vs_index,
-)
-
-logger = logging.getLogger(__name__)
-
-EMBEDDING_DIM = 8
-
-
-@pytest.mark.integration
-class TestCreateIndex:
- """Tests for creating indexes."""
-
- def test_create_direct_access_index(self, vs_endpoint_name: str, test_catalog, test_schema, cleanup_indexes):
- """Should create a Direct Access index."""
- suffix = uuid.uuid4().hex[:8]
- index_name = f"{test_catalog}.{test_schema}.vs_idx_create_{suffix}"
- cleanup_indexes(index_name)
-
- result = create_vs_index(
- name=index_name,
- endpoint_name=vs_endpoint_name,
- primary_key="id",
- index_type="DIRECT_ACCESS",
- direct_access_index_spec={
- "embedding_vector_columns": [
- {
- "name": "embedding",
- "embedding_dimension": EMBEDDING_DIM,
- }
- ],
- "schema_json": json.dumps(
- {
- "id": "int",
- "text": "string",
- "embedding": "array",
- }
- ),
- },
- )
-
- logger.info(f"Create index result: {result}")
-
- assert result["name"] == index_name
- assert result["endpoint_name"] == vs_endpoint_name
- assert result["index_type"] == "DIRECT_ACCESS"
- assert result["status"] in ("CREATING", "ALREADY_EXISTS")
-
- def test_create_duplicate_index(self, vs_endpoint_name: str, vs_direct_index_name: str):
- """Should return ALREADY_EXISTS for duplicate index."""
- result = create_vs_index(
- name=vs_direct_index_name,
- endpoint_name=vs_endpoint_name,
- primary_key="id",
- index_type="DIRECT_ACCESS",
- direct_access_index_spec={
- "embedding_vector_columns": [
- {
- "name": "embedding",
- "embedding_dimension": EMBEDDING_DIM,
- }
- ],
- "schema_json": json.dumps(
- {
- "id": "int",
- "text": "string",
- "embedding": "array",
- }
- ),
- },
- )
-
- assert result["status"] == "ALREADY_EXISTS"
-
-
-@pytest.mark.integration
-class TestGetIndex:
- """Tests for getting index details."""
-
- def test_get_index(self, vs_direct_index_name: str):
- """Should return details for an existing index."""
- result = get_vs_index(vs_direct_index_name)
-
- logger.info(f"Index details: {result}")
-
- assert result["name"] == vs_direct_index_name
- assert result.get("index_type") == "DIRECT_ACCESS"
- assert result.get("primary_key") == "id"
- assert result.get("state") == "ONLINE"
-
- def test_get_index_not_found(self, test_catalog, test_schema):
- """Should return NOT_FOUND for non-existent index."""
- result = get_vs_index(f"{test_catalog}.{test_schema}.nonexistent_idx_99999")
-
- assert result["state"] == "NOT_FOUND"
- assert "error" in result
-
-
-@pytest.mark.integration
-class TestListIndexes:
- """Tests for listing indexes."""
-
- def test_list_indexes(self, vs_endpoint_name: str, vs_direct_index_name: str):
- """Should list indexes on the endpoint including our test index."""
- indexes = list_vs_indexes(vs_endpoint_name)
-
- logger.info(f"Found {len(indexes)} indexes on {vs_endpoint_name}")
-
- assert isinstance(indexes, list)
- assert len(indexes) > 0
-
- names = [idx["name"] for idx in indexes]
- assert vs_direct_index_name in names, f"Test index '{vs_direct_index_name}' not in: {names}"
-
-
-@pytest.mark.integration
-class TestSyncIndex:
- """Tests for syncing indexes."""
-
- def test_sync_direct_access_index(self, vs_direct_index_name: str):
- """Sync on Direct Access index should raise (not a Delta Sync index)."""
- # Direct Access indexes don't support sync - expect an error
- with pytest.raises(Exception):
- sync_vs_index(vs_direct_index_name)
-
-
-@pytest.mark.integration
-class TestDeleteIndex:
- """Tests for deleting indexes."""
-
- def test_delete_index_not_found(self, test_catalog, test_schema):
- """Should return NOT_FOUND when deleting non-existent index."""
- result = delete_vs_index(f"{test_catalog}.{test_schema}.nonexistent_idx_99999")
-
- assert result["status"] == "NOT_FOUND"
- assert "error" in result
diff --git a/databricks-tools-core/tests/integration/vector_search/test_queries.py b/databricks-tools-core/tests/integration/vector_search/test_queries.py
deleted file mode 100644
index 9a1e6205..00000000
--- a/databricks-tools-core/tests/integration/vector_search/test_queries.py
+++ /dev/null
@@ -1,192 +0,0 @@
-"""
-Integration tests for Vector Search query and data operations.
-
-Tests:
-- upsert_vs_data
-- query_vs_index
-- scan_vs_index
-- delete_vs_data
-
-These tests run in order against a shared Direct Access index.
-"""
-
-import json
-import logging
-import time
-
-import pytest
-
-from databricks_tools_core.vector_search import (
- delete_vs_data,
- query_vs_index,
- scan_vs_index,
- upsert_vs_data,
-)
-
-logger = logging.getLogger(__name__)
-
-EMBEDDING_DIM = 8
-
-# Sample data with small embeddings
-SAMPLE_RECORDS = [
- {
- "id": 1,
- "text": "machine learning basics",
- "embedding": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
- },
- {
- "id": 2,
- "text": "deep learning neural networks",
- "embedding": [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
- },
- {
- "id": 3,
- "text": "natural language processing",
- "embedding": [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
- },
- {
- "id": 4,
- "text": "computer vision and images",
- "embedding": [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
- },
- {
- "id": 5,
- "text": "reinforcement learning agents",
- "embedding": [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
- },
-]
-
-
-@pytest.mark.integration
-class TestUpsertData:
- """Tests for upserting data into a Direct Access index."""
-
- def test_upsert_records(self, vs_direct_index_name: str):
- """Should upsert 5 records successfully."""
- result = upsert_vs_data(
- index_name=vs_direct_index_name,
- inputs_json=json.dumps(SAMPLE_RECORDS),
- )
-
- logger.info(f"Upsert result: {result}")
-
- assert result["name"] == vs_direct_index_name
- assert result["num_records"] == 5
- # Status can be SUCCESS or the enum value
- assert "status" in result
-
- # Small delay for data to be indexed
- time.sleep(3)
-
- def test_upsert_single_record(self, vs_direct_index_name: str):
- """Should upsert a single record (update existing)."""
- updated_record = [
- {
- "id": 1,
- "text": "machine learning fundamentals updated",
- "embedding": [0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85],
- }
- ]
-
- result = upsert_vs_data(
- index_name=vs_direct_index_name,
- inputs_json=json.dumps(updated_record),
- )
-
- assert result["num_records"] == 1
- time.sleep(2)
-
-
-@pytest.mark.integration
-class TestQueryIndex:
- """Tests for querying a Vector Search index."""
-
- def test_query_with_vector(self, vs_direct_index_name: str):
- """Should return results when querying with a vector."""
- # Query with a vector similar to record 1
- query_vec = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
-
- result = query_vs_index(
- index_name=vs_direct_index_name,
- columns=["id", "text"],
- query_vector=query_vec,
- num_results=3,
- )
-
- logger.info(f"Query result: {result}")
-
- assert "columns" in result
- assert "data" in result
- assert result["num_results"] > 0
- assert result["num_results"] <= 3
-
- def test_query_with_fewer_results(self, vs_direct_index_name: str):
- """Should respect num_results parameter."""
- query_vec = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
-
- result = query_vs_index(
- index_name=vs_direct_index_name,
- columns=["id", "text"],
- query_vector=query_vec,
- num_results=1,
- )
-
- assert result["num_results"] == 1
-
- def test_query_missing_params_raises(self, vs_direct_index_name: str):
- """Should raise ValueError when neither query_text nor query_vector provided."""
- with pytest.raises(ValueError, match="query_text or query_vector"):
- query_vs_index(
- index_name=vs_direct_index_name,
- columns=["id", "text"],
- )
-
-
-@pytest.mark.integration
-class TestScanIndex:
- """Tests for scanning index contents."""
-
- def test_scan_index(self, vs_direct_index_name: str):
- """Should return all entries in the index."""
- result = scan_vs_index(
- index_name=vs_direct_index_name,
- num_results=100,
- )
-
- logger.info(f"Scan result: {result.get('num_results')} entries")
-
- assert "columns" in result
- assert "data" in result
- assert result["num_results"] >= 5 # we upserted 5 records
-
-
-@pytest.mark.integration
-class TestDeleteData:
- """Tests for deleting data from a Direct Access index."""
-
- def test_delete_records(self, vs_direct_index_name: str):
- """Should delete specified records by primary key."""
- result = delete_vs_data(
- index_name=vs_direct_index_name,
- primary_keys=["4", "5"],
- )
-
- logger.info(f"Delete result: {result}")
-
- assert result["name"] == vs_direct_index_name
- assert result["num_deleted"] == 2
-
- # Small delay for deletion to propagate
- time.sleep(3)
-
- def test_verify_deletion(self, vs_direct_index_name: str):
- """Should show fewer entries after deletion."""
- result = scan_vs_index(
- index_name=vs_direct_index_name,
- num_results=100,
- )
-
- logger.info(f"After delete, scan shows {result.get('num_results')} entries")
-
- # Should have 3 remaining (5 upserted - 2 deleted)
- assert result["num_results"] <= 5 # at most 5
diff --git a/databricks-tools-core/tests/integration/volume/__init__.py b/databricks-tools-core/tests/integration/volume/__init__.py
deleted file mode 100644
index 9ee4d2ec..00000000
--- a/databricks-tools-core/tests/integration/volume/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Integration tests for Unity Catalog volume file operations."""
diff --git a/databricks-tools-core/tests/integration/volume/test_upload.py b/databricks-tools-core/tests/integration/volume/test_upload.py
deleted file mode 100644
index 401b920b..00000000
--- a/databricks-tools-core/tests/integration/volume/test_upload.py
+++ /dev/null
@@ -1,206 +0,0 @@
-"""Integration tests for volume upload functions.
-
-These tests actually upload files to a Databricks Unity Catalog volume and verify they exist.
-Requires a valid Databricks connection and an existing volume at /Volumes/main/demo/raw_data.
-"""
-
-import logging
-import os
-import tempfile
-import uuid
-from pathlib import Path
-
-import pytest
-from databricks.sdk import WorkspaceClient
-
-from databricks_tools_core.unity_catalog import upload_to_volume, list_volume_files
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.fixture(scope="module")
-def volume_path(workspace_client: WorkspaceClient) -> str:
- """Create a unique folder path in an existing volume for testing and clean up after."""
- test_id = str(uuid.uuid4())[:8]
- # Use an existing volume - main.demo.raw_data is commonly available
- path = f"/Volumes/main/demo/raw_data/test_upload_{test_id}"
-
- logger.info(f"Using test volume path: {path}")
-
- yield path
-
- # Cleanup: recursively delete test folder contents
- def delete_recursive(folder_path: str):
- try:
- items = list_volume_files(folder_path)
- # Delete files first
- for item in items:
- if not item.is_directory:
- try:
- workspace_client.files.delete(item.path)
- except Exception:
- pass
- # Then delete subdirectories
- for item in items:
- if item.is_directory:
- delete_recursive(item.path)
- try:
- workspace_client.files.delete_directory(item.path)
- except Exception:
- pass
- except Exception:
- pass
-
- try:
- delete_recursive(path)
- workspace_client.files.delete_directory(path)
- logger.info(f"Cleaned up test folder: {path}")
- except Exception as e:
- logger.warning(f"Failed to clean up test folder {path}: {e}")
-
-
-@pytest.mark.integration
-class TestUploadToVolumeIntegration:
- """Integration tests for upload_to_volume."""
-
- def test_upload_single_file(self, workspace_client: WorkspaceClient, volume_path: str):
- """Test uploading a single file to volume."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
- f.write("col1,col2\n1,2\n3,4\n")
- local_path = f.name
-
- try:
- remote_path = f"{volume_path}/single_file.csv"
-
- result = upload_to_volume(
- local_path=local_path,
- volume_path=remote_path,
- )
-
- assert result.success, f"Upload failed: {result.results}"
- assert result.total_files == 1
- assert result.successful == 1
- assert result.failed == 0
-
- # Verify file exists in volume
- metadata = workspace_client.files.get_metadata(remote_path)
- assert metadata is not None
- logger.info(f"Successfully uploaded single file to {remote_path}")
-
- finally:
- os.unlink(local_path)
-
- def test_upload_folder(self, workspace_client: WorkspaceClient, volume_path: str):
- """Test uploading a folder with multiple files."""
- with tempfile.TemporaryDirectory() as tmpdir:
- # Create test files
- (Path(tmpdir) / "file1.csv").write_text("col1,col2\n1,2")
- (Path(tmpdir) / "file2.json").write_text('{"key": "value"}')
- (Path(tmpdir) / "subdir").mkdir()
- (Path(tmpdir) / "subdir" / "nested.txt").write_text("nested content")
-
- remote_path = f"{volume_path}/folder_test"
-
- result = upload_to_volume(
- local_path=tmpdir,
- volume_path=remote_path,
- )
-
- assert result.success, f"Upload failed: {result.results}"
- assert result.total_files == 3
- assert result.successful == 3
-
- # Verify files exist
- metadata1 = workspace_client.files.get_metadata(f"{remote_path}/file1.csv")
- assert metadata1 is not None
-
- metadata2 = workspace_client.files.get_metadata(f"{remote_path}/subdir/nested.txt")
- assert metadata2 is not None
-
- logger.info(f"Successfully uploaded folder with {result.total_files} files")
-
- def test_upload_glob_pattern(self, workspace_client: WorkspaceClient, volume_path: str):
- """Test uploading files matching a glob pattern."""
- with tempfile.TemporaryDirectory() as tmpdir:
- # Create test files
- (Path(tmpdir) / "data1.csv").write_text("col1\n1")
- (Path(tmpdir) / "data2.csv").write_text("col1\n2")
- (Path(tmpdir) / "config.json").write_text('{"setting": true}')
- (Path(tmpdir) / "readme.md").write_text("# README")
-
- remote_path = f"{volume_path}/glob_test"
-
- # Upload only .csv files
- result = upload_to_volume(
- local_path=f"{tmpdir}/*.csv",
- volume_path=remote_path,
- )
-
- assert result.success, f"Upload failed: {result.results}"
- assert result.total_files == 2 # Only .csv files
- assert result.successful == 2
-
- # Verify .csv files exist
- metadata = workspace_client.files.get_metadata(f"{remote_path}/data1.csv")
- assert metadata is not None
-
- # Verify .json was NOT uploaded
- with pytest.raises(Exception):
- workspace_client.files.get_metadata(f"{remote_path}/config.json")
-
- logger.info("Successfully uploaded files matching glob pattern")
-
- def test_upload_star_glob(self, workspace_client: WorkspaceClient, volume_path: str):
- """Test uploading all contents with /* pattern."""
- with tempfile.TemporaryDirectory() as tmpdir:
- # Create test structure
- (Path(tmpdir) / "main.csv").write_text("col1\nmain")
- (Path(tmpdir) / "utils").mkdir()
- (Path(tmpdir) / "utils" / "helper.txt").write_text("helper content")
-
- remote_path = f"{volume_path}/star_test"
-
- # Upload all contents
- result = upload_to_volume(
- local_path=f"{tmpdir}/*",
- volume_path=remote_path,
- )
-
- assert result.success, f"Upload failed: {result.results}"
- assert result.total_files == 2 # main.csv + utils/helper.txt
-
- # Verify nested file exists
- metadata = workspace_client.files.get_metadata(f"{remote_path}/utils/helper.txt")
- assert metadata is not None
-
- logger.info("Successfully uploaded with /* glob pattern")
-
- def test_upload_overwrites_existing(self, workspace_client: WorkspaceClient, volume_path: str):
- """Test that overwrite=True replaces existing files."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
- f.write("version,1\n")
- local_path = f.name
-
- try:
- remote_path = f"{volume_path}/overwrite_test.csv"
-
- # First upload
- result1 = upload_to_volume(local_path=local_path, volume_path=remote_path)
- assert result1.success
-
- # Modify file
- with open(local_path, "w") as f:
- f.write("version,2\n")
-
- # Second upload with overwrite
- result2 = upload_to_volume(
- local_path=local_path,
- volume_path=remote_path,
- overwrite=True,
- )
- assert result2.success
-
- logger.info("Successfully tested overwrite functionality")
-
- finally:
- os.unlink(local_path)
diff --git a/databricks-tools-core/tests/integration/workspace/__init__.py b/databricks-tools-core/tests/integration/workspace/__init__.py
deleted file mode 100644
index 56961b39..00000000
--- a/databricks-tools-core/tests/integration/workspace/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Integration tests for workspace file operations."""
diff --git a/databricks-tools-core/tests/integration/workspace/test_upload.py b/databricks-tools-core/tests/integration/workspace/test_upload.py
deleted file mode 100644
index 84c4ceda..00000000
--- a/databricks-tools-core/tests/integration/workspace/test_upload.py
+++ /dev/null
@@ -1,185 +0,0 @@
-"""Integration tests for workspace upload functions.
-
-These tests actually upload files to Databricks workspace and verify they exist.
-Requires a valid Databricks connection.
-"""
-
-import logging
-import os
-import tempfile
-import uuid
-from pathlib import Path
-
-import pytest
-from databricks.sdk import WorkspaceClient
-
-from databricks_tools_core.file.workspace import upload_to_workspace
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.fixture(scope="module")
-def workspace_path(workspace_client: WorkspaceClient) -> str:
- """Create a unique workspace path for testing and clean up after."""
- # Get current user's home folder
- user = workspace_client.current_user.me()
- test_id = str(uuid.uuid4())[:8]
- path = f"/Workspace/Users/{user.user_name}/test_upload_{test_id}"
-
- logger.info(f"Using test workspace path: {path}")
-
- yield path
-
- # Cleanup: delete the test folder
- try:
- workspace_client.workspace.delete(path, recursive=True)
- logger.info(f"Cleaned up test folder: {path}")
- except Exception as e:
- logger.warning(f"Failed to clean up test folder {path}: {e}")
-
-
-@pytest.mark.integration
-class TestUploadToWorkspaceIntegration:
- """Integration tests for upload_to_workspace."""
-
- def test_upload_single_file(self, workspace_client: WorkspaceClient, workspace_path: str):
- """Test uploading a single file to workspace."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
- f.write("# Test file\nprint('hello world')\n")
- local_path = f.name
-
- try:
- remote_path = f"{workspace_path}/single_file.py"
-
- result = upload_to_workspace(
- local_path=local_path,
- workspace_path=remote_path,
- )
-
- assert result.success, f"Upload failed: {result.results}"
- assert result.total_files == 1
- assert result.successful == 1
- assert result.failed == 0
-
- # Verify file exists in workspace
- status = workspace_client.workspace.get_status(remote_path)
- assert status is not None
- logger.info(f"Successfully uploaded single file to {remote_path}")
-
- finally:
- os.unlink(local_path)
-
- def test_upload_folder(self, workspace_client: WorkspaceClient, workspace_path: str):
- """Test uploading a folder with multiple files."""
- with tempfile.TemporaryDirectory() as tmpdir:
- # Create test files
- (Path(tmpdir) / "file1.py").write_text("# File 1\nprint(1)")
- (Path(tmpdir) / "file2.py").write_text("# File 2\nprint(2)")
- (Path(tmpdir) / "subdir").mkdir()
- (Path(tmpdir) / "subdir" / "nested.py").write_text("# Nested\nprint('nested')")
-
- remote_path = f"{workspace_path}/folder_test"
-
- result = upload_to_workspace(
- local_path=tmpdir,
- workspace_path=remote_path,
- )
-
- assert result.success, f"Upload failed: {result.results}"
- assert result.total_files == 3
- assert result.successful == 3
-
- # Verify files exist
- status1 = workspace_client.workspace.get_status(f"{remote_path}/file1.py")
- assert status1 is not None
-
- status2 = workspace_client.workspace.get_status(f"{remote_path}/subdir/nested.py")
- assert status2 is not None
-
- logger.info(f"Successfully uploaded folder with {result.total_files} files")
-
- def test_upload_glob_pattern(self, workspace_client: WorkspaceClient, workspace_path: str):
- """Test uploading files matching a glob pattern."""
- with tempfile.TemporaryDirectory() as tmpdir:
- # Create test files
- (Path(tmpdir) / "script1.py").write_text("# Script 1")
- (Path(tmpdir) / "script2.py").write_text("# Script 2")
- (Path(tmpdir) / "data.json").write_text('{"key": "value"}')
- (Path(tmpdir) / "readme.md").write_text("# README")
-
- remote_path = f"{workspace_path}/glob_test"
-
- # Upload only .py files
- result = upload_to_workspace(
- local_path=f"{tmpdir}/*.py",
- workspace_path=remote_path,
- )
-
- assert result.success, f"Upload failed: {result.results}"
- assert result.total_files == 2 # Only .py files
- assert result.successful == 2
-
- # Verify .py files exist
- status = workspace_client.workspace.get_status(f"{remote_path}/script1.py")
- assert status is not None
-
- # Verify .json was NOT uploaded
- with pytest.raises(Exception):
- workspace_client.workspace.get_status(f"{remote_path}/data.json")
-
- logger.info("Successfully uploaded files matching glob pattern")
-
- def test_upload_star_glob(self, workspace_client: WorkspaceClient, workspace_path: str):
- """Test uploading all contents with /* pattern."""
- with tempfile.TemporaryDirectory() as tmpdir:
- # Create test structure
- (Path(tmpdir) / "main.py").write_text("# Main")
- (Path(tmpdir) / "utils").mkdir()
- (Path(tmpdir) / "utils" / "helper.py").write_text("# Helper")
-
- remote_path = f"{workspace_path}/star_test"
-
- # Upload all contents
- result = upload_to_workspace(
- local_path=f"{tmpdir}/*",
- workspace_path=remote_path,
- )
-
- assert result.success, f"Upload failed: {result.results}"
- assert result.total_files == 2 # main.py + utils/helper.py
-
- # Verify nested file exists
- status = workspace_client.workspace.get_status(f"{remote_path}/utils/helper.py")
- assert status is not None
-
- logger.info("Successfully uploaded with /* glob pattern")
-
- def test_upload_overwrites_existing(self, workspace_client: WorkspaceClient, workspace_path: str):
- """Test that overwrite=True replaces existing files."""
- with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
- f.write("# Version 1\n")
- local_path = f.name
-
- try:
- remote_path = f"{workspace_path}/overwrite_test.py"
-
- # First upload
- result1 = upload_to_workspace(local_path=local_path, workspace_path=remote_path)
- assert result1.success
-
- # Modify file
- with open(local_path, "w") as f:
- f.write("# Version 2\n")
-
- # Second upload with overwrite
- result2 = upload_to_workspace(
- local_path=local_path,
- workspace_path=remote_path,
- overwrite=True,
- )
- assert result2.success
-
- logger.info("Successfully tested overwrite functionality")
-
- finally:
- os.unlink(local_path)
diff --git a/databricks-tools-core/tests/unit/__init__.py b/databricks-tools-core/tests/unit/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/databricks-tools-core/tests/unit/test_auth.py b/databricks-tools-core/tests/unit/test_auth.py
deleted file mode 100644
index 6a2a56e6..00000000
--- a/databricks-tools-core/tests/unit/test_auth.py
+++ /dev/null
@@ -1,170 +0,0 @@
-"""Unit tests for workspace switching auth state management."""
-
-import os
-from unittest import mock
-
-import pytest
-
-from databricks_tools_core.auth import (
- clear_active_workspace,
- clear_databricks_auth,
- get_active_workspace,
- get_current_username,
- get_workspace_client,
- set_active_workspace,
- set_databricks_auth,
-)
-
-_WS_CLIENT = "databricks_tools_core.auth.WorkspaceClient"
-_TAG_CLIENT = "databricks_tools_core.auth.tag_client"
-_HAS_OAUTH = "databricks_tools_core.auth._has_oauth_credentials"
-
-
-@pytest.fixture(autouse=True)
-def clean_state():
- """Reset auth state before and after every test."""
- clear_active_workspace()
- clear_databricks_auth()
- yield
- clear_active_workspace()
- clear_databricks_auth()
-
-
-# ---------------------------------------------------------------------------
-# set_active_workspace / get_active_workspace / clear_active_workspace
-# ---------------------------------------------------------------------------
-
-
-def test_default_no_active_workspace():
- """With no active workspace set, get_workspace_client falls through to the default SDK path."""
- with (
- mock.patch(_HAS_OAUTH, return_value=False),
- mock.patch(_TAG_CLIENT, side_effect=lambda c: c),
- mock.patch(_WS_CLIENT) as mock_ws,
- ):
- get_workspace_client()
- call_kwargs = mock_ws.call_args.kwargs
- assert "profile" not in call_kwargs
- assert "host" not in call_kwargs
-
-
-def test_set_active_profile():
- """After set_active_workspace(profile=...), WorkspaceClient is called with that profile."""
- set_active_workspace(profile="prod")
- with mock.patch(_TAG_CLIENT, side_effect=lambda c: c), mock.patch(_WS_CLIENT) as mock_ws:
- get_workspace_client()
- assert mock_ws.call_args.kwargs.get("profile") == "prod"
-
-
-def test_set_active_host():
- """After set_active_workspace(host=...), WorkspaceClient is called with that host."""
- set_active_workspace(host="https://adb-123.azuredatabricks.net")
- with (
- mock.patch(_HAS_OAUTH, return_value=False),
- mock.patch(_TAG_CLIENT, side_effect=lambda c: c),
- mock.patch(_WS_CLIENT) as mock_ws,
- ):
- get_workspace_client()
- assert mock_ws.call_args.kwargs.get("host") == "https://adb-123.azuredatabricks.net"
- assert "profile" not in mock_ws.call_args.kwargs
-
-
-def test_profile_takes_precedence_over_host():
- """When both profile and host are set, profile wins."""
- set_active_workspace(profile="myprofile", host="https://ignored.azuredatabricks.net")
- with mock.patch(_TAG_CLIENT, side_effect=lambda c: c), mock.patch(_WS_CLIENT) as mock_ws:
- get_workspace_client()
- assert mock_ws.call_args.kwargs.get("profile") == "myprofile"
- assert "host" not in mock_ws.call_args.kwargs
-
-
-def test_clear_resets_to_default():
- """After clear_active_workspace(), falls through to the default SDK path."""
- set_active_workspace(profile="prod")
- clear_active_workspace()
- with (
- mock.patch(_HAS_OAUTH, return_value=False),
- mock.patch(_TAG_CLIENT, side_effect=lambda c: c),
- mock.patch(_WS_CLIENT) as mock_ws,
- ):
- get_workspace_client()
- call_kwargs = mock_ws.call_args.kwargs
- assert "profile" not in call_kwargs
- assert "host" not in call_kwargs
-
-
-def test_get_active_workspace_reflects_state():
- """get_active_workspace() returns the current module-level state."""
- assert get_active_workspace() == {"profile": None, "host": None}
- set_active_workspace(profile="staging")
- assert get_active_workspace() == {"profile": "staging", "host": None}
- set_active_workspace(host="https://adb-456.net")
- assert get_active_workspace() == {"profile": None, "host": "https://adb-456.net"}
- clear_active_workspace()
- assert get_active_workspace() == {"profile": None, "host": None}
-
-
-def test_set_active_workspace_is_idempotent():
- """Calling set_active_workspace twice replaces the previous value."""
- set_active_workspace(profile="first")
- set_active_workspace(profile="second")
- assert get_active_workspace()["profile"] == "second"
-
-
-# ---------------------------------------------------------------------------
-# Priority: force_token and OAuth M2M override active workspace
-# ---------------------------------------------------------------------------
-
-
-def test_force_token_overrides_active_workspace():
- """set_databricks_auth(force_token=True) bypasses the active workspace override."""
- set_active_workspace(profile="should-be-ignored")
- set_databricks_auth("https://force-host.net", "force-token", force_token=True)
- with mock.patch(_TAG_CLIENT, side_effect=lambda c: c), mock.patch(_WS_CLIENT) as mock_ws:
- get_workspace_client()
- assert mock_ws.call_args.kwargs.get("host") == "https://force-host.net"
- assert mock_ws.call_args.kwargs.get("token") == "force-token"
- assert "profile" not in mock_ws.call_args.kwargs
-
-
-def test_active_workspace_bypassed_when_oauth_m2m():
- """When OAuth M2M credentials are in env, they take precedence over active workspace."""
- set_active_workspace(profile="should-be-ignored")
- env = {
- "DATABRICKS_CLIENT_ID": "my-client-id",
- "DATABRICKS_CLIENT_SECRET": "my-client-secret",
- "DATABRICKS_HOST": "https://apps-host.azuredatabricks.net",
- }
- with (
- mock.patch.dict(os.environ, env, clear=False),
- mock.patch(_TAG_CLIENT, side_effect=lambda c: c),
- mock.patch(_WS_CLIENT) as mock_ws,
- ):
- get_workspace_client()
- assert mock_ws.call_args.kwargs.get("client_id") == "my-client-id"
- assert mock_ws.call_args.kwargs.get("client_secret") == "my-client-secret"
- assert "profile" not in mock_ws.call_args.kwargs
-
-
-# ---------------------------------------------------------------------------
-# Username cache reset on workspace switch
-# ---------------------------------------------------------------------------
-
-
-def test_username_cache_reset_on_switch():
- """set_active_workspace() resets the cached username so it's re-fetched for the new workspace."""
- mock_client_a = mock.MagicMock()
- mock_client_a.current_user.me.return_value = mock.MagicMock(user_name="user-a@example.com")
- mock_client_b = mock.MagicMock()
- mock_client_b.current_user.me.return_value = mock.MagicMock(user_name="user-b@example.com")
-
- with mock.patch(_TAG_CLIENT, side_effect=lambda c: c), mock.patch(_WS_CLIENT, return_value=mock_client_a):
- set_active_workspace(profile="workspace-a")
- username_a = get_current_username()
- assert username_a == "user-a@example.com"
-
- # Switch workspace — cache should be invalidated
- with mock.patch(_TAG_CLIENT, side_effect=lambda c: c), mock.patch(_WS_CLIENT, return_value=mock_client_b):
- set_active_workspace(profile="workspace-b")
- username_b = get_current_username()
- assert username_b == "user-b@example.com"
diff --git a/databricks-tools-core/tests/unit/test_identity.py b/databricks-tools-core/tests/unit/test_identity.py
deleted file mode 100644
index 1b7ac7ca..00000000
--- a/databricks-tools-core/tests/unit/test_identity.py
+++ /dev/null
@@ -1,91 +0,0 @@
-"""Unit tests for identity module helpers."""
-
-from unittest import mock
-
-from databricks_tools_core.identity import (
- DESCRIPTION_FOOTER,
- _load_version,
- with_description_footer,
-)
-
-
-# ── _load_version ──────────────────────────────────────────────────────
-
-
-def test_load_version_reads_version_file(tmp_path):
- """Finds a VERSION file by walking up from __file__."""
- (tmp_path / "VERSION").write_text("1.2.3\n")
- nested = tmp_path / "a" / "b"
- nested.mkdir(parents=True)
- mod = nested / "mod.py"
- mod.touch()
-
- with mock.patch("databricks_tools_core.identity.__file__", str(mod)):
- assert _load_version() == "1.2.3"
-
-
-def test_load_version_strips_whitespace(tmp_path):
- (tmp_path / "VERSION").write_text(" 0.5.0-beta \n\n")
- mod = tmp_path / "pkg" / "mod.py"
- mod.parent.mkdir()
- mod.touch()
-
- with mock.patch("databricks_tools_core.identity.__file__", str(mod)):
- assert _load_version() == "0.5.0-beta"
-
-
-def test_load_version_fallback_when_no_file(tmp_path):
- """Returns fallback when no VERSION file exists anywhere."""
- mod = tmp_path / "a" / "b" / "c" / "d" / "e" / "f" / "g" / "mod.py"
- mod.parent.mkdir(parents=True)
- mod.touch()
-
- with mock.patch("databricks_tools_core.identity.__file__", str(mod)):
- assert _load_version() == "0.0.0-unknown"
-
-
-def test_load_version_fallback_on_exception():
- """Returns fallback when __file__ resolution fails."""
- with mock.patch(
- "databricks_tools_core.identity.__file__",
- "/nonexistent/path/that/cannot/resolve/mod.py",
- ):
- # Still returns fallback (caught by the broad except)
- result = _load_version()
- assert isinstance(result, str)
-
-
-# ── with_description_footer ────────────────────────────────────────────
-
-
-def test_footer_appended_to_description():
- result = with_description_footer("My cool resource")
- assert result == f"My cool resource\n\n{DESCRIPTION_FOOTER}"
-
-
-def test_footer_only_when_none():
- assert with_description_footer(None) == DESCRIPTION_FOOTER
-
-
-def test_footer_only_when_empty_string():
- assert with_description_footer("") == DESCRIPTION_FOOTER
-
-
-def test_footer_only_when_whitespace():
- # Empty-ish strings are falsy only if literally empty; " " is truthy
- result = with_description_footer(" ")
- assert DESCRIPTION_FOOTER in result
-
-
-def test_footer_preserves_multiline_description():
- desc = "Line one\nLine two"
- result = with_description_footer(desc)
- assert result.startswith("Line one\nLine two\n\n")
- assert result.endswith(DESCRIPTION_FOOTER)
-
-
-def test_footer_idempotent_guard():
- """Calling twice appends twice (no dedup) — caller is responsible."""
- once = with_description_footer("desc")
- twice = with_description_footer(once)
- assert twice.count(DESCRIPTION_FOOTER) == 2
diff --git a/databricks-tools-core/tests/unit/test_llm.py b/databricks-tools-core/tests/unit/test_llm.py
deleted file mode 100644
index 7393c00a..00000000
--- a/databricks-tools-core/tests/unit/test_llm.py
+++ /dev/null
@@ -1,126 +0,0 @@
-"""Unit tests for LLM endpoint discovery."""
-
-import pytest
-from unittest.mock import patch
-
-from databricks_tools_core.pdf.llm import _discover_databricks_gpt_endpoints, _get_model_name, LLMConfigurationError
-
-
-class TestEndpointDiscovery:
- """Test dynamic databricks-gpt endpoint discovery."""
-
- def setup_method(self):
- """Clear the lru_cache before each test."""
- _discover_databricks_gpt_endpoints.cache_clear()
-
- @patch("databricks_tools_core.pdf.llm.list_serving_endpoints")
- def test_discover_latest_gpt_endpoints(self, mock_list):
- """Test discovering the latest databricks-gpt endpoints."""
- mock_list.return_value = [
- {"name": "databricks-gpt-5-2", "state": "READY"},
- {"name": "databricks-gpt-5-4", "state": "READY"},
- {"name": "databricks-gpt-5-3", "state": "READY"},
- {"name": "databricks-gpt-5-4-nano", "state": "READY"},
- {"name": "databricks-gpt-5-2-nano", "state": "READY"},
- {"name": "other-endpoint", "state": "READY"},
- ]
-
- main_model, nano_model = _discover_databricks_gpt_endpoints()
-
- assert main_model == "databricks-gpt-5-4"
- assert nano_model == "databricks-gpt-5-4-nano"
-
- @patch("databricks_tools_core.pdf.llm.list_serving_endpoints")
- def test_discover_no_nano_falls_back_to_main(self, mock_list):
- """Test that nano falls back to main model if no nano available."""
- mock_list.return_value = [
- {"name": "databricks-gpt-5-4", "state": "READY"},
- {"name": "databricks-gpt-5-3", "state": "READY"},
- ]
-
- main_model, nano_model = _discover_databricks_gpt_endpoints()
-
- assert main_model == "databricks-gpt-5-4"
- assert nano_model == "databricks-gpt-5-4" # Falls back to main
-
- @patch("databricks_tools_core.pdf.llm.list_serving_endpoints")
- def test_discover_ignores_not_ready_endpoints(self, mock_list):
- """Test that NOT_READY endpoints are ignored."""
- mock_list.return_value = [
- {"name": "databricks-gpt-5-5", "state": "NOT_READY"},
- {"name": "databricks-gpt-5-4", "state": "READY"},
- ]
-
- main_model, nano_model = _discover_databricks_gpt_endpoints()
-
- assert main_model == "databricks-gpt-5-4"
-
- @patch("databricks_tools_core.pdf.llm.list_serving_endpoints")
- def test_discover_no_gpt_endpoints(self, mock_list):
- """Test when no databricks-gpt endpoints exist."""
- mock_list.return_value = [
- {"name": "my-custom-model", "state": "READY"},
- ]
-
- main_model, nano_model = _discover_databricks_gpt_endpoints()
-
- assert main_model is None
- assert nano_model is None
-
- @patch("databricks_tools_core.pdf.llm.list_serving_endpoints")
- def test_discover_handles_api_error(self, mock_list):
- """Test graceful handling of API errors."""
- mock_list.side_effect = Exception("API error")
-
- main_model, nano_model = _discover_databricks_gpt_endpoints()
-
- assert main_model is None
- assert nano_model is None
-
-
-class TestGetModelName:
- """Test model name resolution with priority order."""
-
- def setup_method(self):
- """Clear the lru_cache before each test."""
- _discover_databricks_gpt_endpoints.cache_clear()
-
- def test_explicit_model_name_takes_priority(self):
- """Test that explicit model_name parameter wins."""
- result = _get_model_name(mini=False, model_name="my-custom-model")
- assert result == "my-custom-model"
-
- @patch.dict("os.environ", {"DATABRICKS_MODEL": "env-model"})
- def test_env_var_takes_priority_over_discovery(self):
- """Test that env var is used before auto-discovery."""
- result = _get_model_name(mini=False)
- assert result == "env-model"
-
- @patch.dict("os.environ", {"DATABRICKS_MODEL_NANO": "env-nano-model"})
- def test_nano_env_var_for_mini(self):
- """Test that DATABRICKS_MODEL_NANO is used for mini=True."""
- result = _get_model_name(mini=True)
- assert result == "env-nano-model"
-
- @patch.dict("os.environ", {}, clear=True)
- @patch("databricks_tools_core.pdf.llm._discover_databricks_gpt_endpoints")
- def test_auto_discovery_when_no_env(self, mock_discover):
- """Test auto-discovery when no env vars set."""
- mock_discover.return_value = ("databricks-gpt-5-4", "databricks-gpt-5-4-nano")
-
- result = _get_model_name(mini=False)
- assert result == "databricks-gpt-5-4"
-
- result = _get_model_name(mini=True)
- assert result == "databricks-gpt-5-4-nano"
-
- @patch.dict("os.environ", {}, clear=True)
- @patch("databricks_tools_core.pdf.llm._discover_databricks_gpt_endpoints")
- def test_raises_error_when_no_model_found(self, mock_discover):
- """Test that LLMConfigurationError is raised when no model available."""
- mock_discover.return_value = (None, None)
-
- with pytest.raises(LLMConfigurationError) as exc_info:
- _get_model_name(mini=False)
-
- assert "No LLM model configured" in str(exc_info.value)
diff --git a/databricks-tools-core/tests/unit/test_sql.py b/databricks-tools-core/tests/unit/test_sql.py
deleted file mode 100644
index fb7a4bc7..00000000
--- a/databricks-tools-core/tests/unit/test_sql.py
+++ /dev/null
@@ -1,224 +0,0 @@
-"""Unit tests for SQL execution functions."""
-
-from unittest import mock
-
-from databricks.sdk.service.sql import QueryTag, State, StatementState
-
-from databricks_tools_core.sql import execute_sql, execute_sql_multi
-from databricks_tools_core.sql.sql_utils import SQLExecutor
-from databricks_tools_core.sql.warehouse import _sort_within_tier, get_best_warehouse
-
-
-class TestExecuteSQLQueryTags:
- """Tests for query_tags parameter passthrough."""
-
- @mock.patch("databricks_tools_core.sql.sql.get_best_warehouse", return_value="wh-123")
- @mock.patch("databricks_tools_core.sql.sql.SQLExecutor")
- def test_execute_sql_passes_query_tags_to_executor(self, mock_executor_cls, mock_warehouse):
- """query_tags should be passed through to SQLExecutor.execute()."""
- mock_executor = mock.Mock()
- mock_executor.execute.return_value = [{"num": 1}]
- mock_executor_cls.return_value = mock_executor
-
- execute_sql(
- sql_query="SELECT 1",
- warehouse_id="wh-123",
- query_tags="team:eng,cost_center:701",
- )
-
- mock_executor.execute.assert_called_once()
- call_kwargs = mock_executor.execute.call_args.kwargs
- assert call_kwargs["query_tags"] == "team:eng,cost_center:701"
-
- @mock.patch("databricks_tools_core.sql.sql.get_best_warehouse", return_value="wh-123")
- @mock.patch("databricks_tools_core.sql.sql.SQLExecutor")
- def test_execute_sql_without_query_tags(self, mock_executor_cls, mock_warehouse):
- """When query_tags not provided, executor should not receive it (or receive None)."""
- mock_executor = mock.Mock()
- mock_executor.execute.return_value = [{"num": 1}]
- mock_executor_cls.return_value = mock_executor
-
- execute_sql(sql_query="SELECT 1", warehouse_id="wh-123")
-
- mock_executor.execute.assert_called_once()
- call_kwargs = mock_executor.execute.call_args.kwargs
- assert call_kwargs.get("query_tags") is None
-
- @mock.patch("databricks_tools_core.sql.sql.get_best_warehouse", return_value="wh-123")
- @mock.patch("databricks_tools_core.sql.sql.SQLParallelExecutor")
- def test_execute_sql_multi_passes_query_tags(self, mock_parallel_cls, mock_warehouse):
- """query_tags should be passed through to SQLParallelExecutor.execute()."""
- mock_executor = mock.Mock()
- mock_executor.execute.return_value = {
- "results": {0: {"status": "success", "query_index": 0}},
- "execution_summary": {"total_queries": 1, "total_groups": 1},
- }
- mock_parallel_cls.return_value = mock_executor
-
- execute_sql_multi(
- sql_content="SELECT 1;",
- warehouse_id="wh-123",
- query_tags="app:agent,env:dev",
- )
-
- mock_executor.execute.assert_called_once()
- call_kwargs = mock_executor.execute.call_args.kwargs
- assert call_kwargs["query_tags"] == "app:agent,env:dev"
-
-
-class TestSQLExecutorQueryTags:
- """Tests for SQLExecutor passing query_tags to the API."""
-
- @mock.patch("databricks_tools_core.sql.sql_utils.executor.get_workspace_client")
- def test_executor_passes_query_tags_to_api(self, mock_get_client):
- """SQLExecutor.execute() should include query_tags in execute_statement call."""
- mock_client = mock.Mock()
- mock_response = mock.Mock()
- mock_response.statement_id = "stmt-1"
- mock_client.statement_execution.execute_statement.return_value = mock_response
-
- # Simulate SUCCEEDED state on get_statement
- mock_status = mock.Mock()
- mock_status.status.state = StatementState.SUCCEEDED
- mock_status.result = mock.Mock()
- mock_status.result.data_array = []
- mock_status.manifest = None
- mock_client.statement_execution.get_statement.return_value = mock_status
-
- mock_get_client.return_value = mock_client
-
- executor = SQLExecutor(warehouse_id="wh-123", client=mock_client)
- executor.execute(
- sql_query="SELECT 1",
- query_tags="team:eng,cost_center:701",
- )
-
- call_kwargs = mock_client.statement_execution.execute_statement.call_args.kwargs
- # query_tags string is parsed into List[QueryTag] objects
- expected_tags = [
- QueryTag(key="team", value="eng"),
- QueryTag(key="cost_center", value="701"),
- ]
- assert call_kwargs.get("query_tags") == expected_tags
-
- @mock.patch("databricks_tools_core.sql.sql_utils.executor.get_workspace_client")
- def test_executor_without_query_tags_omits_from_api(self, mock_get_client):
- """When query_tags not provided, it should not be in the API call."""
- mock_client = mock.Mock()
- mock_response = mock.Mock()
- mock_response.statement_id = "stmt-1"
- mock_client.statement_execution.execute_statement.return_value = mock_response
-
- mock_status = mock.Mock()
- mock_status.status.state = StatementState.SUCCEEDED
- mock_status.result = mock.Mock()
- mock_status.result.data_array = []
- mock_status.manifest = None
- mock_client.statement_execution.get_statement.return_value = mock_status
-
- mock_get_client.return_value = mock_client
-
- executor = SQLExecutor(warehouse_id="wh-123", client=mock_client)
- executor.execute(sql_query="SELECT 1")
-
- call_kwargs = mock_client.statement_execution.execute_statement.call_args.kwargs
- assert "query_tags" not in call_kwargs
-
-
-def _make_warehouse(id, name, state, creator_name="other@example.com", enable_serverless_compute=False):
- """Helper to create a mock warehouse object."""
- w = mock.Mock()
- w.id = id
- w.name = name
- w.state = state
- w.creator_name = creator_name
- w.enable_serverless_compute = enable_serverless_compute
- w.cluster_size = "Small"
- w.auto_stop_mins = 10
- return w
-
-
-class TestSortWithinTier:
- """Tests for _sort_within_tier serverless and user-owned preference."""
-
- def test_serverless_first(self):
- """Serverless warehouses should come before classic ones."""
- classic = _make_warehouse("c1", "Classic WH", State.RUNNING)
- serverless = _make_warehouse("s1", "Serverless WH", State.RUNNING, enable_serverless_compute=True)
- result = _sort_within_tier([classic, serverless], current_user=None)
- assert result[0].id == "s1"
- assert result[1].id == "c1"
-
- def test_serverless_before_user_owned(self):
- """Serverless should be preferred over user-owned classic."""
- classic_owned = _make_warehouse("c1", "My WH", State.RUNNING, creator_name="me@example.com")
- serverless_other = _make_warehouse(
- "s1", "Other WH", State.RUNNING, creator_name="other@example.com", enable_serverless_compute=True
- )
- result = _sort_within_tier([classic_owned, serverless_other], current_user="me@example.com")
- assert result[0].id == "s1"
-
- def test_serverless_user_owned_first(self):
- """Among serverless, user-owned should come first."""
- serverless_other = _make_warehouse(
- "s1", "Other Serverless", State.RUNNING, creator_name="other@example.com", enable_serverless_compute=True
- )
- serverless_owned = _make_warehouse(
- "s2", "My Serverless", State.RUNNING, creator_name="me@example.com", enable_serverless_compute=True
- )
- result = _sort_within_tier([serverless_other, serverless_owned], current_user="me@example.com")
- assert result[0].id == "s2"
- assert result[1].id == "s1"
-
- def test_empty_list(self):
- assert _sort_within_tier([], current_user="me@example.com") == []
-
- def test_no_current_user(self):
- """Without a current user, only serverless preference applies."""
- classic = _make_warehouse("c1", "Classic", State.RUNNING)
- serverless = _make_warehouse("s1", "Serverless", State.RUNNING, enable_serverless_compute=True)
- result = _sort_within_tier([classic, serverless], current_user=None)
- assert result[0].id == "s1"
-
-
-class TestGetBestWarehouseServerless:
- """Tests for serverless preference in get_best_warehouse."""
-
- @mock.patch("databricks_tools_core.sql.warehouse.get_current_username", return_value="me@example.com")
- @mock.patch("databricks_tools_core.sql.warehouse.get_workspace_client")
- def test_prefers_serverless_within_running_shared(self, mock_client_fn, mock_user):
- """Among running shared warehouses, serverless should be picked."""
- classic_shared = _make_warehouse("c1", "Shared WH", State.RUNNING)
- serverless_shared = _make_warehouse("s1", "Shared Serverless", State.RUNNING, enable_serverless_compute=True)
- mock_client = mock.Mock()
- mock_client.warehouses.list.return_value = [classic_shared, serverless_shared]
- mock_client_fn.return_value = mock_client
-
- result = get_best_warehouse()
- assert result == "s1"
-
- @mock.patch("databricks_tools_core.sql.warehouse.get_current_username", return_value="me@example.com")
- @mock.patch("databricks_tools_core.sql.warehouse.get_workspace_client")
- def test_prefers_serverless_within_running_other(self, mock_client_fn, mock_user):
- """Among running non-shared warehouses, serverless should be picked."""
- classic = _make_warehouse("c1", "My WH", State.RUNNING)
- serverless = _make_warehouse("s1", "Fast WH", State.RUNNING, enable_serverless_compute=True)
- mock_client = mock.Mock()
- mock_client.warehouses.list.return_value = [classic, serverless]
- mock_client_fn.return_value = mock_client
-
- result = get_best_warehouse()
- assert result == "s1"
-
- @mock.patch("databricks_tools_core.sql.warehouse.get_current_username", return_value="me@example.com")
- @mock.patch("databricks_tools_core.sql.warehouse.get_workspace_client")
- def test_tier_order_preserved_over_serverless(self, mock_client_fn, mock_user):
- """A running shared classic should still beat a stopped serverless."""
- running_shared_classic = _make_warehouse("c1", "Shared WH", State.RUNNING)
- stopped_serverless = _make_warehouse("s1", "Fast WH", State.STOPPED, enable_serverless_compute=True)
- mock_client = mock.Mock()
- mock_client.warehouses.list.return_value = [stopped_serverless, running_shared_classic]
- mock_client_fn.return_value = mock_client
-
- result = get_best_warehouse()
- assert result == "c1"
diff --git a/databricks-tools-core/tests/unit/test_volume_files.py b/databricks-tools-core/tests/unit/test_volume_files.py
deleted file mode 100644
index c907b33e..00000000
--- a/databricks-tools-core/tests/unit/test_volume_files.py
+++ /dev/null
@@ -1,401 +0,0 @@
-"""Unit tests for volume file upload functions."""
-
-import os
-from pathlib import Path
-from unittest import mock
-
-import pytest
-
-from databricks_tools_core.unity_catalog.volume_files import (
- upload_to_volume,
- delete_from_volume,
- _collect_local_files,
- _collect_local_directories,
- VolumeUploadResult,
- VolumeFolderUploadResult,
- VolumeDeleteResult,
-)
-
-
-class TestCollectLocalFiles:
- """Tests for _collect_local_files helper function."""
-
- def test_collects_files_recursively(self, tmp_path):
- """Should collect all files in nested directories."""
- (tmp_path / "file1.csv").write_text("content1")
- (tmp_path / "subdir").mkdir()
- (tmp_path / "subdir" / "file2.csv").write_text("content2")
- (tmp_path / "subdir" / "nested").mkdir()
- (tmp_path / "subdir" / "nested" / "file3.csv").write_text("content3")
-
- files = _collect_local_files(str(tmp_path))
-
- assert len(files) == 3
- rel_paths = {f[1] for f in files}
- assert "file1.csv" in rel_paths
- assert os.path.join("subdir", "file2.csv") in rel_paths
- assert os.path.join("subdir", "nested", "file3.csv") in rel_paths
-
- def test_skips_hidden_files(self, tmp_path):
- """Should skip files starting with dot."""
- (tmp_path / "visible.csv").write_text("content")
- (tmp_path / ".hidden").write_text("hidden")
-
- files = _collect_local_files(str(tmp_path))
-
- assert len(files) == 1
- assert files[0][1] == "visible.csv"
-
- def test_skips_pycache(self, tmp_path):
- """Should skip __pycache__ directories."""
- (tmp_path / "file.csv").write_text("content")
- (tmp_path / "__pycache__").mkdir()
- (tmp_path / "__pycache__" / "cached.pyc").write_text("cached")
-
- files = _collect_local_files(str(tmp_path))
-
- assert len(files) == 1
- assert files[0][1] == "file.csv"
-
-
-class TestCollectLocalDirectories:
- """Tests for _collect_local_directories helper function."""
-
- def test_collects_directories_recursively(self, tmp_path):
- """Should collect all directories."""
- (tmp_path / "dir1").mkdir()
- (tmp_path / "dir1" / "subdir").mkdir()
- (tmp_path / "dir2").mkdir()
-
- dirs = _collect_local_directories(str(tmp_path))
-
- assert "dir1" in dirs
- assert "dir2" in dirs
- assert os.path.join("dir1", "subdir") in dirs
-
- def test_skips_hidden_directories(self, tmp_path):
- """Should skip directories starting with dot."""
- (tmp_path / "visible").mkdir()
- (tmp_path / ".hidden").mkdir()
-
- dirs = _collect_local_directories(str(tmp_path))
-
- assert "visible" in dirs
- assert ".hidden" not in dirs
-
-
-class TestUploadToVolume:
- """Tests for upload_to_volume function."""
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_single_file_upload(self, mock_get_client, tmp_path):
- """Should upload a single file correctly."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- test_file = tmp_path / "test.csv"
- test_file.write_text("col1,col2\n1,2")
-
- result = upload_to_volume(
- local_path=str(test_file),
- volume_path="/Volumes/catalog/schema/volume/test.csv",
- )
-
- assert result.success
- assert result.total_files == 1
- assert result.successful == 1
- assert result.failed == 0
- mock_client.files.upload_from.assert_called_once()
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_folder_upload(self, mock_get_client, tmp_path):
- """Should upload a folder with all files."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- (tmp_path / "file1.csv").write_text("content1")
- (tmp_path / "file2.json").write_text("{}")
-
- result = upload_to_volume(
- local_path=str(tmp_path),
- volume_path="/Volumes/catalog/schema/volume/data",
- )
-
- assert result.success
- assert result.total_files == 2
- assert result.successful == 2
- assert result.failed == 0
- assert mock_client.files.upload_from.call_count == 2
- mock_client.files.create_directory.assert_called()
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_glob_pattern_files(self, mock_get_client, tmp_path):
- """Should upload files matching glob pattern."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- (tmp_path / "file1.csv").write_text("content1")
- (tmp_path / "file2.csv").write_text("content2")
- (tmp_path / "data.json").write_text("{}")
-
- result = upload_to_volume(
- local_path=str(tmp_path / "*.csv"),
- volume_path="/Volumes/catalog/schema/volume/data",
- )
-
- assert result.success
- assert result.total_files == 2 # Only .csv files
- assert result.successful == 2
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_glob_pattern_star(self, mock_get_client, tmp_path):
- """Should upload all contents with /* pattern."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- (tmp_path / "file.csv").write_text("content")
- (tmp_path / "subdir").mkdir()
- (tmp_path / "subdir" / "nested.json").write_text("{}")
-
- result = upload_to_volume(
- local_path=str(tmp_path / "*"),
- volume_path="/Volumes/catalog/schema/volume/data",
- )
-
- assert result.success
- # file.csv + subdir/nested.json
- assert result.total_files == 2
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_nonexistent_path_returns_error(self, mock_get_client):
- """Should return error result for nonexistent path."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- result = upload_to_volume(
- local_path="/nonexistent/path/file.csv",
- volume_path="/Volumes/catalog/schema/volume/file.csv",
- )
-
- assert not result.success
- assert result.failed == 1
- assert "not found" in result.results[0].error.lower()
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_empty_glob_returns_error(self, mock_get_client, tmp_path):
- """Should return error when glob matches nothing."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- result = upload_to_volume(
- local_path=str(tmp_path / "*.nonexistent"),
- volume_path="/Volumes/catalog/schema/volume/data",
- )
-
- assert not result.success
- assert "no files match" in result.results[0].error.lower()
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_creates_parent_directories(self, mock_get_client, tmp_path):
- """Should create parent directories for single file upload."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- test_file = tmp_path / "test.csv"
- test_file.write_text("content")
-
- upload_to_volume(
- local_path=str(test_file),
- volume_path="/Volumes/catalog/schema/volume/deep/nested/test.csv",
- )
-
- # Should call create_directory for parent
- mock_client.files.create_directory.assert_called()
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_handles_upload_failure(self, mock_get_client, tmp_path):
- """Should handle upload failures gracefully."""
- mock_client = mock.Mock()
- mock_client.files.upload_from.side_effect = Exception("Upload failed")
- mock_get_client.return_value = mock_client
-
- test_file = tmp_path / "test.csv"
- test_file.write_text("content")
-
- result = upload_to_volume(
- local_path=str(test_file),
- volume_path="/Volumes/catalog/schema/volume/test.csv",
- )
-
- assert not result.success
- assert result.failed == 1
- assert "Upload failed" in result.results[0].error
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_expands_tilde_in_path(self, mock_get_client, tmp_path):
- """Should expand ~ in local path."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- # Create a file in a known location we can reference with ~
- home = Path.home()
- test_dir = home / ".test_upload_volume"
- test_dir.mkdir(exist_ok=True)
- test_file = test_dir / "test.csv"
- test_file.write_text("content")
-
- try:
- result = upload_to_volume(
- local_path="~/.test_upload_volume/test.csv",
- volume_path="/Volumes/catalog/schema/volume/test.csv",
- )
-
- assert result.total_files == 1
- mock_client.files.upload_from.assert_called_once()
- finally:
- test_file.unlink()
- test_dir.rmdir()
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_max_workers_parameter(self, mock_get_client, tmp_path):
- """Should respect max_workers parameter."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- # Create multiple files
- for i in range(10):
- (tmp_path / f"file{i}.csv").write_text(f"content{i}")
-
- result = upload_to_volume(
- local_path=str(tmp_path),
- volume_path="/Volumes/catalog/schema/volume/data",
- max_workers=2,
- )
-
- assert result.success
- assert result.total_files == 10
- assert mock_client.files.upload_from.call_count == 10
-
-
-class TestDeleteFromVolume:
- """Tests for delete_from_volume function."""
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_delete_file_succeeds(self, mock_get_client):
- """Should delete a file successfully."""
- mock_client = mock.Mock()
- # get_metadata succeeds means it's a file
- mock_client.files.get_metadata.return_value = mock.Mock(content_length=100)
- mock_get_client.return_value = mock_client
-
- result = delete_from_volume("/Volumes/catalog/schema/volume/file.csv")
-
- assert result.success
- assert result.files_deleted == 1
- assert result.directories_deleted == 0
- mock_client.files.delete.assert_called_once_with("/Volumes/catalog/schema/volume/file.csv")
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_delete_empty_directory_without_recursive(self, mock_get_client):
- """Should delete an empty directory without recursive flag."""
- mock_client = mock.Mock()
- # get_metadata fails for directories
- mock_client.files.get_metadata.side_effect = Exception("Not a file")
- # list_directory_contents returns empty
- mock_client.files.list_directory_contents.return_value = iter([])
- mock_get_client.return_value = mock_client
-
- result = delete_from_volume("/Volumes/catalog/schema/volume/empty_dir")
-
- assert result.success
- assert result.directories_deleted == 1
- mock_client.files.delete_directory.assert_called_once()
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_delete_nonempty_directory_without_recursive_fails(self, mock_get_client):
- """Should fail when deleting non-empty directory without recursive."""
- mock_client = mock.Mock()
- mock_client.files.get_metadata.side_effect = Exception("Not a file")
- # list succeeds with content
- mock_client.files.list_directory_contents.return_value = iter([
- mock.Mock(name="file.csv", path="/Volumes/.../file.csv", is_directory=False)
- ])
- # delete_directory fails because not empty
- mock_client.files.delete_directory.side_effect = Exception("Directory not empty")
- mock_get_client.return_value = mock_client
-
- result = delete_from_volume("/Volumes/catalog/schema/volume/nonempty_dir")
-
- assert not result.success
- assert "recursive=True" in result.error
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_delete_directory_recursive(self, mock_get_client):
- """Should delete directory and contents with recursive=True."""
- mock_client = mock.Mock()
- mock_client.files.get_metadata.side_effect = Exception("Not a file")
-
- # Simulate directory with nested content
- # The function calls list_directory_contents multiple times:
- # 1. Initial check to determine it's a directory
- # 2. _collect_volume_contents for root dir
- # 3. _collect_volume_contents for subdir
- mock_client.files.list_directory_contents.side_effect = [
- # Initial check (called with trailing /)
- iter([
- mock.Mock(name="file1.csv", path="/Volumes/c/s/v/dir/file1.csv", is_directory=False),
- ]),
- # _collect_volume_contents: root dir
- iter([
- mock.Mock(name="file1.csv", path="/Volumes/c/s/v/dir/file1.csv", is_directory=False),
- mock.Mock(name="subdir", path="/Volumes/c/s/v/dir/subdir", is_directory=True),
- ]),
- # _collect_volume_contents: subdir
- iter([
- mock.Mock(name="file2.csv", path="/Volumes/c/s/v/dir/subdir/file2.csv", is_directory=False),
- ]),
- ]
- mock_get_client.return_value = mock_client
-
- result = delete_from_volume("/Volumes/c/s/v/dir", recursive=True)
-
- assert result.success
- assert result.files_deleted == 2
- assert result.directories_deleted == 2 # subdir + root dir
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_delete_nonexistent_path_fails(self, mock_get_client):
- """Should fail for nonexistent path."""
- mock_client = mock.Mock()
- mock_client.files.get_metadata.side_effect = Exception("Not found")
- mock_client.files.list_directory_contents.side_effect = Exception("Not found")
- mock_get_client.return_value = mock_client
-
- result = delete_from_volume("/Volumes/catalog/schema/volume/nonexistent")
-
- assert not result.success
- assert "not found" in result.error.lower()
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_delete_file_handles_api_error(self, mock_get_client):
- """Should handle API errors gracefully."""
- mock_client = mock.Mock()
- mock_client.files.get_metadata.return_value = mock.Mock()
- mock_client.files.delete.side_effect = Exception("Permission denied")
- mock_get_client.return_value = mock_client
-
- result = delete_from_volume("/Volumes/catalog/schema/volume/file.csv")
-
- assert not result.success
- assert "Permission denied" in result.error
-
- @mock.patch("databricks_tools_core.unity_catalog.volume_files.get_workspace_client")
- def test_delete_strips_trailing_slash(self, mock_get_client):
- """Should strip trailing slash from path."""
- mock_client = mock.Mock()
- mock_client.files.get_metadata.return_value = mock.Mock()
- mock_get_client.return_value = mock_client
-
- result = delete_from_volume("/Volumes/catalog/schema/volume/file.csv/")
-
- assert result.volume_path == "/Volumes/catalog/schema/volume/file.csv"
diff --git a/databricks-tools-core/tests/unit/test_workspace.py b/databricks-tools-core/tests/unit/test_workspace.py
deleted file mode 100644
index 35adede5..00000000
--- a/databricks-tools-core/tests/unit/test_workspace.py
+++ /dev/null
@@ -1,387 +0,0 @@
-"""Unit tests for workspace file upload and delete functions."""
-
-import os
-import tempfile
-from pathlib import Path
-from unittest import mock
-
-import pytest
-
-from databricks_tools_core.file.workspace import (
- upload_to_workspace,
- delete_from_workspace,
- _collect_files,
- _collect_directories,
- _is_protected_path,
- UploadResult,
- FolderUploadResult,
- DeleteResult,
-)
-
-
-class TestCollectFiles:
- """Tests for _collect_files helper function."""
-
- def test_collects_files_recursively(self, tmp_path):
- """Should collect all files in nested directories."""
- # Create nested structure
- (tmp_path / "file1.py").write_text("content1")
- (tmp_path / "subdir").mkdir()
- (tmp_path / "subdir" / "file2.py").write_text("content2")
- (tmp_path / "subdir" / "nested").mkdir()
- (tmp_path / "subdir" / "nested" / "file3.py").write_text("content3")
-
- files = _collect_files(str(tmp_path))
-
- assert len(files) == 3
- rel_paths = {f[1] for f in files}
- assert "file1.py" in rel_paths
- assert os.path.join("subdir", "file2.py") in rel_paths
- assert os.path.join("subdir", "nested", "file3.py") in rel_paths
-
- def test_skips_hidden_files(self, tmp_path):
- """Should skip files starting with dot."""
- (tmp_path / "visible.py").write_text("content")
- (tmp_path / ".hidden").write_text("hidden")
-
- files = _collect_files(str(tmp_path))
-
- assert len(files) == 1
- assert files[0][1] == "visible.py"
-
- def test_skips_pycache(self, tmp_path):
- """Should skip __pycache__ directories."""
- (tmp_path / "file.py").write_text("content")
- (tmp_path / "__pycache__").mkdir()
- (tmp_path / "__pycache__" / "cached.pyc").write_text("cached")
-
- files = _collect_files(str(tmp_path))
-
- assert len(files) == 1
- assert files[0][1] == "file.py"
-
-
-class TestCollectDirectories:
- """Tests for _collect_directories helper function."""
-
- def test_collects_directories_recursively(self, tmp_path):
- """Should collect all directories."""
- (tmp_path / "dir1").mkdir()
- (tmp_path / "dir1" / "subdir").mkdir()
- (tmp_path / "dir2").mkdir()
-
- dirs = _collect_directories(str(tmp_path))
-
- assert "dir1" in dirs
- assert "dir2" in dirs
- assert os.path.join("dir1", "subdir") in dirs
-
- def test_skips_hidden_directories(self, tmp_path):
- """Should skip directories starting with dot."""
- (tmp_path / "visible").mkdir()
- (tmp_path / ".hidden").mkdir()
-
- dirs = _collect_directories(str(tmp_path))
-
- assert "visible" in dirs
- assert ".hidden" not in dirs
-
-
-class TestUploadToWorkspace:
- """Tests for upload_to_workspace function."""
-
- @mock.patch("databricks_tools_core.file.workspace.get_workspace_client")
- def test_single_file_upload(self, mock_get_client, tmp_path):
- """Should upload a single file correctly."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- # Create a test file
- test_file = tmp_path / "test.py"
- test_file.write_text("print('hello')")
-
- result = upload_to_workspace(
- local_path=str(test_file),
- workspace_path="/Workspace/Users/test@example.com/test.py",
- )
-
- assert result.success
- assert result.total_files == 1
- assert result.successful == 1
- assert result.failed == 0
- mock_client.workspace.upload.assert_called_once()
-
- @mock.patch("databricks_tools_core.file.workspace.get_workspace_client")
- def test_folder_upload(self, mock_get_client, tmp_path):
- """Should upload a folder with all files."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- # Create test folder with files
- (tmp_path / "file1.py").write_text("content1")
- (tmp_path / "file2.py").write_text("content2")
-
- result = upload_to_workspace(
- local_path=str(tmp_path),
- workspace_path="/Workspace/Users/test@example.com/project",
- )
-
- assert result.success
- assert result.total_files == 2
- assert result.successful == 2
- assert result.failed == 0
- assert mock_client.workspace.upload.call_count == 2
- mock_client.workspace.mkdirs.assert_called()
-
- @mock.patch("databricks_tools_core.file.workspace.get_workspace_client")
- def test_glob_pattern_files(self, mock_get_client, tmp_path):
- """Should upload files matching glob pattern."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- # Create test files
- (tmp_path / "file1.py").write_text("content1")
- (tmp_path / "file2.py").write_text("content2")
- (tmp_path / "data.json").write_text("{}")
-
- result = upload_to_workspace(
- local_path=str(tmp_path / "*.py"),
- workspace_path="/Workspace/Users/test@example.com/project",
- )
-
- assert result.success
- assert result.total_files == 2 # Only .py files
- assert result.successful == 2
-
- @mock.patch("databricks_tools_core.file.workspace.get_workspace_client")
- def test_glob_pattern_star(self, mock_get_client, tmp_path):
- """Should upload all contents with /* pattern."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- # Create test folder with subdir
- (tmp_path / "file.py").write_text("content")
- (tmp_path / "subdir").mkdir()
- (tmp_path / "subdir" / "nested.py").write_text("nested")
-
- result = upload_to_workspace(
- local_path=str(tmp_path / "*"),
- workspace_path="/Workspace/Users/test@example.com/project",
- )
-
- assert result.success
- # file.py + subdir/nested.py
- assert result.total_files == 2
-
- @mock.patch("databricks_tools_core.file.workspace.get_workspace_client")
- def test_nonexistent_path_returns_error(self, mock_get_client):
- """Should return error result for nonexistent path."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- result = upload_to_workspace(
- local_path="/nonexistent/path/file.py",
- workspace_path="/Workspace/Users/test@example.com/file.py",
- )
-
- assert not result.success
- assert result.failed == 1
- assert "not found" in result.results[0].error.lower()
-
- @mock.patch("databricks_tools_core.file.workspace.get_workspace_client")
- def test_empty_glob_returns_error(self, mock_get_client, tmp_path):
- """Should return error when glob matches nothing."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- result = upload_to_workspace(
- local_path=str(tmp_path / "*.nonexistent"),
- workspace_path="/Workspace/Users/test@example.com/project",
- )
-
- assert not result.success
- assert "no files match" in result.results[0].error.lower()
-
- @mock.patch("databricks_tools_core.file.workspace.get_workspace_client")
- def test_creates_parent_directories(self, mock_get_client, tmp_path):
- """Should create parent directories for single file upload."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- test_file = tmp_path / "test.py"
- test_file.write_text("content")
-
- upload_to_workspace(
- local_path=str(test_file),
- workspace_path="/Workspace/Users/test@example.com/deep/nested/test.py",
- )
-
- # Should call mkdirs for parent directory
- mock_client.workspace.mkdirs.assert_called_with(
- "/Workspace/Users/test@example.com/deep/nested"
- )
-
- @mock.patch("databricks_tools_core.file.workspace.get_workspace_client")
- def test_handles_upload_failure(self, mock_get_client, tmp_path):
- """Should handle upload failures gracefully."""
- mock_client = mock.Mock()
- mock_client.workspace.upload.side_effect = Exception("Upload failed")
- mock_get_client.return_value = mock_client
-
- test_file = tmp_path / "test.py"
- test_file.write_text("content")
-
- result = upload_to_workspace(
- local_path=str(test_file),
- workspace_path="/Workspace/Users/test@example.com/test.py",
- )
-
- assert not result.success
- assert result.failed == 1
- assert "Upload failed" in result.results[0].error
-
- @mock.patch("databricks_tools_core.file.workspace.get_workspace_client")
- def test_expands_tilde_in_path(self, mock_get_client, tmp_path):
- """Should expand ~ in local path."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- # Create a file in a known location we can reference with ~
- home = Path.home()
- test_dir = home / ".test_upload_workspace"
- test_dir.mkdir(exist_ok=True)
- test_file = test_dir / "test.py"
- test_file.write_text("content")
-
- try:
- result = upload_to_workspace(
- local_path="~/.test_upload_workspace/test.py",
- workspace_path="/Workspace/Users/test@example.com/test.py",
- )
-
- assert result.total_files == 1
- mock_client.workspace.upload.assert_called_once()
- finally:
- # Cleanup
- test_file.unlink()
- test_dir.rmdir()
-
-
-class TestIsProtectedPath:
- """Tests for _is_protected_path helper function."""
-
- def test_user_home_folder_is_protected(self):
- """Should protect user home folders."""
- assert _is_protected_path("/Workspace/Users/user@example.com") is True
- assert _is_protected_path("/Workspace/Users/user@example.com/") is True
- assert _is_protected_path("/Users/user@example.com") is True
- assert _is_protected_path("/Users/user@example.com/") is True
-
- def test_user_subfolder_is_not_protected(self):
- """Should allow deletion of user subfolders."""
- assert _is_protected_path("/Workspace/Users/user@example.com/my_folder") is False
- assert _is_protected_path("/Workspace/Users/user@example.com/project/src") is False
- assert _is_protected_path("/Users/user@example.com/my_folder") is False
-
- def test_repos_root_is_protected(self):
- """Should protect repos root folders."""
- assert _is_protected_path("/Workspace/Repos/user@example.com") is True
- assert _is_protected_path("/Repos/user@example.com") is True
-
- def test_repos_subfolder_is_not_protected(self):
- """Should allow deletion of repos subfolders."""
- assert _is_protected_path("/Workspace/Repos/user@example.com/my_repo") is False
- assert _is_protected_path("/Repos/user@example.com/my_repo") is False
-
- def test_shared_root_is_protected(self):
- """Should protect shared folder root."""
- assert _is_protected_path("/Workspace/Shared") is True
-
- def test_shared_subfolder_is_not_protected(self):
- """Should allow deletion of shared subfolders."""
- assert _is_protected_path("/Workspace/Shared/team_folder") is False
-
- def test_root_paths_are_protected(self):
- """Should protect root-level paths."""
- assert _is_protected_path("/") is True
- assert _is_protected_path("/Workspace") is True
- assert _is_protected_path("/Workspace/Users") is True
- assert _is_protected_path("/Users") is True
- assert _is_protected_path("/Workspace/Repos") is True
- assert _is_protected_path("/Repos") is True
-
-
-class TestDeleteFromWorkspace:
- """Tests for delete_from_workspace function."""
-
- @mock.patch("databricks_tools_core.file.workspace.get_workspace_client")
- def test_delete_file_succeeds(self, mock_get_client):
- """Should delete a file successfully."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- result = delete_from_workspace(
- workspace_path="/Workspace/Users/test@example.com/my_folder/file.py",
- )
-
- assert result.success
- mock_client.workspace.delete.assert_called_once_with(
- "/Workspace/Users/test@example.com/my_folder/file.py", recursive=False
- )
-
- @mock.patch("databricks_tools_core.file.workspace.get_workspace_client")
- def test_delete_folder_recursive(self, mock_get_client):
- """Should delete a folder recursively."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- result = delete_from_workspace(
- workspace_path="/Workspace/Users/test@example.com/my_folder",
- recursive=True,
- )
-
- assert result.success
- mock_client.workspace.delete.assert_called_once_with(
- "/Workspace/Users/test@example.com/my_folder", recursive=True
- )
-
- @mock.patch("databricks_tools_core.file.workspace.get_workspace_client")
- def test_delete_protected_path_fails(self, mock_get_client):
- """Should fail when trying to delete protected paths."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- result = delete_from_workspace(
- workspace_path="/Workspace/Users/test@example.com",
- )
-
- assert not result.success
- assert "protected path" in result.error.lower()
- mock_client.workspace.delete.assert_not_called()
-
- @mock.patch("databricks_tools_core.file.workspace.get_workspace_client")
- def test_delete_protected_path_with_trailing_slash_fails(self, mock_get_client):
- """Should fail even with trailing slash."""
- mock_client = mock.Mock()
- mock_get_client.return_value = mock_client
-
- result = delete_from_workspace(
- workspace_path="/Workspace/Users/test@example.com/",
- )
-
- assert not result.success
- assert "protected path" in result.error.lower()
-
- @mock.patch("databricks_tools_core.file.workspace.get_workspace_client")
- def test_delete_handles_api_error(self, mock_get_client):
- """Should handle API errors gracefully."""
- mock_client = mock.Mock()
- mock_client.workspace.delete.side_effect = Exception("Not found")
- mock_get_client.return_value = mock_client
-
- result = delete_from_workspace(
- workspace_path="/Workspace/Users/test@example.com/my_folder",
- )
-
- assert not result.success
- assert "Not found" in result.error
diff --git a/install.ps1 b/install.ps1
index e4b4ab67..854e112a 100644
--- a/install.ps1
+++ b/install.ps1
@@ -1,7 +1,7 @@
#
# Databricks AI Dev Kit - Unified Installer (Windows)
#
-# Installs skills, MCP server, and configuration for Claude Code, Cursor, OpenAI Codex, GitHub Copilot, Gemini CLI, and Antigravity.
+# Installs skills and configuration for Claude Code, Cursor, OpenAI Codex, GitHub Copilot, Gemini CLI, and Antigravity.
#
# Usage: irm https://raw.githubusercontent.com/databricks-solutions/ai-dev-kit/main/install.ps1 -OutFile install.ps1
# .\install.ps1 [OPTIONS]
@@ -22,8 +22,6 @@
# # Install for specific tools only
# .\install.ps1 -Tools cursor
#
-# # Skills only (skip MCP server)
-# .\install.ps1 -SkillsOnly
#
# # Install specific branch or tag
# $env:AIDEVKIT_BRANCH = '0.1.0'; .\install.ps1
@@ -52,9 +50,6 @@ $RepoUrl = "https://github.com/$Owner/$Repo.git"
$RawUrl = "https://raw.githubusercontent.com/$Owner/$Repo/$Branch"
$InstallDir = if ($env:AIDEVKIT_HOME) { $env:AIDEVKIT_HOME } else { Join-Path $env:USERPROFILE ".ai-dev-kit" }
$RepoDir = Join-Path $InstallDir "repo"
-$VenvDir = Join-Path $InstallDir ".venv"
-$VenvPython = Join-Path $VenvDir "Scripts\python.exe"
-$McpEntry = Join-Path $RepoDir "databricks-mcp-server\run_server.py"
# Minimum required versions
$MinCliVersion = "0.278.0"
@@ -64,14 +59,11 @@ $MinSdkVersion = "0.85.0"
$script:Profile_ = "DEFAULT"
$script:Scope = "project"
$script:ScopeExplicit = $false # Track if --global was explicitly passed
-$script:InstallMcp = $true
$script:InstallSkills = $true
$script:Force = $false
$script:Silent = $false
$script:UserTools = ""
$script:Tools = ""
-$script:UserMcpPath = ""
-$script:Pkg = ""
$script:ProfileProvided = $false
$script:SkillsProfile = ""
$script:UserSkills = ""
@@ -211,9 +203,6 @@ while ($i -lt $args.Count) {
switch ($args[$i]) {
{ $_ -in "-p", "--profile" } { $script:Profile_ = $args[$i + 1]; $script:ProfileProvided = $true; $i += 2 }
{ $_ -in "-g", "--global", "-Global" } { $script:Scope = "global"; $script:ScopeExplicit = $true; $i++ }
- { $_ -in "--skills-only", "-SkillsOnly" } { $script:InstallMcp = $false; $i++ }
- { $_ -in "--mcp-only", "-McpOnly" } { $script:InstallSkills = $false; $i++ }
- { $_ -in "--mcp-path", "-McpPath" } { $script:UserMcpPath = $args[$i + 1]; $i += 2 }
{ $_ -in "--silent", "-Silent" } { $script:Silent = $true; $i++ }
{ $_ -in "--tools", "-Tools" } { $script:UserTools = $args[$i + 1]; $i += 2 }
{ $_ -in "--skills-profile", "-SkillsProfile" } { $script:SkillsProfile = $args[$i + 1]; $i += 2 }
@@ -229,9 +218,6 @@ while ($i -lt $args.Count) {
Write-Host "Options:"
Write-Host " -p, --profile NAME Databricks profile (default: DEFAULT)"
Write-Host " -g, --global Install globally for all projects"
- Write-Host " --skills-only Skip MCP server setup"
- Write-Host " --mcp-only Skip skills installation"
- Write-Host " --mcp-path PATH Path to MCP server installation"
Write-Host " --silent Silent mode (no output except errors)"
Write-Host " --tools LIST Comma-separated: claude,cursor,copilot,codex,gemini,antigravity"
Write-Host " --skills-profile LIST Comma-separated profiles: all,data-engineer,analyst,ai-ml-engineer,app-developer"
@@ -660,28 +646,6 @@ function Invoke-PromptProfile {
}
}
-# ─── MCP path selection ──────────────────────────────────────
-function Invoke-PromptMcpPath {
- if (-not [string]::IsNullOrWhiteSpace($script:UserMcpPath)) {
- $script:InstallDir = $script:UserMcpPath
- } elseif (-not $script:Silent) {
- Write-Host ""
- Write-Host " MCP server location" -ForegroundColor White
- Write-Host " The MCP server runtime (Python venv + source) will be installed here." -ForegroundColor DarkGray
- Write-Host " Shared across all your projects -- only the config files are per-project." -ForegroundColor DarkGray
- Write-Host ""
-
- $selected = Read-Prompt -PromptText "Install path" -Default $InstallDir
- $script:InstallDir = $selected
- }
-
- # Update derived paths
- $script:RepoDir = Join-Path $script:InstallDir "repo"
- $script:VenvDir = Join-Path $script:InstallDir ".venv"
- $script:VenvPython = Join-Path $script:VenvDir "Scripts\python.exe"
- $script:McpEntry = Join-Path $script:RepoDir "databricks-mcp-server\run_server.py"
-}
-
# ─── Check prerequisites ─────────────────────────────────────
function Test-Dependencies {
# Git
@@ -713,19 +677,6 @@ function Test-Dependencies {
Write-Msg "You can still install, but authentication will require the CLI later."
}
- # Python package manager
- if ($script:InstallMcp) {
- if (Get-Command uv -ErrorAction SilentlyContinue) {
- $script:Pkg = "uv"
- } elseif (Get-Command pip3 -ErrorAction SilentlyContinue) {
- $script:Pkg = "pip3"
- } elseif (Get-Command pip -ErrorAction SilentlyContinue) {
- $script:Pkg = "pip"
- } else {
- Write-Err "Python package manager required. Install Python: choco install python -y"
- }
- Write-Ok $script:Pkg
- }
}
# ─── Check version ───────────────────────────────────────────
@@ -768,11 +719,9 @@ function Test-Version {
}
}
-# ─── Setup MCP server ────────────────────────────────────────
-function Install-McpServer {
- Write-Step "Setting up MCP server"
-
- # Native commands (git, pip) write informational messages to stderr.
+# ─── Clone repository ────────────────────────────────────────
+function Clone-Repo {
+ # Native commands (git) write informational messages to stderr.
# Temporarily relax error handling so these don't terminate the script.
$prevEAP = $ErrorActionPreference
$ErrorActionPreference = "Continue"
@@ -795,50 +744,8 @@ function Install-McpServer {
$ErrorActionPreference = $prevEAP
Write-Err "Failed to clone repository"
}
- Write-Ok "Repository cloned ($Branch)"
-
- # Create venv and install
- Write-Msg "Installing Python dependencies..."
- if ($script:Pkg -eq "uv") {
- & uv venv --python 3.11 --allow-existing $script:VenvDir -q 2>&1 | Out-Null
- if ($LASTEXITCODE -ne 0) {
- & uv venv --allow-existing $script:VenvDir -q 2>&1 | Out-Null
- }
- & uv pip install --python $script:VenvPython -e "$($script:RepoDir)\databricks-tools-core" -e "$($script:RepoDir)\databricks-mcp-server" -q 2>&1 | Out-Null
- } else {
- if (-not (Test-Path $script:VenvDir)) {
- & python -m venv $script:VenvDir 2>&1 | Out-Null
- }
- & $script:VenvPython -m pip install -q -e "$($script:RepoDir)\databricks-tools-core" -e "$($script:RepoDir)\databricks-mcp-server" 2>&1 | Out-Null
- }
-
- # Verify
- & $script:VenvPython -c "import databricks_mcp_server" 2>&1 | Out-Null
- if ($LASTEXITCODE -ne 0) {
- $ErrorActionPreference = $prevEAP
- Write-Err "MCP server install failed"
- }
-
$ErrorActionPreference = $prevEAP
- Write-Ok "MCP server ready"
-
- # Check Databricks SDK version
- try {
- $sdkOutput = & $script:VenvPython -c "from databricks.sdk.version import __version__; print(__version__)" 2>&1
- if ($sdkOutput -match '(\d+\.\d+\.\d+)') {
- $sdkVersion = $Matches[1]
- if ([version]$sdkVersion -ge [version]$MinSdkVersion) {
- Write-Ok "Databricks SDK v$sdkVersion"
- } else {
- Write-Warn "Databricks SDK v$sdkVersion is outdated (minimum: v$MinSdkVersion)"
- Write-Msg " Upgrade: $($script:VenvPython) -m pip install --upgrade databricks-sdk"
- }
- } else {
- Write-Warn "Could not determine Databricks SDK version"
- }
- } catch {
- Write-Warn "Could not determine Databricks SDK version"
- }
+ Write-Ok "Repository cloned ($Branch)"
}
# ─── Skill profile selection ──────────────────────────────────
@@ -1292,192 +1199,6 @@ function Install-Skills {
}
}
-# ─── Write MCP configs ───────────────────────────────────────
-function Write-McpJson {
- param([string]$Path)
-
- $dir = Split-Path $Path -Parent
- if (-not (Test-Path $dir)) {
- New-Item -ItemType Directory -Path $dir -Force | Out-Null
- }
-
- # Backup existing
- if (Test-Path $Path) {
- Copy-Item $Path "$Path.bak" -Force
- Write-Msg "Backed up $(Split-Path $Path -Leaf) -> $(Split-Path $Path -Leaf).bak"
- }
-
- # Try to merge with existing config
- if ((Test-Path $Path) -and (Test-Path $script:VenvPython)) {
- try {
- $existing = Get-Content $Path -Raw | ConvertFrom-Json
- } catch {
- $existing = $null
- }
- }
-
- if ($existing) {
- # Merge into existing config — use forward slashes for JSON compatibility
- if (-not $existing.mcpServers) {
- $existing | Add-Member -NotePropertyName "mcpServers" -NotePropertyValue ([PSCustomObject]@{}) -Force
- }
- $dbEntry = [PSCustomObject]@{
- command = $script:VenvPython -replace '\\', '/'
- args = @($script:McpEntry -replace '\\', '/')
- env = [PSCustomObject]@{ DATABRICKS_CONFIG_PROFILE = $script:Profile_ }
- }
- $existing.mcpServers | Add-Member -NotePropertyName "databricks" -NotePropertyValue $dbEntry -Force
- $existing | ConvertTo-Json -Depth 10 | Set-Content $Path -Encoding UTF8
- } else {
- # Write fresh config — use forward slashes for cross-platform JSON compatibility
- $pythonPath = $script:VenvPython -replace '\\', '/'
- $entryPath = $script:McpEntry -replace '\\', '/'
- $json = @"
-{
- "mcpServers": {
- "databricks": {
- "command": "$pythonPath",
- "args": ["$entryPath"],
- "env": {"DATABRICKS_CONFIG_PROFILE": "$($script:Profile_)"}
- }
- }
-}
-"@
- Set-Content -Path $Path -Value $json -Encoding UTF8
- }
-}
-
-function Write-CopilotMcpJson {
- param([string]$Path)
-
- $dir = Split-Path $Path -Parent
- if (-not (Test-Path $dir)) {
- New-Item -ItemType Directory -Path $dir -Force | Out-Null
- }
-
- # Backup existing
- if (Test-Path $Path) {
- Copy-Item $Path "$Path.bak" -Force
- Write-Msg "Backed up $(Split-Path $Path -Leaf) -> $(Split-Path $Path -Leaf).bak"
- }
-
- # Try to merge with existing config
- if ((Test-Path $Path) -and (Test-Path $script:VenvPython)) {
- try {
- $existing = Get-Content $Path -Raw | ConvertFrom-Json
- } catch {
- $existing = $null
- }
- }
-
- if ($existing) {
- if (-not $existing.servers) {
- $existing | Add-Member -NotePropertyName "servers" -NotePropertyValue ([PSCustomObject]@{}) -Force
- }
- $dbEntry = [PSCustomObject]@{
- command = $script:VenvPython -replace '\\', '/'
- args = @($script:McpEntry -replace '\\', '/')
- env = [PSCustomObject]@{ DATABRICKS_CONFIG_PROFILE = $script:Profile_ }
- }
- $existing.servers | Add-Member -NotePropertyName "databricks" -NotePropertyValue $dbEntry -Force
- $existing | ConvertTo-Json -Depth 10 | Set-Content $Path -Encoding UTF8
- } else {
- $pythonPath = $script:VenvPython -replace '\\', '/'
- $entryPath = $script:McpEntry -replace '\\', '/'
- $json = @"
-{
- "servers": {
- "databricks": {
- "command": "$pythonPath",
- "args": ["$entryPath"],
- "env": {"DATABRICKS_CONFIG_PROFILE": "$($script:Profile_)"}
- }
- }
-}
-"@
- Set-Content -Path $Path -Value $json -Encoding UTF8
- }
-}
-
-function Write-McpToml {
- param([string]$Path)
-
- $dir = Split-Path $Path -Parent
- if (-not (Test-Path $dir)) {
- New-Item -ItemType Directory -Path $dir -Force | Out-Null
- }
-
- # Check if already configured
- if (Test-Path $Path) {
- $content = Get-Content $Path -Raw
- if ($content -match 'mcp_servers\.databricks') { return }
- Copy-Item $Path "$Path.bak" -Force
- Write-Msg "Backed up $(Split-Path $Path -Leaf) -> $(Split-Path $Path -Leaf).bak"
- }
-
- $pythonPath = $script:VenvPython -replace '\\', '/'
- $entryPath = $script:McpEntry -replace '\\', '/'
- $tomlBlock = @"
-
-[mcp_servers.databricks]
-command = "$pythonPath"
-args = ["$entryPath"]
-"@
- Add-Content -Path $Path -Value $tomlBlock -Encoding UTF8
-}
-
-function Write-GeminiMcpJson {
- param([string]$Path)
-
- $dir = Split-Path $Path -Parent
- if (-not (Test-Path $dir)) {
- New-Item -ItemType Directory -Path $dir -Force | Out-Null
- }
-
- # Backup existing
- if (Test-Path $Path) {
- Copy-Item $Path "$Path.bak" -Force
- Write-Msg "Backed up $(Split-Path $Path -Leaf) -> $(Split-Path $Path -Leaf).bak"
- }
-
- # Try to merge with existing config
- if ((Test-Path $Path) -and (Test-Path $script:VenvPython)) {
- try {
- $existing = Get-Content $Path -Raw | ConvertFrom-Json
- } catch {
- $existing = $null
- }
- }
-
- if ($existing) {
- if (-not $existing.mcpServers) {
- $existing | Add-Member -NotePropertyName "mcpServers" -NotePropertyValue ([PSCustomObject]@{}) -Force
- }
- $dbEntry = [PSCustomObject]@{
- command = $script:VenvPython -replace '\\', '/'
- args = @($script:McpEntry -replace '\\', '/')
- env = [PSCustomObject]@{ DATABRICKS_CONFIG_PROFILE = $script:Profile_ }
- }
- $existing.mcpServers | Add-Member -NotePropertyName "databricks" -NotePropertyValue $dbEntry -Force
- $existing | ConvertTo-Json -Depth 10 | Set-Content $Path -Encoding UTF8
- } else {
- $pythonPath = $script:VenvPython -replace '\\', '/'
- $entryPath = $script:McpEntry -replace '\\', '/'
- $json = @"
-{
- "mcpServers": {
- "databricks": {
- "command": "$pythonPath",
- "args": ["$entryPath"],
- "env": {"DATABRICKS_CONFIG_PROFILE": "$($script:Profile_)"}
- }
- }
-}
-"@
- Set-Content -Path $Path -Value $json -Encoding UTF8
- }
-}
-
function Write-GeminiMd {
param([string]$Path)
@@ -1486,17 +1207,7 @@ function Write-GeminiMd {
$content = @"
# Databricks AI Dev Kit
-You have access to Databricks skills and MCP tools installed by the Databricks AI Dev Kit.
-
-## Available MCP Tools
-
-The ``databricks`` MCP server provides 50+ tools for interacting with Databricks, including:
-- SQL execution and warehouse management
-- Unity Catalog operations (tables, volumes, schemas)
-- Jobs and workflow management
-- Model serving endpoints
-- Genie spaces and AI/BI dashboards
-- Databricks Apps deployment
+You have access to Databricks skills installed by the Databricks AI Dev Kit.
## Available Skills
@@ -1511,88 +1222,12 @@ Skills are installed in ``.gemini/skills/`` and provide patterns and best practi
## Getting Started
-Try asking: "List my SQL warehouses" or "Show my Unity Catalog schemas"
+Try asking: "Help me create a Spark pipeline" or "How do I use Unity Catalog?"
"@
Set-Content -Path $Path -Value $content -Encoding UTF8
Write-Ok "GEMINI.md"
}
-function Write-McpConfigs {
- param([string]$BaseDir)
-
- Write-Step "Configuring MCP"
-
- foreach ($tool in ($script:Tools -split ' ')) {
- switch ($tool) {
- "claude" {
- if ($script:Scope -eq "global") {
- Write-McpJson (Join-Path $env:USERPROFILE ".claude\mcp.json")
- } else {
- Write-McpJson (Join-Path $BaseDir ".mcp.json")
- }
- Write-Ok "Claude MCP config"
- }
- "cursor" {
- if ($script:Scope -eq "global") {
- Write-Warn "Cursor global: manual MCP configuration required"
- Write-Msg " 1. Open Cursor -> Settings -> Cursor Settings -> Tools & MCP"
- Write-Msg " 2. Click New MCP Server"
- Write-Msg " 3. Add the following JSON config:"
- Write-Msg " {"
- Write-Msg " `"mcpServers`": {"
- Write-Msg " `"databricks`": {"
- Write-Msg " `"command`": `"$($script:VenvPython)`","
- Write-Msg " `"args`": [`"$($script:McpEntry)`"],"
- Write-Msg " `"env`": {`"DATABRICKS_CONFIG_PROFILE`": `"$($script:Profile)`"}"
- Write-Msg " }"
- Write-Msg " }"
- Write-Msg " }"
- } else {
- Write-McpJson (Join-Path $BaseDir ".cursor\mcp.json")
- Write-Ok "Cursor MCP config"
- }
- Write-Warn "Cursor: MCP servers are disabled by default."
- Write-Msg " Enable in: Cursor -> Settings -> Cursor Settings -> Tools & MCP -> Toggle 'databricks'"
- }
- "copilot" {
- if ($script:Scope -eq "global") {
- Write-Warn "Copilot global: configure MCP in VS Code settings (Ctrl+Shift+P -> 'MCP: Open User Configuration')"
- Write-Msg " Command: $($script:VenvPython) | Args: $($script:McpEntry)"
- } else {
- Write-CopilotMcpJson (Join-Path $BaseDir ".vscode\mcp.json")
- Write-Ok "Copilot MCP config (.vscode/mcp.json)"
- }
- Write-Warn "Copilot: MCP servers must be enabled manually."
- Write-Msg " In Copilot Chat, click 'Configure Tools' (tool icon, bottom-right) and enable 'databricks'"
- }
- "codex" {
- if ($script:Scope -eq "global") {
- Write-McpToml (Join-Path $env:USERPROFILE ".codex\config.toml")
- } else {
- Write-McpToml (Join-Path $BaseDir ".codex\config.toml")
- }
- Write-Ok "Codex MCP config"
- }
- "gemini" {
- if ($script:Scope -eq "global") {
- Write-GeminiMcpJson (Join-Path $env:USERPROFILE ".gemini\settings.json")
- } else {
- Write-GeminiMcpJson (Join-Path $BaseDir ".gemini\settings.json")
- }
- Write-Ok "Gemini CLI MCP config"
- }
- "antigravity" {
- if ($script:Scope -eq "project") {
- Write-Warn "Antigravity only supports global MCP configuration."
- Write-Msg " Config written to ~/.gemini/antigravity/mcp_config.json"
- }
- Write-GeminiMcpJson (Join-Path $env:USERPROFILE ".gemini\antigravity\mcp_config.json")
- Write-Ok "Antigravity MCP config"
- }
- }
- }
-}
-
# ─── Save version ────────────────────────────────────────────
function Save-Version {
try {
@@ -1627,13 +1262,11 @@ function Show-Summary {
Write-Msg "Next steps:"
$step = 1
if ($script:Tools -match 'cursor') {
- Write-Msg "$step. Enable MCP in Cursor: Cursor -> Settings -> Cursor Settings -> Tools & MCP -> Toggle 'databricks'"
+ Write-Msg "$step. Cursor skills are installed in .cursor/skills/"
$step++
}
if ($script:Tools -match 'copilot') {
- Write-Msg "$step. In Copilot Chat, click 'Configure Tools' (tool icon, bottom-right) and enable 'databricks'"
- $step++
- Write-Msg "$step. Use Copilot in Agent mode to access Databricks skills and MCP tools"
+ Write-Msg "$step. Copilot instructions are installed in .github/copilot-instructions.md"
$step++
}
if ($script:Tools -match 'gemini') {
@@ -1641,12 +1274,12 @@ function Show-Summary {
$step++
}
if ($script:Tools -match 'antigravity') {
- Write-Msg "$step. Open your project in Antigravity to use Databricks skills and MCP tools"
+ Write-Msg "$step. Open your project in Antigravity to use Databricks skills"
$step++
}
Write-Msg "$step. Open your project in your tool of choice"
$step++
- Write-Msg "$step. Try: `"List my SQL warehouses`""
+ Write-Msg "$step. Try: `"Help me create a Spark pipeline`""
Write-Host ""
}
@@ -1838,12 +1471,6 @@ function Invoke-Main {
}
}
- # MCP path
- if ($script:InstallMcp) {
- Invoke-PromptMcpPath
- Write-Ok "MCP path: $($script:InstallDir)"
- }
-
# Confirmation summary
if (-not $script:Silent) {
Write-Host ""
@@ -1852,9 +1479,6 @@ function Invoke-Main {
Write-Host " Tools: " -NoNewline; Write-Host "$(($script:Tools -split ' ') -join ', ')" -ForegroundColor Green
Write-Host " Profile: " -NoNewline; Write-Host $script:Profile_ -ForegroundColor Green
Write-Host " Scope: " -NoNewline; Write-Host $script:Scope -ForegroundColor Green
- if ($script:InstallMcp) {
- Write-Host " MCP server: " -NoNewline; Write-Host $script:InstallDir -ForegroundColor Green
- }
if ($script:InstallSkills) {
$skTotal = $script:SelectedSkills.Count + $script:SelectedMlflowSkills.Count + $script:SelectedApxSkills.Count
if (-not [string]::IsNullOrWhiteSpace($script:UserSkills)) {
@@ -1864,9 +1488,6 @@ function Invoke-Main {
Write-Host " Skills: " -NoNewline; Write-Host "$profileDisplay ($skTotal skills)" -ForegroundColor Green
}
}
- if ($script:InstallMcp) {
- Write-Host " MCP config: " -NoNewline; Write-Host "yes" -ForegroundColor Green
- }
Write-Host ""
}
@@ -1889,19 +1510,9 @@ function Invoke-Main {
$baseDir = (Get-Location).Path
}
- # Setup MCP server
- if ($script:InstallMcp) {
- Install-McpServer
- } elseif (-not (Test-Path $script:RepoDir)) {
- Write-Step "Downloading sources"
- if (-not (Test-Path $script:InstallDir)) {
- New-Item -ItemType Directory -Path $script:InstallDir -Force | Out-Null
- }
- $prevEAP = $ErrorActionPreference; $ErrorActionPreference = "Continue"
- & git -c advice.detachedHead=false clone -q --depth 1 --branch $Branch $RepoUrl $script:RepoDir 2>&1 | Out-Null
- $ErrorActionPreference = $prevEAP
- Write-Ok "Repository cloned ($Branch)"
- }
+ # Clone repository for skills
+ Write-Step "Downloading sources"
+ Clone-Repo
# Install skills
if ($script:InstallSkills) {
@@ -1917,11 +1528,6 @@ function Invoke-Main {
}
}
- # Write MCP configs
- if ($script:InstallMcp) {
- Write-McpConfigs -BaseDir $baseDir
- }
-
# Save version
Save-Version
diff --git a/install.sh b/install.sh
index 114cb2c4..4dcc0134 100644
--- a/install.sh
+++ b/install.sh
@@ -2,7 +2,7 @@
#
# Databricks AI Dev Kit - Unified Installer
#
-# Installs skills, MCP server, and configuration for Claude Code, Cursor, OpenAI Codex, GitHub Copilot, Gemini CLI, and Antigravity.
+# Installs skills and configuration for Claude Code, Cursor, OpenAI Codex, GitHub Copilot, Gemini CLI, and Antigravity.
#
# Usage: bash <(curl -sL https://raw.githubusercontent.com/databricks-solutions/ai-dev-kit/main/install.sh) [OPTIONS]
#
@@ -19,8 +19,6 @@
# # Install for specific tools only
# bash <(curl -sL https://raw.githubusercontent.com/databricks-solutions/ai-dev-kit/main/install.sh) --tools cursor,codex,copilot,gemini
#
-# # Skills only (skip MCP server)
-# bash <(curl -sL https://raw.githubusercontent.com/databricks-solutions/ai-dev-kit/main/install.sh) --skills-only
#
# # Install skills for a specific profile
# bash <(curl -sL https://raw.githubusercontent.com/databricks-solutions/ai-dev-kit/main/install.sh) --skills-profile data-engineer
@@ -50,7 +48,6 @@ IS_UPDATE=false
SILENT="${DEVKIT_SILENT:-false}"
TOOLS="${DEVKIT_TOOLS:-}"
USER_TOOLS=""
-USER_MCP_PATH="${DEVKIT_MCP_PATH:-}"
SKILLS_PROFILE="${DEVKIT_SKILLS_PROFILE:-}"
USER_SKILLS="${DEVKIT_SKILLS:-}"
@@ -77,7 +74,6 @@ else
fi
# Installation mode defaults
-INSTALL_MCP=true
INSTALL_SKILLS=true
# Minimum required versions
@@ -127,9 +123,6 @@ while [ $# -gt 0 ]; do
-p|--profile) PROFILE="$2"; shift 2 ;;
-g|--global) SCOPE="global"; SCOPE_EXPLICIT=true; shift ;;
-b|--branch) BRANCH="$2"; shift 2 ;;
- --skills-only) INSTALL_MCP=false; shift ;;
- --mcp-only) INSTALL_SKILLS=false; shift ;;
- --mcp-path) USER_MCP_PATH="$2"; shift 2 ;;
--skills-profile) SKILLS_PROFILE="$2"; shift 2 ;;
--skills) USER_SKILLS="$2"; shift 2 ;;
--list-skills) LIST_SKILLS=true; shift ;;
@@ -145,9 +138,6 @@ while [ $# -gt 0 ]; do
echo " -p, --profile NAME Databricks profile (default: DEFAULT)"
echo " -b, --branch NAME Git branch/tag to install (default: latest release)"
echo " -g, --global Install globally for all projects"
- echo " --skills-only Skip MCP server setup"
- echo " --mcp-only Skip skills installation"
- echo " --mcp-path PATH Path to MCP server installation (default: ~/.ai-dev-kit)"
echo " --silent Silent mode (no output except errors)"
echo " --tools LIST Comma-separated: claude,cursor,copilot,codex,gemini,antigravity"
echo " --skills-profile LIST Comma-separated profiles: all,data-engineer,analyst,ai-ml-engineer,app-developer"
@@ -162,7 +152,6 @@ while [ $# -gt 0 ]; do
echo " DEVKIT_SCOPE 'project' or 'global'"
echo " DEVKIT_TOOLS Comma-separated list of tools"
echo " DEVKIT_FORCE Set to 'true' to force reinstall"
- echo " DEVKIT_MCP_PATH Path to MCP server installation"
echo " DEVKIT_SKILLS_PROFILE Comma-separated skill profiles"
echo " DEVKIT_SKILLS Comma-separated skill names"
echo " DEVKIT_SILENT Set to 'true' for silent mode"
@@ -246,9 +235,6 @@ REPO_URL="https://github.com/databricks-solutions/ai-dev-kit.git"
RAW_URL="https://raw.githubusercontent.com/databricks-solutions/ai-dev-kit/${BRANCH}"
INSTALL_DIR="${AIDEVKIT_HOME:-$HOME/.ai-dev-kit}"
REPO_DIR="$INSTALL_DIR/repo"
-VENV_DIR="$INSTALL_DIR/.venv"
-VENV_PYTHON="$VENV_DIR/bin/python"
-MCP_ENTRY="$REPO_DIR/databricks-mcp-server/run_server.py"
# ─── Interactive helpers ────────────────────────────────────────
# Reads from /dev/tty so prompts work even when piped via curl | bash
@@ -631,32 +617,6 @@ prompt_profile() {
fi
}
-# ─── MCP path selection ────────────────────────────────────────
-prompt_mcp_path() {
- # If provided via --mcp-path flag, skip prompt
- if [ -n "$USER_MCP_PATH" ]; then
- INSTALL_DIR="$USER_MCP_PATH"
- elif [ "$SILENT" = false ] && [ -e /dev/tty ]; then
- [ "$SILENT" = false ] && echo ""
- [ "$SILENT" = false ] && echo -e " ${B}MCP server location${N}"
- [ "$SILENT" = false ] && echo -e " ${D}The MCP server runtime (Python venv + source) will be installed here.${N}"
- [ "$SILENT" = false ] && echo -e " ${D}Shared across all your projects — only the config files are per-project.${N}"
- [ "$SILENT" = false ] && echo ""
-
- local selected
- selected=$(prompt "Install path" "$INSTALL_DIR")
-
- # Expand ~ to $HOME
- INSTALL_DIR="${selected/#\~/$HOME}"
- fi
-
- # Update derived paths
- REPO_DIR="$INSTALL_DIR/repo"
- VENV_DIR="$INSTALL_DIR/.venv"
- VENV_PYTHON="$VENV_DIR/bin/python"
- MCP_ENTRY="$REPO_DIR/databricks-mcp-server/run_server.py"
-}
-
# ─── Skill profile selection ──────────────────────────────────
# Resolve selected skills from profile names or explicit skill list
resolve_skills() {
@@ -961,24 +921,6 @@ check_cli_version() {
fi
}
-# Check Databricks SDK version in the MCP venv
-check_sdk_version() {
- local sdk_version
- sdk_version=$("$VENV_PYTHON" -c "from databricks.sdk.version import __version__; print(__version__)" 2>/dev/null)
-
- if [ -z "$sdk_version" ]; then
- warn "Could not determine Databricks SDK version"
- return
- fi
-
- if version_gte "$sdk_version" "$MIN_SDK_VERSION"; then
- ok "Databricks SDK v${sdk_version}"
- else
- warn "Databricks SDK v${sdk_version} is outdated (minimum: v${MIN_SDK_VERSION})"
- msg " ${B}Upgrade:${N} $VENV_PYTHON -m pip install --upgrade databricks-sdk"
- fi
-}
-
# Check prerequisites
check_deps() {
command -v git >/dev/null 2>&1 || die "git required"
@@ -991,16 +933,6 @@ check_deps() {
msg "${D}You can still install, but authentication will require the CLI later.${N}"
fi
- if [ "$INSTALL_MCP" = true ]; then
- if command -v uv >/dev/null 2>&1; then
- PKG="uv"
- ok "$PKG ($(uv --version 2>/dev/null || echo 'unknown version'))"
- else
- die "uv is required but not found on your PATH.
- Install it with: ${B}curl -LsSf https://astral.sh/uv/install.sh | sh${N}
- Then re-run this installer."
- fi
- fi
}
# Check if update needed
@@ -1038,11 +970,8 @@ check_version() {
fi
}
-# Setup MCP server
-setup_mcp() {
- step "Setting up MCP server"
-
- # Clone or update repo
+# Clone or update repo
+clone_repo() {
if [ -d "$REPO_DIR/.git" ]; then
git -C "$REPO_DIR" fetch -q --depth 1 origin "$BRANCH" 2>/dev/null || true
git -C "$REPO_DIR" reset --hard FETCH_HEAD 2>/dev/null || {
@@ -1054,24 +983,6 @@ setup_mcp() {
git -c advice.detachedHead=false clone -q --depth 1 --branch "$BRANCH" "$REPO_URL" "$REPO_DIR"
fi
ok "Repository cloned ($BRANCH)"
-
- # Create venv and install
- # On Apple Silicon under Rosetta, force arm64 to avoid architecture mismatch
- # with universal2 Python binaries (see: github.com/databricks-solutions/ai-dev-kit/issues/115)
- local arch_prefix=""
- if [ "$(sysctl -n hw.optional.arm64 2>/dev/null)" = "1" ] && [ "$(uname -m)" = "x86_64" ]; then
- if arch -arm64 python3 -c "pass" 2>/dev/null; then
- arch_prefix="arch -arm64"
- warn "Rosetta detected on Apple Silicon — forcing arm64 for Python"
- fi
- fi
-
- msg "Installing Python dependencies..."
- $arch_prefix uv venv --python 3.11 --allow-existing "$VENV_DIR" -q 2>/dev/null || $arch_prefix uv venv --allow-existing "$VENV_DIR" -q
- $arch_prefix uv pip install --python "$VENV_PYTHON" -e "$REPO_DIR/databricks-tools-core" -e "$REPO_DIR/databricks-mcp-server" -q
-
- "$VENV_PYTHON" -c "import databricks_mcp_server" 2>/dev/null || die "MCP server install failed"
- ok "MCP server ready"
}
# Install skills
@@ -1203,150 +1114,13 @@ install_skills() {
fi
}
-# Write MCP configs
-write_mcp_json() {
- local path=$1
- mkdir -p "$(dirname "$path")"
-
- # Backup existing file before any modifications
- if [ -f "$path" ]; then
- cp "$path" "${path}.bak"
- msg "${D}Backed up ${path##*/} → ${path##*/}.bak${N}"
- fi
-
- if [ -f "$VENV_PYTHON" ]; then
- "$VENV_PYTHON" -c "
-import json, sys
-try:
- with open('$path') as f: cfg = json.load(f)
-except: cfg = {}
-cfg.setdefault('mcpServers', {})['databricks'] = {'command': '$VENV_PYTHON', 'args': ['$MCP_ENTRY'], 'defer_loading': True, 'env': {'DATABRICKS_CONFIG_PROFILE': '$PROFILE'}}
-with open('$path', 'w') as f: json.dump(cfg, f, indent=2); f.write('\n')
-" 2>/dev/null && return
- fi
-
- # Fallback: only safe for new files — refuse to overwrite existing files
- # that may contain other settings (e.g. ~/.claude.json)
- if [ -f "$path" ]; then
- warn "Cannot merge MCP config into $path without Python. Add manually."
- return
- fi
-
- cat > "$path" << EOF
-{
- "mcpServers": {
- "databricks": {
- "command": "$VENV_PYTHON",
- "args": ["$MCP_ENTRY"],
- "defer_loading": true,
- "env": {"DATABRICKS_CONFIG_PROFILE": "$PROFILE"}
- }
- }
-}
-EOF
-}
-
-write_copilot_mcp_json() {
- local path=$1
- mkdir -p "$(dirname "$path")"
-
- # Backup existing file before any modifications
- if [ -f "$path" ]; then
- cp "$path" "${path}.bak"
- msg "${D}Backed up ${path##*/} → ${path##*/}.bak${N}"
- fi
-
- if [ -f "$path" ] && [ -f "$VENV_PYTHON" ]; then
- "$VENV_PYTHON" -c "
-import json, sys
-try:
- with open('$path') as f: cfg = json.load(f)
-except: cfg = {}
-cfg.setdefault('servers', {})['databricks'] = {'command': '$VENV_PYTHON', 'args': ['$MCP_ENTRY'], 'env': {'DATABRICKS_CONFIG_PROFILE': '$PROFILE'}}
-with open('$path', 'w') as f: json.dump(cfg, f, indent=2); f.write('\n')
-" 2>/dev/null && return
- fi
-
- cat > "$path" << EOF
-{
- "servers": {
- "databricks": {
- "command": "$VENV_PYTHON",
- "args": ["$MCP_ENTRY"],
- "env": {"DATABRICKS_CONFIG_PROFILE": "$PROFILE"}
- }
- }
-}
-EOF
-}
-
-write_mcp_toml() {
- local path=$1
- mkdir -p "$(dirname "$path")"
- grep -q "mcp_servers.databricks" "$path" 2>/dev/null && return
- if [ -f "$path" ]; then
- cp "$path" "${path}.bak"
- msg "${D}Backed up ${path##*/} → ${path##*/}.bak${N}"
- fi
- cat >> "$path" << EOF
-
-[mcp_servers.databricks]
-command = "$VENV_PYTHON"
-args = ["$MCP_ENTRY"]
-EOF
-}
-
-write_gemini_mcp_json() {
- local path=$1
- mkdir -p "$(dirname "$path")"
-
- # Backup existing file before any modifications
- if [ -f "$path" ]; then
- cp "$path" "${path}.bak"
- msg "${D}Backed up ${path##*/} → ${path##*/}.bak${N}"
- fi
-
- if [ -f "$path" ] && [ -f "$VENV_PYTHON" ]; then
- "$VENV_PYTHON" -c "
-import json, sys
-try:
- with open('$path') as f: cfg = json.load(f)
-except: cfg = {}
-cfg.setdefault('mcpServers', {})['databricks'] = {'command': '$VENV_PYTHON', 'args': ['$MCP_ENTRY'], 'env': {'DATABRICKS_CONFIG_PROFILE': '$PROFILE'}}
-with open('$path', 'w') as f: json.dump(cfg, f, indent=2); f.write('\n')
-" 2>/dev/null && return
- fi
-
- cat > "$path" << EOF
-{
- "mcpServers": {
- "databricks": {
- "command": "$VENV_PYTHON",
- "args": ["$MCP_ENTRY"],
- "env": {"DATABRICKS_CONFIG_PROFILE": "$PROFILE"}
- }
- }
-}
-EOF
-}
-
write_gemini_md() {
local path=$1
[ -f "$path" ] && return # Don't overwrite existing file
cat > "$path" << 'GEMINIEOF'
# Databricks AI Dev Kit
-You have access to Databricks skills and MCP tools installed by the Databricks AI Dev Kit.
-
-## Available MCP Tools
-
-The `databricks` MCP server provides 50+ tools for interacting with Databricks, including:
-- SQL execution and warehouse management
-- Unity Catalog operations (tables, volumes, schemas)
-- Jobs and workflow management
-- Model serving endpoints
-- Genie spaces and AI/BI dashboards
-- Databricks Apps deployment
+You have access to Databricks skills installed by the Databricks AI Dev Kit.
## Available Skills
@@ -1361,135 +1135,11 @@ Skills are installed in `.gemini/skills/` and provide patterns and best practice
## Getting Started
-Try asking: "List my SQL warehouses" or "Show my Unity Catalog schemas"
+Try asking: "Help me create a Spark pipeline" or "How do I use Unity Catalog?"
GEMINIEOF
ok "GEMINI.md"
}
-write_claude_hook() {
- local path=$1
- local script=$2
- mkdir -p "$(dirname "$path")"
-
- # Merge into existing settings.json if present, using Python for safe JSON handling
- if [ -f "$path" ] && [ -f "$VENV_PYTHON" ]; then
- "$VENV_PYTHON" -c "
-import json
-path = '$path'
-script = '$script'
-hook_entry = {'type': 'command', 'command': 'bash ' + script, 'timeout': 5}
-try:
- with open(path) as f: cfg = json.load(f)
-except: cfg = {}
-hooks = cfg.setdefault('hooks', {})
-session_hooks = hooks.setdefault('SessionStart', [])
-# Check if hook already exists
-for group in session_hooks:
- for h in group.get('hooks', []):
- if 'check_update.sh' in h.get('command', ''):
- exit(0) # Already configured
-# Append new hook group
-session_hooks.append({'hooks': [hook_entry]})
-with open(path, 'w') as f: json.dump(cfg, f, indent=2); f.write('\n')
-" 2>/dev/null && return
- fi
-
- # Fallback: write new file (only if no existing file)
- [ -f "$path" ] && return # Don't overwrite existing settings without Python
- cat > "$path" << EOF
-{
- "hooks": {
- "SessionStart": [
- {
- "hooks": [
- {
- "type": "command",
- "command": "bash $script",
- "timeout": 5
- }
- ]
- }
- ]
- }
-}
-EOF
-}
-
-write_mcp_configs() {
- step "Configuring MCP"
-
- local base_dir=$1
- for tool in $TOOLS; do
- case $tool in
- claude)
- [ "$SCOPE" = "global" ] && write_mcp_json "$HOME/.claude.json" || write_mcp_json "$base_dir/.mcp.json"
- ok "Claude MCP config"
- # Add version check hook to Claude settings
- local check_script="$REPO_DIR/.claude-plugin/check_update.sh"
- if [ "$SCOPE" = "global" ]; then
- write_claude_hook "$HOME/.claude/settings.json" "$check_script"
- else
- write_claude_hook "$base_dir/.claude/settings.json" "$check_script"
- fi
- ok "Claude update check hook"
- ;;
- cursor)
- if [ "$SCOPE" = "global" ]; then
- warn "Cursor global: manual MCP configuration required"
- msg " 1. Open ${B}Cursor → Settings → Cursor Settings → Tools & MCP${N}"
- msg " 2. Click ${B}New MCP Server${N}"
- msg " 3. Add the following JSON config:"
- msg " {"
- msg " \"mcpServers\": {"
- msg " \"databricks\": {"
- msg " \"command\": \"$VENV_PYTHON\","
- msg " \"args\": [\"$MCP_ENTRY\"],"
- msg " \"env\": {\"DATABRICKS_CONFIG_PROFILE\": \"$PROFILE\"}"
- msg " }"
- msg " }"
- msg " }"
- else
- write_mcp_json "$base_dir/.cursor/mcp.json"
- ok "Cursor MCP config"
- fi
- warn "Cursor: MCP servers are disabled by default."
- msg " Enable in: ${B}Cursor → Settings → Cursor Settings → Tools & MCP → Toggle 'databricks'${N}"
- ;;
- copilot)
- if [ "$SCOPE" = "global" ]; then
- warn "Copilot global: configure MCP in VS Code settings (Ctrl+Shift+P → 'MCP: Open User Configuration')"
- msg " Command: $VENV_PYTHON | Args: $MCP_ENTRY"
- else
- write_copilot_mcp_json "$base_dir/.vscode/mcp.json"
- ok "Copilot MCP config (.vscode/mcp.json)"
- fi
- warn "Copilot: MCP servers must be enabled manually."
- msg " In Copilot Chat, click ${B}Configure Tools${N} (tool icon, bottom-right) and enable ${B}databricks${N}"
- ;;
- codex)
- [ "$SCOPE" = "global" ] && write_mcp_toml "$HOME/.codex/config.toml" || write_mcp_toml "$base_dir/.codex/config.toml"
- ok "Codex MCP config"
- ;;
- gemini)
- if [ "$SCOPE" = "global" ]; then
- write_gemini_mcp_json "$HOME/.gemini/settings.json"
- else
- write_gemini_mcp_json "$base_dir/.gemini/settings.json"
- fi
- ok "Gemini CLI MCP config"
- ;;
- antigravity)
- if [ "$SCOPE" = "project" ]; then
- warn "Antigravity only supports global MCP configuration."
- msg " Config written to ${B}~/.gemini/antigravity/mcp_config.json${N}"
- fi
- write_gemini_mcp_json "$HOME/.gemini/antigravity/mcp_config.json"
- ok "Antigravity MCP config"
- ;;
- esac
- done
-}
-
# Save version
save_version() {
# Use -f to fail on HTTP errors (like 404)
@@ -1515,14 +1165,8 @@ summary() {
echo ""
msg "${B}Next steps:${N}"
local step=1
- if echo "$TOOLS" | grep -q cursor; then
- msg "${R}${step}. Enable MCP in Cursor: ${B}Cursor → Settings → Cursor Settings → Tools & MCP → Toggle 'databricks'${N}"
- step=$((step + 1))
- fi
if echo "$TOOLS" | grep -q copilot; then
- msg "${step}. In Copilot Chat, click ${B}Configure Tools${N} (tool icon, bottom-right) and enable ${B}databricks${N}"
- step=$((step + 1))
- msg "${step}. Use Copilot in ${B}Agent mode${N} to access Databricks skills and MCP tools"
+ msg "${step}. Use Copilot in ${B}Agent mode${N} to access Databricks skills"
step=$((step + 1))
fi
if echo "$TOOLS" | grep -q gemini; then
@@ -1530,12 +1174,12 @@ summary() {
step=$((step + 1))
fi
if echo "$TOOLS" | grep -q antigravity; then
- msg "${step}. Open your project in Antigravity to use Databricks skills and MCP tools"
+ msg "${step}. Open your project in Antigravity to use Databricks skills"
step=$((step + 1))
fi
msg "${step}. Open your project in your tool of choice"
step=$((step + 1))
- msg "${step}. Try: \"List my SQL warehouses\""
+ msg "${step}. Try: \"Help me create a Spark pipeline\""
echo ""
fi
}
@@ -1706,13 +1350,7 @@ main() {
fi
fi
- # ── Step 5: Interactive MCP path ──
- if [ "$INSTALL_MCP" = true ]; then
- prompt_mcp_path
- ok "MCP path: $INSTALL_DIR"
- fi
-
- # ── Step 6: Confirm before proceeding ──
+ # ── Step 5: Confirm before proceeding ──
if [ "$SILENT" = false ]; then
echo ""
echo -e " ${B}Summary${N}"
@@ -1720,7 +1358,6 @@ main() {
echo -e " Tools: ${G}$(echo "$TOOLS" | tr ' ' ', ')${N}"
echo -e " Profile: ${G}${PROFILE}${N}"
echo -e " Scope: ${G}${SCOPE}${N}"
- [ "$INSTALL_MCP" = true ] && echo -e " MCP server: ${G}${INSTALL_DIR}${N}"
if [ "$INSTALL_SKILLS" = true ]; then
if [ -n "$USER_SKILLS" ]; then
echo -e " Skills: ${G}custom selection${N}"
@@ -1730,7 +1367,6 @@ main() {
echo -e " Skills: ${G}${SKILLS_PROFILE:-all} ($sk_total skills)${N}"
fi
fi
- [ "$INSTALL_MCP" = true ] && echo -e " MCP config: ${G}yes${N}"
echo ""
fi
@@ -1744,23 +1380,19 @@ main() {
fi
fi
- # ── Step 7: Version check (may exit early if up to date) ──
+ # ── Step 6: Version check (may exit early if up to date) ──
check_version
-
+
# Determine base directory
local base_dir
[ "$SCOPE" = "global" ] && base_dir="$HOME" || base_dir="$(pwd)"
-
- # Setup MCP server
- if [ "$INSTALL_MCP" = true ]; then
- setup_mcp
- elif [ ! -d "$REPO_DIR" ]; then
+
+ # Clone repo (for skills)
+ if [ ! -d "$REPO_DIR" ]; then
step "Downloading sources"
- mkdir -p "$INSTALL_DIR"
- git -c advice.detachedHead=false clone -q --depth 1 --branch "$BRANCH" "$REPO_URL" "$REPO_DIR"
- ok "Repository cloned ($BRANCH)"
+ clone_repo
fi
-
+
# Install skills
[ "$INSTALL_SKILLS" = true ] && install_skills "$base_dir"
@@ -1773,9 +1405,6 @@ main() {
fi
fi
- # Write MCP configs
- [ "$INSTALL_MCP" = true ] && write_mcp_configs "$base_dir"
-
# Save version
save_version