diff --git a/CHANGELOG.md b/CHANGELOG.md index efd96a01..a57d153d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,7 +34,7 @@ * **docs:** change idleTimeout from minutes to seconds ([#205](https://github.com/runpod/flash/issues/205)) ([51693c7](https://github.com/runpod/flash/commit/51693c7e2dd0c9d803f3c49de1d0009ded285d5d)) * prevent false deployment attempts in Flash environments ([#192](https://github.com/runpod/flash/issues/192)) ([f07c9fb](https://github.com/runpod/flash/commit/f07c9fb92003d4603fbf8cdc17b956c368009353)) -* **runtime:** restore on-demand provisioning for flash run ([#206](https://github.com/runpod/flash/issues/206)) ([5859f4b](https://github.com/runpod/flash/commit/5859f4b78476a070db2100b689dfd94caf5fc93f)) +* **runtime:** restore on-demand provisioning for flash dev ([#206](https://github.com/runpod/flash/issues/206)) ([5859f4b](https://github.com/runpod/flash/commit/5859f4b78476a070db2100b689dfd94caf5fc93f)) ### Code Refactoring @@ -189,7 +189,7 @@ ### Features * AE-1512: deploy() and undeploy() deployable resources directly ([#126](https://github.com/runpod/runpod-flash/issues/126)) ([3deac3a](https://github.com/runpod/runpod-flash/commit/3deac3a91b84fa4cf07cf553c46431907290a61c)) -* **cli:** Add --auto-provision flag to flash run command ([#125](https://github.com/runpod/runpod-flash/issues/125)) ([ee5793c](https://github.com/runpod/runpod-flash/commit/ee5793c33537acc15e26b680e3bac5aedb3c0735)) +* **cli:** Add --auto-provision flag to flash dev command ([#125](https://github.com/runpod/runpod-flash/issues/125)) ([ee5793c](https://github.com/runpod/runpod-flash/commit/ee5793c33537acc15e26b680e3bac5aedb3c0735)) ### Code Refactoring diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e9d9080a..fc454cdb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -35,7 +35,7 @@ Get your API key from: https://docs.runpod.io/get-started/api-keys - Integration tests that interact with Runpod API **When is the API key NOT needed?** -- Local development with `flash run` (local server only) +- Local development with `flash dev` (local server only) - `flash init` command (project scaffolding) - Unit tests (mocked API calls) - Code formatting, linting, type checking diff --git a/PRD.md b/PRD.md new file mode 100644 index 00000000..ad936c0a --- /dev/null +++ b/PRD.md @@ -0,0 +1,302 @@ +# Flash SDK: Zero-Boilerplate Experience — Product Requirements Document + +## 1. Problem Statement + +Flash currently forces every project into a FastAPI-first model: + +- Users must create `main.py` with a `FastAPI()` instance +- HTTP routing boilerplate adds no semantic value — the routes simply call `@remote` functions +- No straightforward path for deploying a standalone QB function without wrapping it in a FastAPI app +- The "mothership" concept introduces an implicit coordinator with no clear ownership model +- `flash dev` fails unless `main.py` exists with a FastAPI app, blocking the simplest use cases + +## 2. Goals + +- **Zero boilerplate**: a `@remote`-decorated function in any `.py` file is sufficient for `flash dev` and `flash deploy` +- **File-system-as-namespace**: the project directory structure maps 1:1 to URL paths on the local dev server +- **Single command**: `flash dev` works for all project topologies (one QB function, many files, mixed QB+LB) without any configuration +- **`flash deploy` requires no additional configuration** beyond the `@remote` declarations themselves +- **Peer endpoints**: every `@resource_config` is a first-class endpoint; no implicit coordinator + +## 3. Non-Goals + +- No backward compatibility with `main.py`/FastAPI-first style +- No implicit "mothership" concept; all endpoints are peers +- No changes to the QB runtime (`generic_handler.py`) or QB stub behavior +- No changes to deployed endpoint behavior (RunPod QB/LB APIs are unchanged) + +## 4. Developer Experience Specification + +### 4.1 Minimum viable QB project + +```python +# gpu_worker.py +from runpod_flash import LiveServerless, GpuGroup, remote + +gpu_config = LiveServerless(name="gpu_worker", gpus=[GpuGroup.ANY]) + +@remote(gpu_config) +async def process(input_data: dict) -> dict: + return {"result": "processed", "input": input_data} +``` + +`flash dev` → `POST /gpu_worker/run_sync` +`flash deploy` → standalone QB endpoint at `api.runpod.ai/v2/{id}/run` + +### 4.2 LB endpoint + +```python +# api/routes.py +from runpod_flash import CpuLiveLoadBalancer, remote + +lb_config = CpuLiveLoadBalancer(name="api_routes") + +@remote(lb_config, method="POST", path="/compute") +async def compute(input_data: dict) -> dict: + return {"result": input_data} +``` + +`flash dev` → `POST /api/routes/compute` +`flash deploy` → LB endpoint at `{id}.api.runpod.ai/compute` + +### 4.3 Mixed QB + LB (LB calling QB) + +```python +# api/routes.py (LB) +from runpod_flash import CpuLiveLoadBalancer, remote +from workers.gpu import heavy_compute # QB stub + +lb_config = CpuLiveLoadBalancer(name="api_routes") + +@remote(lb_config, method="POST", path="/process") +async def process_route(data: dict): + return await heavy_compute(data) # dispatches to QB endpoint + +# workers/gpu.py (QB) +from runpod_flash import LiveServerless, GpuGroup, remote + +gpu_config = LiveServerless(name="gpu_worker", gpus=[GpuGroup.ANY]) + +@remote(gpu_config) +async def heavy_compute(data: dict) -> dict: ... +``` + +## 5. URL Path Specification + +### 5.1 File prefix derivation + +The local dev server uses the project directory structure as a URL namespace. Each file's URL prefix is its path relative to the project root with `.py` stripped: + +``` +File Local URL prefix +────────────────────────────── ──────────────────────────── +gpu_worker.py → /gpu_worker +longruns/stage1.py → /longruns/stage1 +preprocess/first_pass.py → /preprocess/first_pass +workers/gpu/inference.py → /workers/gpu/inference +``` + +### 5.2 QB route generation + +| Condition | Routes | +|---|---| +| One `@remote` function in file | `POST {file_prefix}/run` and `POST {file_prefix}/run_sync` | +| Multiple `@remote` functions in file | `POST {file_prefix}/{fn_name}/run` and `POST {file_prefix}/{fn_name}/run_sync` | + +### 5.3 LB route generation + +| Condition | Route | +|---|---| +| `@remote(lb_config, method="POST", path="/compute")` | `POST {file_prefix}/compute` | + +The declared `path=` is appended to the file prefix. The `method=` determines the HTTP verb. + +### 5.4 QB request/response envelope + +Mirrors RunPod's API for consistency: + +``` +POST /gpu_worker/run_sync +Body: {"input": {"key": "value"}} +Response: {"id": "uuid", "status": "COMPLETED", "output": {...}} +``` + +## 6. Deployed Topology Specification + +Each unique resource config gets its own RunPod endpoint: + +| Type | Deployed URL | Example | +|---|---|---| +| QB | `https://api.runpod.ai/v2/{endpoint_id}/run` | `https://api.runpod.ai/v2/uoy3n7hkyb052a/run` | +| QB sync | `https://api.runpod.ai/v2/{endpoint_id}/run_sync` | | +| LB | `https://{endpoint_id}.api.runpod.ai/{declared_path}` | `https://rzlk6lph6gw7dk.api.runpod.ai/compute` | + +## 7. `.flash/` Folder Specification + +The `.flash/` directory is used by `flash build` for build artifacts (e.g. `manifest.json`). The `flash dev` command does not create or use `.flash/` at all. + +``` +my_project/ +├── gpu_worker.py +├── longruns/ +│ └── stage1.py +└── .flash/ + └── manifest.json ← generated by flash build +``` + +### 7.1 Dev server launch + +Uvicorn is invoked with the `--factory` flag pointing to the app factory function, and the project root is passed via the `FLASH_PROJECT_ROOT` environment variable: + +```bash +uvicorn --factory runpod_flash.cli.commands._dev_server:create_app \ + --reload \ + --reload-dir . +``` + +## 8. `flash dev` Behavior + +1. Scan project for all `@remote` functions (QB and LB) in any `.py` file + - Skip: `.flash/`, `__pycache__`, `*.pyc`, `__init__.py` +2. If none found: print error with usage instructions, exit 1 +3. Build FastAPI app programmatically via `_dev_server.create_app()` +4. Start uvicorn with `--factory` and `--reload` watching the project root +5. Print startup table: local paths, resource names, types +6. Swagger UI available at `http://localhost:{port}/docs` +7. On exit (Ctrl+C or SIGTERM): deprovision all Live Serverless endpoints provisioned during this session + +### 8.1 Startup table format + +``` +Flash Dev Server http://localhost:8888 + + Local path Resource Type + ────────────────────────────────── ─────────────────── ──── + POST /gpu_worker/run gpu_worker QB + POST /gpu_worker/run_sync gpu_worker QB + POST /longruns/stage1/run longruns_stage1 QB + POST /preprocess/first_pass/compute preprocess_first_pass LB + + Visit http://localhost:8888/docs for Swagger UI +``` + +## 9. `flash build` Behavior + +1. Scan project for all `@remote` functions (QB and LB) +2. Build `.flash/manifest.json` with flat resource structure (see §10) +3. For LB resources: generate deployed handler files using `module_path` +4. Package build artifact + +## 10. Manifest Structure + +Resource names are derived from file paths (slashes → underscores): + +```json +{ + "version": "1.0", + "project_name": "my_project", + "resources": { + "gpu_worker": { + "resource_type": "LiveServerless", + "file_path": "gpu_worker.py", + "local_path_prefix": "/gpu_worker", + "module_path": "gpu_worker", + "functions": ["gpu_hello"], + "is_load_balanced": false, + "makes_remote_calls": false + }, + "longruns_stage1": { + "resource_type": "LiveServerless", + "file_path": "longruns/stage1.py", + "local_path_prefix": "/longruns/stage1", + "module_path": "longruns.stage1", + "functions": ["stage1_process"], + "is_load_balanced": false, + "makes_remote_calls": false + }, + "preprocess_first_pass": { + "resource_type": "CpuLiveLoadBalancer", + "file_path": "preprocess/first_pass.py", + "local_path_prefix": "/preprocess/first_pass", + "module_path": "preprocess.first_pass", + "functions": [ + {"name": "first_pass_fn", "http_method": "POST", "http_path": "/compute"} + ], + "is_load_balanced": true, + "makes_remote_calls": true + } + } +} +``` + +## 11. Dev Server App Structure + +The dev server is built programmatically by `_dev_server.create_app()`. User modules are imported via `importlib.import_module()` and routes are registered with `app.add_api_route()`. Tracebacks point directly to the original source files. + +Conceptual equivalent of the generated app: + +```python +app = FastAPI(title="Flash Dev Server") + +# QB: gpu_worker.py - imported via importlib, route added via add_api_route +# POST /gpu_worker/run_sync -> calls gpu_hello(body["input"]) + +# QB: longruns/stage1.py +# POST /longruns/stage1/run_sync -> calls stage1_process(body["input"]) + +# LB: preprocess/first_pass.py +# POST /preprocess/first_pass/compute -> calls lb_execute(config, first_pass_fn, body) + +# Health +# GET / -> {"message": "Flash Dev Server", "docs": "/docs"} +# GET /ping -> {"status": "healthy"} +``` + +Subdirectory imports use dotted module paths: `longruns/stage1.py` -> `longruns.stage1`. + +Multi-function QB files (2+ `@remote` functions) get sub-prefixed routes: +``` +longruns/stage1.py has: stage1_preprocess, stage1_infer +→ POST /longruns/stage1/stage1_preprocess/run +→ POST /longruns/stage1/stage1_preprocess/run_sync +→ POST /longruns/stage1/stage1_infer/run +→ POST /longruns/stage1/stage1_infer/run_sync +``` + +## 12. Acceptance Criteria + +- [ ] A file with one `@remote(QB_config)` function and nothing else is a valid Flash project +- [ ] `flash dev` produces a Swagger UI showing all routes grouped by source file +- [ ] QB routes accept `{"input": {...}}` and return `{"id": ..., "status": "COMPLETED", "output": {...}}` +- [ ] Subdirectory files produce URL prefixes matching their relative path +- [ ] Multiple `@remote` functions in one file each get their own sub-prefixed routes +- [ ] LB route handler body executes directly (not dispatched remotely) +- [ ] QB calls inside LB route handler body route to the remote QB endpoint +- [ ] `flash deploy` creates a RunPod endpoint for each resource config +- [ ] `flash build` produces `.flash/manifest.json` with `file_path`, `local_path_prefix`, `module_path` per resource +- [ ] When `flash dev` exits, all Live Serverless endpoints provisioned during that session are automatically undeployed + +## 13. Edge Cases + +- **No `@remote` functions found**: Error with clear message and usage instructions +- **Multiple `@remote` functions per file (QB)**: Sub-prefixed routes `/{file_prefix}/{fn_name}/run_sync` +- **`__init__.py` files**: Skipped -- not treated as worker files +- **File path with hyphens** (e.g., `my-worker.py`): Resource name sanitized to `my_worker`, URL prefix `/my-worker` (hyphens valid in URLs, underscores in Python identifiers) +- **LB function calling another LB function**: Not supported via `@remote` -- emit a warning at build time +- **`flash deploy` with no LB endpoints**: QB-only deploy +- **Subdirectory `__init__.py`** imports needed: Generator checks and warns if missing +- **Numeric-prefix directories** (e.g., `01_hello/`): Handled via `importlib.import_module()` with scoped `sys.path` + +## 14. Implementation Files + +| File | Change | +|------|--------| +| `flash/main/PRD.md` | This document | +| `src/runpod_flash/client.py` | Passthrough for LB route handlers (`__is_lb_route_handler__`) | +| `cli/commands/run.py` | Worker scanning, startup table, uvicorn subprocess management | +| `cli/commands/_dev_server.py` | Programmatic FastAPI app factory (`create_app`), route registration | +| `cli/commands/build_utils/scanner.py` | Path utilities; `is_lb_route_handler` field; file-based resource identity | +| `cli/commands/build_utils/manifest.py` | Flat resource structure; `file_path`/`local_path_prefix`/`module_path` fields | +| `cli/commands/build_utils/lb_handler_generator.py` | Import module by `module_path`, walk `__is_lb_route_handler__`, register routes | +| `cli/commands/build.py` | Remove main.py requirement from `validate_project_structure` | +| `core/resources/serverless.py` | Inject `FLASH_MODULE_PATH` env var | diff --git a/README.md b/README.md index cce420ce..136fff22 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,7 @@ Computed on: NVIDIA GeForce RTX 4090 ## Create Flash API endpoints -You can use Flash to deploy and serve API endpoints that compute responses using GPU and CPU Serverless workers. Use `flash run` for local development of `@remote` functions, then `flash deploy` to deploy your full application to Runpod Serverless for production. +You can use Flash to deploy and serve API endpoints that compute responses using GPU and CPU Serverless workers. Use `flash dev` for local development of `@remote` functions, then `flash deploy` to deploy your full application to Runpod Serverless for production. These endpoints use the same Python `@remote` decorators [demonstrated above](#get-started) @@ -179,7 +179,7 @@ This template includes: - Pre-configured worker scaling limits using the `LiveServerless()` object. - A `@remote` decorated function that returns a response from a worker. -When you run `flash run`, it auto-discovers all `@remote` functions and generates a local development server at `.flash/server.py`. Queue-based workers are exposed at `/{file_prefix}/run_sync` (e.g., `/gpu_worker/run_sync`). +When you run `flash dev`, it auto-discovers all `@remote` functions and builds a local development server programmatically. Queue-based workers are exposed at `/{file_prefix}/run_sync` (e.g., `/gpu_worker/run_sync`). ### Step 3: Install Python dependencies @@ -219,10 +219,10 @@ Save the file and close it. ### Step 5: Start the local API server -Use `flash run` to start the API server: +Use `flash dev` to start the API server: ```bash -flash run +flash dev ``` Open a new terminal tab or window and test your GPU API using cURL: @@ -233,23 +233,23 @@ curl -X POST http://localhost:8888/gpu_worker/run_sync \ -d '{"message": "Hello from the GPU!"}' ``` -If you switch back to the terminal tab where you used `flash run`, you'll see the details of the job's progress. +If you switch back to the terminal tab where you used `flash dev`, you'll see the details of the job's progress. -For more `flash run` options and configuration, see the [flash run documentation](src/runpod_flash/cli/docs/flash-run.md). +For more `flash dev` options and configuration, see the [flash dev documentation](src/runpod_flash/cli/docs/flash-dev.md). ### Faster testing with auto-provisioning For development with multiple endpoints, use `--auto-provision` to deploy all resources before testing: ```bash -flash run --auto-provision +flash dev --auto-provision ``` This eliminates cold-start delays by provisioning all serverless endpoints upfront. Endpoints are cached and reused across server restarts, making subsequent runs much faster. Resources are identified by name, so the same endpoint won't be re-deployed if configuration hasn't changed. ### Step 6: Open the API explorer -Besides starting the API server, `flash run` also starts an interactive API explorer. Point your web browser at [http://localhost:8888/docs](http://localhost:8888/docs) to explore the API. +Besides starting the API server, `flash dev` also starts an interactive API explorer. Point your web browser at [http://localhost:8888/docs](http://localhost:8888/docs) to explore the API. To run remote functions in the explorer: @@ -264,7 +264,7 @@ To customize your API: 1. Create new `.py` files with `@remote` decorated functions. 2. Test the scripts individually by running `python your_worker.py`. -3. Run `flash run` to auto-discover all `@remote` functions and serve them. +3. Run `flash dev` to auto-discover all `@remote` functions and serve them. ## CLI Reference @@ -273,7 +273,7 @@ Flash provides a command-line interface for project management, development, and ### Main Commands - **`flash init`** - Initialize a new Flash project with template structure -- **`flash run`** - Start local development server to test your `@remote` functions with auto-reload +- **`flash dev`** - Start local development server to test your `@remote` functions with auto-reload - **`flash build`** - Build deployment artifact with all dependencies - **`flash deploy`** - Build and deploy your application to Runpod Serverless in one step @@ -291,7 +291,7 @@ Flash provides a command-line interface for project management, development, and # Initialize and run locally flash init my-project cd my-project -flash run --auto-provision +flash dev --auto-provision # Build and deploy to production flash build @@ -315,7 +315,7 @@ For complete CLI documentation including all options, examples, and troubleshoot Individual command references: - [flash init](src/runpod_flash/cli/docs/flash-init.md) - Project initialization -- [flash run](src/runpod_flash/cli/docs/flash-run.md) - Development server +- [flash dev](src/runpod_flash/cli/docs/flash-dev.md) - Development server - [flash build](src/runpod_flash/cli/docs/flash-build.md) - Build artifacts - [flash deploy](src/runpod_flash/cli/docs/flash-deploy.md) - Deployment - [flash env](src/runpod_flash/cli/docs/flash-env.md) - Environment management diff --git a/docs/Using_Remote_With_LoadBalancer.md b/docs/Using_Remote_With_LoadBalancer.md index 038204b2..b0f297a4 100644 --- a/docs/Using_Remote_With_LoadBalancer.md +++ b/docs/Using_Remote_With_LoadBalancer.md @@ -528,4 +528,4 @@ A test/development variant of LoadBalancerSlsResource: - Review `docs/LoadBalancer_Runtime_Architecture.md` for runtime execution and request flows - Check examples in `flash-examples/` repository for more patterns - Use `flash build --help` to see build options -- Use `flash run --help` to see local testing options +- Use `flash dev --help` to see local testing options diff --git a/src/runpod_flash/cli/commands/_dev_server.py b/src/runpod_flash/cli/commands/_dev_server.py new file mode 100644 index 00000000..6133b220 --- /dev/null +++ b/src/runpod_flash/cli/commands/_dev_server.py @@ -0,0 +1,210 @@ +"""Programmatic FastAPI dev server for flash dev. + +Builds the FastAPI app by scanning for @remote functions and registering +routes via add_api_route(). User modules are imported directly, so +tracebacks point to the original source files. +""" + +import importlib +import os +import sys +import uuid +from pathlib import Path +from typing import TYPE_CHECKING, Callable, List, Optional + +from fastapi import FastAPI, Request + +if TYPE_CHECKING: + from .run import WorkerInfo + + +def create_app( + project_root: Optional[Path] = None, + workers: Optional[List["WorkerInfo"]] = None, +) -> FastAPI: + """Factory function for the Flash dev server. + + When called by uvicorn via ``--factory``, both parameters are None and + the function reads ``FLASH_PROJECT_ROOT`` from the environment and + scans for workers itself. Tests can pass both directly. + """ + if project_root is None: + project_root = Path(os.environ.get("FLASH_PROJECT_ROOT", os.getcwd())) + + root_str = str(project_root) + if root_str not in sys.path: + sys.path.insert(0, root_str) + + if workers is None: + from .run import _scan_project_workers + + workers = _scan_project_workers(project_root) + + app = FastAPI( + title="Flash Dev Server", + description="Built by `flash dev`. Visit /docs for interactive testing.", + ) + + for worker in workers: + _register_worker_routes(app, worker, project_root) + + @app.get("/", tags=["health"]) + def home(): + return {"message": "Flash Dev Server", "docs": "/docs"} + + @app.get("/ping", tags=["health"]) + def ping(): + return {"status": "healthy"} + + return app + + +def _import_from_module(module_path: str, name: str, project_root: Path): + """Import *name* from *module_path*, handling numeric-prefix directories. + + When a dotted module path contains segments starting with a digit + (e.g. ``01_hello.gpu_worker``), Python's ``from`` syntax cannot be + used. This function uses ``importlib.import_module`` and temporarily + scopes ``sys.path`` so that sibling imports within the target module + resolve to the correct subdirectory. + """ + has_numeric = any(seg and seg[0].isdigit() for seg in module_path.split(".")) + + if has_numeric: + parts = module_path.rsplit(".", 1) + if len(parts) > 1: + subdir = str(project_root / parts[0].replace(".", os.sep)) + sys.path.insert(0, subdir) + try: + mod = importlib.import_module(module_path) + finally: + try: + sys.path.remove(subdir) + except ValueError: + pass + else: + mod = importlib.import_module(module_path) + else: + mod = importlib.import_module(module_path) + + return getattr(mod, name) + + +def _register_worker_routes( + app: FastAPI, worker: "WorkerInfo", project_root: Path +) -> None: + """Register FastAPI routes for a single discovered worker.""" + tag = f"{worker.url_prefix.lstrip('/')} [{worker.worker_type}]" + + if worker.worker_type == "QB": + _register_qb_routes(app, worker, project_root, tag) + elif worker.worker_type == "LB": + _register_lb_routes(app, worker, project_root, tag) + + +def _register_qb_routes( + app: FastAPI, worker: "WorkerInfo", project_root: Path, tag: str +) -> None: + """Register queue-based (QB) routes. + + Single-function workers get one ``/run_sync`` endpoint. + Multi-function workers get ``//run_sync`` for each function. + """ + if len(worker.functions) == 1: + fn_name = worker.functions[0] + fn = _import_from_module(worker.module_path, fn_name, project_root) + path = f"{worker.url_prefix}/run_sync" + + async def qb_handler(body: dict, _fn=fn): + result = await _fn(body.get("input", body)) + return { + "id": str(uuid.uuid4()), + "status": "COMPLETED", + "output": result, + } + + qb_handler.__name__ = f"{worker.resource_name}_run_sync" + app.add_api_route(path, qb_handler, methods=["POST"], tags=[tag]) + else: + for fn_name in worker.functions: + fn = _import_from_module(worker.module_path, fn_name, project_root) + path = f"{worker.url_prefix}/{fn_name}/run_sync" + + async def qb_handler(body: dict, _fn=fn): + result = await _fn(body.get("input", body)) + return { + "id": str(uuid.uuid4()), + "status": "COMPLETED", + "output": result, + } + + qb_handler.__name__ = f"{worker.resource_name}_{fn_name}_run_sync" + app.add_api_route(path, qb_handler, methods=["POST"], tags=[tag]) + + +def _register_lb_routes( + app: FastAPI, + worker: "WorkerInfo", + project_root: Path, + tag: str, + executor: Optional[Callable] = None, +) -> None: + """Register load-balanced (LB) routes. + + Each LB route is dispatched through *executor* (defaults to + ``lb_execute`` from ``_run_server_helpers``). Tests can pass a + substitute to avoid hitting real infrastructure. + """ + if executor is None: + from ._run_server_helpers import lb_execute + + executor = lb_execute + + # import config variables (deduplicated) + config_vars: dict = {} + for route in worker.lb_routes: + var_name = route.get("config_variable") + if var_name and var_name not in config_vars: + config_vars[var_name] = _import_from_module( + worker.module_path, var_name, project_root + ) + + for route in worker.lb_routes: + method = route["method"] + sub_path = route["path"].lstrip("/") + fn_name = route["fn_name"] + config_var_name = route["config_variable"] + full_path = f"{worker.url_prefix}/{sub_path}" + + fn = _import_from_module(worker.module_path, fn_name, project_root) + config = config_vars.get(config_var_name) + + has_body = method.upper() in ("POST", "PUT", "PATCH", "DELETE") + if has_body: + + async def lb_body_handler( + body: dict, _config=config, _fn=fn, _exec=executor + ): + return await _exec(_config, _fn, body) + + lb_body_handler.__name__ = f"_route_{worker.resource_name}_{fn_name}" + app.add_api_route( + full_path, + lb_body_handler, + methods=[method.upper()], + tags=[tag], + ) + else: + + async def lb_query_handler( + request: Request, _config=config, _fn=fn, _exec=executor + ): + return await _exec(_config, _fn, dict(request.query_params)) + + lb_query_handler.__name__ = f"_route_{worker.resource_name}_{fn_name}" + app.add_api_route( + full_path, + lb_query_handler, + methods=[method.upper()], + tags=[tag], + ) diff --git a/src/runpod_flash/cli/commands/_run_server_helpers.py b/src/runpod_flash/cli/commands/_run_server_helpers.py index abf48f06..df370902 100644 --- a/src/runpod_flash/cli/commands/_run_server_helpers.py +++ b/src/runpod_flash/cli/commands/_run_server_helpers.py @@ -1,4 +1,4 @@ -"""Helpers for the flash run dev server — loaded inside the generated server.py.""" +"""Helpers for the flash dev server.""" import inspect from typing import Any, get_type_hints diff --git a/src/runpod_flash/cli/commands/init.py b/src/runpod_flash/cli/commands/init.py index eabd7583..89d7a781 100644 --- a/src/runpod_flash/cli/commands/init.py +++ b/src/runpod_flash/cli/commands/init.py @@ -6,7 +6,6 @@ import typer from rich.console import Console from rich.panel import Panel -from rich.table import Table from ..utils.skeleton import create_project_skeleton, detect_file_conflicts @@ -94,26 +93,17 @@ def init_command( # Next steps console.print("\n[bold]Next steps:[/bold]") - steps_table = Table(show_header=False, box=None, padding=(0, 1)) - steps_table.add_column("Step", style="bold cyan") - steps_table.add_column("Description") - step_num = 1 if not is_current_dir: - steps_table.add_row(f"{step_num}.", f"cd {actual_project_name}") + console.print(f" {step_num}. cd {actual_project_name}") step_num += 1 - - steps_table.add_row(f"{step_num}.", "pip install -r requirements.txt") - step_num += 1 - steps_table.add_row(f"{step_num}.", "cp .env.example .env") + console.print(f" {step_num}. pip install -r requirements.txt") step_num += 1 - steps_table.add_row(f"{step_num}.", "Add your RUNPOD_API_KEY to .env") + console.print(f" {step_num}. cp .env.example .env && add RUNPOD_API_KEY") step_num += 1 - steps_table.add_row(f"{step_num}.", "flash run") + console.print(f" {step_num}. flash dev") - console.print(steps_table) - - console.print("\n[bold]Get your API key:[/bold]") - console.print(" https://docs.runpod.io/get-started/api-keys") - console.print("\nVisit http://localhost:8888/docs after running") - console.print("\nCheck out the README.md for more") + console.print( + "\n [dim]API keys: https://docs.runpod.io/get-started/api-keys[/dim]" + ) + console.print(" [dim]Docs: http://localhost:8888/docs (after running)[/dim]") diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index db2d354c..740286c6 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -1,12 +1,11 @@ -"""Run Flash development server.""" +"""Flash dev server command.""" import logging import os -import re import signal +import socket import subprocess import sys -import threading from dataclasses import dataclass, field from pathlib import Path from typing import List @@ -15,22 +14,6 @@ from rich.console import Console from rich.table import Table -try: - from watchfiles import DefaultFilter as _WatchfilesDefaultFilter - from watchfiles import watch as _watchfiles_watch -except ModuleNotFoundError: - - def _watchfiles_watch(*_a, **_kw): # type: ignore[misc] - raise ModuleNotFoundError( - "watchfiles is required for flash run --reload. " - "Install it with: pip install watchfiles" - ) - - class _WatchfilesDefaultFilter: # type: ignore[no-redef] - def __init__(self, **_kw): - pass - - from .build_utils.scanner import ( RemoteDecoratorScanner, file_to_module_path, @@ -41,9 +24,33 @@ def __init__(self, **_kw): logger = logging.getLogger(__name__) console = Console() -# Resource state file written by ResourceManager in the uvicorn subprocess. +# resource state file written by ResourceManager in the uvicorn subprocess _RESOURCE_STATE_FILE = Path(".runpod") / "resources.pkl" +_MAX_PORT_ATTEMPTS = 20 + + +def _find_available_port(host: str, start_port: int) -> int: + """Find the first available port starting from start_port. + + Tries up to _MAX_PORT_ATTEMPTS consecutive ports. Raises typer.Exit + if no port is available. + """ + for offset in range(_MAX_PORT_ATTEMPTS): + port = start_port + offset + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind((host, port)) + return port + except OSError: + continue + + console.print( + f"[red]Error:[/red] No available port found in range " + f"{start_port}-{start_port + _MAX_PORT_ATTEMPTS - 1}." + ) + raise typer.Exit(1) + @dataclass class WorkerInfo: @@ -55,13 +62,7 @@ class WorkerInfo: resource_name: str # e.g. longruns_stage1 worker_type: str # "QB" or "LB" functions: List[str] # function names - class_remotes: List[dict] = field( - default_factory=list - ) # [{name, methods, method_params}] lb_routes: List[dict] = field(default_factory=list) # [{method, path, fn_name}] - function_params: dict[str, list[str]] = field( - default_factory=dict - ) # fn_name -> param_names def _scan_project_workers(project_root: Path) -> List[WorkerInfo]: @@ -83,7 +84,7 @@ def _scan_project_workers(project_root: Path) -> List[WorkerInfo]: scanner = RemoteDecoratorScanner(project_root) remote_functions = scanner.discover_remote_functions() - # Group by file path + # group by file path by_file: dict[Path, List] = {} for func in remote_functions: by_file.setdefault(func.file_path, []).append(func) @@ -94,11 +95,10 @@ def _scan_project_workers(project_root: Path) -> List[WorkerInfo]: module_path = file_to_module_path(file_path, project_root) resource_name = file_to_resource_name(file_path, project_root) - qb_funcs = [f for f in funcs if not f.is_load_balanced and not f.is_class] - qb_classes = [f for f in funcs if not f.is_load_balanced and f.is_class] + qb_funcs = [f for f in funcs if not f.is_load_balanced] lb_funcs = [f for f in funcs if f.is_load_balanced and f.is_lb_route_handler] - if qb_funcs or qb_classes: + if qb_funcs: workers.append( WorkerInfo( file_path=file_path, @@ -107,15 +107,6 @@ def _scan_project_workers(project_root: Path) -> List[WorkerInfo]: resource_name=resource_name, worker_type="QB", functions=[f.function_name for f in qb_funcs], - class_remotes=[ - { - "name": c.function_name, - "methods": c.class_methods, - "method_params": c.class_method_params, - } - for c in qb_classes - ], - function_params={f.function_name: f.param_names for f in qb_funcs}, ) ) @@ -144,404 +135,6 @@ def _scan_project_workers(project_root: Path) -> List[WorkerInfo]: return workers -def _ensure_gitignore(project_root: Path) -> None: - """Add .flash/ to .gitignore if not already present.""" - gitignore = project_root / ".gitignore" - entry = ".flash/" - - if gitignore.exists(): - content = gitignore.read_text(encoding="utf-8") - if entry in content: - return - # Append with a newline - if not content.endswith("\n"): - content += "\n" - gitignore.write_text(content + entry + "\n", encoding="utf-8") - else: - gitignore.write_text(entry + "\n", encoding="utf-8") - - -def _sanitize_fn_name(name: str) -> str: - """Sanitize a string for use as a Python function name. - - Replaces non-identifier characters with underscores and prepends '_' - if the result starts with a digit (Python identifiers cannot start - with digits). - """ - result = name.replace("/", "_").replace(".", "_").replace("-", "_") - if result and result[0].isdigit(): - result = "_" + result - return result - - -def _has_numeric_module_segments(module_path: str) -> bool: - """Check if any segment in a dotted module path starts with a digit. - - Python identifiers cannot start with digits, so ``from 01_foo import bar`` - is a SyntaxError. Callers should use ``importlib.import_module()`` instead. - """ - return any(seg and seg[0].isdigit() for seg in module_path.split(".")) - - -def _module_parent_subdir(module_path: str) -> str | None: - """Return the parent sub-directory for a dotted module path, or None for top-level. - - Example: ``01_getting_started.03_mixed.pipeline`` → ``01_getting_started/03_mixed`` - """ - parts = module_path.rsplit(".", 1) - if len(parts) == 1: - return None - return parts[0].replace(".", "/") - - -def _make_import_line(module_path: str, name: str) -> str: - """Build an import statement for *name* from *module_path*. - - Uses a regular ``from … import …`` when the module path is a valid - Python identifier chain. Falls back to ``_flash_import()`` (a generated - helper in server.py) when any segment starts with a digit. The helper - temporarily scopes ``sys.path`` so sibling imports in the target module - resolve to the correct directory. - """ - if _has_numeric_module_segments(module_path): - subdir = _module_parent_subdir(module_path) - if subdir: - return f'{name} = _flash_import("{module_path}", "{name}", "{subdir}")' - return f'{name} = _flash_import("{module_path}", "{name}")' - return f"from {module_path} import {name}" - - -_PATH_PARAM_RE = re.compile(r"\{(\w+)\}") - - -def _extract_path_params(path: str) -> list[str]: - """Extract path parameter names from a FastAPI-style route path. - - Example: "/images/{file_id}" -> ["file_id"] - """ - return _PATH_PARAM_RE.findall(path) - - -def _build_call_expr(callable_name: str, params: list[str] | None) -> tuple[str, bool]: - """Build an async call expression based on parameter count. - - Args: - callable_name: Fully qualified callable (e.g. "fn" or "instance.method") - params: List of param names, or None if unknown (backward compat) - - Returns: - Tuple of (call_expression, needs_body). needs_body is False when the - handler signature should omit the body parameter. - """ - if params is not None and len(params) == 0: - return f"await {callable_name}()", False - return f"await _call_with_body({callable_name}, body)", True - - -def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Path: - """Generate .flash/server.py from the discovered workers. - - Args: - project_root: Root of the Flash project - workers: List of discovered worker infos - - Returns: - Path to the generated server.py - """ - flash_dir = project_root / ".flash" - flash_dir.mkdir(exist_ok=True) - - _ensure_gitignore(project_root) - - has_lb_workers = any(w.worker_type == "LB" for w in workers) - - lines = [ - '"""Auto-generated Flash dev server. Do not edit — regenerated on each flash run."""', - "import sys", - "import uuid", - "from pathlib import Path", - "_project_root = Path(__file__).parent.parent", - "sys.path.insert(0, str(_project_root))", - "", - ] - - # When modules live in directories with numeric prefixes (e.g. 01_hello/), - # we cannot use ``from … import …`` — Python identifiers cannot start with - # digits. Instead we emit a small ``_flash_import`` helper that uses - # ``importlib.import_module()`` *and* temporarily scopes ``sys.path`` so - # that sibling imports inside the loaded module (e.g. ``from cpu_worker - # import …``) resolve to the correct directory rather than a same-named - # file from a different example subdirectory. - needs_importlib = any(_has_numeric_module_segments(w.module_path) for w in workers) - - if needs_importlib: - lines += [ - "import importlib as _importlib", - "", - "", - "def _flash_import(module_path, name, subdir=None):", - ' """Import *name* from *module_path* with scoped sys.path for sibling imports."""', - " _path = str(_project_root / subdir) if subdir else None", - " if _path:", - " sys.path.insert(0, _path)", - " try:", - " return getattr(_importlib.import_module(module_path), name)", - " finally:", - " if _path is not None:", - " try:", - " if sys.path and sys.path[0] == _path:", - " sys.path.pop(0)", - " else:", - " sys.path.remove(_path)", - " except ValueError:", - " pass", - "", - ] - - lines += [ - "from runpod_flash.cli.commands._run_server_helpers import make_input_model as _make_input_model", - "from runpod_flash.cli.commands._run_server_helpers import call_with_body as _call_with_body", - ] - - if has_lb_workers: - lines += [ - "from fastapi import FastAPI, Request", - "from runpod_flash.cli.commands._run_server_helpers import lb_execute as _lb_execute", - "from runpod_flash.cli.commands._run_server_helpers import to_dict as _to_dict", - "", - ] - else: - lines += [ - "from fastapi import FastAPI", - "", - ] - - # Collect imports — QB functions are called directly, LB config variables and - # functions are passed to lb_execute for dispatch via LoadBalancerSlsStub. - all_imports: List[str] = [] - for worker in workers: - if worker.worker_type == "QB": - for fn_name in worker.functions: - all_imports.append(_make_import_line(worker.module_path, fn_name)) - for cls_info in worker.class_remotes: - all_imports.append( - _make_import_line(worker.module_path, cls_info["name"]) - ) - elif worker.worker_type == "LB": - # Import the resource config variable (e.g. "api" from api = LiveLoadBalancer(...)) - config_vars = { - r["config_variable"] - for r in worker.lb_routes - if r.get("config_variable") - } - for var in sorted(config_vars): - all_imports.append(_make_import_line(worker.module_path, var)) - for fn_name in worker.functions: - all_imports.append(_make_import_line(worker.module_path, fn_name)) - - if all_imports: - lines.extend(all_imports) - lines.append("") - - lines += [ - "app = FastAPI(", - ' title="Flash Dev Server",', - ' description="Auto-generated by `flash run`. Visit /docs for interactive testing.",', - ")", - "", - ] - - # Module-level instance creation for @remote classes - for worker in workers: - for cls_info in worker.class_remotes: - cls_name = cls_info["name"] - lines.append(f"_instance_{cls_name} = {cls_name}()") - # Add blank line if any instances were created - if any(worker.class_remotes for worker in workers): - lines.append("") - - # Module-level Pydantic model creation for typed Swagger UI - model_lines: list[str] = [] - for worker in workers: - if worker.worker_type == "QB": - for fn in worker.functions: - params = worker.function_params.get(fn) - if params is None or len(params) > 0: - model_var = f"_{worker.resource_name}_{fn}_Input" - model_lines.append( - f'{model_var} = _make_input_model("{model_var}", {fn}) or dict' - ) - for cls_info in worker.class_remotes: - cls_name = cls_info["name"] - method_params = cls_info.get("method_params", {}) - instance_var = f"_instance_{cls_name}" - for method in cls_info["methods"]: - params = method_params.get(method) - if params is None or len(params) > 0: - model_var = f"_{worker.resource_name}_{cls_name}_{method}_Input" - # Use _class_type to get the original unwrapped method - # (RemoteClassWrapper.__getattr__ returns proxies with (*args, **kwargs)) - class_ref = f"getattr({instance_var}, '_class_type', type({instance_var}))" - model_lines.append( - f'{model_var} = _make_input_model("{model_var}", {class_ref}.{method}) or dict' - ) - elif worker.worker_type == "LB": - for route in worker.lb_routes: - method = route["method"].lower() - if method in ("post", "put", "patch", "delete"): - fn_name = route["fn_name"] - model_var = f"_{worker.resource_name}_{fn_name}_Input" - model_lines.append( - f'{model_var} = _make_input_model("{model_var}", {fn_name}) or dict' - ) - if model_lines: - lines.extend(model_lines) - lines.append("") - - for worker in workers: - # Group routes by project directory in Swagger UI. - # Nested: /03_mixed_workers/cpu_worker -> "03_mixed_workers/" - # Root: /worker -> "worker" - prefix = worker.url_prefix.lstrip("/") - tag = f"{prefix.rsplit('/', 1)[0]}/" if "/" in prefix else prefix - lines.append(f"# {'─' * 60}") - lines.append(f"# {worker.worker_type}: {worker.file_path.name}") - lines.append(f"# {'─' * 60}") - - if worker.worker_type == "QB": - # Total callable count: functions + sum of class methods - total_class_methods = sum(len(c["methods"]) for c in worker.class_remotes) - total_callables = len(worker.functions) + total_class_methods - use_multi = total_callables > 1 - - # Function-based routes - for fn in worker.functions: - if use_multi: - handler_name = _sanitize_fn_name( - f"{worker.resource_name}_{fn}_run_sync" - ) - sync_path = f"{worker.url_prefix}/{fn}/run_sync" - else: - handler_name = _sanitize_fn_name(f"{worker.resource_name}_run_sync") - sync_path = f"{worker.url_prefix}/run_sync" - params = worker.function_params.get(fn) - call_expr, needs_body = _build_call_expr(fn, params) - if needs_body: - model_var = f"_{worker.resource_name}_{fn}_Input" - handler_sig = f"async def {handler_name}(body: {model_var}):" - else: - handler_sig = f"async def {handler_name}():" - lines += [ - f'@app.post("{sync_path}", tags=["{tag}"])', - handler_sig, - f" result = {call_expr}", - ' return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result}', - "", - ] - - # Class-based routes - for cls_info in worker.class_remotes: - cls_name = cls_info["name"] - methods = cls_info["methods"] - method_params = cls_info.get("method_params", {}) - instance_var = f"_instance_{cls_name}" - - for method in methods: - if use_multi: - handler_name = _sanitize_fn_name( - f"{worker.resource_name}_{cls_name}_{method}_run_sync" - ) - sync_path = f"{worker.url_prefix}/{method}/run_sync" - else: - handler_name = _sanitize_fn_name( - f"{worker.resource_name}_{cls_name}_run_sync" - ) - sync_path = f"{worker.url_prefix}/run_sync" - params = method_params.get(method) - call_expr, needs_body = _build_call_expr( - f"{instance_var}.{method}", params - ) - if needs_body: - model_var = f"_{worker.resource_name}_{cls_name}_{method}_Input" - handler_sig = f"async def {handler_name}(body: {model_var}):" - else: - handler_sig = f"async def {handler_name}():" - lines += [ - f'@app.post("{sync_path}", tags=["{tag}"])', - handler_sig, - f" result = {call_expr}", - ' return {"id": str(uuid.uuid4()), "status": "COMPLETED", "output": result}', - "", - ] - - elif worker.worker_type == "LB": - for route in worker.lb_routes: - method = route["method"].lower() - sub_path = route["path"].lstrip("/") - fn_name = route["fn_name"] - config_var = route["config_variable"] - full_path = f"{worker.url_prefix}/{sub_path}" - handler_name = _sanitize_fn_name( - f"_route_{worker.resource_name}_{fn_name}" - ) - path_params = _extract_path_params(full_path) - has_body = method in ("post", "put", "patch", "delete") - if has_body: - model_var = f"_{worker.resource_name}_{fn_name}_Input" - # POST/PUT/PATCH/DELETE: typed body + optional path params - if path_params: - param_sig = ", ".join(f"{p}: str" for p in path_params) - param_dict = ", ".join(f'"{p}": {p}' for p in path_params) - lines += [ - f'@app.{method}("{full_path}", tags=["{tag}"])', - f"async def {handler_name}(body: {model_var}, {param_sig}):", - f" return await _lb_execute({config_var}, {fn_name}, {{**_to_dict(body), {param_dict}}})", - "", - ] - else: - lines += [ - f'@app.{method}("{full_path}", tags=["{tag}"])', - f"async def {handler_name}(body: {model_var}):", - f" return await _lb_execute({config_var}, {fn_name}, _to_dict(body))", - "", - ] - else: - # GET/etc: path params + query params (unchanged) - if path_params: - param_sig = ", ".join(f"{p}: str" for p in path_params) - param_dict = ", ".join(f'"{p}": {p}' for p in path_params) - lines += [ - f'@app.{method}("{full_path}", tags=["{tag}"])', - f"async def {handler_name}({param_sig}, request: Request):", - f" return await _lb_execute({config_var}, {fn_name}, {{**dict(request.query_params), {param_dict}}})", - "", - ] - else: - lines += [ - f'@app.{method}("{full_path}", tags=["{tag}"])', - f"async def {handler_name}(request: Request):", - f" return await _lb_execute({config_var}, {fn_name}, dict(request.query_params))", - "", - ] - - # Health endpoints - lines += [ - "# Health", - '@app.get("/", tags=["health"])', - "def home():", - ' return {"message": "Flash Dev Server", "docs": "/docs"}', - "", - '@app.get("/ping", tags=["health"])', - "def ping():", - ' return {"status": "healthy"}', - "", - ] - - server_path = flash_dir / "server.py" - server_path.write_text("\n".join(lines), encoding="utf-8") - return server_path - - def _print_startup_table(workers: List[WorkerInfo], host: str, port: int) -> None: """Print the startup table showing local paths, resource names, and types.""" console.print(f"\n[bold green]Flash Dev Server[/bold green] http://{host}:{port}") @@ -554,39 +147,19 @@ def _print_startup_table(workers: List[WorkerInfo], host: str, port: int) -> Non for worker in workers: if worker.worker_type == "QB": - total_class_methods = sum(len(c["methods"]) for c in worker.class_remotes) - total_callables = len(worker.functions) + total_class_methods - use_multi = total_callables > 1 - - for fn in worker.functions: - if use_multi: + if len(worker.functions) == 1: + table.add_row( + f"POST {worker.url_prefix}/run_sync", + worker.resource_name, + "QB", + ) + else: + for fn in worker.functions: table.add_row( f"POST {worker.url_prefix}/{fn}/run_sync", worker.resource_name, "QB", ) - else: - table.add_row( - f"POST {worker.url_prefix}/run_sync", - worker.resource_name, - "QB", - ) - - for cls_info in worker.class_remotes: - methods = cls_info["methods"] - for method in methods: - if use_multi: - table.add_row( - f"POST {worker.url_prefix}/{method}/run_sync", - worker.resource_name, - "QB", - ) - else: - table.add_row( - f"POST {worker.url_prefix}/run_sync", - worker.resource_name, - "QB", - ) elif worker.worker_type == "LB": for route in worker.lb_routes: sub_path = route["path"].lstrip("/") @@ -600,7 +173,7 @@ def _print_startup_table(workers: List[WorkerInfo], host: str, port: int) -> Non console.print(table) console.print(f"\n Visit [bold]http://{host}:{port}/docs[/bold] for Swagger UI") console.print( - " Press [bold]Ctrl+C[/bold] to stop — provisioned endpoints are cleaned up automatically\n" + " Press [bold]Ctrl+C[/bold] to stop. Provisioned endpoints are cleaned up automatically.\n" ) @@ -615,74 +188,68 @@ def _cleanup_live_endpoints() -> None: if not _RESOURCE_STATE_FILE.exists(): return - import asyncio - import cloudpickle - from ...core.utils.file_lock import file_lock - - # Load persisted resource state. If this fails (lock error, corruption), - # log and return — don't let it prevent the rest of shutdown. try: + import asyncio + import cloudpickle + from ...core.utils.file_lock import file_lock + with open(_RESOURCE_STATE_FILE, "rb") as f: with file_lock(f, exclusive=False): data = cloudpickle.load(f) - except Exception as e: - logger.warning(f"Could not read resource state for cleanup: {e}") - return - if isinstance(data, tuple): - resources, configs = data - else: - resources, configs = data, {} + if isinstance(data, tuple): + resources, configs = data + else: + resources, configs = data, {} - live_items = { - key: resource - for key, resource in resources.items() - if hasattr(resource, "name") - and resource.name - and resource.name.startswith("live-") - } + live_items = { + key: resource + for key, resource in resources.items() + if hasattr(resource, "name") + and resource.name + and resource.name.startswith("live-") + } - if not live_items: - return + if not live_items: + return - import time + import time - async def _do_cleanup(): - undeployed = 0 - for key, resource in live_items.items(): - name = getattr(resource, "name", key) - try: - success = await resource._do_undeploy() - if success: - console.print(f" Deprovisioned: {name}") - undeployed += 1 - else: - logger.warning(f"Failed to deprovision: {name}") - except Exception as e: - logger.warning(f"Error deprovisioning {name}: {e}") - return undeployed + async def _do_cleanup(): + undeployed = 0 + for key, resource in live_items.items(): + name = getattr(resource, "name", key) + try: + success = await resource._do_undeploy() + if success: + console.print(f" Deprovisioned: {name}") + undeployed += 1 + else: + logger.warning(f"Failed to deprovision: {name}") + except Exception as e: + logger.warning(f"Error deprovisioning {name}: {e}") + return undeployed + + t0 = time.monotonic() + undeployed = asyncio.run(_do_cleanup()) + elapsed = time.monotonic() - t0 + console.print( + f" Cleanup completed: {undeployed}/{len(live_items)} " + f"resource(s) undeployed in {elapsed:.1f}s" + ) - t0 = time.monotonic() - loop = asyncio.new_event_loop() - try: - undeployed = loop.run_until_complete(_do_cleanup()) - finally: - loop.close() - elapsed = time.monotonic() - t0 - console.print( - f" Cleanup completed: {undeployed}/{len(live_items)} " - f"resource(s) undeployed in {elapsed:.1f}s" - ) + # remove live- entries from persisted state so they don't linger + remaining = {k: v for k, v in resources.items() if k not in live_items} + remaining_configs = {k: v for k, v in configs.items() if k not in live_items} + try: + with open(_RESOURCE_STATE_FILE, "wb") as f: + with file_lock(f, exclusive=True): + cloudpickle.dump((remaining, remaining_configs), f) + except Exception as e: + logger.warning(f"Could not update resource state after cleanup: {e}") - # Remove live- entries from persisted state so they don't linger. - remaining = {k: v for k, v in resources.items() if k not in live_items} - remaining_configs = {k: v for k, v in configs.items() if k not in live_items} - try: - with open(_RESOURCE_STATE_FILE, "wb") as f: - with file_lock(f, exclusive=True): - cloudpickle.dump((remaining, remaining_configs), f) except Exception as e: - logger.warning(f"Could not update resource state after cleanup: {e}") + logger.warning(f"Live endpoint cleanup failed: {e}") def _is_reload() -> bool: @@ -690,39 +257,6 @@ def _is_reload() -> bool: return "UVICORN_RELOADER_PID" in os.environ -def _watch_and_regenerate(project_root: Path, stop_event: threading.Event) -> None: - """Watch project .py files and regenerate server.py when they change. - - Ignores .flash/ to avoid reacting to our own writes. Runs until - stop_event is set. - """ - # Suppress watchfiles' internal debug chatter (filter hits, rust timeouts). - logging.getLogger("watchfiles").setLevel(logging.WARNING) - - watch_filter = _WatchfilesDefaultFilter(ignore_paths=[str(project_root / ".flash")]) - - try: - for changes in _watchfiles_watch( - project_root, - watch_filter=watch_filter, - stop_event=stop_event, - ): - py_changed = [p for _, p in changes if p.endswith(".py")] - if not py_changed: - continue - try: - workers = _scan_project_workers(project_root) - _generate_flash_server(project_root, workers) - logger.debug("server.py regenerated (%d changed)", len(py_changed)) - except Exception as e: - logger.warning("Failed to regenerate server.py: %s", e) - except ModuleNotFoundError as e: - logger.warning("File watching disabled: %s", e) - except Exception as e: - if not stop_event.is_set(): - logger.exception("Unexpected error in file watcher: %s", e) - - def _discover_resources(project_root: Path): """Discover deployable resources in project files. @@ -746,8 +280,7 @@ def _discover_resources(project_root: Path): ) ) - # Add project root to sys.path so cross-module imports resolve - # (e.g. api/routes.py doing "from longruns.stage1 import stage1_process"). + # add project root to sys.path so cross-module imports resolve root_str = str(project_root) added_to_path = root_str not in sys.path if added_to_path: @@ -824,21 +357,21 @@ def run_command( help="Auto-provision all endpoints on startup (eliminates cold-start on first request)", ), ): - """Run Flash development server. + """Start Flash development server. - Scans the project for @remote decorated functions, generates a dev server - at .flash/server.py, and starts uvicorn with hot-reload. + Scans the project for @remote decorated functions and starts a FastAPI + dev server via uvicorn. Tracebacks point directly to your source files. No main.py or FastAPI boilerplate required. Any .py file with @remote decorated functions is a valid Flash project. """ project_root = Path.cwd() - # Set flag for live provisioning so stubs get the live- prefix + # set flag for live provisioning so stubs get the live- prefix if not _is_reload(): os.environ["FLASH_IS_LIVE_PROVISIONING"] = "true" - # Auto-provision all endpoints upfront (eliminates cold-start) + # auto-provision all endpoints upfront (eliminates cold-start) if auto_provision and not _is_reload(): try: resources = _discover_resources(project_root) @@ -851,14 +384,14 @@ def run_command( "[dim]Resources will be provisioned on-demand at first request.[/dim]" ) - # Discover @remote functions + # discover @remote functions for the startup table workers = _scan_project_workers(project_root) if not workers: console.print("[red]Error:[/red] No @remote functions found.") console.print("Add @remote decorators to your functions to get started.") - console.print("\nExample:") console.print( + "\nExample:\n" " from runpod_flash import LiveServerless, remote\n" " gpu_config = LiveServerless(name='my_worker')\n" "\n" @@ -868,19 +401,25 @@ def run_command( ) raise typer.Exit(1) - # Generate .flash/server.py - _generate_flash_server(project_root, workers) + # find a free port, counting up from the requested one + actual_port = _find_available_port(host, port) + if actual_port != port: + console.print( + f"[yellow]Port {port} is in use, using {actual_port} instead.[/yellow]" + ) + port = actual_port _print_startup_table(workers, host, port) - # Build uvicorn command using --app-dir so server:app is importable + # tell the factory function where the project lives + os.environ["FLASH_PROJECT_ROOT"] = str(project_root) + cmd = [ sys.executable, "-m", "uvicorn", - "server:app", - "--app-dir", - ".flash", + "--factory", + "runpod_flash.cli.commands._dev_server:create_app", "--host", host, "--port", @@ -893,21 +432,9 @@ def run_command( cmd += [ "--reload", "--reload-dir", - ".flash", - "--reload-include", - "server.py", + str(project_root), ] - stop_event = threading.Event() - watcher_thread = None - if reload: - watcher_thread = threading.Thread( - target=_watch_and_regenerate, - args=(project_root, stop_event), - daemon=True, - name="flash-watcher", - ) - process = None try: if sys.platform == "win32": @@ -917,18 +444,11 @@ def run_command( else: process = subprocess.Popen(cmd, preexec_fn=os.setsid) - if watcher_thread is not None: - watcher_thread.start() - process.wait() except KeyboardInterrupt: console.print("\n[yellow]Stopping server and cleaning up...[/yellow]") - stop_event.set() - if watcher_thread is not None and watcher_thread.is_alive(): - watcher_thread.join(timeout=2) - if process: try: if sys.platform == "win32": @@ -955,10 +475,6 @@ def run_command( except Exception as e: console.print(f"[red]Error:[/red] {e}") - stop_event.set() - if watcher_thread is not None and watcher_thread.is_alive(): - watcher_thread.join(timeout=2) - if process: try: if sys.platform == "win32": diff --git a/src/runpod_flash/cli/commands/test_mothership.py b/src/runpod_flash/cli/commands/test_mothership.py index 4e7b66bd..246401e5 100644 --- a/src/runpod_flash/cli/commands/test_mothership.py +++ b/src/runpod_flash/cli/commands/test_mothership.py @@ -278,7 +278,7 @@ def _create_entrypoint_script(build_dir: str) -> None: """Create entrypoint.sh script for Docker container. This script handles signal trapping and cleanup on shutdown. - It runs manifest-based provisioning then flash run (without --auto-provision + It runs manifest-based provisioning then flash dev (without --auto-provision to avoid duplicate discovery from bundled dependencies). """ build_path = Path(build_dir) @@ -369,7 +369,7 @@ def _display_test_objectives() -> None: def _display_config(build_dir: str, image: str, port: int, endpoint_id: str) -> None: """Display test configuration.""" config_text = f"""[bold]Build directory:[/bold] {build_dir} -[bold]Command:[/bold] flash run +[bold]Command:[/bold] flash dev [bold]Docker image:[/bold] {image} [bold]Endpoint ID:[/bold] {endpoint_id} [bold]Port:[/bold] http://localhost:{port}""" diff --git a/src/runpod_flash/cli/docs/README.md b/src/runpod_flash/cli/docs/README.md index 1a1b4dfe..00c7baef 100644 --- a/src/runpod_flash/cli/docs/README.md +++ b/src/runpod_flash/cli/docs/README.md @@ -26,7 +26,7 @@ echo "RUNPOD_API_KEY=your_api_key_here" > .env Start the development server to test your `@remote` functions: ```bash -flash run +flash dev ``` When you're ready to deploy your application to Runpod, use: @@ -128,12 +128,12 @@ flash deploy --preview --- -### flash run +### flash dev Start a Flash development server for testing/debugging/development. ```bash -flash run [OPTIONS] +flash dev [OPTIONS] ``` **Options:** @@ -144,11 +144,11 @@ flash run [OPTIONS] **Example:** ```bash -flash run -flash run --port 3000 +flash dev +flash dev --port 3000 ``` -[Full documentation](./flash-run.md) +[Full documentation](./flash-dev.md) --- @@ -331,5 +331,5 @@ curl -X POST http://localhost:8888/cpu_worker/run_sync \ ```bash flash --help flash init --help -flash run --help +flash dev --help ``` diff --git a/src/runpod_flash/cli/docs/flash-build.md b/src/runpod_flash/cli/docs/flash-build.md index deb0e633..91f010f9 100644 --- a/src/runpod_flash/cli/docs/flash-build.md +++ b/src/runpod_flash/cli/docs/flash-build.md @@ -287,7 +287,7 @@ Check the [worker-flash repository](https://github.com/runpod-workers/worker-fla After building: -1. **Test Locally**: Run `flash run` to test the application +1. **Test Locally**: Run `flash dev` to test the application 2. **Deploy**: Use `flash deploy` to deploy to Runpod Serverless 3. **Preview**: Test with `flash build --preview` before production deployment 4. **Monitor**: Use `flash env get` to check deployment status @@ -295,6 +295,6 @@ After building: ## Related Commands - [flash deploy](./flash-deploy.md) - Build and deploy in one step -- [flash run](./flash-run.md) - Start development server +- [flash dev](./flash-dev.md) - Start development server - [flash env](./flash-env.md) - Manage deployment environments - [flash undeploy](./flash-undeploy.md) - Manage deployed endpoints diff --git a/src/runpod_flash/cli/docs/flash-deploy.md b/src/runpod_flash/cli/docs/flash-deploy.md index d0fcb6a7..c8acd3b1 100644 --- a/src/runpod_flash/cli/docs/flash-deploy.md +++ b/src/runpod_flash/cli/docs/flash-deploy.md @@ -61,11 +61,11 @@ With `flash deploy`, your **entire application** runs on Runpod Serverless—all - **No `live-` prefix** on endpoint names (these are production endpoints) - **No hot reload:** code changes require a new deployment -This is different from `flash run`, where your FastAPI app runs locally on your machine. See [flash run](./flash-run.md) for the hybrid development architecture. +This is different from `flash dev`, where your FastAPI app runs locally on your machine. See [flash dev](./flash-dev.md) for the hybrid development architecture. -### flash run vs flash deploy +### flash dev vs flash deploy -| Aspect | `flash run` | `flash deploy` | +| Aspect | `flash dev` | `flash deploy` | |--------|-------------|----------------| | **App runs on** | Your machine (localhost) | Runpod Serverless | | **`@remote` functions run on** | Runpod Serverless | Runpod Serverless | @@ -454,4 +454,4 @@ After deploying: - [flash env](./flash-env.md) - Manage deployment environments - [flash app](./flash-app.md) - Manage Flash applications - [flash undeploy](./flash-undeploy.md) - Remove deployed endpoints -- [flash run](./flash-run.md) - Local development server +- [flash dev](./flash-dev.md) - Local development server diff --git a/src/runpod_flash/cli/docs/flash-run.md b/src/runpod_flash/cli/docs/flash-dev.md similarity index 82% rename from src/runpod_flash/cli/docs/flash-run.md rename to src/runpod_flash/cli/docs/flash-dev.md index 70976d6c..39cdfba2 100644 --- a/src/runpod_flash/cli/docs/flash-run.md +++ b/src/runpod_flash/cli/docs/flash-dev.md @@ -1,25 +1,25 @@ -# flash run +# flash dev Start the Flash development server for testing/debugging/development. ## Overview -The `flash run` command starts a local development server that auto-discovers your `@remote` functions and serves them on your machine while deploying them to Runpod Serverless. This hybrid architecture lets you rapidly iterate on your application with hot-reload while testing real GPU/CPU workloads in the cloud. +The `flash dev` command starts a local development server that auto-discovers your `@remote` functions and serves them on your machine while deploying them to Runpod Serverless. This hybrid architecture lets you rapidly iterate on your application with hot-reload while testing real GPU/CPU workloads in the cloud. -Use `flash run` when you want to skip the build step and test/develop/debug your remote functions rapidly before deploying your full application with `flash deploy`. (See [Flash Deploy](./flash-deploy.md) for details.) +Use `flash dev` when you want to skip the build step and test/develop/debug your remote functions rapidly before deploying your full application with `flash deploy`. (See [Flash Deploy](./flash-deploy.md) for details.) ## Architecture: Local App + Remote Workers -With `flash run`, your system runs in a **hybrid architecture**: +With `flash dev`, your system runs in a **hybrid architecture**: ``` ┌─────────────────────────────────────────────────────────────────┐ │ YOUR MACHINE (localhost:8888) │ │ ┌─────────────────────────────────────┐ │ -│ │ Auto-generated server │ │ -│ │ (.flash/server.py) │ │ +│ │ Programmatic FastAPI server │ │ +│ │ (programmatic FastAPI app) │ │ │ │ - Discovers @remote functions │─────────┐ │ -│ │ - Hot-reload via watchfiles │ │ │ +│ │ - Hot-reload via uvicorn │ │ │ │ └─────────────────────────────────────┘ │ │ └──────────────────────────────────────────────────│──────────────┘ │ HTTPS @@ -34,25 +34,26 @@ With `flash run`, your system runs in a **hybrid architecture**: ``` **Key points:** -- **`flash run` auto-discovers `@remote` functions** and generates `.flash/server.py` +- **`flash dev` auto-discovers `@remote` functions** and builds routes programmatically - **Queue-based (QB) routes execute locally** at `/{file_prefix}/run_sync` - **Load-balanced (LB) routes dispatch remotely** via `LoadBalancerSlsStub` - **`@remote` functions run on Runpod** as serverless endpoints -- **Hot reload** watches for `.py` file changes via watchfiles +- **Hot reload** watches your project directory via uvicorn's built-in reloader - **Endpoints are prefixed with `live-`** to distinguish development endpoints from production (e.g., `gpu-worker` becomes `live-gpu-worker`) +- **Direct tracebacks** - errors point to your original source files This is different from `flash deploy`, where **everything** (including your FastAPI app) runs on Runpod. See [flash deploy](./flash-deploy.md) for the fully-deployed architecture. ## Usage ```bash -flash run [OPTIONS] +flash dev [OPTIONS] ``` ## Options -- `--host`: Host to bind to (default: localhost) -- `--port, -p`: Port to bind to (default: 8888) +- `--host`: Host to bind to (default: localhost, env: FLASH_HOST) +- `--port, -p`: Port to bind to (default: 8888, env: FLASH_PORT) - `--reload/--no-reload`: Enable auto-reload (default: enabled) - `--auto-provision`: Auto-provision Serverless endpoints on startup (default: disabled) @@ -60,33 +61,34 @@ flash run [OPTIONS] ```bash # Start server with defaults -flash run +flash dev # Custom port -flash run --port 3000 +flash dev --port 3000 # Disable auto-reload -flash run --no-reload +flash dev --no-reload # Custom host and port -flash run --host 0.0.0.0 --port 8000 +flash dev --host 0.0.0.0 --port 8000 ``` ## What It Does 1. Scans project files for `@remote` decorated functions -2. Generates `.flash/server.py` with QB and LB routes -3. Starts uvicorn server with hot-reload via watchfiles +2. Builds FastAPI routes programmatically via `create_app()` +3. Starts uvicorn server with hot-reload watching your project directory 4. GPU workers use LiveServerless (no packaging needed) + ### How It Works -When you call a `@remote` function using `flash run`, Flash deploys a **Serverless endpoint** to Runpod. (These are actual cloud resources that incur costs.) +When you call a `@remote` function using `flash dev`, Flash deploys a **Serverless endpoint** to Runpod. (These are actual cloud resources that incur costs.) ``` -flash run +flash dev │ ├── Scans project for @remote functions - ├── Generates .flash/server.py + ├── Builds FastAPI app programmatically ├── Starts local server (e.g. localhost:8888) │ ├── QB routes: /{file_prefix}/run_sync (local execution) │ └── LB routes: /{file_prefix}/{path} (remote dispatch) @@ -121,13 +123,13 @@ Auto-provisioning discovers and deploys Serverless endpoints before the Flash de Enable it with the `--auto-provision` flag: ```bash -flash run --auto-provision +flash dev --auto-provision ``` Example with custom host and port: ```bash -flash run --auto-provision --host 0.0.0.0 --port 8000 +flash dev --auto-provision --host 0.0.0.0 --port 8000 ``` ### Benefits @@ -149,10 +151,10 @@ Resources are cached by name and automatically reused: ```bash # First run: deploys endpoints -flash run --auto-provision +flash dev --auto-provision # Subsequent runs: reuses cached endpoints (faster) -flash run --auto-provision +flash dev --auto-provision ``` Resources persist in `.runpod/resources.pkl` and survive server restarts. Configuration changes are detected automatically and trigger re-deployment only when needed. diff --git a/src/runpod_flash/cli/docs/flash-env.md b/src/runpod_flash/cli/docs/flash-env.md index 81ce7993..1688f673 100644 --- a/src/runpod_flash/cli/docs/flash-env.md +++ b/src/runpod_flash/cli/docs/flash-env.md @@ -380,7 +380,7 @@ production ### Deployment Workflow -1. **Develop locally**: Test with `flash run` or `flash deploy --preview` +1. **Develop locally**: Test with `flash dev` or `flash deploy --preview` 2. **Deploy to dev**: `flash deploy --env dev` for initial testing 3. **Deploy to staging**: `flash deploy --env staging` for QA validation 4. **Deploy to production**: `flash deploy --env production` after approval diff --git a/src/runpod_flash/cli/docs/flash-init.md b/src/runpod_flash/cli/docs/flash-init.md index 19c32f13..bbfbd86e 100644 --- a/src/runpod_flash/cli/docs/flash-init.md +++ b/src/runpod_flash/cli/docs/flash-init.md @@ -15,7 +15,7 @@ The `flash init` command scaffolds a new Flash project with everything you need **After initialization:** 1. Copy `.env.example` to `.env` and add your `RUNPOD_API_KEY` -2. Run `flash run` to start the local development server +2. Run `flash dev` to start the local development server 3. Customize the workers for your use case 4. Deploy with `flash deploy` when ready @@ -64,7 +64,7 @@ my-project/ cd my-project uv sync # or: pip install -r requirements.txt # Add RUNPOD_API_KEY to .env -flash run +flash dev ``` Visit http://localhost:8888/docs for interactive API documentation. diff --git a/src/runpod_flash/cli/docs/flash-logging.md b/src/runpod_flash/cli/docs/flash-logging.md index 0b586578..bae55ebe 100644 --- a/src/runpod_flash/cli/docs/flash-logging.md +++ b/src/runpod_flash/cli/docs/flash-logging.md @@ -8,12 +8,12 @@ Flash automatically logs all CLI operations to local files during development. T ### How it works -File-based logging is enabled by default in local development mode ([flash run](./flash-run.md)) and automatically disabled in deployed containers ([flash deploy](./flash-deploy.md)). +File-based logging is enabled by default in local development mode ([flash dev](./flash-dev.md)) and automatically disabled in deployed containers ([flash deploy](./flash-deploy.md)). When you run a `@remote` function, Flash logs the activity to a file: ``` -flash run +flash dev │ ├── Console output (what you see) └── .flash/logs/activity.log (persistent record) @@ -95,7 +95,7 @@ Custom directory for log files. ```bash # Use custom log directory export FLASH_LOG_DIR=/var/log/flash -flash run +flash dev ``` **Note:** The directory will be created automatically if it doesn't exist. @@ -118,7 +118,7 @@ Keep only recent logs during active development: ```bash export FLASH_LOG_RETENTION_DAYS=3 -flash run +flash dev ``` ### Custom Log Directory @@ -241,7 +241,7 @@ find .flash/logs -name "activity.log.*" -mtime +30 -delete ## Related Configuration - `LOG_LEVEL`: Controls console and file log verbosity (DEBUG, INFO, WARNING, ERROR) -- See [flash-run.md](./flash-run.md) for environment variable usage in local development +- See [flash-dev.md](./flash-dev.md) for environment variable usage in local development - See [flash-build.md](./flash-build.md) for build-time logging behavior ## Summary diff --git a/src/runpod_flash/cli/docs/flash-undeploy.md b/src/runpod_flash/cli/docs/flash-undeploy.md index 2c03a313..bc2cf132 100644 --- a/src/runpod_flash/cli/docs/flash-undeploy.md +++ b/src/runpod_flash/cli/docs/flash-undeploy.md @@ -4,7 +4,7 @@ Manage and delete Runpod serverless endpoints deployed via Flash. ## Overview -The `flash undeploy` command helps you clean up Serverless endpoints that Flash has created when you ran/deployed a `@remote` function using `flash run` or `flash deploy`. It manages endpoints recorded in `.runpod/resources.pkl` and ensures both the cloud resources and local tracking state stay in sync. +The `flash undeploy` command helps you clean up Serverless endpoints that Flash has created when you ran/deployed a `@remote` function using `flash dev` or `flash deploy`. It manages endpoints recorded in `.runpod/resources.pkl` and ensures both the cloud resources and local tracking state stay in sync. ### When To Use This Command @@ -23,7 +23,7 @@ For production deployments, use `flash env delete` to remove the entire environm ### How Endpoint Tracking Works Flash tracks deployed endpoints in `.runpod/resources.pkl`. Endpoints get added to this file when you: -- Run `flash run --auto-provision` (local development) +- Run `flash dev --auto-provision` (local development) - Run `flash deploy` (production deployment) ## Synopsis @@ -253,7 +253,7 @@ flash undeploy list ## Related Commands - `flash init` - Initialize new project -- `flash run` - Run development server +- `flash dev` - Run development server - `flash build` - Build deployment packages - `flash deploy` - Deploy to Runpod diff --git a/src/runpod_flash/cli/main.py b/src/runpod_flash/cli/main.py index 405a8f61..49d604c4 100644 --- a/src/runpod_flash/cli/main.py +++ b/src/runpod_flash/cli/main.py @@ -36,7 +36,8 @@ def get_version() -> str: # command: flash app.command("init")(init.init_command) -app.command("run")(run.run_command) +app.command("dev")(run.run_command) +app.command("run", hidden=True)(run.run_command) # legacy alias for flash dev app.command("build")(build.build_command) app.command("deploy")(deploy.deploy_command) # app.command("report")(resource.report_command) diff --git a/src/runpod_flash/cli/utils/skeleton_template/README.md b/src/runpod_flash/cli/utils/skeleton_template/README.md index 328a8ab3..5d549da3 100644 --- a/src/runpod_flash/cli/utils/skeleton_template/README.md +++ b/src/runpod_flash/cli/utils/skeleton_template/README.md @@ -16,7 +16,7 @@ Set up the project: uv venv && source .venv/bin/activate uv sync cp .env.example .env # Add your RUNPOD_API_KEY -flash run +flash dev ``` Or with pip: @@ -25,12 +25,12 @@ Or with pip: python -m venv .venv && source .venv/bin/activate pip install -r requirements.txt cp .env.example .env # Add your RUNPOD_API_KEY -flash run +flash dev ``` Server starts at **http://localhost:8888**. Visit **http://localhost:8888/docs** for interactive Swagger UI. -Use `flash run --auto-provision` to pre-deploy all endpoints on startup, eliminating cold-start delays on first request. Provisioned endpoints are cached and reused across restarts. +Use `flash dev --auto-provision` to pre-deploy all endpoints on startup, eliminating cold-start delays on first request. Provisioned endpoints are cached and reused across restarts. When you stop the server with Ctrl+C, all endpoints provisioned during the session are automatically cleaned up. @@ -131,7 +131,7 @@ async def health() -> dict: ## Adding New Workers -Create a new `.py` file with a `@remote` function. `flash run` auto-discovers all +Create a new `.py` file with a `@remote` function. `flash dev` auto-discovers all `@remote` functions in the project. ```python @@ -147,7 +147,7 @@ async def predict(input_data: dict) -> dict: return pipe(input_data["text"])[0] ``` -Then run `flash run` — the new worker appears automatically. +Then run `flash dev` — the new worker appears automatically. ## GPU Types diff --git a/src/runpod_flash/core/discovery.py b/src/runpod_flash/core/discovery.py index 06c5d57e..88fa0b86 100644 --- a/src/runpod_flash/core/discovery.py +++ b/src/runpod_flash/core/discovery.py @@ -1,4 +1,4 @@ -"""Resource discovery for auto-provisioning during flash run startup.""" +"""Resource discovery for auto-provisioning during flash dev startup.""" import ast import importlib.util diff --git a/src/runpod_flash/core/resources/live_serverless.py b/src/runpod_flash/core/resources/live_serverless.py index 8ae0b3a5..228e4673 100644 --- a/src/runpod_flash/core/resources/live_serverless.py +++ b/src/runpod_flash/core/resources/live_serverless.py @@ -74,7 +74,7 @@ class LiveLoadBalancer(LiveServerlessMixin, LoadBalancerSlsResource): Features: - Locks to Flash LB image (flash-lb) - Direct HTTP execution (not queue-based) - - Local development with flash run + - Local development with flash dev - Same @remote decorator pattern as LoadBalancerSlsResource Usage: @@ -92,7 +92,7 @@ async def process_data(x: int, y: int): Local Development Flow: 1. Create LiveLoadBalancer with routing 2. Decorate functions with @remote(lb_resource, method=..., path=...) - 3. Run with `flash run` to start local endpoint + 3. Run with `flash dev` to start local endpoint 4. Call functions directly in tests or scripts 5. Deploy to production with `flash build` and `flash deploy` @@ -124,7 +124,7 @@ class CpuLiveLoadBalancer(LiveServerlessMixin, CpuLoadBalancerSlsResource): - Locks to CPU Flash LB image (flash-lb-cpu) - CPU instance support with automatic disk sizing - Direct HTTP execution (not queue-based) - - Local development with flash run + - Local development with flash dev - Same @remote decorator pattern as CpuLoadBalancerSlsResource Usage: @@ -142,7 +142,7 @@ async def process_data(x: int, y: int): Local Development Flow: 1. Create CpuLiveLoadBalancer with routing 2. Decorate functions with @remote(lb_resource, method=..., path=...) - 3. Run with `flash run` to start local endpoint + 3. Run with `flash dev` to start local endpoint 4. Call functions directly in tests or scripts 5. Deploy to production with `flash build` and `flash deploy` """ diff --git a/src/runpod_flash/core/resources/serverless.py b/src/runpod_flash/core/resources/serverless.py index 869b85a8..50b0ce82 100644 --- a/src/runpod_flash/core/resources/serverless.py +++ b/src/runpod_flash/core/resources/serverless.py @@ -474,8 +474,8 @@ def is_deployed(self) -> bool: if not self.id: return False - # During flash run, skip the health check. Newly-created endpoints - # can fail health checks due to RunPod propagation delay — the + # During flash dev, skip the health check. Newly-created endpoints + # can fail health checks due to RunPod propagation delay -- the # endpoint exists but the health API hasn't registered it yet. # Trusting the cached ID is correct here; actual failures surface # on the first real run/run_sync call. @@ -496,7 +496,7 @@ def _payload_exclude(self) -> Set[str]: # When templateId is already set, exclude template from the payload. # RunPod rejects requests that contain both fields simultaneously. # Both can coexist after deploy mutates config (sets templateId while - # template remains from initialization) — templateId takes precedence. + # template remains from initialization) -- templateId takes precedence. if self.templateId: exclude_fields.add("template") return exclude_fields diff --git a/tests/conftest.py b/tests/conftest.py index 7641d521..85766c5d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -283,7 +283,7 @@ def isolate_resource_state_file( def clear_live_provisioning_env(monkeypatch: pytest.MonkeyPatch): """Clear FLASH_IS_LIVE_PROVISIONING env var between tests. - This fixture ensures that the flag set by `flash run` command + This fixture ensures that the flag set by `flash dev` command doesn't leak into unit tests. It's autouse so it runs for all tests. Args: diff --git a/tests/unit/cli/commands/test_run.py b/tests/unit/cli/commands/test_run.py deleted file mode 100644 index c8b3311e..00000000 --- a/tests/unit/cli/commands/test_run.py +++ /dev/null @@ -1,589 +0,0 @@ -"""Tests for flash run dev server generation.""" - -import tempfile -from pathlib import Path - -from runpod_flash.cli.commands.run import ( - WorkerInfo, - _generate_flash_server, - _scan_project_workers, -) - - -def test_scan_separates_classes_from_functions(): - """Test that _scan_project_workers puts classes in class_remotes, not functions.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - worker_file = project_root / "gpu_worker.py" - worker_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="gpu_worker") - -@remote(config) -async def process(data): - return data - -@remote(config) -class SimpleSD: - def generate_image(self, prompt): - return {"image": "data"} - - def upscale(self, image): - return {"image": "upscaled"} -""" - ) - - workers = _scan_project_workers(project_root) - - assert len(workers) == 1 - worker = workers[0] - assert worker.worker_type == "QB" - assert worker.functions == ["process"] - assert len(worker.class_remotes) == 1 - assert worker.class_remotes[0]["name"] == "SimpleSD" - assert worker.class_remotes[0]["methods"] == ["generate_image", "upscale"] - - -def test_scan_class_only_worker(): - """Test scanning a file with only a class-based @remote.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - worker_file = project_root / "sd_worker.py" - worker_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="sd_worker") - -@remote(config) -class StableDiffusion: - def __init__(self): - self.model = None - - def generate(self, prompt): - return {"image": "data"} -""" - ) - - workers = _scan_project_workers(project_root) - - assert len(workers) == 1 - worker = workers[0] - assert worker.worker_type == "QB" - assert worker.functions == [] - assert len(worker.class_remotes) == 1 - assert worker.class_remotes[0]["name"] == "StableDiffusion" - assert worker.class_remotes[0]["methods"] == ["generate"] - - -def test_codegen_class_single_method(): - """Test generated server.py for a class with a single method uses short URL.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - workers = [ - WorkerInfo( - file_path=Path("sd_worker.py"), - url_prefix="/sd_worker", - module_path="sd_worker", - resource_name="sd_worker", - worker_type="QB", - functions=[], - class_remotes=[ - { - "name": "StableDiffusion", - "methods": ["generate"], - "method_params": {"generate": ["prompt"]}, - }, - ], - ), - ] - - server_path = _generate_flash_server(project_root, workers) - content = server_path.read_text() - - assert "_instance_StableDiffusion = StableDiffusion()" in content - assert "_call_with_body(_instance_StableDiffusion.generate, body)" in content - assert "body: _sd_worker_StableDiffusion_generate_Input" in content - assert "_make_input_model" in content - assert '"/sd_worker/run_sync"' in content - # Single method: no method name in URL - assert '"/sd_worker/generate/run_sync"' not in content - - -def test_codegen_class_multiple_methods(): - """Test generated server.py for a class with multiple methods uses method URLs.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - workers = [ - WorkerInfo( - file_path=Path("gpu_worker.py"), - url_prefix="/gpu_worker", - module_path="gpu_worker", - resource_name="gpu_worker", - worker_type="QB", - functions=[], - class_remotes=[ - { - "name": "SimpleSD", - "methods": ["generate_image", "upscale"], - "method_params": { - "generate_image": ["prompt"], - "upscale": ["image"], - }, - }, - ], - ), - ] - - server_path = _generate_flash_server(project_root, workers) - content = server_path.read_text() - - assert "_instance_SimpleSD = SimpleSD()" in content - assert '"/gpu_worker/generate_image/run_sync"' in content - assert '"/gpu_worker/upscale/run_sync"' in content - assert "_call_with_body(_instance_SimpleSD.generate_image, body)" in content - assert "_call_with_body(_instance_SimpleSD.upscale, body)" in content - assert "body: _gpu_worker_SimpleSD_generate_image_Input" in content - assert "body: _gpu_worker_SimpleSD_upscale_Input" in content - - -def test_codegen_mixed_function_and_class(): - """Test codegen when a worker has both functions and class remotes.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - workers = [ - WorkerInfo( - file_path=Path("worker.py"), - url_prefix="/worker", - module_path="worker", - resource_name="worker", - worker_type="QB", - functions=["process"], - class_remotes=[ - { - "name": "MyModel", - "methods": ["predict"], - "method_params": {"predict": ["data"]}, - }, - ], - ), - ] - - server_path = _generate_flash_server(project_root, workers) - content = server_path.read_text() - - # Both should use multi-callable URL pattern (total_callables = 2) - assert '"/worker/process/run_sync"' in content - assert '"/worker/predict/run_sync"' in content - assert "_instance_MyModel = MyModel()" in content - assert "_call_with_body(_instance_MyModel.predict, body)" in content - assert "_call_with_body(process, body)" in content - - -def test_codegen_function_only(): - """Test that function-only workers use Pydantic model and _call_with_body.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - workers = [ - WorkerInfo( - file_path=Path("simple.py"), - url_prefix="/simple", - module_path="simple", - resource_name="simple", - worker_type="QB", - functions=["process"], - ), - ] - - server_path = _generate_flash_server(project_root, workers) - content = server_path.read_text() - - # Single function: short URL - assert '"/simple/run_sync"' in content - assert "_call_with_body(process, body)" in content - assert "_simple_process_Input = _make_input_model(" in content - assert "body: _simple_process_Input" in content - # No instance creation - assert "_instance_" not in content - - -def test_codegen_zero_param_function(): - """Test generated code uses await fn() for zero-param functions.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - workers = [ - WorkerInfo( - file_path=Path("worker.py"), - url_prefix="/worker", - module_path="worker", - resource_name="worker", - worker_type="QB", - functions=["list_images"], - function_params={"list_images": []}, - ), - ] - - server_path = _generate_flash_server(project_root, workers) - content = server_path.read_text() - - assert "await list_images()" in content - assert 'body.get("input"' not in content - # Handler should not accept body parameter - assert "async def worker_run_sync():" in content - - -def test_codegen_multi_param_function(): - """Test generated code uses _call_with_body for multi-param functions.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - workers = [ - WorkerInfo( - file_path=Path("worker.py"), - url_prefix="/worker", - module_path="worker", - resource_name="worker", - worker_type="QB", - functions=["transform"], - function_params={"transform": ["text", "operation"]}, - ), - ] - - server_path = _generate_flash_server(project_root, workers) - content = server_path.read_text() - - assert "_call_with_body(transform, body)" in content - assert "_worker_transform_Input = _make_input_model(" in content - assert "body: _worker_transform_Input" in content - - -def test_codegen_single_param_function(): - """Test generated code uses _call_with_body for single-param functions.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - workers = [ - WorkerInfo( - file_path=Path("worker.py"), - url_prefix="/worker", - module_path="worker", - resource_name="worker", - worker_type="QB", - functions=["process"], - function_params={"process": ["data"]}, - ), - ] - - server_path = _generate_flash_server(project_root, workers) - content = server_path.read_text() - - assert "_call_with_body(process, body)" in content - assert "body: _worker_process_Input" in content - - -def test_codegen_zero_param_class_method(): - """Test generated code uses await instance.method() for zero-param class methods.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - workers = [ - WorkerInfo( - file_path=Path("worker.py"), - url_prefix="/worker", - module_path="worker", - resource_name="worker", - worker_type="QB", - functions=[], - class_remotes=[ - { - "name": "ImageProcessor", - "methods": ["list_models"], - "method_params": {"list_models": []}, - }, - ], - ), - ] - - server_path = _generate_flash_server(project_root, workers) - content = server_path.read_text() - - assert "await _instance_ImageProcessor.list_models()" in content - # Handler should not accept body parameter - assert "worker_ImageProcessor_run_sync():" in content - - -def test_codegen_multi_param_class_method(): - """Test generated code uses _call_with_body for multi-param class methods.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - workers = [ - WorkerInfo( - file_path=Path("worker.py"), - url_prefix="/worker", - module_path="worker", - resource_name="worker", - worker_type="QB", - functions=[], - class_remotes=[ - { - "name": "ImageProcessor", - "methods": ["generate"], - "method_params": {"generate": ["prompt", "width"]}, - }, - ], - ), - ] - - server_path = _generate_flash_server(project_root, workers) - content = server_path.read_text() - - assert "_call_with_body(_instance_ImageProcessor.generate, body)" in content - assert "body: _worker_ImageProcessor_generate_Input" in content - # Model creation uses _class_type to get original method signature - assert "_class_type" in content - - -def test_codegen_backward_compat_no_method_params(): - """Test that missing method_params in class_remotes uses _call_with_body.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - workers = [ - WorkerInfo( - file_path=Path("worker.py"), - url_prefix="/worker", - module_path="worker", - resource_name="worker", - worker_type="QB", - functions=[], - class_remotes=[ - {"name": "OldStyle", "methods": ["process"]}, - ], - ), - ] - - server_path = _generate_flash_server(project_root, workers) - content = server_path.read_text() - - # Should use _call_with_body when method_params not provided (params=None) - assert "_call_with_body(_instance_OldStyle.process, body)" in content - assert "body: _worker_OldStyle_process_Input" in content - - -def test_scan_populates_function_params(): - """Test that _scan_project_workers populates function_params from scanner.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - worker_file = project_root / "worker.py" - worker_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="worker") - -@remote(config) -async def no_params() -> dict: - return {} - -@remote(config) -async def one_param(data: dict) -> dict: - return data - -@remote(config) -async def multi_params(text: str, mode: str = "default") -> dict: - return {"text": text} -""" - ) - - workers = _scan_project_workers(project_root) - - assert len(workers) == 1 - worker = workers[0] - assert worker.function_params == { - "no_params": [], - "one_param": ["data"], - "multi_params": ["text", "mode"], - } - - -def test_scan_populates_class_method_params(): - """Test that _scan_project_workers populates method_params in class_remotes.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - worker_file = project_root / "worker.py" - worker_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="worker") - -@remote(config) -class Processor: - def run(self, data: dict): - return data - - def status(self): - return {"ok": True} -""" - ) - - workers = _scan_project_workers(project_root) - - assert len(workers) == 1 - worker = workers[0] - assert len(worker.class_remotes) == 1 - cls = worker.class_remotes[0] - assert cls["method_params"] == { - "run": ["data"], - "status": [], - } - - -def test_codegen_lb_get_with_path_params(): - """Test LB GET route with path params generates proper Swagger-compatible handler.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - workers = [ - WorkerInfo( - file_path=Path("worker.py"), - url_prefix="/worker", - module_path="worker", - resource_name="worker", - worker_type="LB", - functions=["get_image"], - lb_routes=[ - { - "method": "GET", - "path": "/images/{file_id}", - "fn_name": "get_image", - "config_variable": "cpu_config", - }, - ], - ), - ] - - server_path = _generate_flash_server(project_root, workers) - content = server_path.read_text() - - # Handler must declare file_id as a typed parameter for Swagger - assert "file_id: str" in content - # Path param must be forwarded in the body dict - assert '"file_id": file_id' in content - # Should NOT use bare request: Request as only param - assert ( - "async def _route_worker_get_image(file_id: str, request: Request):" - in content - ) - - -def test_codegen_lb_get_without_path_params(): - """Test LB GET route without path params uses request: Request.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - workers = [ - WorkerInfo( - file_path=Path("worker.py"), - url_prefix="/worker", - module_path="worker", - resource_name="worker", - worker_type="LB", - functions=["health"], - lb_routes=[ - { - "method": "GET", - "path": "/health", - "fn_name": "health", - "config_variable": "cpu_config", - }, - ], - ), - ] - - server_path = _generate_flash_server(project_root, workers) - content = server_path.read_text() - - assert "async def _route_worker_health(request: Request):" in content - assert "dict(request.query_params)" in content - - -def test_codegen_lb_post_with_path_params(): - """Test LB POST route with path params includes both body and path params.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - workers = [ - WorkerInfo( - file_path=Path("worker.py"), - url_prefix="/worker", - module_path="worker", - resource_name="worker", - worker_type="LB", - functions=["update_item"], - lb_routes=[ - { - "method": "POST", - "path": "/items/{item_id}", - "fn_name": "update_item", - "config_variable": "api_config", - }, - ], - ), - ] - - server_path = _generate_flash_server(project_root, workers) - content = server_path.read_text() - - # POST handler must have typed body and path param - assert ( - "async def _route_worker_update_item(body: _worker_update_item_Input, item_id: str):" - in content - ) - assert '"item_id": item_id' in content - assert "_to_dict(body)" in content - - -def test_codegen_lb_get_with_multiple_path_params(): - """Test LB GET route with multiple path params.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - workers = [ - WorkerInfo( - file_path=Path("worker.py"), - url_prefix="/worker", - module_path="worker", - resource_name="worker", - worker_type="LB", - functions=["get_version"], - lb_routes=[ - { - "method": "GET", - "path": "/items/{item_id}/versions/{version_id}", - "fn_name": "get_version", - "config_variable": "api_config", - }, - ], - ), - ] - - server_path = _generate_flash_server(project_root, workers) - content = server_path.read_text() - - assert "item_id: str" in content - assert "version_id: str" in content - assert '"item_id": item_id' in content - assert '"version_id": version_id' in content diff --git a/tests/unit/cli/test_run.py b/tests/unit/cli/test_run.py index 73bee3bd..ec930082 100644 --- a/tests/unit/cli/test_run.py +++ b/tests/unit/cli/test_run.py @@ -1,30 +1,32 @@ -"""Unit tests for run CLI command.""" +"""Unit tests for run CLI command and programmatic dev server.""" + +import os +import sys +from unittest.mock import MagicMock, patch import pytest -from pathlib import Path -from unittest.mock import patch, MagicMock +from fastapi import FastAPI +from fastapi.testclient import TestClient from typer.testing import CliRunner -from runpod_flash.cli.main import app -from runpod_flash.cli.commands.run import ( - WorkerInfo, - _generate_flash_server, - _has_numeric_module_segments, - _make_import_line, - _module_parent_subdir, - _sanitize_fn_name, +from runpod_flash.cli.commands._dev_server import ( + _import_from_module, + _register_lb_routes, + _register_qb_routes, + create_app, ) +from runpod_flash.cli.commands.run import WorkerInfo +from runpod_flash.cli.main import app @pytest.fixture def runner(): - """Create CLI test runner.""" return CliRunner() @pytest.fixture -def temp_fastapi_app(tmp_path): - """Create minimal Flash project with @remote function for testing.""" +def temp_project(tmp_path): + """Create a minimal Flash project with a @remote function.""" worker_file = tmp_path / "worker.py" worker_file.write_text( "from runpod_flash import LiveServerless, remote\n" @@ -36,771 +38,512 @@ def temp_fastapi_app(tmp_path): return tmp_path -class TestRunCommandEnvironmentVariables: - """Test flash run command environment variable support.""" - - @pytest.fixture(autouse=True) - def patch_watcher(self): - """Prevent the background watcher thread from blocking tests.""" - with patch("runpod_flash.cli.commands.run._watch_and_regenerate"): - yield - - def test_port_from_environment_variable( - self, runner, temp_fastapi_app, monkeypatch - ): - """Test that FLASH_PORT environment variable is respected.""" - monkeypatch.chdir(temp_fastapi_app) - monkeypatch.setenv("FLASH_PORT", "8080") - - # Mock subprocess to capture command and prevent actual server start - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level process group operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - with patch("runpod_flash.cli.commands.run.os.killpg"): - runner.invoke(app, ["run"]) - - # Verify port 8080 was used in uvicorn command - call_args = mock_popen.call_args[0][0] - assert "--port" in call_args - port_index = call_args.index("--port") - assert call_args[port_index + 1] == "8080" - - def test_host_from_environment_variable( - self, runner, temp_fastapi_app, monkeypatch - ): - """Test that FLASH_HOST environment variable is respected.""" - monkeypatch.chdir(temp_fastapi_app) - monkeypatch.setenv("FLASH_HOST", "0.0.0.0") - - # Mock subprocess to capture command - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - with patch("runpod_flash.cli.commands.run.os.killpg"): - runner.invoke(app, ["run"]) - - # Verify host 0.0.0.0 was used - call_args = mock_popen.call_args[0][0] - assert "--host" in call_args - host_index = call_args.index("--host") - assert call_args[host_index + 1] == "0.0.0.0" - - def test_cli_flag_overrides_environment_variable( - self, runner, temp_fastapi_app, monkeypatch - ): - """Test that --port flag overrides FLASH_PORT environment variable.""" - monkeypatch.chdir(temp_fastapi_app) - monkeypatch.setenv("FLASH_PORT", "8080") - - # Mock subprocess to capture command - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Use --port flag to override env var - runner.invoke(app, ["run", "--port", "9000"]) - - # Verify port 9000 was used (flag overrides env) - call_args = mock_popen.call_args[0][0] - assert "--port" in call_args - port_index = call_args.index("--port") - assert call_args[port_index + 1] == "9000" - - def test_default_port_when_no_env_or_flag( - self, runner, temp_fastapi_app, monkeypatch - ): - """Test that default port 8888 is used when no env var or flag.""" - monkeypatch.chdir(temp_fastapi_app) - # Ensure FLASH_PORT is not set - monkeypatch.delenv("FLASH_PORT", raising=False) - - # Mock subprocess to capture command - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - with patch("runpod_flash.cli.commands.run.os.killpg"): - runner.invoke(app, ["run"]) - - # Verify default port 8888 was used - call_args = mock_popen.call_args[0][0] - assert "--port" in call_args - port_index = call_args.index("--port") - assert call_args[port_index + 1] == "8888" - - def test_default_host_when_no_env_or_flag( - self, runner, temp_fastapi_app, monkeypatch - ): - """Test that default host localhost is used when no env var or flag.""" - monkeypatch.chdir(temp_fastapi_app) - # Ensure FLASH_HOST is not set - monkeypatch.delenv("FLASH_HOST", raising=False) - - # Mock subprocess to capture command - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - with patch("runpod_flash.cli.commands.run.os.killpg"): - runner.invoke(app, ["run"]) - - # Verify default host localhost was used - call_args = mock_popen.call_args[0][0] - assert "--host" in call_args - host_index = call_args.index("--host") - assert call_args[host_index + 1] == "localhost" - - def test_both_host_and_port_from_environment( - self, runner, temp_fastapi_app, monkeypatch - ): - """Test that both FLASH_HOST and FLASH_PORT environment variables work together.""" - monkeypatch.chdir(temp_fastapi_app) - monkeypatch.setenv("FLASH_HOST", "0.0.0.0") - monkeypatch.setenv("FLASH_PORT", "3000") - - # Mock subprocess to capture command - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - with patch("runpod_flash.cli.commands.run.os.killpg"): - runner.invoke(app, ["run"]) - - # Verify both host and port were used - call_args = mock_popen.call_args[0][0] - - assert "--host" in call_args - host_index = call_args.index("--host") - assert call_args[host_index + 1] == "0.0.0.0" - - assert "--port" in call_args - port_index = call_args.index("--port") - assert call_args[port_index + 1] == "3000" - - def test_short_port_flag_overrides_environment( - self, runner, temp_fastapi_app, monkeypatch - ): - """Test that -p short flag also overrides FLASH_PORT environment variable.""" - monkeypatch.chdir(temp_fastapi_app) - monkeypatch.setenv("FLASH_PORT", "8080") - - # Mock subprocess to capture command - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - # Mock OS-level operations - with patch("runpod_flash.cli.commands.run.os.getpgid") as mock_getpgid: - mock_getpgid.return_value = 12345 - with patch("runpod_flash.cli.commands.run.os.killpg"): - # Use -p short flag - runner.invoke(app, ["run", "-p", "7000"]) - - # Verify port 7000 was used (short flag overrides env) - call_args = mock_popen.call_args[0][0] - assert "--port" in call_args - port_index = call_args.index("--port") - assert call_args[port_index + 1] == "7000" - - -class TestRunCommandHotReload: - """Test flash run hot-reload behavior.""" - - @pytest.fixture(autouse=True) - def patch_watcher(self): - """Prevent the background watcher thread from blocking tests.""" - with patch("runpod_flash.cli.commands.run._watch_and_regenerate"): - yield - - def _invoke_run(self, runner, monkeypatch, temp_fastapi_app, extra_args=None): - """Helper: invoke flash run and return the Popen call args.""" - monkeypatch.chdir(temp_fastapi_app) - monkeypatch.delenv("FLASH_PORT", raising=False) - monkeypatch.delenv("FLASH_HOST", raising=False) - - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process +def _run_cli(runner, project_dir, extra_args=None): + """Invoke ``flash dev`` with subprocess mocked and return the Popen command.""" + saved_env = { + k: os.environ.get(k) + for k in ("FLASH_PROJECT_ROOT", "FLASH_IS_LIVE_PROVISIONING") + } + with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: + mock_process = MagicMock() + mock_process.pid = 12345 + mock_process.wait.side_effect = KeyboardInterrupt() + mock_popen.return_value = mock_process + # bypass port probing so tests get the exact port they requested + with patch( + "runpod_flash.cli.commands.run._find_available_port", + side_effect=lambda host, port: port, + ): with patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345): with patch("runpod_flash.cli.commands.run.os.killpg"): - runner.invoke(app, ["run"] + (extra_args or [])) - - return mock_popen.call_args[0][0] - - def test_reload_watches_flash_server_py( - self, runner, temp_fastapi_app, monkeypatch - ): - """Uvicorn watches .flash/server.py, not the whole project.""" - cmd = self._invoke_run(runner, monkeypatch, temp_fastapi_app) - + old_cwd = os.getcwd() + try: + os.chdir(project_dir) + runner.invoke(app, ["dev"] + (extra_args or [])) + finally: + os.chdir(old_cwd) + for k, v in saved_env.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v + + return mock_popen.call_args[0][0] + + +# --------------------------------------------------------------------------- +# CLI: uvicorn command construction +# --------------------------------------------------------------------------- + + +class TestRunCommandFlags: + """Test that run_command builds the correct uvicorn command.""" + + def test_uses_factory_flag(self, runner, temp_project): + cmd = _run_cli(runner, temp_project) + assert "--factory" in cmd + idx = cmd.index("--factory") + assert cmd[idx + 1] == "runpod_flash.cli.commands._dev_server:create_app" + + def test_no_flash_dir_created(self, runner, temp_project): + _run_cli(runner, temp_project) + assert not (temp_project / ".flash").exists() + + def test_default_host_and_port(self, runner, temp_project): + cmd = _run_cli(runner, temp_project) + assert cmd[cmd.index("--host") + 1] == "localhost" + assert cmd[cmd.index("--port") + 1] == "8888" + + def test_custom_port_flag(self, runner, temp_project): + cmd = _run_cli(runner, temp_project, ["--port", "9000"]) + assert cmd[cmd.index("--port") + 1] == "9000" + + def test_custom_host_flag(self, runner, temp_project): + cmd = _run_cli(runner, temp_project, ["--host", "0.0.0.0"]) + assert cmd[cmd.index("--host") + 1] == "0.0.0.0" + + def test_short_port_flag(self, runner, temp_project): + cmd = _run_cli(runner, temp_project, ["-p", "7000"]) + assert cmd[cmd.index("--port") + 1] == "7000" + + def test_reload_watches_project_root(self, runner, temp_project): + cmd = _run_cli(runner, temp_project) assert "--reload" in cmd - assert "--reload-dir" in cmd - reload_dir_index = cmd.index("--reload-dir") - assert cmd[reload_dir_index + 1] == ".flash" - - assert "--reload-include" in cmd - reload_include_index = cmd.index("--reload-include") - assert cmd[reload_include_index + 1] == "server.py" - - def test_reload_does_not_watch_project_root( - self, runner, temp_fastapi_app, monkeypatch - ): - """Uvicorn reload-dir must not be '.' to prevent double-reload.""" - cmd = self._invoke_run(runner, monkeypatch, temp_fastapi_app) - - reload_dir_index = cmd.index("--reload-dir") - assert cmd[reload_dir_index + 1] != "." - - def test_no_reload_skips_watcher_thread( - self, runner, temp_fastapi_app, monkeypatch - ): - """--no-reload: neither uvicorn reload args nor watcher thread started.""" - monkeypatch.chdir(temp_fastapi_app) - - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - with patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345): - with patch("runpod_flash.cli.commands.run.os.killpg"): - with patch( - "runpod_flash.cli.commands.run.threading.Thread" - ) as mock_thread_cls: - mock_thread = MagicMock() - mock_thread_cls.return_value = mock_thread + idx = cmd.index("--reload-dir") + assert cmd[idx + 1] == str(temp_project) - runner.invoke(app, ["run", "--no-reload"]) + def test_no_reload_flag(self, runner, temp_project): + cmd = _run_cli(runner, temp_project, ["--no-reload"]) + assert "--reload" not in cmd + assert "--reload-dir" not in cmd - cmd = mock_popen.call_args[0][0] - assert "--reload" not in cmd - mock_thread.start.assert_not_called() + def test_sets_project_root_env_var(self, runner, temp_project): + """FLASH_PROJECT_ROOT is set when Popen is called (inherited by child).""" + captured_env = {} - def test_watcher_thread_started_on_reload( - self, runner, temp_fastapi_app, monkeypatch, patch_watcher - ): - """When reload=True, the background watcher thread is started.""" - monkeypatch.chdir(temp_fastapi_app) - - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process - - with patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345): - with patch("runpod_flash.cli.commands.run.os.killpg"): - with patch( - "runpod_flash.cli.commands.run.threading.Thread" - ) as mock_thread_cls: - mock_thread = MagicMock() - mock_thread_cls.return_value = mock_thread - - runner.invoke(app, ["run"]) - - mock_thread.start.assert_called_once() - - def test_watcher_thread_stopped_on_keyboard_interrupt( - self, runner, temp_fastapi_app, monkeypatch - ): - """KeyboardInterrupt sets stop_event and joins the watcher thread.""" - monkeypatch.chdir(temp_fastapi_app) - - with patch("runpod_flash.cli.commands.run.subprocess.Popen") as mock_popen: + def capture_popen(cmd, **kwargs): + captured_env["FLASH_PROJECT_ROOT"] = os.environ.get("FLASH_PROJECT_ROOT") mock_process = MagicMock() mock_process.pid = 12345 mock_process.wait.side_effect = KeyboardInterrupt() - mock_popen.return_value = mock_process + return mock_process + with patch( + "runpod_flash.cli.commands.run.subprocess.Popen", side_effect=capture_popen + ): with patch("runpod_flash.cli.commands.run.os.getpgid", return_value=12345): with patch("runpod_flash.cli.commands.run.os.killpg"): - with patch( - "runpod_flash.cli.commands.run.threading.Thread" - ) as mock_thread_cls: - mock_thread = MagicMock() - mock_thread_cls.return_value = mock_thread - with patch( - "runpod_flash.cli.commands.run.threading.Event" - ) as mock_event_cls: - mock_stop = MagicMock() - mock_event_cls.return_value = mock_stop - - runner.invoke(app, ["run"]) + old_cwd = os.getcwd() + try: + os.chdir(temp_project) + runner.invoke(app, ["dev"]) + finally: + os.chdir(old_cwd) + os.environ.pop("FLASH_PROJECT_ROOT", None) + os.environ.pop("FLASH_IS_LIVE_PROVISIONING", None) - mock_stop.set.assert_called_once() - mock_thread.join.assert_called_once_with(timeout=2) + assert captured_env["FLASH_PROJECT_ROOT"] == str(temp_project) -class TestWatchAndRegenerate: - """Unit tests for the _watch_and_regenerate background function.""" +# --------------------------------------------------------------------------- +# create_app factory +# --------------------------------------------------------------------------- - def test_regenerates_server_py_on_py_file_change(self, tmp_path): - """When a .py file changes, server.py is regenerated.""" - import threading - from runpod_flash.cli.commands.run import _watch_and_regenerate - stop = threading.Event() - - with patch( - "runpod_flash.cli.commands.run._scan_project_workers", return_value=[] - ) as mock_scan: - with patch( - "runpod_flash.cli.commands.run._generate_flash_server" - ) as mock_gen: - with patch( - "runpod_flash.cli.commands.run._watchfiles_watch" - ) as mock_watch: - # Yield one batch of changes then stop - mock_watch.return_value = iter([{(1, "/path/to/worker.py")}]) - stop.set() # ensures the loop exits after one iteration - _watch_and_regenerate(tmp_path, stop) - - mock_scan.assert_called_once_with(tmp_path) - mock_gen.assert_called_once() - - def test_ignores_non_py_changes(self, tmp_path): - """Changes to non-.py files do not trigger regeneration.""" - import threading - from runpod_flash.cli.commands.run import _watch_and_regenerate - - stop = threading.Event() - - with patch("runpod_flash.cli.commands.run._scan_project_workers") as mock_scan: - with patch( - "runpod_flash.cli.commands.run._generate_flash_server" - ) as mock_gen: - with patch( - "runpod_flash.cli.commands.run._watchfiles_watch" - ) as mock_watch: - mock_watch.return_value = iter([{(1, "/path/to/README.md")}]) - _watch_and_regenerate(tmp_path, stop) - - mock_scan.assert_not_called() - mock_gen.assert_not_called() - - def test_scan_error_does_not_crash_watcher(self, tmp_path): - """If regeneration raises, the watcher logs a warning and continues.""" - import threading - from runpod_flash.cli.commands.run import _watch_and_regenerate - - stop = threading.Event() +class TestCreateApp: + """Test the programmatic create_app factory.""" - with patch( - "runpod_flash.cli.commands.run._scan_project_workers", - side_effect=RuntimeError("scan failed"), - ): - with patch("runpod_flash.cli.commands.run._watchfiles_watch") as mock_watch: - mock_watch.return_value = iter([{(1, "/path/to/worker.py")}]) - # Should not raise - _watch_and_regenerate(tmp_path, stop) + def test_returns_fastapi_instance(self, tmp_path): + result = create_app(project_root=tmp_path, workers=[]) + assert isinstance(result, FastAPI) + def test_health_endpoints(self, tmp_path): + test_app = create_app(project_root=tmp_path, workers=[]) + client = TestClient(test_app) -class TestGenerateFlashServer: - """Test _generate_flash_server() route code generation.""" + resp = client.get("/") + assert resp.status_code == 200 + assert resp.json()["docs"] == "/docs" - def _make_lb_worker(self, tmp_path: Path, method: str = "GET") -> WorkerInfo: - return WorkerInfo( - file_path=tmp_path / "api.py", - url_prefix="/api", - module_path="api", - resource_name="api", - worker_type="LB", - functions=["list_routes"], - lb_routes=[ - { - "method": method, - "path": "/routes/list", - "fn_name": "list_routes", - "config_variable": "api_config", - } - ], - ) + resp = client.get("/ping") + assert resp.status_code == 200 + assert resp.json()["status"] == "healthy" - def test_post_lb_route_generates_body_param(self, tmp_path): - """POST/PUT/PATCH/DELETE LB routes use typed body for OpenAPI docs.""" - for method in ("POST", "PUT", "PATCH", "DELETE"): - worker = self._make_lb_worker(tmp_path, method) - content = _generate_flash_server(tmp_path, [worker]).read_text() - assert "body: _api_list_routes_Input" in content - assert "_lb_execute(api_config, list_routes, _to_dict(body))" in content - - def test_get_lb_route_uses_query_params(self, tmp_path): - """GET LB routes pass query params as a dict.""" - worker = self._make_lb_worker(tmp_path, "GET") - content = _generate_flash_server(tmp_path, [worker]).read_text() - assert "async def _route_api_list_routes(request: Request):" in content - assert ( - "_lb_execute(api_config, list_routes, dict(request.query_params))" - in content - ) + def test_registers_qb_worker_routes(self, tmp_path): + mod = tmp_path / "worker.py" + mod.write_text("async def process(data):\n return {'echo': data}\n") - def test_lb_config_var_and_function_imported(self, tmp_path): - """LB config vars and functions are both imported for remote dispatch.""" - worker = self._make_lb_worker(tmp_path) - content = _generate_flash_server(tmp_path, [worker]).read_text() - assert "from api import api_config" in content - assert "from api import list_routes" in content - - def test_lb_execute_import_present_when_lb_routes_exist(self, tmp_path): - """server.py imports _lb_execute when there are LB workers.""" - worker = self._make_lb_worker(tmp_path) - content = _generate_flash_server(tmp_path, [worker]).read_text() - assert "_lb_execute" in content - assert "lb_execute" in content - - def test_qb_function_still_imported_directly(self, tmp_path): - """QB workers still import and call functions directly.""" worker = WorkerInfo( - file_path=tmp_path / "worker.py", + file_path=mod, url_prefix="/worker", module_path="worker", resource_name="worker", worker_type="QB", functions=["process"], ) - content = _generate_flash_server(tmp_path, [worker]).read_text() - assert "from worker import process" in content - assert "_call_with_body(process, body)" in content - - -class TestSanitizeFnName: - """Test _sanitize_fn_name handles leading-digit identifiers.""" - - def test_normal_name_unchanged(self): - assert _sanitize_fn_name("worker_run_sync") == "worker_run_sync" - - def test_leading_digit_gets_underscore_prefix(self): - assert _sanitize_fn_name("01_hello_run_sync") == "_01_hello_run_sync" - - def test_slashes_replaced(self): - assert _sanitize_fn_name("a/b/c") == "a_b_c" - - def test_dots_and_hyphens_replaced(self): - assert _sanitize_fn_name("a.b-c") == "a_b_c" - - def test_numeric_after_slash(self): - assert _sanitize_fn_name("01_foo/02_bar") == "_01_foo_02_bar" - - -class TestHasNumericModuleSegments: - """Test _has_numeric_module_segments detects digit-prefixed segments.""" - - def test_normal_module_path(self): - assert _has_numeric_module_segments("worker") is False - - def test_dotted_normal(self): - assert _has_numeric_module_segments("longruns.stage1") is False - - def test_leading_digit_first_segment(self): - assert _has_numeric_module_segments("01_hello.worker") is True - - def test_leading_digit_nested_segment(self): - assert _has_numeric_module_segments("getting_started.01_hello.worker") is True - - def test_digit_in_middle_not_leading(self): - assert _has_numeric_module_segments("stage1.worker") is False - - -class TestModuleParentSubdir: - """Test _module_parent_subdir extracts parent directory from dotted path.""" - - def test_top_level_returns_none(self): - assert _module_parent_subdir("worker") is None - - def test_single_parent(self): - assert _module_parent_subdir("01_hello.gpu_worker") == "01_hello" - - def test_nested_parent(self): - assert ( - _module_parent_subdir("01_getting_started.03_mixed.pipeline") - == "01_getting_started/03_mixed" - ) - - -class TestMakeImportLine: - """Test _make_import_line generates correct import syntax.""" - - def test_normal_module_uses_from_import(self): - result = _make_import_line("worker", "process") - assert result == "from worker import process" - - def test_numeric_module_uses_flash_import(self): - result = _make_import_line("01_hello.gpu_worker", "gpu_hello") - assert ( - result - == 'gpu_hello = _flash_import("01_hello.gpu_worker", "gpu_hello", "01_hello")' - ) - - def test_nested_numeric_includes_full_subdir(self): - result = _make_import_line( - "01_getting_started.01_hello.gpu_worker", "gpu_hello" - ) - assert '"01_getting_started/01_hello"' in result + sys.path.insert(0, str(tmp_path)) + try: + test_app = create_app(project_root=tmp_path, workers=[worker]) + client = TestClient(test_app) + resp = client.post("/worker/run_sync", json={"input": "hello"}) + assert resp.status_code == 200 + assert resp.json()["output"] == {"echo": "hello"} + finally: + sys.path.remove(str(tmp_path)) + sys.modules.pop("worker", None) - def test_top_level_numeric_module_no_subdir(self): - result = _make_import_line("01_worker", "process") - assert result == 'process = _flash_import("01_worker", "process")' +# --------------------------------------------------------------------------- +# QB routes +# --------------------------------------------------------------------------- -class TestGenerateFlashServerNumericDirs: - """Test _generate_flash_server with numeric-prefixed directory names.""" - def test_qb_numeric_dir_uses_flash_import(self, tmp_path): - """QB workers in numeric dirs use _flash_import with scoped sys.path.""" - worker = WorkerInfo( - file_path=tmp_path / "01_hello" / "gpu_worker.py", - url_prefix="/01_hello/gpu_worker", - module_path="01_hello.gpu_worker", - resource_name="01_hello_gpu_worker", - worker_type="QB", - functions=["gpu_hello"], - ) - content = _generate_flash_server(tmp_path, [worker]).read_text() - - # Must NOT contain invalid 'from 01_hello...' import - assert "from 01_hello" not in content - # Must have _flash_import helper and importlib - assert "import importlib as _importlib" in content - assert "def _flash_import(" in content - assert ( - '_flash_import("01_hello.gpu_worker", "gpu_hello", "01_hello")' in content - ) +class TestRegisterQBRoutes: + """Test QB route registration and invocation.""" - def test_qb_numeric_dir_function_name_prefixed(self, tmp_path): - """QB handler function names starting with digits get '_' prefix.""" - worker = WorkerInfo( - file_path=tmp_path / "01_hello" / "gpu_worker.py", - url_prefix="/01_hello/gpu_worker", - module_path="01_hello.gpu_worker", - resource_name="01_hello_gpu_worker", - worker_type="QB", - functions=["gpu_hello"], - ) - content = _generate_flash_server(tmp_path, [worker]).read_text() + def test_single_function_run_sync(self, tmp_path): + mod = tmp_path / "worker.py" + mod.write_text("async def process(data):\n return {'echo': data}\n") - # Function name must start with '_', not a digit - assert ( - "async def _01_hello_gpu_worker_run_sync(body: _01_hello_gpu_worker_gpu_hello_Input):" - in content - ) - - def test_lb_numeric_dir_uses_flash_import(self, tmp_path): - """LB workers in numeric dirs use _flash_import for config and function imports.""" worker = WorkerInfo( - file_path=tmp_path / "03_advanced" / "05_lb" / "cpu_lb.py", - url_prefix="/03_advanced/05_lb/cpu_lb", - module_path="03_advanced.05_lb.cpu_lb", - resource_name="03_advanced_05_lb_cpu_lb", - worker_type="LB", - functions=["validate_data"], - lb_routes=[ - { - "method": "POST", - "path": "/validate", - "fn_name": "validate_data", - "config_variable": "cpu_config", - } - ], - ) - content = _generate_flash_server(tmp_path, [worker]).read_text() - - assert "from 03_advanced" not in content - assert ( - '_flash_import("03_advanced.05_lb.cpu_lb", "cpu_config", "03_advanced/05_lb")' - in content - ) - assert ( - '_flash_import("03_advanced.05_lb.cpu_lb", "validate_data", "03_advanced/05_lb")' - in content - ) - - def test_mixed_numeric_and_normal_dirs(self, tmp_path): - """Normal modules use 'from' imports, numeric modules use _flash_import.""" - normal_worker = WorkerInfo( - file_path=tmp_path / "worker.py", + file_path=mod, url_prefix="/worker", module_path="worker", resource_name="worker", worker_type="QB", functions=["process"], ) - numeric_worker = WorkerInfo( - file_path=tmp_path / "01_hello" / "gpu_worker.py", - url_prefix="/01_hello/gpu_worker", - module_path="01_hello.gpu_worker", - resource_name="01_hello_gpu_worker", - worker_type="QB", - functions=["gpu_hello"], - ) - content = _generate_flash_server( - tmp_path, [normal_worker, numeric_worker] - ).read_text() - - # Normal worker uses standard import - assert "from worker import process" in content - # Numeric worker uses scoped _flash_import - assert ( - '_flash_import("01_hello.gpu_worker", "gpu_hello", "01_hello")' in content + sys.path.insert(0, str(tmp_path)) + try: + test_app = FastAPI() + _register_qb_routes(test_app, worker, tmp_path, "test [QB]") + client = TestClient(test_app) + resp = client.post("/worker/run_sync", json={"input": {"k": "v"}}) + body = resp.json() + assert resp.status_code == 200 + assert body["status"] == "COMPLETED" + assert body["output"] == {"echo": {"k": "v"}} + assert "id" in body + finally: + sys.path.remove(str(tmp_path)) + sys.modules.pop("worker", None) + + def test_multi_function_routes(self, tmp_path): + mod = tmp_path / "multi.py" + mod.write_text( + "async def alpha(d):\n return 'a'\nasync def beta(d):\n return 'b'\n" ) - - def test_no_importlib_when_all_normal_dirs(self, tmp_path): - """importlib and _flash_import are not emitted when no numeric dirs exist.""" worker = WorkerInfo( - file_path=tmp_path / "worker.py", - url_prefix="/worker", - module_path="worker", - resource_name="worker", + file_path=mod, + url_prefix="/multi", + module_path="multi", + resource_name="multi", worker_type="QB", - functions=["process"], + functions=["alpha", "beta"], + ) + sys.path.insert(0, str(tmp_path)) + try: + test_app = FastAPI() + _register_qb_routes(test_app, worker, tmp_path, "test [QB]") + client = TestClient(test_app) + assert ( + client.post("/multi/alpha/run_sync", json={"input": {}}).json()[ + "output" + ] + == "a" + ) + assert ( + client.post("/multi/beta/run_sync", json={"input": {}}).json()["output"] + == "b" + ) + finally: + sys.path.remove(str(tmp_path)) + sys.modules.pop("multi", None) + + +# --------------------------------------------------------------------------- +# LB routes +# --------------------------------------------------------------------------- + + +class TestRegisterLBRoutes: + """Test LB route registration using an injected executor.""" + + def _write_lb_module(self, tmp_path, name, config_var, fn_name): + mod = tmp_path / f"{name}.py" + mod.write_text( + f"{config_var} = 'fake_config'\nasync def {fn_name}(d):\n return d\n" ) - content = _generate_flash_server(tmp_path, [worker]).read_text() - assert "importlib" not in content - assert "_flash_import" not in content - def test_scoped_import_includes_subdir(self, tmp_path): - """_flash_import calls pass the subdirectory for sibling import scoping.""" - worker = WorkerInfo( - file_path=tmp_path / "01_getting_started" / "03_mixed" / "pipeline.py", - url_prefix="/01_getting_started/03_mixed/pipeline", - module_path="01_getting_started.03_mixed.pipeline", - resource_name="01_getting_started_03_mixed_pipeline", + def _make_lb_worker(self, tmp_path, name, config_var, fn_name, method, path): + return WorkerInfo( + file_path=tmp_path / f"{name}.py", + url_prefix=f"/{name}", + module_path=name, + resource_name=name, worker_type="LB", - functions=["classify"], + functions=[fn_name], lb_routes=[ { - "method": "POST", - "path": "/classify", - "fn_name": "classify", - "config_variable": "pipeline_config", + "method": method, + "path": path, + "fn_name": fn_name, + "config_variable": config_var, } ], ) - content = _generate_flash_server(tmp_path, [worker]).read_text() - - # Must scope to correct subdirectory, not add all dirs to sys.path - assert '"01_getting_started/03_mixed"' in content - # No global sys.path additions for subdirs — only the project root - # line at the top and the one inside _flash_import helper body - lines = content.split("\n") - global_sys_path_lines = [ - line - for line in lines - if "sys.path.insert" in line and not line.startswith(" ") - ] - assert len(global_sys_path_lines) == 1 - - def test_generated_server_is_valid_python(self, tmp_path): - """Generated server.py with numeric dirs must be parseable Python.""" - worker = WorkerInfo( - file_path=tmp_path / "01_getting_started" / "01_hello" / "gpu_worker.py", - url_prefix="/01_getting_started/01_hello/gpu_worker", - module_path="01_getting_started.01_hello.gpu_worker", - resource_name="01_getting_started_01_hello_gpu_worker", - worker_type="QB", - functions=["gpu_hello"], - ) - server_path = _generate_flash_server(tmp_path, [worker]) - content = server_path.read_text() - - # Must parse without SyntaxError - import ast - ast.parse(content) + def test_post_route_passes_body(self, tmp_path): + """POST LB routes forward the request body to the executor.""" + self._write_lb_module(tmp_path, "api", "api_config", "handle") + worker = self._make_lb_worker( + tmp_path, "api", "api_config", "handle", "POST", "/do" + ) + captured = {} + + async def fake_executor(config, fn, body): + captured["config"] = config + captured["body"] = body + return {"ok": True} + + sys.path.insert(0, str(tmp_path)) + try: + test_app = FastAPI() + _register_lb_routes( + test_app, worker, tmp_path, "lb", executor=fake_executor + ) + client = TestClient(test_app) + resp = client.post("/api/do", json={"key": "val"}) + assert resp.status_code == 200 + assert captured["config"] == "fake_config" + assert captured["body"] == {"key": "val"} + finally: + sys.path.remove(str(tmp_path)) + sys.modules.pop("api", None) + + def test_get_route_passes_query_params(self, tmp_path): + """GET LB routes forward query params as a dict.""" + self._write_lb_module(tmp_path, "search", "search_cfg", "find") + worker = self._make_lb_worker( + tmp_path, "search", "search_cfg", "find", "GET", "/query" + ) + captured = {} + + async def fake_executor(config, fn, body): + captured["body"] = body + return {"ok": True} + + sys.path.insert(0, str(tmp_path)) + try: + test_app = FastAPI() + _register_lb_routes( + test_app, worker, tmp_path, "lb", executor=fake_executor + ) + client = TestClient(test_app) + resp = client.get("/search/query?q=test&limit=10") + assert resp.status_code == 200 + assert captured["body"] == {"q": "test", "limit": "10"} + finally: + sys.path.remove(str(tmp_path)) + sys.modules.pop("search", None) + + def test_all_body_methods(self, tmp_path): + """POST/PUT/PATCH/DELETE all register as body-accepting routes.""" + for method in ("POST", "PUT", "PATCH", "DELETE"): + mod_name = f"mod_{method.lower()}" + self._write_lb_module(tmp_path, mod_name, "cfg", "handler") + worker = self._make_lb_worker( + tmp_path, mod_name, "cfg", "handler", method, "/ep" + ) + + async def noop_executor(config, fn, body): + return {"ok": True} + + sys.path.insert(0, str(tmp_path)) + try: + test_app = FastAPI() + _register_lb_routes( + test_app, worker, tmp_path, "lb", executor=noop_executor + ) + route = next( + r + for r in test_app.routes + if hasattr(r, "path") and r.path == f"/{mod_name}/ep" + ) + assert method in route.methods + finally: + sys.path.remove(str(tmp_path)) + sys.modules.pop(mod_name, None) + + +# --------------------------------------------------------------------------- +# _import_from_module +# --------------------------------------------------------------------------- + + +class TestImportFromModule: + """Test module importing with standard and numeric-prefix paths.""" + + def test_standard_module(self, tmp_path): + (tmp_path / "mymod.py").write_text("MY_VAR = 42\n") + sys.path.insert(0, str(tmp_path)) + try: + assert _import_from_module("mymod", "MY_VAR", tmp_path) == 42 + finally: + sys.path.remove(str(tmp_path)) + sys.modules.pop("mymod", None) + + def test_numeric_prefix_module(self, tmp_path): + subdir = tmp_path / "01_hello" + subdir.mkdir() + (subdir / "__init__.py").write_text("") + (subdir / "gpu_worker.py").write_text("VALUE = 'hello'\n") + sys.path.insert(0, str(tmp_path)) + try: + assert ( + _import_from_module("01_hello.gpu_worker", "VALUE", tmp_path) == "hello" + ) + finally: + sys.path.remove(str(tmp_path)) + sys.modules.pop("01_hello.gpu_worker", None) + sys.modules.pop("01_hello", None) + + def test_top_level_numeric_module(self, tmp_path): + (tmp_path / "01_worker.py").write_text("RESULT = 'ok'\n") + sys.path.insert(0, str(tmp_path)) + try: + assert _import_from_module("01_worker", "RESULT", tmp_path) == "ok" + finally: + sys.path.remove(str(tmp_path)) + sys.modules.pop("01_worker", None) + + +# --------------------------------------------------------------------------- +# _map_body_to_params +# --------------------------------------------------------------------------- class TestMapBodyToParams: - """Tests for _map_body_to_params — maps HTTP body to function arguments.""" + """Tests for _map_body_to_params.""" - def test_body_keys_match_params_spreads_as_kwargs(self): + def test_matching_keys_spread_as_kwargs(self): from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params def process(name: str, value: int): pass - result = _map_body_to_params(process, {"name": "test", "value": 42}) - assert result == {"name": "test", "value": 42} + assert _map_body_to_params(process, {"name": "t", "value": 1}) == { + "name": "t", + "value": 1, + } - def test_body_keys_mismatch_wraps_in_first_param(self): + def test_mismatched_keys_wrap_in_first_param(self): from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params - def run_pipeline(input_data: dict): + def run(input_data: dict): pass - body = {"text": "hello", "mode": "fast"} - result = _map_body_to_params(run_pipeline, body) - assert result == {"input_data": {"text": "hello", "mode": "fast"}} + assert _map_body_to_params(run, {"a": 1}) == {"input_data": {"a": 1}} - def test_non_dict_body_wraps_in_first_param(self): + def test_non_dict_wraps_in_first_param(self): from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params - def run_pipeline(input_data): + def run(input_data): pass - result = _map_body_to_params(run_pipeline, [1, 2, 3]) - assert result == {"input_data": [1, 2, 3]} + assert _map_body_to_params(run, [1, 2]) == {"input_data": [1, 2]} def test_no_params_returns_empty(self): from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params - def no_args(): + def noop(): pass - result = _map_body_to_params(no_args, {"key": "val"}) - assert result == {} + assert _map_body_to_params(noop, {"k": "v"}) == {} - def test_partial_key_match_wraps_in_first_param(self): + def test_partial_match_wraps_in_first_param(self): from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params def process(name: str, value: int): pass - result = _map_body_to_params(process, {"name": "test", "extra": "bad"}) - assert result == {"name": {"name": "test", "extra": "bad"}} + assert _map_body_to_params(process, {"name": "t", "extra": "x"}) == { + "name": {"name": "t", "extra": "x"} + } - def test_empty_dict_body_spreads_as_empty_kwargs(self): + def test_empty_dict_spreads_as_empty(self): from runpod_flash.cli.commands._run_server_helpers import _map_body_to_params - def run_pipeline(input_data: dict): + def run(input_data: dict): pass - result = _map_body_to_params(run_pipeline, {}) - assert result == {} + assert _map_body_to_params(run, {}) == {} + + +class TestFindAvailablePort: + def test_returns_start_port_when_free(self): + from runpod_flash.cli.commands.run import _find_available_port + + # use port 0 trick: bind to 0 to get a free port, then test near it + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + free_port = s.getsockname()[1] + + # free_port is now unbound, so _find_available_port should return it + assert _find_available_port("localhost", free_port) == free_port + + def test_skips_occupied_port(self): + from runpod_flash.cli.commands.run import _find_available_port + + import socket + + # occupy a port + blocker = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + blocker.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + blocker.bind(("localhost", 0)) + occupied_port = blocker.getsockname()[1] + blocker.listen(1) + + try: + result = _find_available_port("localhost", occupied_port) + assert result > occupied_port + finally: + blocker.close() + + def test_exits_when_no_port_available(self): + from runpod_flash.cli.commands.run import _find_available_port + + import socket + + from runpod_flash.cli.commands.run import _MAX_PORT_ATTEMPTS + + blockers = [] + # find a free starting port first + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + start = s.getsockname()[1] + + # bind all ports in the range + for i in range(_MAX_PORT_ATTEMPTS): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + s.bind(("127.0.0.1", start + i)) + s.listen(1) + blockers.append(s) + except OSError: + s.close() + blockers.append(None) + + try: + from click.exceptions import Exit as ClickExit + + with pytest.raises((SystemExit, ClickExit)): + _find_available_port("127.0.0.1", start) + finally: + for s in blockers: + if s: + s.close() diff --git a/tests/unit/resources/test_serverless.py b/tests/unit/resources/test_serverless.py index 25899904..1450b65d 100644 --- a/tests/unit/resources/test_serverless.py +++ b/tests/unit/resources/test_serverless.py @@ -468,7 +468,7 @@ def test_is_deployed_false_when_no_id(self): assert serverless.is_deployed() is False def test_is_deployed_skips_health_check_during_live_provisioning(self, monkeypatch): - """During flash run, is_deployed returns True based on ID alone.""" + """During flash dev, is_deployed returns True based on ID alone.""" monkeypatch.setenv("FLASH_IS_LIVE_PROVISIONING", "true") serverless = ServerlessResource(name="test") serverless.id = "ep-live-123" @@ -477,7 +477,7 @@ def test_is_deployed_skips_health_check_during_live_provisioning(self, monkeypat assert serverless.is_deployed() is True def test_is_deployed_uses_health_check_outside_live_provisioning(self, monkeypatch): - """Outside flash run, is_deployed falls back to health check.""" + """Outside flash dev, is_deployed falls back to health check.""" monkeypatch.delenv("FLASH_IS_LIVE_PROVISIONING", raising=False) serverless = ServerlessResource(name="test") serverless.id = "ep-123"