diff --git a/flowfile_core/CLAUDE.md b/flowfile_core/CLAUDE.md index f9a407b40..d0a70c7b1 100644 --- a/flowfile_core/CLAUDE.md +++ b/flowfile_core/CLAUDE.md @@ -11,9 +11,11 @@ Central FastAPI backend and DAG execution engine for Flowfile: manages flows as - `flowfile_core/main.py` — FastAPI app, lifespan (scheduler/kernel/local-model shutdown), CORS (Tauri origin regex + explicit dev/Docker origins), all router mounts, `--run-flow` CLI. - `flowfile_core/routes/` — REST routers: `routes.py` (editor/transform, JWT-gated), `flow_api.py` (`data_router` API-key data + `management_router` JWT), `auth.py`, `secrets.py`, `catalog.py`, `cloud_connections.py`, `ga_connections.py`, `kafka.py`, `file_manager.py`, `api_consumers.py`, `user_defined_components.py`, `logs.py`, `public.py`. (More routers live under `ai/`, `kernel/`, `artifacts/`, `ml/`.) - `flowfile_core/flowfile/flow_graph.py` — DAG execution engine (`FlowGraph`, node add/run, worker offload). `flowfile/handler.py` — `FlowfileHandler` in-memory flow registry. +- `flowfile_core/flowfile/execution/` — compute-location seam: `transport.py` (`WorkerTransport`, sole owner of worker URLs/HTTP/WS), `exceptions.py` (typed `WorkerConnectionError` etc.), `handles.py` (`TaskHandle` protocol), `backends/` (`ExecutionBackend` ABC + `LocalBackend`/`RemoteWorkerBackend`, `resolve_backend(location)`). Route new local-vs-remote variance through a backend method, never an inline `execution_location` branch (ratchet test enforces this). +- `flowfile_core/flowfile/node_registry/` — **single source of truth for built-in node types**: one `NodeSpec` per type (`builtin/*.py`) bundling the `NodeTemplate`, settings class, defaults flag, AI classification, and (for simple nodes) the compute factory used by `FlowGraph._add_from_spec`. The legacy catalogs (`get_all_standard_nodes`, `NODE_TYPE_TO_SETTINGS_CLASS`, `nodes_with_defaults`, `_NODE_CLASS_MAP`) are derived views; contract tests in `tests/flowfile/execution/test_node_registry.py` pin them. - `flowfile_core/flowfile/flow_data_engine/flow_data_engine.py` — per-node Polars compute wrapper (lazy frames, previews; `join/`, `fuzzy_matching/`, `subprocess_operations/` subdirs). - `flowfile_core/flowfile/sources/external_sources/` — SQL / REST API / Google Analytics / custom source connectors (`factory.py`). -- `flowfile_core/configs/node_store/nodes.py` — node template/default registry (`get_all_standard_nodes`). +- `flowfile_core/configs/node_store/nodes.py` — legacy node template/default accessors (`get_all_standard_nodes`), now derived from `flowfile/node_registry`. - `flowfile_core/schemas/input_schema.py` — Pydantic node-config models (~90 classes); other request/response schemas alongside. - `flowfile_core/ai/` — AI subsystem (see patterns); routers under `ai/*_routes.py`, plus `agents/`, `providers/`, `tools/` (incl. `tools/executor/`), `local_model/`, `context/`. - `flowfile_core/auth/` — JWT (`jwt.py`), API keys (`api_key.py`), passwords (`password.py`). diff --git a/flowfile_core/flowfile_core/ai/tools/classification.py b/flowfile_core/flowfile_core/ai/tools/classification.py index 7f5b663cf..61a251383 100644 --- a/flowfile_core/flowfile_core/ai/tools/classification.py +++ b/flowfile_core/flowfile_core/ai/tools/classification.py @@ -30,56 +30,28 @@ from __future__ import annotations -from typing import Final, Literal +from typing import Literal NodeClass = Literal["static", "dynamic", "source", "passthrough"] -_NODE_CLASS_MAP: Final[dict[str, NodeClass]] = { - "manual_input": "source", - "filter": "static", - "formula": "static", - "select": "static", - "dynamic_rename": "dynamic", - "sort": "static", - "record_id": "static", - "sample": "static", - "random_split": "static", - "unique": "static", - "group_by": "static", - "window_functions": "static", - "pivot": "dynamic", - "unpivot": "dynamic", - "text_to_rows": "dynamic", - "graph_solver": "dynamic", - "python_script": "dynamic", - "polars_code": "dynamic", - "sql_query": "dynamic", - "join": "static", - "cross_join": "static", - "fuzzy_match": "static", - "record_count": "static", - "explore_data": "static", - "union": "static", - "output": "static", - "api_response": "static", - "read": "source", - "database_reader": "source", - "database_writer": "static", - "cloud_storage_reader": "source", - "cloud_storage_writer": "static", - "catalog_reader": "source", - "catalog_writer": "static", - "kafka_source": "source", - "google_analytics_reader": "source", - "rest_api_reader": "source", - "external_source": "source", - "promise": "passthrough", - "user_defined": "dynamic", - "train_model": "static", - "apply_model": "static", - "evaluate_model": "static", - "wait_for": "static", -} +# Derived from the node registry (each NodeSpec carries its ai_classification); +# built lazily to keep this module import-light. +_node_class_map: dict[str, NodeClass] | None = None + + +def _get_node_class_map() -> dict[str, NodeClass]: + global _node_class_map + if _node_class_map is None: + from flowfile_core.flowfile.node_registry import BUILTIN_REGISTRY + + _node_class_map = BUILTIN_REGISTRY.ai_classification_map() + return _node_class_map + + +def __getattr__(name: str): + if name == "_NODE_CLASS_MAP": + return _get_node_class_map() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") def classify_node_type(node_type: str) -> NodeClass: @@ -89,7 +61,7 @@ def classify_node_type(node_type: str) -> NodeClass: through the kernel dry-run path which fails closed (refusal on missing upstream sample) rather than producing a wrong schema. """ - return _NODE_CLASS_MAP.get(node_type, "dynamic") + return _get_node_class_map().get(node_type, "dynamic") def is_predictable_via_mirror(node_type: str) -> bool: diff --git a/flowfile_core/flowfile_core/configs/node_store/__init__.py b/flowfile_core/flowfile_core/configs/node_store/__init__.py index c5ff7cb11..1f755e379 100644 --- a/flowfile_core/flowfile_core/configs/node_store/__init__.py +++ b/flowfile_core/flowfile_core/configs/node_store/__init__.py @@ -1,6 +1,6 @@ import logging -from flowfile_core.configs.node_store.nodes import get_all_standard_nodes +from flowfile_core.configs.node_store.nodes import get_all_standard_nodes, get_nodes_with_default_settings from flowfile_core.configs.node_store.user_defined_node_registry import ( get_all_nodes_from_standard_location, load_single_node_from_file, @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) -nodes_with_defaults = {"sample", "sort", "union", "select", "record_count"} +nodes_with_defaults = get_nodes_with_default_settings() def register_custom_node(node: NodeTemplate): diff --git a/flowfile_core/flowfile_core/configs/node_store/nodes.py b/flowfile_core/flowfile_core/configs/node_store/nodes.py index 1a73c49f8..365e2b586 100644 --- a/flowfile_core/flowfile_core/configs/node_store/nodes.py +++ b/flowfile_core/flowfile_core/configs/node_store/nodes.py @@ -1,735 +1,32 @@ +"""Legacy node-catalog views, derived from the node registry. + +The NodeTemplate literals moved to +``flowfile_core/flowfile/node_registry/builtin/`` (the single source of +truth); this module keeps the historical accessors for existing importers. +""" + from functools import lru_cache -from flowfile_core.schemas.schemas import NodeDefault, NodeTag, NodeTemplate +from flowfile_core.schemas.schemas import NodeDefault, NodeTemplate def get_all_standard_nodes() -> tuple[list[NodeTemplate], dict[str, NodeTemplate], dict[str, NodeDefault]]: """ Initializes and returns the complete list, dict, and defaults for all nodes. """ - nodes_list: list[NodeTemplate] = [ - NodeTemplate( - name="External source", - item="external_source", - input=0, - output=1, - image="external_source.svg", - node_type="input", - transform_type="other", - node_group="input", - prod_ready=False, - drawer_title="External Source", - drawer_intro="Connect to external data sources and APIs", - laziness="eager", - tags=[NodeTag.API, NodeTag.REST, NodeTag.HTTP, NodeTag.EXTERNAL], - ), - NodeTemplate( - name="Manual input", - item="manual_input", - input=0, - output=1, - transform_type="other", - node_type="input", - image="manual_input.svg", - node_group="input", - drawer_title="Manual Input", - drawer_intro="Create data directly", - laziness="lazy", - tags=[NodeTag.MANUAL, NodeTag.PASTE, NodeTag.INPUT], - ), - NodeTemplate( - name="Read data", - item="read", - input=0, - output=1, - transform_type="other", - node_type="input", - image="input_data.svg", - node_group="input", - drawer_title="Read Data", - drawer_intro="Load data from CSV, Excel, or Parquet files", - # TODO: resolve laziness in check_upstream_laziness via isinstance check (like catalog_reader), - # then change to "lazy" - laziness="conditional", - tags=[ - NodeTag.CSV, - NodeTag.EXCEL, - NodeTag.PARQUET, - NodeTag.JSON, - NodeTag.FILE, - NodeTag.IMPORT, - NodeTag.READ, - ], - ), - NodeTemplate( - name="Join", - item="join", - input=2, - output=1, - transform_type="wide", - node_type="process", - image="join.svg", - node_group="combine", - drawer_title="Join Datasets", - drawer_intro="Merge two datasets based on matching column values", - laziness="lazy", - tags=[NodeTag.JOIN, NodeTag.MERGE, NodeTag.LOOKUP, NodeTag.VLOOKUP, NodeTag.INNER, NodeTag.OUTER], - ), - NodeTemplate( - name="Formula", - item="formula", - input=1, - output=1, - transform_type="narrow", - node_type="process", - image="formula.svg", - node_group="transform", - drawer_title="Formula Editor", - drawer_intro="Create or modify columns using custom expressions", - laziness="lazy", - tags=[NodeTag.FORMULA, NodeTag.EXPRESSION, NodeTag.TRANSFORM, - NodeTag.CALCULATE, NodeTag.MATH, NodeTag.CONCAT, NodeTag.SUM], - ), - NodeTemplate( - name="Write data", - item="output", - input=1, - output=0, - transform_type="other", - image="output.svg", - node_type="output", - node_group="output", - drawer_title="Write Data", - drawer_intro="Save your data as CSV, Excel, or Parquet files", - laziness="eager", - tags=[ - NodeTag.CSV, - NodeTag.EXCEL, - NodeTag.PARQUET, - NodeTag.JSON, - NodeTag.FILE, - NodeTag.EXPORT, - NodeTag.SAVE, - NodeTag.WRITE, - ], - ), - NodeTemplate( - name="API response", - item="api_response", - input=1, - output=0, - transform_type="other", - image="api_response.svg", - node_type="output", - node_group="output", - drawer_title="API Response", - drawer_intro="Return this dataset as the body of an HTTP API endpoint", - laziness="eager", - tags=[NodeTag.API, NodeTag.REST, NodeTag.HTTP, NodeTag.RESPONSE], - ), - NodeTemplate( - name="Select data", - item="select", - input=1, - output=1, - transform_type="narrow", - node_type="process", - image="select.svg", - node_group="transform", - drawer_title="Select Columns", - drawer_intro="Choose, rename, and reorder columns to keep", - laziness="lazy", - tags=[NodeTag.SELECT, NodeTag.COLUMNS, NodeTag.RENAME, NodeTag.REORDER, NodeTag.PROJECTION], - ), - NodeTemplate( - name="Rename columns", - item="dynamic_rename", - input=1, - output=1, - transform_type="narrow", - node_type="process", - image="dynamic_rename.svg", - node_group="transform", - drawer_title="Rename Columns", - drawer_intro="Bulk-rename columns by prefix, suffix, or a formula", - laziness="lazy", - tags=[NodeTag.RENAME, NodeTag.COLUMNS], - ), - NodeTemplate( - name="Filter data", - item="filter", - input=1, - output=1, - transform_type="narrow", - node_type="process", - image="filter.svg", - node_group="transform", - drawer_title="Filter Rows", - drawer_intro="Keep only rows that match your conditions", - laziness="lazy", - tags=[NodeTag.FILTER, NodeTag.WHERE, NodeTag.SUBSET], - ), - NodeTemplate( - name="Group by", - item="group_by", - input=1, - output=1, - transform_type="wide", - node_type="process", - image="group_by.svg", - node_group="aggregate", - drawer_title="Group By", - drawer_intro="Aggregate data by grouping and calculating statistics", - laziness="lazy", - tags=[ - NodeTag.GROUP_BY, - NodeTag.AGGREGATE, - NodeTag.SUM, - NodeTag.MEAN, - NodeTag.AVERAGE, - NodeTag.COUNT, - NodeTag.MIN, - NodeTag.MAX, - NodeTag.MEDIAN, - NodeTag.SUMMARIZE, - ], - ), - NodeTemplate( - name="Window functions", - item="window_functions", - input=1, - output=1, - transform_type="wide", - node_type="process", - image="window_functions.svg", - node_group="aggregate", - drawer_title="Window Functions", - drawer_intro="Rolling, cumulative, rank and tile calculations (optionally per partition)", - laziness="lazy", - tags=[ - NodeTag.WINDOW, - NodeTag.ROLLING, - NodeTag.CUMULATIVE, - NodeTag.RANK, - NodeTag.PARTITION, - NodeTag.LAG, - NodeTag.LEAD, - ], - ), - NodeTemplate( - name="Fuzzy match", - item="fuzzy_match", - input=2, - output=1, - transform_type="wide", - image="fuzzy_match.svg", - node_type="process", - node_group="combine", - drawer_title="Fuzzy Match", - drawer_intro="Join datasets based on similar values instead of exact matches", - laziness="eager", - tags=[NodeTag.FUZZY, NodeTag.SIMILARITY, NodeTag.LEVENSHTEIN, NodeTag.JOIN, NodeTag.LOOKUP], - ), - NodeTemplate( - name="Sort data", - item="sort", - input=1, - output=1, - transform_type="wide", - node_type="process", - image="sort.svg", - node_group="transform", - drawer_title="Sort Data", - drawer_intro="Order your data by one or more columns", - laziness="lazy", - tags=[NodeTag.SORT, NodeTag.ORDER, NodeTag.RANK, NodeTag.ASCENDING, NodeTag.DESCENDING], - ), - NodeTemplate( - name="Add record Id", - item="record_id", - input=1, - output=1, - transform_type="wide", - node_type="process", - image="record_id.svg", - node_group="transform", - drawer_title="Add Record ID", - drawer_intro="Generate unique identifiers for each row", - laziness="lazy", - tags=[NodeTag.RECORD_ID, NodeTag.ROW_NUMBER, NodeTag.INDEX], - ), - NodeTemplate( - name="Take Sample", - item="sample", - input=1, - output=1, - transform_type="narrow", - node_type="process", - image="sample.svg", - node_group="transform", - drawer_title="Take Sample", - drawer_intro="Work with a subset of your data", - laziness="lazy", - tags=[NodeTag.SAMPLE, NodeTag.SUBSET, NodeTag.LIMIT, NodeTag.HEAD], - ), - NodeTemplate( - name="Random Split", - item="random_split", - input=1, - output=2, - output_names=["train", "test"], - transform_type="narrow", - node_type="process", - image="random_split.svg", - node_group="ml", - drawer_title="Random Split", - drawer_intro="Randomly partition rows into named groups (e.g. train/test)", - laziness="lazy", - tags=[NodeTag.SPLIT, NodeTag.TRAIN, NodeTag.TEST, NodeTag.ML, NodeTag.PARTITION], - ), - NodeTemplate( - name="Explore data", - item="explore_data", - input=1, - output=0, - transform_type="other", - node_type="output", - image="explore_data.svg", - node_group="output", - drawer_title="Explore Data", - drawer_intro="Interactive data exploration and analysis", - laziness="eager", - tags=[NodeTag.EXPLORE, NodeTag.PROFILE, NodeTag.PREVIEW, NodeTag.EDA, NodeTag.STATISTICS, NodeTag.VISUALIZE, - NodeTag.INSIGHT, NodeTag.BAR_CHART, NodeTag.GRAPHS], - ), - NodeTemplate( - name="Pivot data", - item="pivot", - input=1, - output=1, - transform_type="wide", - image="pivot.svg", - node_type="process", - node_group="aggregate", - drawer_title="Pivot Data", - drawer_intro="Convert data from long format to wide format", - laziness="eager", - tags=[NodeTag.PIVOT, NodeTag.CROSSTAB, NodeTag.RESHAPE], - ), - NodeTemplate( - name="Unpivot data", - item="unpivot", - input=1, - output=1, - transform_type="wide", - node_type="process", - image="unpivot.svg", - node_group="aggregate", - drawer_title="Unpivot Data", - drawer_intro="Transform data from wide format to long format", - laziness="lazy", - tags=[NodeTag.UNPIVOT, NodeTag.MELT, NodeTag.RESHAPE], - ), - NodeTemplate( - name="Union data", - item="union", - input=10, - output=1, - transform_type="narrow", - node_type="process", - image="union.svg", - multi=True, - node_group="combine", - drawer_title="Union Data", - drawer_intro="Stack multiple datasets by combining rows", - laziness="lazy", - tags=[NodeTag.UNION, NodeTag.CONCAT, NodeTag.APPEND], - ), - NodeTemplate( - name="Drop duplicates", - item="unique", - input=1, - output=1, - transform_type="wide", - node_type="process", - image="unique.svg", - node_group="transform", - drawer_title="Drop Duplicates", - drawer_intro="Remove duplicate rows based on selected columns", - laziness="lazy", - tags=[NodeTag.UNIQUE, NodeTag.DEDUPE, NodeTag.DISTINCT, NodeTag.DROP_DUPLICATES], - ), - NodeTemplate( - name="Graph solver", - item="graph_solver", - input=1, - output=1, - transform_type="other", - node_type="process", - image="graph_solver.svg", - node_group="combine", - drawer_title="Graph Solver", - drawer_intro="Group related records in graph-structured data", - laziness="lazy", - tags=[NodeTag.GRAPH, NodeTag.NETWORK, NodeTag.CLUSTER, NodeTag.CONNECTED_COMPONENTS], - ), - NodeTemplate( - name="Count records", - item="record_count", - input=1, - output=1, - transform_type="wide", - node_type="process", - image="record_count.svg", - node_group="aggregate", - drawer_title="Count Records", - drawer_intro="Calculate the total number of rows", - laziness="lazy", - tags=[NodeTag.RECORD_COUNT, NodeTag.COUNT, NodeTag.ROWS], - ), - NodeTemplate( - name="Cross join", - item="cross_join", - input=2, - output=1, - transform_type="wide", - node_type="process", - image="cross_join.svg", - node_group="combine", - drawer_title="Cross Join", - drawer_intro="Create all possible combinations between two datasets", - laziness="lazy", - tags=[NodeTag.CROSS_JOIN, NodeTag.CARTESIAN, NodeTag.JOIN], - ), - NodeTemplate( - name="Text to rows", - item="text_to_rows", - input=1, - output=1, - transform_type="wide", - node_type="process", - image="text_to_rows.svg", - node_group="transform", - drawer_title="Text to Rows", - drawer_intro="Split text into multiple rows based on a delimiter", - laziness="lazy", - tags=[NodeTag.TEXT_TO_ROWS, NodeTag.SPLIT, NodeTag.EXPLODE], - ), - NodeTemplate( - name="Polars code", - item="polars_code", - input=10, - output=1, - transform_type="narrow", - image="polars_code.svg", - node_group="transform", - node_type="process", - multi=True, - can_be_start=True, - drawer_title="Polars Code", - drawer_intro="Write custom Polars DataFrame transformations", - # TODO: resolve laziness in check_upstream_laziness via isinstance check (like catalog_reader), - # then change to "lazy" - laziness="conditional", - tags=[NodeTag.POLARS, NodeTag.CODE, NodeTag.PYTHON, NodeTag.SCRIPT, NodeTag.CUSTOM, NodeTag.DATAFRAME, - NodeTag.TRANSFORM], - ), - NodeTemplate( - name="SQL Query", - item="sql_query", - input=10, - output=1, - transform_type="narrow", - image="sql_query.svg", - node_group="transform", - node_type="process", - multi=True, - can_be_start=True, - drawer_title="SQL Query", - drawer_intro="Write SQL queries against connected data sources", - laziness="lazy", - tags=[NodeTag.SQL, NodeTag.QUERY, NodeTag.DUCKDB], - ), - NodeTemplate( - name="Python Script", - item="python_script", - input=10, - output=1, - transform_type="narrow", - image="python_code.svg", - node_group="transform", - multi=True, - can_be_start=True, - node_type="process", - drawer_title="Python Script", - drawer_intro="Execute Python code on an isolated kernel container", - laziness="eager", - tags=[NodeTag.PYTHON, NodeTag.CODE, NodeTag.SCRIPT, NodeTag.KERNEL, NodeTag.CUSTOM, NodeTag.TRANSFORM], - ), - NodeTemplate( - name="Read from Database", - item="database_reader", - input=0, - output=1, - node_type="input", - transform_type="other", - image="database_reader.svg", - node_group="input", - drawer_title="Database Reader", - drawer_intro="Load data from database tables or queries", - laziness="eager", - tags=[ - NodeTag.DATABASE, - NodeTag.SQL, - NodeTag.POSTGRES, - NodeTag.MYSQL, - NodeTag.SQL_SERVER, - NodeTag.SNOWFLAKE, - NodeTag.ORACLE, - NodeTag.SQLITE, - NodeTag.REDSHIFT, - NodeTag.BIGQUERY, - NodeTag.QUERY, - NodeTag.TABLE, - ], - ), - NodeTemplate( - name="Write to Database", - item="database_writer", - input=1, - output=0, - transform_type="other", - node_type="output", - image="database_writer.svg", - node_group="output", - drawer_title="Database Writer", - drawer_intro="Save data to database tables", - laziness="eager", - tags=[ - NodeTag.DATABASE, - NodeTag.SQL, - NodeTag.POSTGRES, - NodeTag.MYSQL, - NodeTag.SNOWFLAKE, - NodeTag.REDSHIFT, - NodeTag.BIGQUERY, - NodeTag.TABLE, - ], - ), - NodeTemplate( - name="Read from cloud provider", - item="cloud_storage_reader", - input=0, - output=1, - transform_type="other", - node_type="input", - image="cloud_storage_reader.svg", - node_group="input", - drawer_title="Cloud Storage Reader", - drawer_intro="Read data from AWS S3 and other cloud storage", - # TODO: resolve laziness in check_upstream_laziness via isinstance check (like catalog_reader), - # then change to "lazy" - laziness="conditional", - tags=[ - NodeTag.S3, - NodeTag.AWS, - NodeTag.AZURE, - NodeTag.ADLS, - NodeTag.GCS, - NodeTag.BLOB, - NodeTag.BUCKET, - NodeTag.CLOUD, - NodeTag.DELTA, - ], - ), - NodeTemplate( - name="Read from Catalog", - item="catalog_reader", - input=0, - output=1, - transform_type="other", - node_type="input", - image="catalog_reader.svg", - node_group="input", - drawer_title="Catalog Reader", - drawer_intro="Read a table from the data catalog", - laziness="lazy", - tags=[NodeTag.CATALOG, NodeTag.DELTA, NodeTag.TABLE, NodeTag.LAKEHOUSE, NodeTag.TIME_TRAVEL], - ), - NodeTemplate( - name="Write to Catalog", - item="catalog_writer", - input=1, - output=0, - transform_type="other", - node_type="output", - image="catalog_writer.svg", - node_group="output", - drawer_title="Catalog Writer", - drawer_intro="Save data as a table in the data catalog", - laziness="eager", - tags=[NodeTag.CATALOG, NodeTag.DELTA, NodeTag.TABLE, NodeTag.LAKEHOUSE], - ), - NodeTemplate( - name="Write to cloud provider", - item="cloud_storage_writer", - input=1, - output=0, - transform_type="other", - node_type="output", - image="cloud_storage_writer.svg", - node_group="output", - drawer_title="Cloud Storage Writer", - drawer_intro="Save data to AWS S3 and other cloud storage", - laziness="eager", - tags=[ - NodeTag.S3, - NodeTag.AWS, - NodeTag.AZURE, - NodeTag.ADLS, - NodeTag.GCS, - NodeTag.BLOB, - NodeTag.BUCKET, - NodeTag.CLOUD, - NodeTag.DELTA, - ], - ), - NodeTemplate( - name="Kafka Source", - item="kafka_source", - input=0, - output=1, - node_type="input", - transform_type="other", - image="kafka_source.svg", - node_group="input", - drawer_title="Kafka Source", - drawer_intro="Read data from a Kafka or Redpanda topic", - laziness="eager", - tags=[NodeTag.KAFKA, NodeTag.REDPANDA, NodeTag.STREAMING, NodeTag.TOPIC], - ), - NodeTemplate( - name="Google Analytics", - item="google_analytics_reader", - input=0, - output=1, - node_type="input", - transform_type="other", - image="google_analytics.svg", - node_group="input", - drawer_title="Google Analytics", - drawer_intro="Load reports from a Google Analytics 4 property", - laziness="eager", - tags=[NodeTag.GOOGLE_ANALYTICS, NodeTag.GA4, NodeTag.ANALYTICS], - ), - NodeTemplate( - name="REST API", - item="rest_api_reader", - input=0, - output=1, - node_type="input", - transform_type="other", - image="rest_api_reader.svg", - node_group="input", - drawer_title="REST API", - drawer_intro="Read JSON data from a REST API with auth and pagination", - laziness="eager", - tags=[NodeTag.REST, NodeTag.API, NodeTag.HTTP, NodeTag.JSON, NodeTag.PAGINATION], - ), - NodeTemplate( - name="Train Model", - item="train_model", - input=1, - output=1, - transform_type="other", - node_type="process", - image="train_model.svg", - node_group="ml", - drawer_title="Train ML Model", - drawer_intro="Fit a regression or classification model; optionally save it to the catalog", - laziness="eager", - tags=[ - NodeTag.ML, - NodeTag.MACHINE_LEARNING, - NodeTag.TRAIN, - NodeTag.MODEL, - NodeTag.REGRESSION, - NodeTag.CLASSIFICATION, - ], - ), - NodeTemplate( - name="Apply Model", - item="apply_model", - input=1, - output=1, - transform_type="wide", - node_type="process", - image="apply_model.svg", - node_group="ml", - drawer_title="Apply ML Model", - drawer_intro="Score data with an upstream Train Model node, or with a trained model from the catalog", - laziness="eager", - tags=[NodeTag.ML, NodeTag.MACHINE_LEARNING, NodeTag.PREDICT, NodeTag.SCORE, NodeTag.MODEL], - ), - NodeTemplate( - name="Evaluate Model", - item="evaluate_model", - input=1, - output=1, - transform_type="narrow", - node_type="process", - image="evaluate_model.svg", - node_group="ml", - drawer_title="Evaluate Model", - drawer_intro="Compare actual vs predicted columns and compute quality metrics", - laziness="eager", - tags=[NodeTag.ML, NodeTag.MACHINE_LEARNING, NodeTag.EVALUATE, NodeTag.METRICS, NodeTag.MODEL], - ), - NodeTemplate( - name="Wait For", - item="wait_for", - input=2, - output=1, - transform_type="other", - node_type="process", - image="wait_for.svg", - node_group="combine", - drawer_title="Wait For", - drawer_intro="Pass the left input through; the right input only enforces ordering", - laziness="eager", - tags=[NodeTag.WAIT, NodeTag.DEPENDENCY], - ), - ] - nodes_list.sort(key=lambda x: x.name) - nodes_with_defaults = {"sample", "sort", "union", "select", "record_count"} + from flowfile_core.flowfile.node_registry import BUILTIN_REGISTRY - def check_if_has_default_setting(node_item: str): - return node_item in nodes_with_defaults + nodes_list = BUILTIN_REGISTRY.drawer_templates() + node_dict = BUILTIN_REGISTRY.template_dict() + node_defaults = BUILTIN_REGISTRY.node_defaults() + return nodes_list, node_dict, node_defaults - node_defaults = { - node.item: NodeDefault( - node_name=node.name, - node_type=node.node_type, - transform_type=node.transform_type, - has_default_settings=check_if_has_default_setting(node.item), - ) - for node in nodes_list - } - node_dict = {n.item: n for n in nodes_list} - node_dict["polars_lazy_frame"] = NodeTemplate( - name="LazyFrame node", - item="polars_lazy_frame", - input=0, - output=1, - node_group="special", - image="", - node_type="input", - transform_type="other", - laziness="lazy", - ) +def get_nodes_with_default_settings() -> set[str]: + """Node types that can be added with auto-generated default settings.""" + from flowfile_core.flowfile.node_registry import BUILTIN_REGISTRY - return nodes_list, node_dict, node_defaults + return BUILTIN_REGISTRY.node_types_with_default_settings() @lru_cache(maxsize=1) diff --git a/flowfile_core/flowfile_core/flowfile/execution/__init__.py b/flowfile_core/flowfile_core/flowfile/execution/__init__.py new file mode 100644 index 000000000..de0b35378 --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/execution/__init__.py @@ -0,0 +1 @@ +"""Execution-layer abstractions: worker transport, typed errors, and compute backends.""" diff --git a/flowfile_core/flowfile_core/flowfile/execution/backends/__init__.py b/flowfile_core/flowfile_core/flowfile/execution/backends/__init__.py new file mode 100644 index 000000000..7814340bb --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/execution/backends/__init__.py @@ -0,0 +1,26 @@ +"""Backend registry: map an execution location to an ExecutionBackend.""" + +from __future__ import annotations + +from flowfile_core.flowfile.execution.backends.base import ExecutionBackend +from flowfile_core.flowfile.execution.backends.local import LocalBackend +from flowfile_core.flowfile.execution.backends.worker import RemoteWorkerBackend +from flowfile_core.flowfile.execution.transport import WorkerTransport + +_local_backend = LocalBackend() +_worker_backend: RemoteWorkerBackend | None = None + + +def resolve_backend(location: str, transport: WorkerTransport | None = None) -> ExecutionBackend: + """Return the backend for an execution location ("local" or "remote").""" + if location == "local": + return _local_backend + if transport is not None: + return RemoteWorkerBackend(transport=transport) + global _worker_backend + if _worker_backend is None: + _worker_backend = RemoteWorkerBackend() + return _worker_backend + + +__all__ = ["ExecutionBackend", "LocalBackend", "RemoteWorkerBackend", "resolve_backend"] diff --git a/flowfile_core/flowfile_core/flowfile/execution/backends/base.py b/flowfile_core/flowfile_core/flowfile/execution/backends/base.py new file mode 100644 index 000000000..b59a0a25d --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/execution/backends/base.py @@ -0,0 +1,109 @@ +"""ExecutionBackend: the seam deciding where node compute runs.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import TYPE_CHECKING, ClassVar + +import polars as pl + +from flowfile_core.flowfile.execution.handles import TaskHandle + +if TYPE_CHECKING: + from flowfile_core.flowfile.flow_data_engine.flow_data_engine import FlowDataEngine + from flowfile_core.flowfile.flow_node.multi_output import NamedOutputs + from flowfile_core.flowfile.sources.external_sources.sql_source.models import DatabaseExternalReadSettings + from flowfile_core.schemas.input_schema import OutputSettings + + +class ExecutionBackend(ABC): + """Typed compute operations, implemented per execution location. + + ``LocalBackend`` runs in-process; ``RemoteWorkerBackend`` ships work to the + flowfile_worker service. Node code calls these methods instead of branching + on ``execution_location`` and constructing fetchers inline, so new backends + (worker pools, remote runners) can be added without touching node logic. + """ + + location: ClassVar[str] + + # -- frame operations ------------------------------------------------------ + + @abstractmethod + def run_lazyframe( + self, + lf: pl.LazyFrame, + *, + flow_id: int, + node_id: int | str, + file_ref: str, + wait_on_completion: bool = False, + operation_type: str = "store", + ) -> TaskHandle: + """Materialise a LazyFrame plan; the result stays owned by the backend.""" + + @abstractmethod + def sample( + self, + lf: pl.LazyFrame, + *, + file_ref: str, + flow_id: int, + node_id: int | str, + sample_size: int = 100, + wait_on_completion: bool = True, + ) -> TaskHandle: + """Produce preview rows; ``handle.status.file_ref`` points at an Arrow IPC file.""" + + @abstractmethod + def count_records(self, lf: pl.LazyFrame, *, flow_id: int, node_id: int | str) -> int | None: + """Count the rows a LazyFrame plan produces.""" + + @abstractmethod + def random_split( + self, + df: FlowDataEngine, + splits: list[tuple[str, float]], + seed: int | None, + *, + flow_id: int, + node_id: int | str, + ) -> NamedOutputs: + """Partition rows into labeled outputs by the given percentages.""" + + # -- sources and sinks ------------------------------------------------------- + + @abstractmethod + def read_database( + self, + settings: DatabaseExternalReadSettings, + *, + cancel_check: Callable[[], bool] | None = None, + ) -> TaskHandle: + """Run the rendered SQL read; the result is a LazyFrame.""" + + @abstractmethod + def write_output( + self, + df: FlowDataEngine, + settings: OutputSettings, + *, + flow_id: int, + node_id: int | str, + ) -> TaskHandle: + """Write the frame to its file destination.""" + + # -- result cache ------------------------------------------------------------ + + @abstractmethod + def results_exist(self, file_ref: str) -> bool: + """Whether a completed cached result exists for ``file_ref``.""" + + @abstractmethod + def get_cached_lazyframe(self, file_ref: str) -> pl.LazyFrame: + """Load a completed cached result; raises when unavailable.""" + + @abstractmethod + def clear_result(self, file_ref: str) -> bool: + """Drop the cached result for ``file_ref``.""" diff --git a/flowfile_core/flowfile_core/flowfile/execution/backends/local.py b/flowfile_core/flowfile_core/flowfile/execution/backends/local.py new file mode 100644 index 000000000..dd47a70b8 --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/execution/backends/local.py @@ -0,0 +1,130 @@ +"""Backend that runs compute in the core process. + +Only selected when ``execution_location == "local"`` (flowfile_frame, the +scheduler/CLI path, WASM-style single-process runs). Bounded preview collects +are allowed here; the core-never-collects contract applies to the remote path. +""" + +from __future__ import annotations + +from collections.abc import Callable +from pathlib import Path +from typing import TYPE_CHECKING, ClassVar + +import polars as pl + +from flowfile_core.flowfile.execution.backends.base import ExecutionBackend +from flowfile_core.flowfile.execution.exceptions import WorkerTaskError +from flowfile_core.flowfile.execution.handles import LocalResultHandle, TaskHandle +from shared.storage_config import storage + +if TYPE_CHECKING: + from flowfile_core.flowfile.flow_data_engine.flow_data_engine import FlowDataEngine + from flowfile_core.flowfile.flow_node.multi_output import NamedOutputs + from flowfile_core.flowfile.sources.external_sources.sql_source.models import DatabaseExternalReadSettings + from flowfile_core.schemas.input_schema import OutputSettings + + +class LocalBackend(ExecutionBackend): + """Runs operations in-process; there is no shared result cache.""" + + location: ClassVar[str] = "local" + + def run_lazyframe( + self, + lf: pl.LazyFrame, + *, + flow_id: int, + node_id: int | str, + file_ref: str, + wait_on_completion: bool = False, + operation_type: str = "store", + ) -> TaskHandle: + return LocalResultHandle(result=lf, file_ref=file_ref) + + def sample( + self, + lf: pl.LazyFrame, + *, + file_ref: str, + flow_id: int, + node_id: int | str, + sample_size: int = 100, + wait_on_completion: bool = True, + ) -> TaskHandle: + from flowfile_core.flowfile.flow_data_engine.subprocess_operations.models import Status + + df = lf.head(sample_size).collect() + path = Path(storage.cache_directory) / str(flow_id) / f"{file_ref}.arrow" + path.parent.mkdir(parents=True, exist_ok=True) + df.write_ipc(path) + status = Status( + background_task_id=file_ref, + status="Completed", + file_ref=str(path), + progress=100, + results=None, + result_type="other", + ) + return LocalResultHandle(result=df.lazy(), file_ref=str(path), status=status) + + def count_records(self, lf: pl.LazyFrame, *, flow_id: int, node_id: int | str) -> int | None: + return int(lf.select(pl.len()).collect().item()) + + def random_split( + self, + df: FlowDataEngine, + splits: list[tuple[str, float]], + seed: int | None, + *, + flow_id: int, + node_id: int | str, + ) -> NamedOutputs: + return df.random_split(splits, seed) + + def read_database( + self, + settings: DatabaseExternalReadSettings, + *, + cancel_check: Callable[[], bool] | None = None, + ) -> TaskHandle: + from flowfile_core.flowfile.sources.external_sources.sql_source import utils as sql_utils + from flowfile_core.flowfile.sources.external_sources.sql_source.sql_source import SqlSource + from flowfile_core.secret_manager.secret_manager import decrypt_secret + + connection = settings.connection + source = SqlSource( + connection_string=sql_utils.construct_sql_uri( + database_type=connection.database_type, + host=connection.host, + port=connection.port, + database=connection.database, + username=connection.username, + password=decrypt_secret(connection.password) if connection.password else None, + ssl_enabled=bool(getattr(connection, "ssl_enabled", False)), + connect_timeout=10, + ), + query=settings.query, + cancel_check=cancel_check, + ) + return LocalResultHandle(result=source.get_pl_df().lazy(), file_ref=str(settings.flowfile_node_id)) + + def write_output( + self, + df: FlowDataEngine, + settings: OutputSettings, + *, + flow_id: int, + node_id: int | str, + ) -> TaskHandle: + df.output(output_fs=settings, flow_id=flow_id, node_id=node_id, execute_remote=False) + return LocalResultHandle(result=None, file_ref=str(node_id)) + + def results_exist(self, file_ref: str) -> bool: + return False + + def get_cached_lazyframe(self, file_ref: str) -> pl.LazyFrame: + raise WorkerTaskError("Local execution keeps no shared result cache") + + def clear_result(self, file_ref: str) -> bool: + return False diff --git a/flowfile_core/flowfile_core/flowfile/execution/backends/worker.py b/flowfile_core/flowfile_core/flowfile/execution/backends/worker.py new file mode 100644 index 000000000..efa21687e --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/execution/backends/worker.py @@ -0,0 +1,149 @@ +"""Backend that ships compute to the flowfile_worker service.""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING, ClassVar + +import polars as pl + +from flowfile_core.configs.settings import OFFLOAD_TO_WORKER +from flowfile_core.flowfile.execution.backends.base import ExecutionBackend +from flowfile_core.flowfile.execution.exceptions import WorkerTaskError +from flowfile_core.flowfile.execution.handles import TaskHandle +from flowfile_core.flowfile.execution.transport import WorkerTransport, get_default_transport + +if TYPE_CHECKING: + from flowfile_core.flowfile.flow_data_engine.flow_data_engine import FlowDataEngine + from flowfile_core.flowfile.flow_node.multi_output import NamedOutputs + from flowfile_core.flowfile.sources.external_sources.sql_source.models import DatabaseExternalReadSettings + from flowfile_core.schemas.input_schema import OutputSettings + + +class RemoteWorkerBackend(ExecutionBackend): + """Runs operations on the worker; results live in the worker's cache.""" + + location: ClassVar[str] = "remote" + + def __init__(self, transport: WorkerTransport | None = None): + self.transport = transport or get_default_transport() + + def run_lazyframe( + self, + lf: pl.LazyFrame, + *, + flow_id: int, + node_id: int | str, + file_ref: str, + wait_on_completion: bool = False, + operation_type: str = "store", + ) -> TaskHandle: + from flowfile_core.flowfile.flow_data_engine.subprocess_operations import ExternalDfFetcher + + return ExternalDfFetcher( + lf=lf, + file_ref=file_ref, + wait_on_completion=wait_on_completion, + flow_id=flow_id, + node_id=node_id, + operation_type=operation_type, + transport=self.transport, + ) + + def sample( + self, + lf: pl.LazyFrame, + *, + file_ref: str, + flow_id: int, + node_id: int | str, + sample_size: int = 100, + wait_on_completion: bool = True, + ) -> TaskHandle: + from flowfile_core.flowfile.flow_data_engine.subprocess_operations import ExternalSampler + + return ExternalSampler( + lf=lf, + file_ref=file_ref, + wait_on_completion=wait_on_completion, + node_id=node_id, + flow_id=flow_id, + sample_size=sample_size, + transport=self.transport, + ) + + def count_records(self, lf: pl.LazyFrame, *, flow_id: int, node_id: int | str) -> int | None: + from flowfile_core.flowfile.flow_data_engine.subprocess_operations import ExternalDfFetcher + + return ExternalDfFetcher( + lf=lf, + operation_type="calculate_number_of_records", + flow_id=flow_id, + node_id=node_id, + transport=self.transport, + ).result + + def random_split( + self, + df: FlowDataEngine, + splits: list[tuple[str, float]], + seed: int | None, + *, + flow_id: int, + node_id: int | str, + ) -> NamedOutputs: + return df.random_split_external(splits, seed, flow_id=flow_id, node_id=node_id) + + def read_database( + self, + settings: DatabaseExternalReadSettings, + *, + cancel_check: Callable[[], bool] | None = None, + ) -> TaskHandle: + from flowfile_core.flowfile.flow_data_engine.subprocess_operations import ExternalDatabaseFetcher + + return ExternalDatabaseFetcher(settings, wait_on_completion=False, transport=self.transport) + + def write_output( + self, + df: FlowDataEngine, + settings: OutputSettings, + *, + flow_id: int, + node_id: int | str, + ) -> TaskHandle: + from flowfile_core.flowfile.flow_data_engine.subprocess_operations import ExternalOutputWriter + + return ExternalOutputWriter( + lf=df.data_frame, + data_type=settings.file_type, + path=settings.abs_file_path, + write_mode=settings.write_mode, + sheet_name=settings.sheet_name, + delimiter=settings.delimiter, + compression=settings.compression, + flow_id=flow_id, + node_id=node_id, + wait_on_completion=False, + transport=self.transport, + ) + + def results_exist(self, file_ref: str) -> bool: + if not OFFLOAD_TO_WORKER: + return False + return self.transport.results_exist(file_ref) + + def get_cached_lazyframe(self, file_ref: str) -> pl.LazyFrame: + from flowfile_core.flowfile.flow_data_engine.subprocess_operations import get_df_result + + status = self.transport.get_status(file_ref) + if status.status != "Completed": + raise WorkerTaskError(f"Status is not completed, {status.status}") + if status.result_type != "polars": + raise WorkerTaskError(f"Result type is not polars, {status.result_type}") + return get_df_result(status.results) + + def clear_result(self, file_ref: str) -> bool: + if not OFFLOAD_TO_WORKER: + return False + return self.transport.clear_task(file_ref) diff --git a/flowfile_core/flowfile_core/flowfile/execution/exceptions.py b/flowfile_core/flowfile_core/flowfile/execution/exceptions.py new file mode 100644 index 000000000..4399d6ddb --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/execution/exceptions.py @@ -0,0 +1,23 @@ +"""Typed errors for core -> worker communication.""" + +import requests + + +class WorkerError(Exception): + """Base class for worker-communication failures.""" + + +class WorkerConnectionError(WorkerError, requests.exceptions.ConnectionError): + """The worker service is unreachable. + + Also a ``requests.ConnectionError`` so pre-existing + ``except requests.RequestException`` call sites keep catching it. + """ + + +class WorkerTaskError(WorkerError): + """The worker was reachable but the task request failed.""" + + +class TaskCancelledError(WorkerError): + """The task was cancelled before it completed.""" diff --git a/flowfile_core/flowfile_core/flowfile/execution/handles.py b/flowfile_core/flowfile_core/flowfile/execution/handles.py new file mode 100644 index 000000000..43bea7ff7 --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/execution/handles.py @@ -0,0 +1,57 @@ +"""Task handles: a uniform view over in-process and worker-backed results.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + from flowfile_core.flowfile.flow_data_engine.subprocess_operations.models import Status + + +@runtime_checkable +class TaskHandle(Protocol): + """What node execution needs from a submitted compute task. + + ``BaseFetcher`` (worker-backed) satisfies this protocol; + ``LocalResultHandle`` is the in-process equivalent. + """ + + file_ref: str + status: Status | None + + def get_result(self) -> Any: ... + + def cancel(self) -> None: ... + + @property + def error_code(self) -> int: ... + + @property + def error_description(self) -> str | None: ... + + +class LocalResultHandle: + """Immediately-available in-process result satisfying ``TaskHandle``.""" + + def __init__(self, result: Any, file_ref: str = "", status: Status | None = None): + self._result = result + self.file_ref = file_ref + self.status = status + + def get_result(self) -> Any: + return self._result + + @property + def result(self) -> Any: + return self._result + + def cancel(self) -> None: + return None + + @property + def error_code(self) -> int: + return 0 + + @property + def error_description(self) -> str | None: + return None diff --git a/flowfile_core/flowfile_core/flowfile/execution/transport.py b/flowfile_core/flowfile_core/flowfile/execution/transport.py new file mode 100644 index 000000000..04e80c869 --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/execution/transport.py @@ -0,0 +1,146 @@ +"""HTTP/WebSocket transport to the flowfile_worker service. + +Single owner of worker URLs and request plumbing. Fetchers and trigger +functions talk to a ``WorkerTransport`` instead of formatting ``WORKER_URL`` +themselves, so an alternative worker (or a future pool) can be injected. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +import requests + +from flowfile_core.flowfile.execution.exceptions import WorkerConnectionError, WorkerTaskError + +if TYPE_CHECKING: + from flowfile_core.flowfile.flow_data_engine.subprocess_operations.models import Status + + +class WorkerTransport: + """Client for one worker service, addressed by base URL. + + ``base_url`` may be a string, a zero-arg callable (resolved per request), + or ``None`` to follow ``settings.WORKER_URL`` (which already accounts for + single-file mode's ``/worker`` prefix). + """ + + def __init__(self, base_url: str | Callable[[], str] | None = None): + self._base_url = base_url + + @property + def base_url(self) -> str: + if callable(self._base_url): + return self._base_url() + if self._base_url is not None: + return self._base_url + from flowfile_core.configs.settings import WORKER_URL + + return WORKER_URL + + @property + def ws_base_url(self) -> str: + return self.base_url.replace("http://", "ws://").replace("https://", "wss://") + + # -- raw verbs ----------------------------------------------------------- + + def post(self, path: str, **kwargs: Any) -> requests.Response: + return self._request("post", path, **kwargs) + + def get(self, path: str, **kwargs: Any) -> requests.Response: + return self._request("get", path, **kwargs) + + def delete(self, path: str, **kwargs: Any) -> requests.Response: + return self._request("delete", path, **kwargs) + + def _request(self, method: str, path: str, **kwargs: Any) -> requests.Response: + url = f"{self.base_url}{path}" + try: + return requests.request(method, url, **kwargs) + except requests.exceptions.ConnectionError as e: + raise WorkerConnectionError(f"Could not connect to the worker at {url}: {e}") from e + + # -- task lifecycle -------------------------------------------------------- + + def get_status(self, task_id: str, timeout: float | None = None) -> Status: + from flowfile_core.flowfile.flow_data_engine.subprocess_operations.models import Status + + response = self.get(f"/status/{task_id}", timeout=timeout) + if response.status_code != 200: + raise WorkerTaskError(f"Could not fetch the status, {response.text}") + return Status(**response.json()) + + def results_exist(self, task_id: str) -> bool: + try: + response = self.get(f"/status/{task_id}") + return response.status_code == 200 and response.json()["status"] == "Completed" + except requests.RequestException: + return False + + def cancel_task(self, task_id: str) -> bool: + try: + return self.post(f"/cancel_task/{task_id}").ok + except requests.RequestException as e: + raise WorkerTaskError(f"Failed to cancel task: {e}") from e + + def clear_task(self, task_id: str) -> bool: + try: + return self.delete(f"/clear_task/{task_id}").status_code == 200 + except requests.RequestException: + return False + + # -- WebSocket streaming --------------------------------------------------- + + def streaming_submit( + self, + task_id: str, + operation_type: str, + flow_id: int, + node_id: int | str, + lf_bytes: bytes, + kwargs: dict | None = None, + ) -> tuple[Any, Status]: + from flowfile_core.flowfile.flow_data_engine.subprocess_operations.streaming import streaming_submit + + return streaming_submit( + task_id=task_id, + operation_type=operation_type, + flow_id=flow_id, + node_id=node_id, + lf_bytes=lf_bytes, + kwargs=kwargs, + ws_base_url=self.ws_base_url, + ) + + def streaming_start( + self, + task_id: str, + operation_type: str, + flow_id: int, + node_id: int | str, + lf_bytes: bytes, + kwargs: dict | None = None, + ): + from flowfile_core.flowfile.flow_data_engine.subprocess_operations.streaming import streaming_start + + return streaming_start( + task_id=task_id, + operation_type=operation_type, + flow_id=flow_id, + node_id=node_id, + lf_bytes=lf_bytes, + kwargs=kwargs, + ws_base_url=self.ws_base_url, + ) + + +_default_transport: WorkerTransport | None = None + + +def get_default_transport() -> WorkerTransport: + """The process-wide transport pointing at ``settings.WORKER_URL``.""" + global _default_transport + if _default_transport is None: + _default_transport = WorkerTransport() + return _default_transport diff --git a/flowfile_core/flowfile_core/flowfile/flow_data_engine/subprocess_operations/streaming.py b/flowfile_core/flowfile_core/flowfile/flow_data_engine/subprocess_operations/streaming.py index 6687c45b1..290578bf1 100644 --- a/flowfile_core/flowfile_core/flowfile/flow_data_engine/subprocess_operations/streaming.py +++ b/flowfile_core/flowfile_core/flowfile/flow_data_engine/subprocess_operations/streaming.py @@ -17,12 +17,15 @@ import polars as pl from websockets.sync.client import connect -from flowfile_core.configs.settings import WORKER_URL from flowfile_core.flowfile.flow_data_engine.subprocess_operations.models import Status -def _get_ws_url() -> str: +def _get_ws_url(ws_base_url: str | None = None) -> str: """Convert HTTP worker URL to WebSocket URL.""" + if ws_base_url is not None: + return ws_base_url + from flowfile_core.configs.settings import WORKER_URL + return WORKER_URL.replace("http://", "ws://").replace("https://", "wss://") @@ -139,6 +142,7 @@ def streaming_start( node_id: int | str, lf_bytes: bytes, kwargs: dict | None = None, + ws_base_url: str | None = None, ): """Open a WebSocket connection and send the task. @@ -148,7 +152,7 @@ def streaming_start( Raises immediately on connection failure or send error. """ - ws_url = _get_ws_url() + "/ws/submit" + ws_url = _get_ws_url(ws_base_url) + "/ws/submit" metadata = _build_metadata(task_id, operation_type, flow_id, node_id, kwargs) ws = connect(ws_url) @@ -196,11 +200,12 @@ def streaming_submit( node_id: int | str, lf_bytes: bytes, kwargs: dict | None = None, + ws_base_url: str | None = None, ) -> tuple[Any, Status]: """Submit a task via WebSocket and block until the result arrives. Convenience wrapper around :func:`streaming_start` + :func:`streaming_receive`. """ - ws = streaming_start(task_id, operation_type, flow_id, node_id, lf_bytes, kwargs) + ws = streaming_start(task_id, operation_type, flow_id, node_id, lf_bytes, kwargs, ws_base_url) return streaming_receive(ws, task_id) diff --git a/flowfile_core/flowfile_core/flowfile/flow_data_engine/subprocess_operations/subprocess_operations.py b/flowfile_core/flowfile_core/flowfile/flow_data_engine/subprocess_operations/subprocess_operations.py index c7be0ed4f..e8ba144ce 100644 --- a/flowfile_core/flowfile_core/flowfile/flow_data_engine/subprocess_operations/subprocess_operations.py +++ b/flowfile_core/flowfile_core/flowfile/flow_data_engine/subprocess_operations/subprocess_operations.py @@ -11,7 +11,9 @@ from pl_fuzzy_frame_match.models import FuzzyMapping from flowfile_core.configs import logger -from flowfile_core.configs.settings import OFFLOAD_TO_WORKER, WORKER_URL +from flowfile_core.configs.settings import OFFLOAD_TO_WORKER +from flowfile_core.flowfile.execution.exceptions import WorkerTaskError +from flowfile_core.flowfile.execution.transport import WorkerTransport, get_default_transport from flowfile_core.flowfile.flow_data_engine.subprocess_operations.models import ( ApplyModelInput, FuzzyJoinInput, @@ -22,8 +24,6 @@ ) from flowfile_core.flowfile.flow_data_engine.subprocess_operations.streaming import ( streaming_receive, - streaming_start, - streaming_submit, ) from flowfile_core.flowfile.sources.external_sources.sql_source.models import ( DatabaseExternalReadSettings, @@ -43,6 +43,7 @@ def trigger_df_operation( file_ref: str, operation_type: OperationType = "store", kwargs: dict | None = None, + transport: WorkerTransport | None = None, ) -> Status: # Send raw bytes directly - no base64 encoding overhead headers = { @@ -54,14 +55,19 @@ def trigger_df_operation( } if kwargs: headers["X-Kwargs"] = json.dumps(kwargs) - v = requests.post(url=f"{WORKER_URL}/submit_query/", data=lf.serialize(), headers=headers) + v = (transport or get_default_transport()).post("/submit_query/", data=lf.serialize(), headers=headers) if not v.ok: - raise Exception(f"trigger_df_operation: Could not cache the data, {v.text}") + raise WorkerTaskError(f"trigger_df_operation: Could not cache the data, {v.text}") return Status(**v.json()) def trigger_sample_operation( - lf: pl.LazyFrame, file_ref: str, flow_id: int, node_id: str | int, sample_size: int = 100 + lf: pl.LazyFrame, + file_ref: str, + flow_id: int, + node_id: str | int, + sample_size: int = 100, + transport: WorkerTransport | None = None, ) -> Status: # Send raw bytes directly - no base64 encoding overhead headers = { @@ -72,9 +78,9 @@ def trigger_sample_operation( "X-Flow-Id": str(flow_id), "X-Node-Id": str(node_id), } - v = requests.post(url=f"{WORKER_URL}/store_sample/", data=lf.serialize(), headers=headers) + v = (transport or get_default_transport()).post("/store_sample/", data=lf.serialize(), headers=headers) if not v.ok: - raise Exception(f"trigger_sample_operation: Could not cache the data, {v.text}") + raise WorkerTaskError(f"trigger_sample_operation: Could not cache the data, {v.text}") return Status(**v.json()) @@ -85,6 +91,7 @@ def trigger_fuzzy_match_operation( file_ref: str, flow_id: int, node_id: int | str, + transport: WorkerTransport | None = None, ) -> Status: # Use raw bytes - Pydantic will handle single base64 encoding for JSON transport left_serializable_object = PolarsOperation(operation=left_df.serialize()) @@ -97,9 +104,9 @@ def trigger_fuzzy_match_operation( flowfile_flow_id=flow_id, flowfile_node_id=node_id, ) - v = requests.post(f"{WORKER_URL}/add_fuzzy_join", data=fuzzy_join_input.model_dump_json()) + v = (transport or get_default_transport()).post("/add_fuzzy_join", data=fuzzy_join_input.model_dump_json()) if not v.ok: - raise Exception(f"trigger_fuzzy_match_operation: Could not cache the data, {v.text}") + raise WorkerTaskError(f"trigger_fuzzy_match_operation: Could not cache the data, {v.text}") return Status(**v.json()) @@ -113,6 +120,7 @@ def trigger_train_model_operation( file_ref: str, flow_id: int, node_id: int | str, + transport: WorkerTransport | None = None, ) -> Status: """Submit a training job to the worker. @@ -130,9 +138,9 @@ def trigger_train_model_operation( flowfile_flow_id=flow_id, flowfile_node_id=node_id, ) - v = requests.post(f"{WORKER_URL}/train_ml_model", data=payload.model_dump_json()) + v = (transport or get_default_transport()).post("/train_ml_model", data=payload.model_dump_json()) if not v.ok: - raise Exception(f"trigger_train_model_operation: Could not start training, {v.text}") + raise WorkerTaskError(f"trigger_train_model_operation: Could not start training, {v.text}") return Status(**v.json()) @@ -143,6 +151,7 @@ def trigger_apply_model_operation( file_ref: str, flow_id: int, node_id: int | str, + transport: WorkerTransport | None = None, ) -> Status: """Submit an apply-model job to the worker.""" payload = ApplyModelInput( @@ -153,9 +162,9 @@ def trigger_apply_model_operation( flowfile_flow_id=flow_id, flowfile_node_id=node_id, ) - v = requests.post(f"{WORKER_URL}/apply_ml_model", data=payload.model_dump_json()) + v = (transport or get_default_transport()).post("/apply_ml_model", data=payload.model_dump_json()) if not v.ok: - raise Exception(f"trigger_apply_model_operation: Could not start scoring, {v.text}") + raise WorkerTaskError(f"trigger_apply_model_operation: Could not start scoring, {v.text}") return Status(**v.json()) @@ -164,42 +173,48 @@ def trigger_create_operation( node_id: int | str, received_table: ReceivedTable, file_type: str = Literal["csv", "parquet", "json", "excel", "ipc", "ndjson", "avro"], + transport: WorkerTransport | None = None, ): - f = requests.post( - url=f"{WORKER_URL}/create_table/{file_type}", + f = (transport or get_default_transport()).post( + f"/create_table/{file_type}", data=received_table.model_dump_json(), params={"flowfile_flow_id": flow_id, "flowfile_node_id": node_id}, ) if not f.ok: - raise Exception(f"trigger_create_operation: Could not cache the data, {f.text}") + raise WorkerTaskError(f"trigger_create_operation: Could not cache the data, {f.text}") return Status(**f.json()) -def trigger_database_read_collector(database_external_read_settings: DatabaseExternalReadSettings): - f = requests.post( - url=f"{WORKER_URL}/store_database_read_result", data=database_external_read_settings.model_dump_json() +def trigger_database_read_collector( + database_external_read_settings: DatabaseExternalReadSettings, + transport: WorkerTransport | None = None, +): + f = (transport or get_default_transport()).post( + "/store_database_read_result", data=database_external_read_settings.model_dump_json() ) if not f.ok: - raise Exception(f"trigger_database_read_collector: Could not cache the data, {f.text}") + raise WorkerTaskError(f"trigger_database_read_collector: Could not cache the data, {f.text}") return Status(**f.json()) -def trigger_kafka_read(kafka_read_settings) -> Status: +def trigger_kafka_read(kafka_read_settings, transport: WorkerTransport | None = None) -> Status: """Send a Kafka read request to the worker service.""" - f = requests.post(url=f"{WORKER_URL}/store_kafka_read_result", data=kafka_read_settings.model_dump_json()) + f = (transport or get_default_transport()).post( + "/store_kafka_read_result", data=kafka_read_settings.model_dump_json() + ) if not f.ok: - raise Exception(f"trigger_kafka_read: Could not read from Kafka, {f.text}") + raise WorkerTaskError(f"trigger_kafka_read: Could not read from Kafka, {f.text}") return Status(**f.json()) -def fetch_kafka_offsets(task_id: str) -> dict | None: +def fetch_kafka_offsets(task_id: str, transport: WorkerTransport | None = None) -> dict | None: """Fetch deferred Kafka offset data from the worker for a completed task. Returns a dict with ``new_offsets``, ``messages_consumed``, etc. from the KafkaReadResult that was saved as a sidecar file, or ``None`` if no offsets were recorded (e.g. empty topic). """ - f = requests.get(f"{WORKER_URL}/kafka_offsets/{task_id}") + f = (transport or get_default_transport()).get(f"/kafka_offsets/{task_id}") if not f.ok: logger.warning("Failed to fetch Kafka offsets for task %s: %s", task_id, f.text) return None @@ -207,35 +222,45 @@ def fetch_kafka_offsets(task_id: str) -> dict | None: return data -def trigger_google_analytics_read(ga_read_settings) -> Status: +def trigger_google_analytics_read(ga_read_settings, transport: WorkerTransport | None = None) -> Status: """Send a Google Analytics 4 read request to the worker service.""" - f = requests.post(url=f"{WORKER_URL}/store_google_analytics_read_result", data=ga_read_settings.model_dump_json()) + f = (transport or get_default_transport()).post( + "/store_google_analytics_read_result", data=ga_read_settings.model_dump_json() + ) if not f.ok: - raise Exception(f"trigger_google_analytics_read: Could not read from GA, {f.text}") + raise WorkerTaskError(f"trigger_google_analytics_read: Could not read from GA, {f.text}") return Status(**f.json()) -def trigger_rest_api_read(settings) -> Status: +def trigger_rest_api_read(settings, transport: WorkerTransport | None = None) -> Status: """Send a REST API read request to the worker service.""" - f = requests.post(url=f"{WORKER_URL}/store_rest_api_read_result", data=settings.model_dump_json()) + f = (transport or get_default_transport()).post("/store_rest_api_read_result", data=settings.model_dump_json()) if not f.ok: - raise Exception(f"trigger_rest_api_read: Could not read from the REST API, {f.text}") + raise WorkerTaskError(f"trigger_rest_api_read: Could not read from the REST API, {f.text}") return Status(**f.json()) -def trigger_database_write(database_external_write_settings: DatabaseExternalWriteSettings): - f = requests.post( - url=f"{WORKER_URL}/store_database_write_result", data=database_external_write_settings.model_dump_json() +def trigger_database_write( + database_external_write_settings: DatabaseExternalWriteSettings, + transport: WorkerTransport | None = None, +): + f = (transport or get_default_transport()).post( + "/store_database_write_result", data=database_external_write_settings.model_dump_json() ) if not f.ok: - raise Exception(f"trigger_database_write: Could not cache the data, {f.text}") + raise WorkerTaskError(f"trigger_database_write: Could not cache the data, {f.text}") return Status(**f.json()) -def trigger_cloud_storage_write(database_external_write_settings: CloudStorageWriteSettingsWorkerInterface): - f = requests.post(url=f"{WORKER_URL}/write_data_to_cloud", data=database_external_write_settings.model_dump_json()) +def trigger_cloud_storage_write( + database_external_write_settings: CloudStorageWriteSettingsWorkerInterface, + transport: WorkerTransport | None = None, +): + f = (transport or get_default_transport()).post( + "/write_data_to_cloud", data=database_external_write_settings.model_dump_json() + ) if not f.ok: - raise Exception(f"trigger_cloud_storage_write: Could not cache the data, {f.text}") + raise WorkerTaskError(f"trigger_cloud_storage_write: Could not cache the data, {f.text}") return Status(**f.json()) @@ -249,12 +274,13 @@ def trigger_write_output( sheet_name: str | None = None, delimiter: str | None = None, compression: str | None = None, + transport: WorkerTransport | None = None, ) -> Status: from base64 import encodebytes serializable_df = lf.serialize() - r = requests.post( - f"{WORKER_URL}/write_results/", + r = (transport or get_default_transport()).post( + "/write_results/", json={ "operation": encodebytes(serializable_df).decode(), "data_type": data_type, @@ -268,7 +294,7 @@ def trigger_write_output( }, ) if not r.ok: - raise Exception(f"trigger_write_output: Could not write the data, {r.text}") + raise WorkerTaskError(f"trigger_write_output: Could not write the data, {r.text}") return Status(**r.json()) @@ -283,7 +309,7 @@ def trigger_catalog_materialize( } if storage is not None: payload["storage"] = storage - response = requests.post(f"{WORKER_URL}/catalog/materialize", json=payload) + response = get_default_transport().post("/catalog/materialize", json=payload) return response @@ -305,7 +331,7 @@ def trigger_resolve_virtual_table( "plan_bytes": b64encode(plan_bytes).decode("ascii"), "source_versions_hash": source_versions_hash, } - response = requests.post(f"{WORKER_URL}/flow/resolve_virtual_table", json=payload, timeout=300) + response = get_default_transport().post("/flow/resolve_virtual_table", json=payload, timeout=300) if not response.ok: raise RuntimeError(f"Worker resolve_virtual_table failed: {response.text}") return response.json() @@ -331,7 +357,7 @@ def trigger_sql_query( payload["virtual_refs"] = virtual_refs if storage is not None: payload["storage"] = storage - response = requests.post(f"{WORKER_URL}/catalog/sql_query", json=payload) + response = get_default_transport().post("/catalog/sql_query", json=payload) if not response.ok: raise RuntimeError(f"Worker SQL query execution failed: {response.text}") return response.json() @@ -354,7 +380,7 @@ def trigger_visualize_query(worker_source: dict, payload: dict, max_rows: int) - max_rows, ) body = {"source": worker_source, "payload": payload, "max_rows": max_rows} - response = requests.post(f"{WORKER_URL}/catalog/visualize_query", json=body, timeout=HTTP_TIMEOUT_SECONDS) + response = get_default_transport().post("/catalog/visualize_query", json=body, timeout=HTTP_TIMEOUT_SECONDS) if not response.ok: logger.warning( "[viz] <- worker /catalog/visualize_query session_key=%s status=%d body=%s", @@ -384,7 +410,7 @@ def trigger_visualize_fields(worker_source: dict) -> dict: worker_source.get("kind"), ) body = {"source": worker_source} - response = requests.post(f"{WORKER_URL}/catalog/visualize_fields", json=body, timeout=30) + response = get_default_transport().post("/catalog/visualize_fields", json=body, timeout=30) if not response.ok: logger.warning( "[viz] <- worker /catalog/visualize_fields session_key=%s status=%d body=%s", @@ -415,7 +441,7 @@ def trigger_visualize_column_stats(worker_source: dict, column: str, limit: int) limit, ) body = {"source": worker_source, "column": column, "limit": limit} - response = requests.post(f"{WORKER_URL}/catalog/visualize_column_stats", json=body, timeout=HTTP_TIMEOUT_SECONDS) + response = get_default_transport().post("/catalog/visualize_column_stats", json=body, timeout=HTTP_TIMEOUT_SECONDS) if not response.ok: logger.warning( "[viz] <- worker /catalog/visualize_column_stats session_key=%s status=%d body=%s", @@ -426,7 +452,8 @@ def trigger_visualize_column_stats(worker_source: dict, column: str, limit: int) raise RuntimeError(f"Worker visualize_column_stats failed: {response.text}") data = response.json() logger.info( - "[viz] <- worker /catalog/visualize_column_stats session_key=%s status=%d cache_hit=%s value_count=%d truncated=%s", + "[viz] <- worker /catalog/visualize_column_stats session_key=%s status=%d cache_hit=%s " + "value_count=%d truncated=%s", session_key, response.status_code, data.get("cache_hit"), @@ -439,8 +466,8 @@ def trigger_visualize_column_stats(worker_source: dict, column: str, limit: int) def trigger_visualize_evict(session_key: str) -> None: """Ask the worker to drop a cached viz session (e.g. after a table update).""" logger.info("[viz] -> worker /catalog/visualize_evict session_key=%s", session_key) - response = requests.post( - f"{WORKER_URL}/catalog/visualize_evict", + response = get_default_transport().post( + "/catalog/visualize_evict", params={"session_key": session_key}, timeout=10, ) @@ -459,7 +486,7 @@ def trigger_read_table_metadata(table_name: str, storage: dict | None = None) -> payload = {"table_path": table_name} if storage is not None: payload["storage"] = storage - response = requests.post(f"{WORKER_URL}/catalog/table_metadata", json=payload) + response = get_default_transport().post("/catalog/table_metadata", json=payload) if not response.ok: raise RuntimeError(f"Worker table metadata read failed: {response.text}") return response.json() @@ -478,7 +505,7 @@ def trigger_delta_history( payload = {"table_path": table_name, "limit": limit} if storage is not None: payload["storage"] = storage - response = requests.post(f"{WORKER_URL}/catalog/delta_history", json=payload) + response = get_default_transport().post("/catalog/delta_history", json=payload) if not response.ok: raise RuntimeError(f"Worker delta history read failed: {response.text}") return DeltaTableHistory.model_validate(response.json()) @@ -498,7 +525,7 @@ def trigger_delta_version_preview( payload = {"table_path": table_name, "version": version, "n_rows": n_rows} if storage is not None: payload["storage"] = storage - response = requests.post(f"{WORKER_URL}/catalog/delta_version_preview", json=payload) + response = get_default_transport().post("/catalog/delta_version_preview", json=payload) if not response.ok: raise RuntimeError(f"Worker delta version preview failed: {response.text}") return CatalogTablePreview.model_validate(response.json()) @@ -517,7 +544,7 @@ def trigger_delta_preview( payload = {"table_path": table_name, "n_rows": n_rows} if storage is not None: payload["storage"] = storage - response = requests.post(f"{WORKER_URL}/catalog/delta_preview", json=payload) + response = get_default_transport().post("/catalog/delta_preview", json=payload) if not response.ok: raise RuntimeError(f"Worker delta preview failed: {response.text}") return CatalogTablePreview.model_validate(response.json()) @@ -536,7 +563,7 @@ def trigger_optimize_catalog_table( payload = {"table_path": table_name, "z_order_columns": z_order_columns} if storage is not None: payload["storage"] = storage - response = requests.post(f"{WORKER_URL}/catalog/optimize", json=payload, timeout=600) + response = get_default_transport().post("/catalog/optimize", json=payload, timeout=600) if not response.ok: raise RuntimeError(f"Worker optimize failed: {response.text}") return response.json() @@ -556,14 +583,14 @@ def trigger_vacuum_catalog_table( payload = {"table_path": table_name, "retention_hours": retention_hours, "dry_run": dry_run} if storage is not None: payload["storage"] = storage - response = requests.post(f"{WORKER_URL}/catalog/vacuum", json=payload, timeout=600) + response = get_default_transport().post("/catalog/vacuum", json=payload, timeout=600) if not response.ok: raise RuntimeError(f"Worker vacuum failed: {response.text}") return response.json() def get_results(file_ref: str) -> Status | None: - f = requests.get(f"{WORKER_URL}/status/{file_ref}") + f = get_default_transport().get(f"/status/{file_ref}") if f.status_code == 200: return Status(**f.json()) else: @@ -575,7 +602,7 @@ def results_exists(file_ref: str): return False try: - f = requests.get(f"{WORKER_URL}/status/{file_ref}") + f = get_default_transport().get(f"/status/{file_ref}") if f.status_code == 200: if f.json()["status"] == "Completed": return True @@ -602,7 +629,7 @@ def clear_task_from_worker(file_ref: str) -> bool: return False try: - f = requests.delete(f"{WORKER_URL}/clear_task/{file_ref}") + f = get_default_transport().delete(f"/clear_task/{file_ref}") if f.status_code == 200: return True return False @@ -627,7 +654,7 @@ def get_external_df_result(file_ref: str) -> pl.LazyFrame | None: def get_status(file_ref: str) -> Status: - status_response = requests.get(f"{WORKER_URL}/status/{file_ref}") + status_response = get_default_transport().get(f"/status/{file_ref}") if status_response.status_code == 200: return Status(**status_response.json()) else: @@ -648,7 +675,7 @@ def cancel_task(file_ref: str) -> bool: Exception: If there's an error communicating with the worker service """ try: - response = requests.post(f"{WORKER_URL}/cancel_task/{file_ref}") + response = get_default_transport().post(f"/cancel_task/{file_ref}") if response.ok: return True return False @@ -661,8 +688,11 @@ class BaseFetcher: Thread-safe fetcher for polling worker status and retrieving results. """ - def __init__(self, file_ref: str = None): + status: Status | None = None + + def __init__(self, file_ref: str = None, transport: WorkerTransport | None = None): self.file_ref = file_ref if file_ref else str(uuid4()) + self._transport = transport or get_default_transport() # Thread synchronization self._lock = threading.Lock() @@ -730,7 +760,7 @@ def _fetch_cached_df(self): try: while not self._stop_event.is_set(): try: - r = requests.get(f"{WORKER_URL}/status/{self.file_ref}", timeout=10) + r = self._transport.get(f"/status/{self.file_ref}", timeout=10) if r.status_code == 200: status = Status(**r.json()) @@ -833,7 +863,7 @@ def cancel(self): pass try: - cancel_task(self.file_ref) + self._transport.cancel_task(self.file_ref) except Exception as e: logger.error(f"Failed to cancel task on worker: {str(e)}") @@ -908,7 +938,7 @@ def _execute_streaming( Raises on connection or send error so the caller can fall back to REST. """ if blocking: - result, status = streaming_submit( + result, status = self._transport.streaming_submit( task_id=self.file_ref, operation_type=operation_type, flow_id=flow_id, @@ -922,7 +952,7 @@ def _execute_streaming( self._started = True self.status = status else: - ws = streaming_start( + ws = self._transport.streaming_start( task_id=self.file_ref, operation_type=operation_type, flow_id=flow_id, @@ -974,8 +1004,9 @@ def __init__( operation_type: OperationType = "store", offload_to_worker: bool = True, kwargs: dict | None = None, + transport: WorkerTransport | None = None, ): - super().__init__(file_ref=file_ref) + super().__init__(file_ref=file_ref, transport=transport) lf = lf.lazy() if isinstance(lf, pl.DataFrame) else lf try: @@ -999,11 +1030,12 @@ def __init__( node_id=node_id, flow_id=flow_id, kwargs=kwargs, + transport=self._transport, ) self.running = r.status == "Processing" if wait_on_completion: _ = self.get_result() - self.status = get_status(self.file_ref) + self.status = self._transport.get_status(self.file_ref) class ExternalSampler(BaseFetcher): @@ -1017,8 +1049,9 @@ def __init__( file_ref: str = None, wait_on_completion: bool = True, sample_size: int = 100, + transport: WorkerTransport | None = None, ): - super().__init__(file_ref=file_ref) + super().__init__(file_ref=file_ref, transport=transport) lf = lf.lazy() if isinstance(lf, pl.DataFrame) else lf try: @@ -1036,12 +1069,17 @@ def __init__( # REST fallback (original behavior) r = trigger_sample_operation( - lf=lf, file_ref=file_ref, sample_size=sample_size, node_id=node_id, flow_id=flow_id + lf=lf, + file_ref=file_ref, + sample_size=sample_size, + node_id=node_id, + flow_id=flow_id, + transport=self._transport, ) self.running = r.status == "Processing" if wait_on_completion: _ = self.get_result() - self.status = get_status(self.file_ref) + self.status = self._transport.get_status(self.file_ref) class ExternalFuzzyMatchFetcher(BaseFetcher): @@ -1054,8 +1092,9 @@ def __init__( node_id: int | str, file_ref: str = None, wait_on_completion: bool = True, + transport: WorkerTransport | None = None, ): - super().__init__(file_ref=file_ref) + super().__init__(file_ref=file_ref, transport=transport) r = trigger_fuzzy_match_operation( left_df=left_df, @@ -1064,6 +1103,7 @@ def __init__( file_ref=file_ref, flow_id=flow_id, node_id=node_id, + transport=self._transport, ) self.file_ref = r.background_task_id self.running = r.status == "Processing" @@ -1091,8 +1131,9 @@ def __init__( node_id: int | str, file_ref: str, wait_on_completion: bool = True, + transport: WorkerTransport | None = None, ): - super().__init__(file_ref=file_ref) + super().__init__(file_ref=file_ref, transport=transport) lf = lf.lazy() if isinstance(lf, pl.DataFrame) else lf r = trigger_train_model_operation( lf=lf, @@ -1104,6 +1145,7 @@ def __init__( file_ref=file_ref, flow_id=flow_id, node_id=node_id, + transport=self._transport, ) self.file_ref = r.background_task_id self.running = r.status == "Processing" @@ -1123,8 +1165,9 @@ def __init__( node_id: int | str, file_ref: str, wait_on_completion: bool = True, + transport: WorkerTransport | None = None, ): - super().__init__(file_ref=file_ref) + super().__init__(file_ref=file_ref, transport=transport) lf = lf.lazy() if isinstance(lf, pl.DataFrame) else lf r = trigger_apply_model_operation( lf=lf, @@ -1133,6 +1176,7 @@ def __init__( file_ref=file_ref, flow_id=flow_id, node_id=node_id, + transport=self._transport, ) self.file_ref = r.background_task_id self.running = r.status == "Processing" @@ -1148,20 +1192,30 @@ def __init__( flow_id: int, file_type: str = "csv", wait_on_completion: bool = True, + transport: WorkerTransport | None = None, ): + transport = transport or get_default_transport() r = trigger_create_operation( - received_table=received_table, file_type=file_type, node_id=node_id, flow_id=flow_id + received_table=received_table, file_type=file_type, node_id=node_id, flow_id=flow_id, transport=transport ) - super().__init__(file_ref=r.background_task_id) + super().__init__(file_ref=r.background_task_id, transport=transport) self.running = r.status == "Processing" if wait_on_completion: _ = self.get_result() class ExternalDatabaseFetcher(BaseFetcher): - def __init__(self, database_external_read_settings: DatabaseExternalReadSettings, wait_on_completion: bool = True): - r = trigger_database_read_collector(database_external_read_settings=database_external_read_settings) - super().__init__(file_ref=r.background_task_id) + def __init__( + self, + database_external_read_settings: DatabaseExternalReadSettings, + wait_on_completion: bool = True, + transport: WorkerTransport | None = None, + ): + transport = transport or get_default_transport() + r = trigger_database_read_collector( + database_external_read_settings=database_external_read_settings, transport=transport + ) + super().__init__(file_ref=r.background_task_id, transport=transport) self.running = r.status == "Processing" if wait_on_completion: _ = self.get_result() @@ -1170,9 +1224,10 @@ def __init__(self, database_external_read_settings: DatabaseExternalReadSettings class ExternalKafkaFetcher(BaseFetcher): """Fetches data from Kafka via the worker service. Same pattern as ExternalDatabaseFetcher.""" - def __init__(self, kafka_read_settings, wait_on_completion: bool = True): - r = trigger_kafka_read(kafka_read_settings=kafka_read_settings) - super().__init__(file_ref=r.background_task_id) + def __init__(self, kafka_read_settings, wait_on_completion: bool = True, transport: WorkerTransport | None = None): + transport = transport or get_default_transport() + r = trigger_kafka_read(kafka_read_settings=kafka_read_settings, transport=transport) + super().__init__(file_ref=r.background_task_id, transport=transport) self.running = r.status == "Processing" if wait_on_completion: _ = self.get_result() @@ -1181,9 +1236,10 @@ def __init__(self, kafka_read_settings, wait_on_completion: bool = True): class ExternalGoogleAnalyticsFetcher(BaseFetcher): """Fetches GA4 data via the worker service. Same pattern as ExternalDatabaseFetcher.""" - def __init__(self, ga_read_settings, wait_on_completion: bool = True): - r = trigger_google_analytics_read(ga_read_settings=ga_read_settings) - super().__init__(file_ref=r.background_task_id) + def __init__(self, ga_read_settings, wait_on_completion: bool = True, transport: WorkerTransport | None = None): + transport = transport or get_default_transport() + r = trigger_google_analytics_read(ga_read_settings=ga_read_settings, transport=transport) + super().__init__(file_ref=r.background_task_id, transport=transport) self.running = r.status == "Processing" if wait_on_completion: _ = self.get_result() @@ -1192,9 +1248,10 @@ def __init__(self, ga_read_settings, wait_on_completion: bool = True): class ExternalRestApiFetcher(BaseFetcher): """Fetches REST API data via the worker service. Same pattern as ExternalDatabaseFetcher.""" - def __init__(self, settings, wait_on_completion: bool = True): - r = trigger_rest_api_read(settings=settings) - super().__init__(file_ref=r.background_task_id) + def __init__(self, settings, wait_on_completion: bool = True, transport: WorkerTransport | None = None): + transport = transport or get_default_transport() + r = trigger_rest_api_read(settings=settings, transport=transport) + super().__init__(file_ref=r.background_task_id, transport=transport) self.running = r.status == "Processing" if wait_on_completion: _ = self.get_result() @@ -1202,10 +1259,16 @@ def __init__(self, settings, wait_on_completion: bool = True): class ExternalDatabaseWriter(BaseFetcher): def __init__( - self, database_external_write_settings: DatabaseExternalWriteSettings, wait_on_completion: bool = True + self, + database_external_write_settings: DatabaseExternalWriteSettings, + wait_on_completion: bool = True, + transport: WorkerTransport | None = None, ): - r = trigger_database_write(database_external_write_settings=database_external_write_settings) - super().__init__(file_ref=r.background_task_id) + transport = transport or get_default_transport() + r = trigger_database_write( + database_external_write_settings=database_external_write_settings, transport=transport + ) + super().__init__(file_ref=r.background_task_id, transport=transport) self.running = r.status == "Processing" if wait_on_completion: _ = self.get_result() @@ -1213,10 +1276,16 @@ def __init__( class ExternalCloudWriter(BaseFetcher): def __init__( - self, cloud_storage_write_settings: CloudStorageWriteSettingsWorkerInterface, wait_on_completion: bool = True + self, + cloud_storage_write_settings: CloudStorageWriteSettingsWorkerInterface, + wait_on_completion: bool = True, + transport: WorkerTransport | None = None, ): - r = trigger_cloud_storage_write(database_external_write_settings=cloud_storage_write_settings) - super().__init__(file_ref=r.background_task_id) + transport = transport or get_default_transport() + r = trigger_cloud_storage_write( + database_external_write_settings=cloud_storage_write_settings, transport=transport + ) + super().__init__(file_ref=r.background_task_id, transport=transport) self.running = r.status == "Processing" if wait_on_completion: _ = self.get_result() @@ -1241,7 +1310,9 @@ def __init__( delimiter: str | None = None, compression: str | None = None, wait_on_completion: bool = True, + transport: WorkerTransport | None = None, ): + transport = transport or get_default_transport() lf = lf.lazy() if isinstance(lf, pl.DataFrame) else lf r = trigger_write_output( lf=lf, @@ -1253,8 +1324,9 @@ def __init__( compression=compression, flow_id=flow_id, node_id=node_id, + transport=transport, ) - super().__init__(file_ref=r.background_task_id) + super().__init__(file_ref=r.background_task_id, transport=transport) self.running = r.status == "Processing" if wait_on_completion: _ = self.get_result() diff --git a/flowfile_core/flowfile_core/flowfile/flow_graph.py b/flowfile_core/flowfile_core/flowfile/flow_graph.py index 847eec76d..3e65254ce 100644 --- a/flowfile_core/flowfile_core/flowfile/flow_graph.py +++ b/flowfile_core/flowfile_core/flowfile/flow_graph.py @@ -41,7 +41,7 @@ get_encrypted_credential, get_ga_connection, ) -from flowfile_core.flowfile.filter_expressions import build_filter_expression +from flowfile_core.flowfile.execution.backends import ExecutionBackend, resolve_backend from flowfile_core.flowfile.flow_data_engine.flow_data_engine import ( FlowDataEngine, execute_polars_code, @@ -55,12 +55,10 @@ ) from flowfile_core.flowfile.flow_data_engine.subprocess_operations.subprocess_operations import ( ExternalCloudWriter, - ExternalDatabaseFetcher, ExternalDatabaseWriter, ExternalDfFetcher, ExternalGoogleAnalyticsFetcher, ExternalKafkaFetcher, - ExternalOutputWriter, ExternalRestApiFetcher, MLApplyFetcher, MLTrainFetcher, @@ -79,6 +77,8 @@ group_nodes_by_depth, ) from flowfile_core.flowfile.node_designer.custom_node import CustomNodeBase +from flowfile_core.flowfile.node_registry import get_node_spec +from flowfile_core.flowfile.node_registry.spec import NodeBuildContext, NodeSpec from flowfile_core.flowfile.parameter_resolver import ( apply_parameters_in_place, find_unresolved_in_model, @@ -2080,18 +2080,7 @@ def add_union(self, union_settings: input_schema.NodeUnion): Args: union_settings: The settings for the union operation. """ - - def _func(*flowfile_tables: FlowDataEngine): - dfs: list[pl.LazyFrame] | list[pl.DataFrame] = [flt.data_frame for flt in flowfile_tables] - return FlowDataEngine(pl.concat(dfs, how="diagonal_relaxed")) - - self.add_node_step( - node_id=union_settings.node_id, - function=_func, - node_type="union", - setting_input=union_settings, - input_node_ids=union_settings.depending_on_ids, - ) + self._add_from_spec(get_node_spec("union"), union_settings) def add_initial_node_analysis(self, node_promise: input_schema.NodePromise, track_history: bool = True): """Adds a data exploration/analysis node based on a node promise. @@ -2212,38 +2201,7 @@ def add_filter(self, filter_settings: input_schema.NodeFilter): Args: filter_settings: The settings for the filter operation. """ - - def _func(fl: FlowDataEngine): - is_advanced = filter_settings.filter_input.is_advanced() - - if is_advanced: - expression = filter_settings.filter_input.advanced_filter - else: - basic_filter = filter_settings.filter_input.basic_filter - if basic_filter is None: - logger.warning("Basic filter is None, returning unfiltered data") - return fl - - try: - field_data_type = fl.get_schema_column(basic_filter.field).generic_datatype() - except Exception: - field_data_type = None - - expression = build_filter_expression(basic_filter, field_data_type) - filter_settings.filter_input.advanced_filter = expression - - if filter_settings.split_mode: - return fl.filter_split(expression) - return fl.do_filter(expression) - - self.add_node_step( - filter_settings.node_id, - _func, - node_type="filter", - renew_schema=False, - setting_input=filter_settings, - input_node_ids=[filter_settings.depending_on_id], - ) + self._add_from_spec(get_node_spec("filter"), filter_settings) @with_history_capture(HistoryActionType.UPDATE_SETTINGS) def add_record_count(self, node_number_of_records: input_schema.NodeRecordCount): @@ -2252,17 +2210,7 @@ def add_record_count(self, node_number_of_records: input_schema.NodeRecordCount) Args: node_number_of_records: The settings for the record count operation. """ - - def _func(fl: FlowDataEngine) -> FlowDataEngine: - return fl.get_record_count() - - self.add_node_step( - node_id=node_number_of_records.node_id, - function=_func, - node_type="record_count", - setting_input=node_number_of_records, - input_node_ids=[node_number_of_records.depending_on_id], - ) + self._add_from_spec(get_node_spec("record_count"), node_number_of_records) @with_history_capture(HistoryActionType.UPDATE_SETTINGS) def add_polars_code(self, node_polars_code: input_schema.NodePolarsCode): @@ -3134,18 +3082,7 @@ def add_sort(self, sort_settings: input_schema.NodeSort) -> "FlowGraph": Returns: The `FlowGraph` instance for method chaining. """ - - def _func(table: FlowDataEngine) -> FlowDataEngine: - return table.do_sort(sort_settings.sort_input) - - self.add_node_step( - node_id=sort_settings.node_id, - function=_func, - node_type="sort", - setting_input=sort_settings, - input_node_ids=[sort_settings.depending_on_id], - ) - return self + return self._add_from_spec(get_node_spec("sort"), sort_settings) @with_history_capture(HistoryActionType.UPDATE_SETTINGS) def add_sample(self, sample_settings: input_schema.NodeSample) -> "FlowGraph": @@ -3157,18 +3094,7 @@ def add_sample(self, sample_settings: input_schema.NodeSample) -> "FlowGraph": Returns: The `FlowGraph` instance for method chaining. """ - - def _func(table: FlowDataEngine) -> FlowDataEngine: - return table.get_sample(sample_settings.sample_size) - - self.add_node_step( - node_id=sample_settings.node_id, - function=_func, - node_type="sample", - setting_input=sample_settings, - input_node_ids=[sample_settings.depending_on_id], - ) - return self + return self._add_from_spec(get_node_spec("sample"), sample_settings) @with_history_capture(HistoryActionType.UPDATE_SETTINGS) def add_random_split(self, settings: input_schema.NodeRandomSplit) -> "FlowGraph": @@ -3187,9 +3113,8 @@ def add_random_split(self, settings: input_schema.NodeRandomSplit) -> "FlowGraph def _func(table: FlowDataEngine) -> NamedOutputs: split_pairs = [(s.name, s.percentage) for s in settings.splits] - if self.execution_location == "local": - return table.random_split(split_pairs, settings.seed) - return table.random_split_external( + return self.execution_backend.random_split( + table, split_pairs, settings.seed, flow_id=self.flow_id, @@ -3360,6 +3285,27 @@ def graph_has_input_data(self) -> bool: """Checks if the graph has an initial input data source.""" return self._input_data is not None + def _add_from_spec(self, spec: NodeSpec, settings: Any) -> "FlowGraph": + """Adds or updates a node declaratively from its NodeSpec. + + The spec's compute_factory builds the closure and the input node ids + are derived from the spec's arity; everything else goes through + add_node_step unchanged. + """ + if spec.compute_factory is None: + raise ValueError(f"Node type {spec.node_type!r} has no compute_factory; use its explicit add_* method") + ctx = NodeBuildContext(graph=self, node_id=settings.node_id) + function = spec.compute_factory(settings, ctx) + self.add_node_step( + node_id=settings.node_id, + function=function, + node_type=spec.node_type, + renew_schema=spec.renew_schema, + setting_input=settings, + input_node_ids=spec.derive_input_node_ids(settings), + ) + return self + def add_node_step( self, node_id: int | str, @@ -3501,30 +3447,15 @@ def add_output(self, output_file: input_schema.NodeOutput): """ def _func(df: FlowDataEngine): - if self.execution_location == "local": - df.output( - output_fs=output_file.output_settings, - flow_id=self.flow_id, - node_id=output_file.node_id, - execute_remote=False, - ) - return df - output_fs = output_file.output_settings node = self.get_node(output_file.node_id) - writer = ExternalOutputWriter( - lf=df.data_frame, - data_type=output_fs.file_type, - path=output_fs.abs_file_path, - write_mode=output_fs.write_mode, - sheet_name=output_fs.sheet_name, - delimiter=output_fs.delimiter, - compression=output_fs.compression, + writer_handle = self.execution_backend.write_output( + df, + output_file.output_settings, flow_id=self.flow_id, node_id=output_file.node_id, - wait_on_completion=False, ) - node._fetch_cached_df = writer - writer.get_result() + node._fetch_cached_df = writer_handle + writer_handle.get_result() return df def schema_callback(): @@ -3875,30 +3806,7 @@ def _func(): ) # Local and worker reads share shared.db_reader.read_sql_with_fallback - # (via SqlSource here, via read_sql_source in the worker). - if self.execution_location == "local": - local_source = SqlSource( - connection_string=sql_utils.construct_sql_uri( - database_type=database_connection.database_type, - host=database_connection.host, - port=database_connection.port, - database=database_connection.database, - username=database_connection.username, - password=decrypt_secret(encrypted_password) if encrypted_password else None, - ssl_enabled=bool(getattr(database_connection, "ssl_enabled", False)), - connect_timeout=10, - ), - query=None if database_settings.query_mode == "table" else database_settings.query, - table_name=database_settings.table_name, - schema_name=database_settings.schema_name, - fields=node_database_reader.fields, - cancel_check=lambda: self.flow_settings.is_canceled or node._execution_state.is_canceled, - ) - fl = FlowDataEngine(local_source.get_pl_df()) - fl.lazy = True - node_database_reader.fields = [c.get_minimal_field_info() for c in fl.schema] - return fl - + # (via SqlSource in the LocalBackend, via read_sql_source in the worker). database_external_read_settings = ( sql_models.DatabaseExternalReadSettings.create_from_from_node_database_reader( node_database_reader=node_database_reader, @@ -3910,11 +3818,13 @@ def _func(): ) ) - external_database_fetcher = ExternalDatabaseFetcher( - database_external_read_settings, wait_on_completion=False + reader_handle = self.execution_backend.read_database( + database_external_read_settings, + cancel_check=lambda: self.flow_settings.is_canceled or node._execution_state.is_canceled, ) - node._fetch_cached_df = external_database_fetcher - fl = FlowDataEngine(external_database_fetcher.get_result()) + node._fetch_cached_df = reader_handle + fl = FlowDataEngine(reader_handle.get_result()) + fl.lazy = True node_database_reader.fields = [c.get_minimal_field_info() for c in fl.schema] return fl @@ -4753,6 +4663,15 @@ def execution_location(self, execution_location: schemas.ExecutionLocationsLiter self.reset() self.flow_settings.execution_location = execution_location + @property + def execution_backend(self) -> ExecutionBackend: + """The compute backend for the flow's current execution location. + + Resolved per call so node closures built at add-time honor the + location at run time. + """ + return resolve_backend(self.execution_location) + def validate_if_node_can_be_fetched(self, node_id: int) -> None: flow_node = self._node_db.get(node_id) if not flow_node: diff --git a/flowfile_core/flowfile_core/flowfile/flow_node/executor.py b/flowfile_core/flowfile_core/flowfile/flow_node/executor.py index 390eea689..7c618f9af 100644 --- a/flowfile_core/flowfile_core/flowfile/flow_node/executor.py +++ b/flowfile_core/flowfile_core/flowfile/flow_node/executor.py @@ -9,6 +9,8 @@ from typing import TYPE_CHECKING, Protocol +from flowfile_core.flowfile.execution.backends import ExecutionBackend, resolve_backend +from flowfile_core.flowfile.execution.exceptions import WorkerConnectionError from flowfile_core.flowfile.flow_data_engine.subprocess_operations import ( results_exists, ) @@ -93,6 +95,7 @@ def execute( retry: bool = True, node_logger: NodeLogger = None, optimize_for_downstream: bool = True, + backend: ExecutionBackend | None = None, ) -> None: """ Main execution entry point. @@ -107,10 +110,12 @@ def execute( retry: Allow retry on recoverable errors node_logger: Logger for this node's execution optimize_for_downstream: Cache wide transforms for downstream nodes + backend: Compute backend override; resolved from run_location when None """ if node_logger is None: raise ValueError("node_logger is required") + backend = backend or resolve_backend(run_location) state = self.state_provider.get_state(self.node.node_id, self.node.parent_uuid) if reset_cache: @@ -145,7 +150,7 @@ def execute( self.node.reset() try: - self._execute_with_strategy(state, decision.strategy, effective_performance_mode, node_logger) + self._execute_with_strategy(state, decision.strategy, effective_performance_mode, node_logger, backend) self._update_source_file_info(state) self._sync_state_to_legacy(state) self.state_provider.save_state(self.node.node_id, self.node.parent_uuid, state) @@ -240,6 +245,7 @@ def _execute_with_strategy( strategy: ExecutionStrategy, performance_mode: bool, node_logger: NodeLogger, + backend: ExecutionBackend, ) -> None: """Execute using the determined strategy.""" match strategy: @@ -248,9 +254,9 @@ def _execute_with_strategy( case ExecutionStrategy.FULL_LOCAL: self._do_full_local(state, performance_mode) case ExecutionStrategy.LOCAL_WITH_SAMPLING: - self._do_local_with_sampling(state, performance_mode, node_logger.flow_id) + self._do_local_with_sampling(state, performance_mode, node_logger.flow_id, backend) case ExecutionStrategy.REMOTE: - self._do_remote(state, performance_mode, node_logger) + self._do_remote(state, performance_mode, node_logger, backend) def _do_full_local(self, state: NodeExecutionState, performance_mode: bool) -> None: """ @@ -264,26 +270,38 @@ def _do_full_local(self, state: NodeExecutionState, performance_mode: bool) -> N if self.node.results.resulting_data is not None: state.result_schema = self.node.results.resulting_data.schema - def _do_local_with_sampling(self, state: NodeExecutionState, performance_mode: bool, flow_id: int) -> None: + def _do_local_with_sampling( + self, + state: NodeExecutionState, + performance_mode: bool, + flow_id: int, + backend: ExecutionBackend, + ) -> None: """ In-process execution with external sampler for preview data. The main computation runs locally, but sample data is generated via an external process for the UI preview. """ - self.node._do_execute_local_with_sampling(performance_mode, flow_id) + self.node._do_execute_local_with_sampling(performance_mode, flow_id, backend) if self.node.results.resulting_data is not None: state.result_schema = self.node.results.resulting_data.schema if self.node.results.errors is None and not self.node.node_stats.is_canceled: state.mark_successful() - def _do_remote(self, state: NodeExecutionState, performance_mode: bool, node_logger: NodeLogger) -> None: + def _do_remote( + self, + state: NodeExecutionState, + performance_mode: bool, + node_logger: NodeLogger, + backend: ExecutionBackend, + ) -> None: """ Full remote worker execution. Computation is offloaded to an external worker process. """ - self.node._do_execute_remote(performance_mode, node_logger) + self.node._do_execute_remote(performance_mode, node_logger, backend) if self.node.results.resulting_data is not None: state.result_schema = self.node.results.resulting_data.schema state.mark_successful() @@ -384,7 +402,7 @@ def _handle_error( ) return - if "Connection refused" in error_str and "/submit_query/" in error_str: + if isinstance(error, WorkerConnectionError): node_logger.warning( "Could not connect to remote worker. " "Ensure the worker process is running, or change settings to local execution." diff --git a/flowfile_core/flowfile_core/flowfile/flow_node/flow_node.py b/flowfile_core/flowfile_core/flowfile/flow_node/flow_node.py index fad3844b3..5d76b2290 100644 --- a/flowfile_core/flowfile_core/flowfile/flow_node/flow_node.py +++ b/flowfile_core/flowfile_core/flowfile/flow_node/flow_node.py @@ -7,17 +7,12 @@ from flowfile_core.configs import logger, node_store from flowfile_core.configs.flow_logger import NodeLogger +from flowfile_core.flowfile.execution.backends import ExecutionBackend, resolve_backend +from flowfile_core.flowfile.execution.handles import TaskHandle from flowfile_core.flowfile.flow_data_engine.flow_data_engine import FlowDataEngine from flowfile_core.flowfile.flow_data_engine.flow_file_column.main import FlowfileColumn from flowfile_core.flowfile.flow_data_engine.subprocess_operations import ( - ExternalCloudWriter, - ExternalDatabaseFetcher, - ExternalDatabaseWriter, - ExternalDfFetcher, - ExternalOutputWriter, - ExternalSampler, clear_task_from_worker, - get_external_df_result, results_exists, ) from flowfile_core.flowfile.flow_node.executor import NodeExecutor @@ -71,22 +66,8 @@ class FlowNode: _name: str = None # name of the node, used for display _schema_callback: SingleExecutionFuture | None = None # Function that calculates the schema without executing _state_needs_reset: bool = False - _fetch_cached_df: ( - ExternalDfFetcher - | ExternalDatabaseFetcher - | ExternalDatabaseWriter - | ExternalCloudWriter - | ExternalOutputWriter - | None - ) = None - _cache_progress: ( - ExternalDfFetcher - | ExternalDatabaseFetcher - | ExternalDatabaseWriter - | ExternalCloudWriter - | ExternalOutputWriter - | None - ) = None + _fetch_cached_df: TaskHandle | None = None + _cache_progress: TaskHandle | None = None _execution_state: NodeExecutionState = None _executor: NodeExecutor | None = None # Lazy-initialized # Callable that returns the parent flow's current parameters (name → value). @@ -1117,29 +1098,36 @@ def get_example_data(): self.node_schema.result_schema = self.results.resulting_data.schema self.node_stats.has_completed_last_run = True - def _do_execute_local_with_sampling(self, performance_mode: bool = False, flow_id: int = None): - """Executes the node's logic locally with external sampling. + def _do_execute_local_with_sampling( + self, + performance_mode: bool = False, + flow_id: int = None, + backend: ExecutionBackend | None = None, + ): + """Executes the node's logic locally, with the backend producing preview data. Internal method called by NodeExecutor. Args: performance_mode: If True, skips generating example data. flow_id: The ID of the parent flow. + backend: Compute backend for the sampling; the worker backend when None. Raises: Exception: Propagates exceptions from the execution. """ + backend = backend or resolve_backend("remote") try: resulting_data = self.get_resulting_data() if not performance_mode: - external_sampler = ExternalSampler( - lf=resulting_data.data_frame, + sampler_handle = backend.sample( + resulting_data.data_frame, file_ref=self.hash, wait_on_completion=True, node_id=self.node_id, flow_id=flow_id, ) - self.store_example_data_generator(external_sampler) + self.store_example_data_generator(sampler_handle) if self.results.errors is None and not self.node_stats.is_canceled: self.node_stats.has_run_with_current_setup = True self.node_schema.result_schema = resulting_data.schema @@ -1157,23 +1145,30 @@ def _do_execute_local_with_sampling(self, performance_mode: bool = False, flow_i if not self.node_settings.streamable: step.node_settings.streamable = self.node_settings.streamable - def _do_execute_remote(self, performance_mode: bool = False, node_logger: NodeLogger = None): - """Executes the node's logic remotely or handles cached results. + def _do_execute_remote( + self, + performance_mode: bool = False, + node_logger: NodeLogger = None, + backend: ExecutionBackend | None = None, + ): + """Executes the node's logic on the compute backend or handles cached results. Internal method called by NodeExecutor. Args: performance_mode: If True, skips generating example data. node_logger: The logger for this node execution. + backend: Compute backend; the worker backend when None. Raises: Exception: If the node_logger is not provided or if execution fails. """ if node_logger is None: raise Exception("Node logger is not defined") - if self.node_settings.cache_results and results_exists(self.hash): + backend = backend or resolve_backend("remote") + if self.node_settings.cache_results and backend.results_exist(self.hash): try: - self.results.resulting_data = FlowDataEngine(get_external_df_result(self.hash)) + self.results.resulting_data = FlowDataEngine(backend.get_cached_lazyframe(self.hash)) self._cache_progress = None return except Exception: @@ -1195,34 +1190,33 @@ def _do_execute_remote(self, performance_mode: bool = False, node_logger: NodeLo raise e if not performance_mode: - external_df_fetcher = ExternalDfFetcher( - lf=self.get_resulting_data().data_frame, + task_handle = backend.run_lazyframe( + self.get_resulting_data().data_frame, file_ref=self.hash, wait_on_completion=False, flow_id=node_logger.flow_id, node_id=self.node_id, ) - self._fetch_cached_df = external_df_fetcher + self._fetch_cached_df = task_handle try: - lf = external_df_fetcher.get_result() + lf = task_handle.get_result() self.results.resulting_data = FlowDataEngine( lf, - number_of_records=ExternalDfFetcher( - lf=lf, - operation_type="calculate_number_of_records", + number_of_records=backend.count_records( + lf, flow_id=node_logger.flow_id, node_id=self.node_id, - ).result, + ), ) if not performance_mode: - self.store_example_data_generator(external_df_fetcher) + self.store_example_data_generator(task_handle) self.node_stats.has_run_with_current_setup = True except Exception as e: node_logger.error("Error with external process") - if external_df_fetcher.error_code == -1: + if task_handle.error_code == -1: try: self.results.resulting_data = self.get_resulting_data() self.results.warnings = ( @@ -1234,12 +1228,12 @@ def _do_execute_remote(self, performance_mode: bool = False, node_logger: NodeLo except Exception as e: self.results.errors = str(e) raise e - elif external_df_fetcher.error_description is None: + elif task_handle.error_description is None: self.results.errors = str(e) raise e else: - self.results.errors = external_df_fetcher.error_description - raise Exception(external_df_fetcher.error_description) from e + self.results.errors = task_handle.error_description + raise Exception(task_handle.error_description) from e finally: self._fetch_cached_df = None @@ -1322,14 +1316,14 @@ def execute_node( optimize_for_downstream=optimize_for_downstream, ) - def store_example_data_generator(self, external_df_fetcher: ExternalDfFetcher | ExternalSampler): + def store_example_data_generator(self, task_handle: TaskHandle): """Stores a generator function for fetching a sample of the result data. Args: - external_df_fetcher: The process that generated the sample data. + task_handle: The task that generated the sample data. """ - if external_df_fetcher.status is not None: - file_ref = external_df_fetcher.status.file_ref + if task_handle.status is not None: + file_ref = task_handle.status.file_ref self.results.example_data_path = file_ref self.results.example_data_generator = get_read_top_n(file_path=file_ref, n=100) else: diff --git a/flowfile_core/flowfile_core/flowfile/node_registry/__init__.py b/flowfile_core/flowfile_core/flowfile/node_registry/__init__.py new file mode 100644 index 000000000..4d78922ae --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/node_registry/__init__.py @@ -0,0 +1,28 @@ +"""Single source of truth for built-in node types. + +``BUILTIN_REGISTRY`` holds one NodeSpec per node type; the legacy catalogs +(get_all_standard_nodes, NODE_TYPE_TO_SETTINGS_CLASS, nodes_with_defaults, +the AI classification map) are derived from it. +""" + +from flowfile_core.flowfile.node_registry.registry import NodeRegistry +from flowfile_core.flowfile.node_registry.spec import InputArity, NodeSpec + +BUILTIN_REGISTRY = NodeRegistry() + + +def _populate() -> None: + from flowfile_core.flowfile.node_registry.builtin import ALL_SPECS + + for spec in ALL_SPECS: + BUILTIN_REGISTRY.register(spec) + + +_populate() + + +def get_node_spec(node_type: str) -> NodeSpec | None: + return BUILTIN_REGISTRY.get(node_type) + + +__all__ = ["BUILTIN_REGISTRY", "InputArity", "NodeRegistry", "NodeSpec", "get_node_spec"] diff --git a/flowfile_core/flowfile_core/flowfile/node_registry/builtin/__init__.py b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/__init__.py new file mode 100644 index 000000000..925e4fd0c --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/__init__.py @@ -0,0 +1,22 @@ +"""Built-in node specs, grouped by domain.""" + +from flowfile_core.flowfile.node_registry.builtin import ( + database, + io_nodes, + ml, + scripting, + simple, + special, + streaming_sources, +) +from flowfile_core.flowfile.node_registry.spec import NodeSpec + +ALL_SPECS: list[NodeSpec] = [ + *simple.SPECS, + *io_nodes.SPECS, + *database.SPECS, + *streaming_sources.SPECS, + *scripting.SPECS, + *ml.SPECS, + *special.SPECS, +] diff --git a/flowfile_core/flowfile_core/flowfile/node_registry/builtin/database.py b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/database.py new file mode 100644 index 000000000..47e7ca2a6 --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/database.py @@ -0,0 +1,70 @@ +"""Database source and sink nodes. + +Generated from the pre-registry catalogs (NodeTemplate list, settings map, +AI classification map); maintained by hand from here on. +""" + +from flowfile_core.flowfile.node_registry.spec import NodeSpec +from flowfile_core.schemas import input_schema +from flowfile_core.schemas.schemas import NodeTag, NodeTemplate + +SPECS: list[NodeSpec] = [ + NodeSpec( + node_type="database_reader", + settings_class=input_schema.NodeDatabaseReader, + template=NodeTemplate( + name="Read from Database", + item="database_reader", + input=0, + output=1, + image="database_reader.svg", + node_type="input", + transform_type="other", + node_group="input", + drawer_title="Database Reader", + drawer_intro="Load data from database tables or queries", + tags=[ + NodeTag.DATABASE, + NodeTag.SQL, + NodeTag.POSTGRES, + NodeTag.MYSQL, + NodeTag.SQL_SERVER, + NodeTag.SNOWFLAKE, + NodeTag.ORACLE, + NodeTag.SQLITE, + NodeTag.REDSHIFT, + NodeTag.BIGQUERY, + NodeTag.QUERY, + NodeTag.TABLE, + ], + ), + ai_classification="source", + ), + NodeSpec( + node_type="database_writer", + settings_class=input_schema.NodeDatabaseWriter, + template=NodeTemplate( + name="Write to Database", + item="database_writer", + input=1, + output=0, + image="database_writer.svg", + node_type="output", + transform_type="other", + node_group="output", + drawer_title="Database Writer", + drawer_intro="Save data to database tables", + tags=[ + NodeTag.DATABASE, + NodeTag.SQL, + NodeTag.POSTGRES, + NodeTag.MYSQL, + NodeTag.SNOWFLAKE, + NodeTag.REDSHIFT, + NodeTag.BIGQUERY, + NodeTag.TABLE, + ], + ), + ai_classification="static", + ), +] diff --git a/flowfile_core/flowfile_core/flowfile/node_registry/builtin/io_nodes.py b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/io_nodes.py new file mode 100644 index 000000000..dc7c6bf76 --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/io_nodes.py @@ -0,0 +1,225 @@ +"""File, cloud and catalog IO nodes. + +Generated from the pre-registry catalogs (NodeTemplate list, settings map, +AI classification map); maintained by hand from here on. +""" + +from flowfile_core.flowfile.node_registry.spec import NodeSpec +from flowfile_core.schemas import input_schema +from flowfile_core.schemas.schemas import NodeTag, NodeTemplate + +SPECS: list[NodeSpec] = [ + NodeSpec( + node_type="api_response", + settings_class=input_schema.NodeApiResponse, + template=NodeTemplate( + name="API response", + item="api_response", + input=1, + output=0, + image="api_response.svg", + node_type="output", + transform_type="other", + node_group="output", + drawer_title="API Response", + drawer_intro="Return this dataset as the body of an HTTP API endpoint", + tags=[NodeTag.API, NodeTag.REST, NodeTag.HTTP, NodeTag.RESPONSE], + ), + ai_classification="static", + ), + NodeSpec( + node_type="explore_data", + settings_class=input_schema.NodeExploreData, + template=NodeTemplate( + name="Explore data", + item="explore_data", + input=1, + output=0, + image="explore_data.svg", + node_type="output", + transform_type="other", + node_group="output", + drawer_title="Explore Data", + drawer_intro="Interactive data exploration and analysis", + tags=[ + NodeTag.EXPLORE, + NodeTag.PROFILE, + NodeTag.PREVIEW, + NodeTag.EDA, + NodeTag.STATISTICS, + NodeTag.VISUALIZE, + NodeTag.INSIGHT, + NodeTag.BAR_CHART, + NodeTag.GRAPHS, + ], + ), + ai_classification="static", + ), + NodeSpec( + node_type="manual_input", + settings_class=input_schema.NodeManualInput, + template=NodeTemplate( + name="Manual input", + item="manual_input", + input=0, + output=1, + image="manual_input.svg", + node_type="input", + transform_type="other", + node_group="input", + drawer_title="Manual Input", + drawer_intro="Create data directly", + laziness="lazy", + tags=[NodeTag.MANUAL, NodeTag.PASTE, NodeTag.INPUT], + ), + ai_classification="source", + ), + NodeSpec( + node_type="read", + settings_class=input_schema.NodeRead, + template=NodeTemplate( + name="Read data", + item="read", + input=0, + output=1, + image="input_data.svg", + node_type="input", + transform_type="other", + node_group="input", + drawer_title="Read Data", + drawer_intro="Load data from CSV, Excel, or Parquet files", + laziness="conditional", + tags=[ + NodeTag.CSV, + NodeTag.EXCEL, + NodeTag.PARQUET, + NodeTag.JSON, + NodeTag.FILE, + NodeTag.IMPORT, + NodeTag.READ, + ], + ), + ai_classification="source", + ), + NodeSpec( + node_type="catalog_reader", + settings_class=input_schema.NodeCatalogReader, + template=NodeTemplate( + name="Read from Catalog", + item="catalog_reader", + input=0, + output=1, + image="catalog_reader.svg", + node_type="input", + transform_type="other", + node_group="input", + drawer_title="Catalog Reader", + drawer_intro="Read a table from the data catalog", + laziness="lazy", + tags=[NodeTag.CATALOG, NodeTag.DELTA, NodeTag.TABLE, NodeTag.LAKEHOUSE, NodeTag.TIME_TRAVEL], + ), + ai_classification="source", + ), + NodeSpec( + node_type="cloud_storage_reader", + settings_class=input_schema.NodeCloudStorageReader, + template=NodeTemplate( + name="Read from cloud provider", + item="cloud_storage_reader", + input=0, + output=1, + image="cloud_storage_reader.svg", + node_type="input", + transform_type="other", + node_group="input", + drawer_title="Cloud Storage Reader", + drawer_intro="Read data from AWS S3 and other cloud storage", + laziness="conditional", + tags=[ + NodeTag.S3, + NodeTag.AWS, + NodeTag.AZURE, + NodeTag.ADLS, + NodeTag.GCS, + NodeTag.BLOB, + NodeTag.BUCKET, + NodeTag.CLOUD, + NodeTag.DELTA, + ], + ), + ai_classification="source", + ), + NodeSpec( + node_type="output", + settings_class=input_schema.NodeOutput, + template=NodeTemplate( + name="Write data", + item="output", + input=1, + output=0, + image="output.svg", + node_type="output", + transform_type="other", + node_group="output", + drawer_title="Write Data", + drawer_intro="Save your data as CSV, Excel, or Parquet files", + tags=[ + NodeTag.CSV, + NodeTag.EXCEL, + NodeTag.PARQUET, + NodeTag.JSON, + NodeTag.FILE, + NodeTag.EXPORT, + NodeTag.SAVE, + NodeTag.WRITE, + ], + ), + ai_classification="static", + ), + NodeSpec( + node_type="catalog_writer", + settings_class=input_schema.NodeCatalogWriter, + template=NodeTemplate( + name="Write to Catalog", + item="catalog_writer", + input=1, + output=0, + image="catalog_writer.svg", + node_type="output", + transform_type="other", + node_group="output", + drawer_title="Catalog Writer", + drawer_intro="Save data as a table in the data catalog", + tags=[NodeTag.CATALOG, NodeTag.DELTA, NodeTag.TABLE, NodeTag.LAKEHOUSE], + ), + ai_classification="static", + ), + NodeSpec( + node_type="cloud_storage_writer", + settings_class=input_schema.NodeCloudStorageWriter, + template=NodeTemplate( + name="Write to cloud provider", + item="cloud_storage_writer", + input=1, + output=0, + image="cloud_storage_writer.svg", + node_type="output", + transform_type="other", + node_group="output", + drawer_title="Cloud Storage Writer", + drawer_intro="Save data to AWS S3 and other cloud storage", + tags=[ + NodeTag.S3, + NodeTag.AWS, + NodeTag.AZURE, + NodeTag.ADLS, + NodeTag.GCS, + NodeTag.BLOB, + NodeTag.BUCKET, + NodeTag.CLOUD, + NodeTag.DELTA, + ], + ), + ai_classification="static", + ), +] diff --git a/flowfile_core/flowfile_core/flowfile/node_registry/builtin/ml.py b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/ml.py new file mode 100644 index 000000000..062425e40 --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/ml.py @@ -0,0 +1,73 @@ +"""Machine-learning nodes. + +Generated from the pre-registry catalogs (NodeTemplate list, settings map, +AI classification map); maintained by hand from here on. +""" + +from flowfile_core.flowfile.node_registry.spec import NodeSpec +from flowfile_core.schemas import input_schema +from flowfile_core.schemas.schemas import NodeTag, NodeTemplate + +SPECS: list[NodeSpec] = [ + NodeSpec( + node_type="apply_model", + settings_class=input_schema.NodeApplyModel, + template=NodeTemplate( + name="Apply Model", + item="apply_model", + input=1, + output=1, + image="apply_model.svg", + node_type="process", + transform_type="wide", + node_group="ml", + drawer_title="Apply ML Model", + drawer_intro="Score data with an upstream Train Model node, or with a trained model from the catalog", + tags=[NodeTag.ML, NodeTag.MACHINE_LEARNING, NodeTag.PREDICT, NodeTag.SCORE, NodeTag.MODEL], + ), + ai_classification="static", + ), + NodeSpec( + node_type="evaluate_model", + settings_class=input_schema.NodeEvaluateModel, + template=NodeTemplate( + name="Evaluate Model", + item="evaluate_model", + input=1, + output=1, + image="evaluate_model.svg", + node_type="process", + transform_type="narrow", + node_group="ml", + drawer_title="Evaluate Model", + drawer_intro="Compare actual vs predicted columns and compute quality metrics", + tags=[NodeTag.ML, NodeTag.MACHINE_LEARNING, NodeTag.EVALUATE, NodeTag.METRICS, NodeTag.MODEL], + ), + ai_classification="static", + ), + NodeSpec( + node_type="train_model", + settings_class=input_schema.NodeTrainModel, + template=NodeTemplate( + name="Train Model", + item="train_model", + input=1, + output=1, + image="train_model.svg", + node_type="process", + transform_type="other", + node_group="ml", + drawer_title="Train ML Model", + drawer_intro="Fit a regression or classification model; optionally save it to the catalog", + tags=[ + NodeTag.ML, + NodeTag.MACHINE_LEARNING, + NodeTag.TRAIN, + NodeTag.MODEL, + NodeTag.REGRESSION, + NodeTag.CLASSIFICATION, + ], + ), + ai_classification="static", + ), +] diff --git a/flowfile_core/flowfile_core/flowfile/node_registry/builtin/scripting.py b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/scripting.py new file mode 100644 index 000000000..58ea9d977 --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/scripting.py @@ -0,0 +1,82 @@ +"""Code-execution nodes (Python, Polars code, SQL). + +Generated from the pre-registry catalogs (NodeTemplate list, settings map, +AI classification map); maintained by hand from here on. +""" + +from flowfile_core.flowfile.node_registry.spec import NodeSpec +from flowfile_core.schemas import input_schema +from flowfile_core.schemas.schemas import NodeTag, NodeTemplate + +SPECS: list[NodeSpec] = [ + NodeSpec( + node_type="polars_code", + settings_class=input_schema.NodePolarsCode, + template=NodeTemplate( + name="Polars code", + item="polars_code", + input=10, + output=1, + image="polars_code.svg", + multi=True, + node_type="process", + transform_type="narrow", + node_group="transform", + can_be_start=True, + drawer_title="Polars Code", + drawer_intro="Write custom Polars DataFrame transformations", + laziness="conditional", + tags=[ + NodeTag.POLARS, + NodeTag.CODE, + NodeTag.PYTHON, + NodeTag.SCRIPT, + NodeTag.CUSTOM, + NodeTag.DATAFRAME, + NodeTag.TRANSFORM, + ], + ), + ai_classification="dynamic", + ), + NodeSpec( + node_type="python_script", + settings_class=input_schema.NodePythonScript, + template=NodeTemplate( + name="Python Script", + item="python_script", + input=10, + output=1, + image="python_code.svg", + multi=True, + node_type="process", + transform_type="narrow", + node_group="transform", + can_be_start=True, + drawer_title="Python Script", + drawer_intro="Execute Python code on an isolated kernel container", + tags=[NodeTag.PYTHON, NodeTag.CODE, NodeTag.SCRIPT, NodeTag.KERNEL, NodeTag.CUSTOM, NodeTag.TRANSFORM], + ), + ai_classification="dynamic", + ), + NodeSpec( + node_type="sql_query", + settings_class=input_schema.NodeSqlQuery, + template=NodeTemplate( + name="SQL Query", + item="sql_query", + input=10, + output=1, + image="sql_query.svg", + multi=True, + node_type="process", + transform_type="narrow", + node_group="transform", + can_be_start=True, + drawer_title="SQL Query", + drawer_intro="Write SQL queries against connected data sources", + laziness="lazy", + tags=[NodeTag.SQL, NodeTag.QUERY, NodeTag.DUCKDB], + ), + ai_classification="dynamic", + ), +] diff --git a/flowfile_core/flowfile_core/flowfile/node_registry/builtin/simple.py b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/simple.py new file mode 100644 index 000000000..2f6a1fbc8 --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/simple.py @@ -0,0 +1,527 @@ +"""Single-call transform nodes. + +Generated from the pre-registry catalogs (NodeTemplate list, settings map, +AI classification map); maintained by hand from here on. + +Compute factories build the closures FlowGraph._add_from_spec wires into +add_node_step. FlowDataEngine and friends are imported inside the factories: +this module is loaded while configs.node_store is still initializing, so a +module-level import would create a cycle. Every closure is named ``_func`` — +add_node_step uses function.__name__ for the node name. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import polars as pl + +from flowfile_core.flowfile.node_registry.spec import NodeBuildContext, NodeSpec +from flowfile_core.schemas import input_schema +from flowfile_core.schemas.schemas import NodeTag, NodeTemplate + +if TYPE_CHECKING: + from collections.abc import Callable + + from flowfile_core.flowfile.flow_data_engine.flow_data_engine import FlowDataEngine + + +def _sort_compute(settings: input_schema.NodeSort, ctx: NodeBuildContext) -> Callable: + def _func(table: FlowDataEngine) -> FlowDataEngine: + return table.do_sort(settings.sort_input) + + return _func + + +def _sample_compute(settings: input_schema.NodeSample, ctx: NodeBuildContext) -> Callable: + def _func(table: FlowDataEngine) -> FlowDataEngine: + return table.get_sample(settings.sample_size) + + return _func + + +def _record_count_compute(settings: input_schema.NodeRecordCount, ctx: NodeBuildContext) -> Callable: + def _func(fl: FlowDataEngine) -> FlowDataEngine: + return fl.get_record_count() + + return _func + + +def _filter_compute(settings: input_schema.NodeFilter, ctx: NodeBuildContext) -> Callable: + def _func(fl: FlowDataEngine) -> FlowDataEngine: + from flowfile_core.configs import logger + from flowfile_core.flowfile.filter_expressions import build_filter_expression + + is_advanced = settings.filter_input.is_advanced() + + if is_advanced: + expression = settings.filter_input.advanced_filter + else: + basic_filter = settings.filter_input.basic_filter + if basic_filter is None: + logger.warning("Basic filter is None, returning unfiltered data") + return fl + + try: + field_data_type = fl.get_schema_column(basic_filter.field).generic_datatype() + except Exception: + field_data_type = None + + expression = build_filter_expression(basic_filter, field_data_type) + settings.filter_input.advanced_filter = expression + + if settings.split_mode: + return fl.filter_split(expression) + return fl.do_filter(expression) + + return _func + + +def _union_compute(settings: input_schema.NodeUnion, ctx: NodeBuildContext) -> Callable: + def _func(*flowfile_tables: FlowDataEngine) -> FlowDataEngine: + from flowfile_core.flowfile.flow_data_engine.flow_data_engine import FlowDataEngine + + dfs: list[pl.LazyFrame] | list[pl.DataFrame] = [flt.data_frame for flt in flowfile_tables] + return FlowDataEngine(pl.concat(dfs, how="diagonal_relaxed")) + + return _func + + +SPECS: list[NodeSpec] = [ + NodeSpec( + node_type="record_id", + settings_class=input_schema.NodeRecordId, + template=NodeTemplate( + name="Add record Id", + item="record_id", + input=1, + output=1, + image="record_id.svg", + node_type="process", + transform_type="wide", + node_group="transform", + drawer_title="Add Record ID", + drawer_intro="Generate unique identifiers for each row", + laziness="lazy", + tags=[NodeTag.RECORD_ID, NodeTag.ROW_NUMBER, NodeTag.INDEX], + ), + ai_classification="static", + ), + NodeSpec( + node_type="record_count", + settings_class=input_schema.NodeRecordCount, + template=NodeTemplate( + name="Count records", + item="record_count", + input=1, + output=1, + image="record_count.svg", + node_type="process", + transform_type="wide", + node_group="aggregate", + drawer_title="Count Records", + drawer_intro="Calculate the total number of rows", + laziness="lazy", + tags=[NodeTag.RECORD_COUNT, NodeTag.COUNT, NodeTag.ROWS], + ), + has_default_settings=True, + ai_classification="static", + compute_factory=_record_count_compute, + ), + NodeSpec( + node_type="cross_join", + settings_class=input_schema.NodeCrossJoin, + template=NodeTemplate( + name="Cross join", + item="cross_join", + input=2, + output=1, + image="cross_join.svg", + node_type="process", + transform_type="wide", + node_group="combine", + drawer_title="Cross Join", + drawer_intro="Create all possible combinations between two datasets", + laziness="lazy", + tags=[NodeTag.CROSS_JOIN, NodeTag.CARTESIAN, NodeTag.JOIN], + ), + ai_classification="static", + ), + NodeSpec( + node_type="unique", + settings_class=input_schema.NodeUnique, + template=NodeTemplate( + name="Drop duplicates", + item="unique", + input=1, + output=1, + image="unique.svg", + node_type="process", + transform_type="wide", + node_group="transform", + drawer_title="Drop Duplicates", + drawer_intro="Remove duplicate rows based on selected columns", + laziness="lazy", + tags=[NodeTag.UNIQUE, NodeTag.DEDUPE, NodeTag.DISTINCT, NodeTag.DROP_DUPLICATES], + ), + ai_classification="static", + ), + NodeSpec( + node_type="filter", + settings_class=input_schema.NodeFilter, + template=NodeTemplate( + name="Filter data", + item="filter", + input=1, + output=1, + image="filter.svg", + node_type="process", + transform_type="narrow", + node_group="transform", + drawer_title="Filter Rows", + drawer_intro="Keep only rows that match your conditions", + laziness="lazy", + tags=[NodeTag.FILTER, NodeTag.WHERE, NodeTag.SUBSET], + ), + ai_classification="static", + compute_factory=_filter_compute, + renew_schema=False, + ), + NodeSpec( + node_type="formula", + settings_class=input_schema.NodeFormula, + template=NodeTemplate( + name="Formula", + item="formula", + input=1, + output=1, + image="formula.svg", + node_type="process", + transform_type="narrow", + node_group="transform", + drawer_title="Formula Editor", + drawer_intro="Create or modify columns using custom expressions", + laziness="lazy", + tags=[ + NodeTag.FORMULA, + NodeTag.EXPRESSION, + NodeTag.TRANSFORM, + NodeTag.CALCULATE, + NodeTag.MATH, + NodeTag.CONCAT, + NodeTag.SUM, + ], + ), + ai_classification="static", + ), + NodeSpec( + node_type="fuzzy_match", + settings_class=input_schema.NodeFuzzyMatch, + template=NodeTemplate( + name="Fuzzy match", + item="fuzzy_match", + input=2, + output=1, + image="fuzzy_match.svg", + node_type="process", + transform_type="wide", + node_group="combine", + drawer_title="Fuzzy Match", + drawer_intro="Join datasets based on similar values instead of exact matches", + tags=[NodeTag.FUZZY, NodeTag.SIMILARITY, NodeTag.LEVENSHTEIN, NodeTag.JOIN, NodeTag.LOOKUP], + ), + ai_classification="static", + ), + NodeSpec( + node_type="graph_solver", + settings_class=input_schema.NodeGraphSolver, + template=NodeTemplate( + name="Graph solver", + item="graph_solver", + input=1, + output=1, + image="graph_solver.svg", + node_type="process", + transform_type="other", + node_group="combine", + drawer_title="Graph Solver", + drawer_intro="Group related records in graph-structured data", + laziness="lazy", + tags=[NodeTag.GRAPH, NodeTag.NETWORK, NodeTag.CLUSTER, NodeTag.CONNECTED_COMPONENTS], + ), + ai_classification="dynamic", + ), + NodeSpec( + node_type="group_by", + settings_class=input_schema.NodeGroupBy, + template=NodeTemplate( + name="Group by", + item="group_by", + input=1, + output=1, + image="group_by.svg", + node_type="process", + transform_type="wide", + node_group="aggregate", + drawer_title="Group By", + drawer_intro="Aggregate data by grouping and calculating statistics", + laziness="lazy", + tags=[ + NodeTag.GROUP_BY, + NodeTag.AGGREGATE, + NodeTag.SUM, + NodeTag.MEAN, + NodeTag.AVERAGE, + NodeTag.COUNT, + NodeTag.MIN, + NodeTag.MAX, + NodeTag.MEDIAN, + NodeTag.SUMMARIZE, + ], + ), + ai_classification="static", + ), + NodeSpec( + node_type="join", + settings_class=input_schema.NodeJoin, + template=NodeTemplate( + name="Join", + item="join", + input=2, + output=1, + image="join.svg", + node_type="process", + transform_type="wide", + node_group="combine", + drawer_title="Join Datasets", + drawer_intro="Merge two datasets based on matching column values", + laziness="lazy", + tags=[NodeTag.JOIN, NodeTag.MERGE, NodeTag.LOOKUP, NodeTag.VLOOKUP, NodeTag.INNER, NodeTag.OUTER], + ), + ai_classification="static", + ), + NodeSpec( + node_type="pivot", + settings_class=input_schema.NodePivot, + template=NodeTemplate( + name="Pivot data", + item="pivot", + input=1, + output=1, + image="pivot.svg", + node_type="process", + transform_type="wide", + node_group="aggregate", + drawer_title="Pivot Data", + drawer_intro="Convert data from long format to wide format", + tags=[NodeTag.PIVOT, NodeTag.CROSSTAB, NodeTag.RESHAPE], + ), + ai_classification="dynamic", + ), + NodeSpec( + node_type="random_split", + settings_class=input_schema.NodeRandomSplit, + template=NodeTemplate( + name="Random Split", + item="random_split", + input=1, + output=2, + image="random_split.svg", + node_type="process", + transform_type="narrow", + node_group="ml", + drawer_title="Random Split", + drawer_intro="Randomly partition rows into named groups (e.g. train/test)", + laziness="lazy", + output_names=["train", "test"], + tags=[NodeTag.SPLIT, NodeTag.TRAIN, NodeTag.TEST, NodeTag.ML, NodeTag.PARTITION], + ), + ai_classification="static", + ), + NodeSpec( + node_type="dynamic_rename", + settings_class=input_schema.NodeDynamicRename, + template=NodeTemplate( + name="Rename columns", + item="dynamic_rename", + input=1, + output=1, + image="dynamic_rename.svg", + node_type="process", + transform_type="narrow", + node_group="transform", + drawer_title="Rename Columns", + drawer_intro="Bulk-rename columns by prefix, suffix, or a formula", + laziness="lazy", + tags=[NodeTag.RENAME, NodeTag.COLUMNS], + ), + ai_classification="dynamic", + ), + NodeSpec( + node_type="select", + settings_class=input_schema.NodeSelect, + template=NodeTemplate( + name="Select data", + item="select", + input=1, + output=1, + image="select.svg", + node_type="process", + transform_type="narrow", + node_group="transform", + drawer_title="Select Columns", + drawer_intro="Choose, rename, and reorder columns to keep", + laziness="lazy", + tags=[NodeTag.SELECT, NodeTag.COLUMNS, NodeTag.RENAME, NodeTag.REORDER, NodeTag.PROJECTION], + ), + has_default_settings=True, + ai_classification="static", + ), + NodeSpec( + node_type="sort", + settings_class=input_schema.NodeSort, + template=NodeTemplate( + name="Sort data", + item="sort", + input=1, + output=1, + image="sort.svg", + node_type="process", + transform_type="wide", + node_group="transform", + drawer_title="Sort Data", + drawer_intro="Order your data by one or more columns", + laziness="lazy", + tags=[NodeTag.SORT, NodeTag.ORDER, NodeTag.RANK, NodeTag.ASCENDING, NodeTag.DESCENDING], + ), + has_default_settings=True, + ai_classification="static", + compute_factory=_sort_compute, + ), + NodeSpec( + node_type="sample", + settings_class=input_schema.NodeSample, + template=NodeTemplate( + name="Take Sample", + item="sample", + input=1, + output=1, + image="sample.svg", + node_type="process", + transform_type="narrow", + node_group="transform", + drawer_title="Take Sample", + drawer_intro="Work with a subset of your data", + laziness="lazy", + tags=[NodeTag.SAMPLE, NodeTag.SUBSET, NodeTag.LIMIT, NodeTag.HEAD], + ), + has_default_settings=True, + ai_classification="static", + compute_factory=_sample_compute, + ), + NodeSpec( + node_type="text_to_rows", + settings_class=input_schema.NodeTextToRows, + template=NodeTemplate( + name="Text to rows", + item="text_to_rows", + input=1, + output=1, + image="text_to_rows.svg", + node_type="process", + transform_type="wide", + node_group="transform", + drawer_title="Text to Rows", + drawer_intro="Split text into multiple rows based on a delimiter", + laziness="lazy", + tags=[NodeTag.TEXT_TO_ROWS, NodeTag.SPLIT, NodeTag.EXPLODE], + ), + ai_classification="dynamic", + ), + NodeSpec( + node_type="union", + settings_class=input_schema.NodeUnion, + template=NodeTemplate( + name="Union data", + item="union", + input=10, + output=1, + image="union.svg", + multi=True, + node_type="process", + transform_type="narrow", + node_group="combine", + drawer_title="Union Data", + drawer_intro="Stack multiple datasets by combining rows", + laziness="lazy", + tags=[NodeTag.UNION, NodeTag.CONCAT, NodeTag.APPEND], + ), + has_default_settings=True, + ai_classification="static", + compute_factory=_union_compute, + ), + NodeSpec( + node_type="unpivot", + settings_class=input_schema.NodeUnpivot, + template=NodeTemplate( + name="Unpivot data", + item="unpivot", + input=1, + output=1, + image="unpivot.svg", + node_type="process", + transform_type="wide", + node_group="aggregate", + drawer_title="Unpivot Data", + drawer_intro="Transform data from wide format to long format", + laziness="lazy", + tags=[NodeTag.UNPIVOT, NodeTag.MELT, NodeTag.RESHAPE], + ), + ai_classification="dynamic", + ), + NodeSpec( + node_type="wait_for", + settings_class=input_schema.NodeWaitFor, + template=NodeTemplate( + name="Wait For", + item="wait_for", + input=2, + output=1, + image="wait_for.svg", + node_type="process", + transform_type="other", + node_group="combine", + drawer_title="Wait For", + drawer_intro="Pass the left input through; the right input only enforces ordering", + tags=[NodeTag.WAIT, NodeTag.DEPENDENCY], + ), + ai_classification="static", + ), + NodeSpec( + node_type="window_functions", + settings_class=input_schema.NodeWindowFunctions, + template=NodeTemplate( + name="Window functions", + item="window_functions", + input=1, + output=1, + image="window_functions.svg", + node_type="process", + transform_type="wide", + node_group="aggregate", + drawer_title="Window Functions", + drawer_intro="Rolling, cumulative, rank and tile calculations (optionally per partition)", + laziness="lazy", + tags=[ + NodeTag.WINDOW, + NodeTag.ROLLING, + NodeTag.CUMULATIVE, + NodeTag.RANK, + NodeTag.PARTITION, + NodeTag.LAG, + NodeTag.LEAD, + ], + ), + ai_classification="static", + ), +] diff --git a/flowfile_core/flowfile_core/flowfile/node_registry/builtin/special.py b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/special.py new file mode 100644 index 000000000..2125f8341 --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/special.py @@ -0,0 +1,45 @@ +"""Settings-only and dict-only node types (no drawer entry). + +Generated from the pre-registry catalogs (NodeTemplate list, settings map, +AI classification map); maintained by hand from here on. +""" + +from flowfile_core.flowfile.node_registry.spec import NodeSpec +from flowfile_core.schemas import input_schema +from flowfile_core.schemas.schemas import NodeTemplate + +SPECS: list[NodeSpec] = [ + NodeSpec( + node_type="promise", + settings_class=input_schema.NodePromise, + ai_classification="passthrough", + ), + NodeSpec( + node_type="user_defined", + settings_class=input_schema.UserDefinedNode, + ai_classification="dynamic", + ), + NodeSpec( + node_type="polars_lazy_frame", + settings_class=None, + template=NodeTemplate( + name="LazyFrame node", + item="polars_lazy_frame", + input=0, + output=1, + image="", + node_type="input", + transform_type="other", + node_group="special", + laziness="lazy", + ), + drawer_visible=False, + ), + # Legacy alias for the read node; its NodeDatasource settings class is + # resolved via the reflective fallback in routes.get_node_model and must + # stay out of the settings map (AI tools generate one tool per map entry). + NodeSpec( + node_type="datasource", + settings_class=None, + ), +] diff --git a/flowfile_core/flowfile_core/flowfile/node_registry/builtin/streaming_sources.py b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/streaming_sources.py new file mode 100644 index 000000000..e9166b82d --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/node_registry/builtin/streaming_sources.py @@ -0,0 +1,85 @@ +"""External streaming/API source nodes. + +Generated from the pre-registry catalogs (NodeTemplate list, settings map, +AI classification map); maintained by hand from here on. +""" + +from flowfile_core.flowfile.node_registry.spec import NodeSpec +from flowfile_core.schemas import input_schema +from flowfile_core.schemas.schemas import NodeTag, NodeTemplate + +SPECS: list[NodeSpec] = [ + NodeSpec( + node_type="external_source", + settings_class=input_schema.NodeExternalSource, + template=NodeTemplate( + name="External source", + item="external_source", + input=0, + output=1, + image="external_source.svg", + node_type="input", + transform_type="other", + node_group="input", + prod_ready=False, + drawer_title="External Source", + drawer_intro="Connect to external data sources and APIs", + tags=[NodeTag.API, NodeTag.REST, NodeTag.HTTP, NodeTag.EXTERNAL], + ), + ai_classification="source", + ), + NodeSpec( + node_type="google_analytics_reader", + settings_class=input_schema.NodeGoogleAnalyticsReader, + template=NodeTemplate( + name="Google Analytics", + item="google_analytics_reader", + input=0, + output=1, + image="google_analytics.svg", + node_type="input", + transform_type="other", + node_group="input", + drawer_title="Google Analytics", + drawer_intro="Load reports from a Google Analytics 4 property", + tags=[NodeTag.GOOGLE_ANALYTICS, NodeTag.GA4, NodeTag.ANALYTICS], + ), + ai_classification="source", + ), + NodeSpec( + node_type="kafka_source", + settings_class=input_schema.NodeKafkaSource, + template=NodeTemplate( + name="Kafka Source", + item="kafka_source", + input=0, + output=1, + image="kafka_source.svg", + node_type="input", + transform_type="other", + node_group="input", + drawer_title="Kafka Source", + drawer_intro="Read data from a Kafka or Redpanda topic", + tags=[NodeTag.KAFKA, NodeTag.REDPANDA, NodeTag.STREAMING, NodeTag.TOPIC], + ), + ai_classification="source", + ), + NodeSpec( + node_type="rest_api_reader", + settings_class=input_schema.NodeRestApiReader, + template=NodeTemplate( + name="REST API", + item="rest_api_reader", + input=0, + output=1, + image="rest_api_reader.svg", + node_type="input", + transform_type="other", + node_group="input", + drawer_title="REST API", + drawer_intro="Read JSON data from a REST API with auth and pagination", + tags=[NodeTag.REST, NodeTag.API, NodeTag.HTTP, NodeTag.JSON, NodeTag.PAGINATION], + ), + ai_classification="source", + ), +] diff --git a/flowfile_core/flowfile_core/flowfile/node_registry/registry.py b/flowfile_core/flowfile_core/flowfile/node_registry/registry.py new file mode 100644 index 000000000..d194eccaa --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/node_registry/registry.py @@ -0,0 +1,52 @@ +"""Registry of built-in NodeSpecs, with the derived catalog views.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from flowfile_core.flowfile.node_registry.spec import NodeSpec +from flowfile_core.schemas.schemas import NodeDefault, NodeTemplate + + +class NodeRegistry: + def __init__(self): + self._specs: dict[str, NodeSpec] = {} + + def register(self, spec: NodeSpec) -> None: + if spec.node_type in self._specs: + raise ValueError(f"Node type {spec.node_type!r} is already registered") + self._specs[spec.node_type] = spec + + def get(self, node_type: str) -> NodeSpec | None: + return self._specs.get(node_type) + + def __contains__(self, node_type: str) -> bool: + return node_type in self._specs + + def __iter__(self) -> Iterator[NodeSpec]: + return iter(self._specs.values()) + + def __len__(self) -> int: + return len(self._specs) + + # -- derived views over the legacy catalogs --------------------------------- + + def settings_class_map(self) -> dict[str, type]: + return {s.node_type: s.settings_class for s in self if s.settings_class is not None} + + def drawer_templates(self) -> list[NodeTemplate]: + templates = [s.template for s in self if s.template is not None and s.drawer_visible] + templates.sort(key=lambda t: t.name) + return templates + + def template_dict(self) -> dict[str, NodeTemplate]: + return {s.node_type: s.template for s in self if s.template is not None} + + def node_defaults(self) -> dict[str, NodeDefault]: + return {s.node_type: s.default for s in self if s.template is not None and s.drawer_visible} + + def node_types_with_default_settings(self) -> set[str]: + return {s.node_type for s in self if s.has_default_settings} + + def ai_classification_map(self) -> dict[str, str]: + return {s.node_type: s.ai_classification for s in self if s.ai_classification is not None} diff --git a/flowfile_core/flowfile_core/flowfile/node_registry/spec.py b/flowfile_core/flowfile_core/flowfile/node_registry/spec.py new file mode 100644 index 000000000..54fb2cf53 --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/node_registry/spec.py @@ -0,0 +1,104 @@ +"""NodeSpec: single source of truth for a built-in node type. + +One spec bundles what used to live in four hand-maintained catalogs: +the NodeTemplate literal (configs/node_store/nodes.py), the settings-class +map (schemas.NODE_TYPE_TO_SETTINGS_CLASS), the nodes-with-defaults sets, +and the AI classification map (ai/tools/classification._NODE_CLASS_MAP). +Those catalogs are now derived views over the registry. +""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Any, Literal + +from flowfile_core.schemas.schemas import NodeDefault, NodeTemplate + +if TYPE_CHECKING: + from pydantic import BaseModel + + from flowfile_core.flowfile.execution.backends import ExecutionBackend + from flowfile_core.flowfile.flow_graph import FlowGraph + +NodeAiClassification = Literal["static", "dynamic", "source", "passthrough"] + + +class InputArity(Enum): + SOURCE = "source" # no inputs (readers, manual_input) + SINGLE = "single" # one main input + DOUBLE = "double" # main + right (join, cross_join, fuzzy_match) + MULTI = "multi" # any number of main inputs (union) + + +@dataclass +class NodeBuildContext: + """What a compute factory may need beyond the settings object.""" + + graph: FlowGraph + node_id: int | str + + def backend(self) -> ExecutionBackend: + # Resolved per call so closures honor the location at run time. + return self.graph.execution_backend + + +# (settings, ctx) -> the node's compute closure, as passed to add_node_step. +ComputeFactory = Callable[[Any, NodeBuildContext], Callable] + + +@dataclass(frozen=True) +class NodeSpec: + """Declarative description of one node type. + + ``template`` is None for settings-only types (promise, user_defined, + datasource); ``drawer_visible=False`` keeps a templated type out of the + UI drawer list (polars_lazy_frame). + """ + + node_type: str + settings_class: type[BaseModel] | None + template: NodeTemplate | None = None + has_default_settings: bool = False + ai_classification: NodeAiClassification | None = None + drawer_visible: bool = True + # Declarative path: FlowGraph._add_from_spec builds the compute closure via + # this factory. None for node types whose add_* method is still explicit. + compute_factory: ComputeFactory | None = None + renew_schema: bool = True + + @property + def input_arity(self) -> InputArity: + if self.template is None: + return InputArity.SINGLE + if self.template.multi: + return InputArity.MULTI + if self.template.input == 0: + return InputArity.SOURCE + if self.template.input == 2: + return InputArity.DOUBLE + return InputArity.SINGLE + + def derive_input_node_ids(self, settings: Any) -> list[int]: + """Input node ids for add_node_step, derived from arity + settings.""" + arity = self.input_arity + if arity is InputArity.SINGLE: + return [settings.depending_on_id] + if arity is InputArity.MULTI: + return settings.depending_on_ids + raise NotImplementedError( + f"Cannot derive input node ids for {self.node_type!r} ({arity}); use an explicit add_* body" + ) + + @property + def default(self) -> NodeDefault | None: + """The NodeDefault view previously built in get_all_standard_nodes.""" + if self.template is None: + return None + return NodeDefault( + node_name=self.template.name, + node_type=self.template.node_type, + transform_type=self.template.transform_type, + has_default_settings=self.has_default_settings, + ) diff --git a/flowfile_core/flowfile_core/routes/routes.py b/flowfile_core/flowfile_core/routes/routes.py index 575758b90..7d8f63e76 100644 --- a/flowfile_core/flowfile_core/routes/routes.py +++ b/flowfile_core/flowfile_core/routes/routes.py @@ -108,8 +108,20 @@ _MANAGED_FLOW_STEM_DISALLOWED_RE = re.compile(r"[^A-Za-z0-9_-]+") -def get_node_model(setting_name_ref: str): - """(Internal) Retrieves a node's Pydantic model from the input_schema module by its name.""" +def get_node_model(setting_name_ref: str, node_type: str | None = None): + """(Internal) Retrieves a node's Pydantic settings model. + + Resolves through the node registry when ``node_type`` is given; falls back + to the legacy reflective scan of the input_schema module by lowercased + class name (still needed for types outside the registry's settings map, + e.g. the legacy ``datasource``). + """ + if node_type is not None: + from flowfile_core.flowfile.node_registry import get_node_spec + + spec = get_node_spec(node_type) + if spec is not None and spec.settings_class is not None: + return spec.settings_class logger.info("Getting node model for: " + setting_name_ref) for ref_name, ref in inspect.getmodule(input_schema).__dict__.items(): if ref_name.lower() == setting_name_ref: @@ -573,7 +585,7 @@ def add_node( if check_if_has_default_setting(node_type): logger.info(f"Found standard settings for {node_type}, trying to upload them") setting_name_ref = "node" + node_type.replace("_", "") - node_model = get_node_model(setting_name_ref) + node_model = get_node_model(setting_name_ref, node_type=node_type) # Temporarily disable history tracking for initial settings original_track_history = flow.flow_settings.track_history @@ -1151,7 +1163,7 @@ def add_generic_settings( if add_func is None: raise HTTPException(404, "could not find the function") try: - ref = get_node_model(setting_name_ref) + ref = get_node_model(setting_name_ref, node_type=node_type) if ref: parsed_input = ref(**input_data) except ValidationError as e: diff --git a/flowfile_core/flowfile_core/schemas/schemas.py b/flowfile_core/flowfile_core/schemas/schemas.py index 5edc2325d..cd4f86343 100644 --- a/flowfile_core/flowfile_core/schemas/schemas.py +++ b/flowfile_core/flowfile_core/schemas/schemas.py @@ -16,52 +16,25 @@ LazinessLiteral = Literal["lazy", "eager", "conditional"] _custom_node_store_cache = None -NODE_TYPE_TO_SETTINGS_CLASS = { - "manual_input": input_schema.NodeManualInput, - "filter": input_schema.NodeFilter, - "formula": input_schema.NodeFormula, - "dynamic_rename": input_schema.NodeDynamicRename, - "select": input_schema.NodeSelect, - "sort": input_schema.NodeSort, - "record_id": input_schema.NodeRecordId, - "sample": input_schema.NodeSample, - "random_split": input_schema.NodeRandomSplit, - "unique": input_schema.NodeUnique, - "group_by": input_schema.NodeGroupBy, - "window_functions": input_schema.NodeWindowFunctions, - "pivot": input_schema.NodePivot, - "unpivot": input_schema.NodeUnpivot, - "text_to_rows": input_schema.NodeTextToRows, - "graph_solver": input_schema.NodeGraphSolver, - "python_script": input_schema.NodePythonScript, - "polars_code": input_schema.NodePolarsCode, - "sql_query": input_schema.NodeSqlQuery, - "join": input_schema.NodeJoin, - "cross_join": input_schema.NodeCrossJoin, - "fuzzy_match": input_schema.NodeFuzzyMatch, - "record_count": input_schema.NodeRecordCount, - "explore_data": input_schema.NodeExploreData, - "union": input_schema.NodeUnion, - "output": input_schema.NodeOutput, - "api_response": input_schema.NodeApiResponse, - "read": input_schema.NodeRead, - "database_reader": input_schema.NodeDatabaseReader, - "database_writer": input_schema.NodeDatabaseWriter, - "cloud_storage_reader": input_schema.NodeCloudStorageReader, - "cloud_storage_writer": input_schema.NodeCloudStorageWriter, - "catalog_reader": input_schema.NodeCatalogReader, - "catalog_writer": input_schema.NodeCatalogWriter, - "kafka_source": input_schema.NodeKafkaSource, - "google_analytics_reader": input_schema.NodeGoogleAnalyticsReader, - "rest_api_reader": input_schema.NodeRestApiReader, - "external_source": input_schema.NodeExternalSource, - "promise": input_schema.NodePromise, - "user_defined": input_schema.UserDefinedNode, - "train_model": input_schema.NodeTrainModel, - "apply_model": input_schema.NodeApplyModel, - "evaluate_model": input_schema.NodeEvaluateModel, - "wait_for": input_schema.NodeWaitFor, -} +# NODE_TYPE_TO_SETTINGS_CLASS is derived from the node registry (the single +# source of truth for node types) and exposed lazily via module __getattr__ +# below to avoid a schemas <-> node_registry import cycle. +_node_type_to_settings_class: dict[str, type] | None = None + + +def _get_node_type_to_settings_class() -> dict[str, type]: + global _node_type_to_settings_class + if _node_type_to_settings_class is None: + from flowfile_core.flowfile.node_registry import BUILTIN_REGISTRY + + _node_type_to_settings_class = BUILTIN_REGISTRY.settings_class_map() + return _node_type_to_settings_class + + +def __getattr__(name: str): + if name == "NODE_TYPE_TO_SETTINGS_CLASS": + return _get_node_type_to_settings_class() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") def get_global_execution_location() -> ExecutionLocationsLiteral: @@ -88,7 +61,7 @@ def _get_custom_node_store(): def get_settings_class_for_node_type(node_type: str): """Get the settings class for a node type, supporting both standard and user-defined nodes.""" - model_class = NODE_TYPE_TO_SETTINGS_CLASS.get(node_type) + model_class = _get_node_type_to_settings_class().get(node_type) if model_class is None: if node_type in _get_custom_node_store(): return input_schema.UserDefinedNode diff --git a/flowfile_core/tests/flowfile/execution/__init__.py b/flowfile_core/tests/flowfile/execution/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/flowfile_core/tests/flowfile/execution/test_add_from_spec.py b/flowfile_core/tests/flowfile/execution/test_add_from_spec.py new file mode 100644 index 000000000..e45826cb0 --- /dev/null +++ b/flowfile_core/tests/flowfile/execution/test_add_from_spec.py @@ -0,0 +1,134 @@ +"""Behavior tests for the declarative _add_from_spec path (Phase 4 nodes).""" + +import polars as pl +import pytest + +from flowfile_core.flowfile.flow_graph import FlowGraph, add_connection +from flowfile_core.flowfile.handler import FlowfileHandler +from flowfile_core.schemas import input_schema, schemas, transform_schema + +pytestmark = pytest.mark.usefixtures("flowfile_worker") + +MIGRATED_TYPES = ["filter", "sort", "record_count", "sample", "union"] + + +def _graph(execution_location: str = "remote") -> FlowGraph: + handler = FlowfileHandler() + handler.register_flow( + schemas.FlowSettings(flow_id=1, name="spec_flow", path=".", execution_location=execution_location) + ) + graph = handler.get_flow(1) + graph.add_node_promise(input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input")) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, + node_id=1, + raw_data_format=input_schema.RawData.from_pylist( + [{"a": 1, "b": "x"}, {"a": 2, "b": "y"}, {"a": 3, "b": "z"}] + ), + ) + ) + return graph + + +def _add_node(graph: FlowGraph, node_id: int, node_type: str): + graph.add_node_promise(input_schema.NodePromise(flow_id=1, node_id=node_id, node_type=node_type)) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, node_id)) + if node_type == "filter": + graph.add_filter( + input_schema.NodeFilter( + flow_id=1, + node_id=node_id, + depending_on_id=1, + filter_input=transform_schema.FilterInput(mode="advanced", advanced_filter="[a] > 1"), + ) + ) + elif node_type == "sort": + graph.add_sort( + input_schema.NodeSort( + flow_id=1, + node_id=node_id, + depending_on_id=1, + sort_input=[transform_schema.SortByInput(column="a", how="desc")], + ) + ) + elif node_type == "record_count": + graph.add_record_count(input_schema.NodeRecordCount(flow_id=1, node_id=node_id, depending_on_id=1)) + elif node_type == "sample": + graph.add_sample(input_schema.NodeSample(flow_id=1, node_id=node_id, depending_on_id=1, sample_size=2)) + elif node_type == "union": + second_input_id = node_id + 100 + graph.add_node_promise( + input_schema.NodePromise(flow_id=1, node_id=second_input_id, node_type="manual_input") + ) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, + node_id=second_input_id, + raw_data_format=input_schema.RawData.from_pylist([{"a": 4, "b": "w"}]), + ) + ) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(second_input_id, node_id)) + graph.add_union(input_schema.NodeUnion(flow_id=1, node_id=node_id, depending_on_ids=[1, second_input_id])) + + +@pytest.mark.parametrize("node_type", MIGRATED_TYPES) +def test_spec_node_preserves_identity_conventions(node_type): + graph = _graph() + _add_node(graph, 2, node_type) + node = graph.get_node(2) + # add_node_step conventions the skip-logic and saved flows rely on. + assert node.name == node_type + assert node._function.__name__ == "_func" + assert node.node_type == node_type + assert node.setting_input.node_id == 2 + + +@pytest.mark.parametrize("node_type", MIGRATED_TYPES) +def test_spec_node_runs_and_produces_expected_data(node_type): + graph = _graph() + _add_node(graph, 2, node_type) + run_info = graph.run_graph() + assert run_info.success, f"run failed: {[r.error for r in run_info.node_step_result if not r.success]}" + result = graph.get_node(2).get_resulting_data().data_frame.lazy().collect() + if node_type == "filter": + assert result["a"].to_list() == [2, 3] + elif node_type == "sort": + assert result["a"].to_list() == [3, 2, 1] + elif node_type == "record_count": + assert result["number_of_records"].to_list() == [3] + elif node_type == "sample": + assert result.height == 2 + elif node_type == "union": + assert result.height == 4 + + +def test_spec_node_serialization_roundtrip(tmp_path): + graph = _graph() + for i, node_type in enumerate(MIGRATED_TYPES, start=2): + _add_node(graph, i, node_type) + storage = graph.get_flowfile_data() + node_types = {n.id: n.type for n in storage.nodes} + for i, node_type in enumerate(MIGRATED_TYPES, start=2): + assert node_types[i] == node_type + + +def test_filter_split_mode_still_works(): + graph = _graph() + graph.add_node_promise(input_schema.NodePromise(flow_id=1, node_id=2, node_type="filter")) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 2)) + graph.add_filter( + input_schema.NodeFilter( + flow_id=1, + node_id=2, + depending_on_id=1, + filter_input=transform_schema.FilterInput(mode="advanced", advanced_filter="[a] > 1"), + split_mode=True, + ) + ) + run_info = graph.run_graph() + assert run_info.success + named = graph.get_node(2)._named_outputs + assert len(named) == 2 + heights = sorted(fde.data_frame.lazy().select(pl.len()).collect().item() for fde in named.values()) + assert heights == [1, 2] diff --git a/flowfile_core/tests/flowfile/execution/test_backend_parity.py b/flowfile_core/tests/flowfile/execution/test_backend_parity.py new file mode 100644 index 000000000..a81c30b1c --- /dev/null +++ b/flowfile_core/tests/flowfile/execution/test_backend_parity.py @@ -0,0 +1,93 @@ +"""Parity tests: nodes migrated to the ExecutionBackend seam behave the same +in local and remote execution locations.""" + +from pathlib import Path + +import polars as pl +import pytest + +from flowfile_core.flowfile.flow_graph import add_connection +from flowfile_core.flowfile.handler import FlowfileHandler +from flowfile_core.schemas import input_schema, schemas + +pytestmark = pytest.mark.usefixtures("flowfile_worker") + + +def _create_graph(execution_location: str, flow_id: int = 1): + handler = FlowfileHandler() + handler.register_flow( + schemas.FlowSettings( + flow_id=flow_id, + name="parity_flow", + path=".", + execution_mode="Development", + execution_location=execution_location, + ) + ) + return handler.get_flow(flow_id) + + +def _add_manual_input(graph, data, node_id: int = 1): + graph.add_node_promise(input_schema.NodePromise(flow_id=1, node_id=node_id, node_type="manual_input")) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=node_id, raw_data_format=input_schema.RawData.from_pylist(data) + ) + ) + + +def _connect(graph, from_id: int, to_id: int): + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(from_id, to_id)) + + +_DATA = [{"a": i, "b": f"row{i}"} for i in range(100)] + + +@pytest.mark.parametrize("execution_location", ["local", "remote"]) +def test_output_node_writes_file_in_both_locations(execution_location, tmp_path: Path): + graph = _create_graph(execution_location) + _add_manual_input(graph, _DATA, node_id=1) + graph.add_node_promise(input_schema.NodePromise(flow_id=1, node_id=2, node_type="output")) + _connect(graph, 1, 2) + + out_file = tmp_path / f"parity_{execution_location}.csv" + output_settings = input_schema.OutputSettings( + name=out_file.name, + directory=str(tmp_path), + file_type="csv", + table_settings=input_schema.OutputCsvTable(), + ) + graph.add_output(input_schema.NodeOutput(flow_id=1, node_id=2, output_settings=output_settings)) + + run_info = graph.run_graph() + assert run_info.success, f"run failed: {[r.error for r in run_info.node_step_result if not r.success]}" + assert out_file.exists() + written = pl.read_csv(out_file) + assert written.height == len(_DATA) + + +@pytest.mark.parametrize("execution_location", ["local", "remote"]) +def test_random_split_node_in_both_locations(execution_location): + graph = _create_graph(execution_location) + _add_manual_input(graph, _DATA, node_id=1) + graph.add_node_promise(input_schema.NodePromise(flow_id=1, node_id=2, node_type="random_split")) + _connect(graph, 1, 2) + graph.add_random_split( + input_schema.NodeRandomSplit( + flow_id=1, + node_id=2, + splits=[ + input_schema.RandomSplitGroup(name="train", percentage=70.0), + input_schema.RandomSplitGroup(name="test", percentage=30.0), + ], + seed=42, + ) + ) + + run_info = graph.run_graph() + assert run_info.success, f"run failed: {[r.error for r in run_info.node_step_result if not r.success]}" + node = graph.get_node(2) + named_outputs = node._named_outputs + assert set(named_outputs) == {"output-0", "output-1"} + total = sum(fde.data_frame.lazy().select(pl.len()).collect().item() for fde in named_outputs.values()) + assert total == len(_DATA) diff --git a/flowfile_core/tests/flowfile/execution/test_backends.py b/flowfile_core/tests/flowfile/execution/test_backends.py new file mode 100644 index 000000000..f7bed6789 --- /dev/null +++ b/flowfile_core/tests/flowfile/execution/test_backends.py @@ -0,0 +1,85 @@ +from uuid import uuid4 + +import polars as pl +import pytest + +from flowfile_core.flowfile.execution.backends import ( + LocalBackend, + RemoteWorkerBackend, + resolve_backend, +) +from flowfile_core.flowfile.execution.handles import LocalResultHandle, TaskHandle +from flowfile_core.utils.arrow_reader import read_top_n + + +@pytest.fixture(params=["local", "remote"]) +def backend(request): + return resolve_backend(request.param) + + +def _sample_lf() -> pl.LazyFrame: + return pl.LazyFrame({"a": [1, 2, 3, 4, 5], "b": ["x", "y", "z", "x", "y"]}) + + +def test_resolve_backend_mapping(): + assert isinstance(resolve_backend("local"), LocalBackend) + assert isinstance(resolve_backend("remote"), RemoteWorkerBackend) + + +def test_remote_location_never_resolves_local(): + # Guard for the core-never-collects contract: any non-local location must + # route full materialisation to the worker backend. + assert isinstance(resolve_backend("remote"), RemoteWorkerBackend) + + +def test_run_lazyframe_returns_equivalent_result(backend): + handle = backend.run_lazyframe( + _sample_lf(), flow_id=1, node_id=-1, file_ref=str(uuid4()), wait_on_completion=True + ) + assert isinstance(handle, TaskHandle) + result = handle.get_result() + assert isinstance(result, pl.LazyFrame) + assert result.collect().sort("a").equals(_sample_lf().collect().sort("a")) + + +def test_count_records_parity(backend): + assert backend.count_records(_sample_lf(), flow_id=1, node_id=-1) == 5 + + +def test_sample_writes_readable_arrow_file(backend): + handle = backend.sample( + _sample_lf(), file_ref=str(uuid4()), flow_id=1, node_id=-1, sample_size=3, wait_on_completion=True + ) + assert handle.status is not None + table = read_top_n(handle.status.file_ref, n=3) + assert table.num_rows == 3 + assert set(table.column_names) == {"a", "b"} + + +def test_local_backend_has_no_result_cache(): + backend = LocalBackend() + assert backend.results_exist("does-not-exist") is False + assert backend.clear_result("does-not-exist") is False + with pytest.raises(Exception): + backend.get_cached_lazyframe("does-not-exist") + + +def test_remote_backend_result_cache_roundtrip(): + backend = resolve_backend("remote") + file_ref = str(uuid4()) + handle = backend.run_lazyframe(_sample_lf(), flow_id=1, node_id=-1, file_ref=file_ref, wait_on_completion=True) + handle.get_result() + assert backend.results_exist(file_ref) is True + cached = backend.get_cached_lazyframe(file_ref) + assert cached.collect().sort("a").equals(_sample_lf().collect().sort("a")) + assert backend.clear_result(file_ref) is True + assert backend.results_exist(file_ref) is False + + +def test_local_result_handle_is_task_handle(): + handle = LocalResultHandle(result=42, file_ref="x") + assert isinstance(handle, TaskHandle) + assert handle.get_result() == 42 + assert handle.error_code == 0 + assert handle.error_description is None + handle.cancel() diff --git a/flowfile_core/tests/flowfile/execution/test_no_inline_location_branches.py b/flowfile_core/tests/flowfile/execution/test_no_inline_location_branches.py new file mode 100644 index 000000000..ced003260 --- /dev/null +++ b/flowfile_core/tests/flowfile/execution/test_no_inline_location_branches.py @@ -0,0 +1,42 @@ +"""Ratchet guard: inline execution_location branches in flow_graph.py may only shrink. + +Nodes migrated to the ExecutionBackend seam (output, database_reader, +random_split) must not regress to inline ``if execution_location == "local"`` +branches, and new nodes must not add any. Lower the ceiling as follow-up +migrations land; never raise it. +""" + +import re +from pathlib import Path + +import flowfile_core.flowfile.flow_graph as flow_graph_module + +# 13 sites existed before the backend seam; 3 have been migrated. +MAX_INLINE_LOCATION_COMPARISONS = 10 + +_MIGRATED_FUNCTIONS = ["def add_output", "def add_database_reader", "def add_random_split"] + + +def _flow_graph_source() -> str: + return Path(flow_graph_module.__file__).read_text() + + +def test_inline_location_branch_count_only_shrinks(): + source = _flow_graph_source() + comparisons = re.findall(r'execution_location [!=]= "local"', source) + assert len(comparisons) <= MAX_INLINE_LOCATION_COMPARISONS, ( + f"flow_graph.py has {len(comparisons)} inline execution_location comparisons " + f"(ceiling {MAX_INLINE_LOCATION_COMPARISONS}). Route new local/remote variance through " + f"ExecutionBackend (flowfile/execution/backends) instead of branching inline." + ) + + +def test_migrated_builders_stay_branch_free(): + source = _flow_graph_source() + for marker in _MIGRATED_FUNCTIONS: + start = source.index(marker) + # A method body ends where the next def at the same indentation starts. + next_def = source.find("\n def ", start + 1) + body = source[start : next_def if next_def != -1 else len(source)] + assert 'execution_location == "local"' not in body, f"{marker} regressed to an inline location branch" + assert 'execution_location != "local"' not in body, f"{marker} regressed to an inline location branch" diff --git a/flowfile_core/tests/flowfile/execution/test_node_registry.py b/flowfile_core/tests/flowfile/execution/test_node_registry.py new file mode 100644 index 000000000..df94d7063 --- /dev/null +++ b/flowfile_core/tests/flowfile/execution/test_node_registry.py @@ -0,0 +1,149 @@ +"""Contract tests: the node registry is the single source of truth and its +derived views match frozen snapshots of the pre-registry catalogs.""" + +from flowfile_core.flowfile.flow_graph import FlowGraph +from flowfile_core.flowfile.node_registry import BUILTIN_REGISTRY, InputArity, get_node_spec + +# Frozen snapshot of NODE_TYPE_TO_SETTINGS_CLASS before it became a derived +# view (node_type -> settings class name). Guards against silent drift. +SETTINGS_MAP_SNAPSHOT = { + "manual_input": "NodeManualInput", + "filter": "NodeFilter", + "formula": "NodeFormula", + "dynamic_rename": "NodeDynamicRename", + "select": "NodeSelect", + "sort": "NodeSort", + "record_id": "NodeRecordId", + "sample": "NodeSample", + "random_split": "NodeRandomSplit", + "unique": "NodeUnique", + "group_by": "NodeGroupBy", + "window_functions": "NodeWindowFunctions", + "pivot": "NodePivot", + "unpivot": "NodeUnpivot", + "text_to_rows": "NodeTextToRows", + "graph_solver": "NodeGraphSolver", + "python_script": "NodePythonScript", + "polars_code": "NodePolarsCode", + "sql_query": "NodeSqlQuery", + "join": "NodeJoin", + "cross_join": "NodeCrossJoin", + "fuzzy_match": "NodeFuzzyMatch", + "record_count": "NodeRecordCount", + "explore_data": "NodeExploreData", + "union": "NodeUnion", + "output": "NodeOutput", + "api_response": "NodeApiResponse", + "read": "NodeRead", + "database_reader": "NodeDatabaseReader", + "database_writer": "NodeDatabaseWriter", + "cloud_storage_reader": "NodeCloudStorageReader", + "cloud_storage_writer": "NodeCloudStorageWriter", + "catalog_reader": "NodeCatalogReader", + "catalog_writer": "NodeCatalogWriter", + "kafka_source": "NodeKafkaSource", + "google_analytics_reader": "NodeGoogleAnalyticsReader", + "rest_api_reader": "NodeRestApiReader", + "external_source": "NodeExternalSource", + "promise": "NodePromise", + "user_defined": "UserDefinedNode", + "train_model": "NodeTrainModel", + "apply_model": "NodeApplyModel", + "evaluate_model": "NodeEvaluateModel", + "wait_for": "NodeWaitFor", +} + +NODES_WITH_DEFAULTS_SNAPSHOT = {"sample", "sort", "union", "select", "record_count"} + + +def test_settings_class_map_matches_snapshot(): + derived = {k: v.__name__ for k, v in BUILTIN_REGISTRY.settings_class_map().items()} + assert derived == SETTINGS_MAP_SNAPSHOT + + +def test_schemas_module_exposes_derived_map(): + from flowfile_core.schemas import schemas + + assert {k: v.__name__ for k, v in schemas.NODE_TYPE_TO_SETTINGS_CLASS.items()} == SETTINGS_MAP_SNAPSHOT + from flowfile_core.schemas.schemas import NODE_TYPE_TO_SETTINGS_CLASS + + assert NODE_TYPE_TO_SETTINGS_CLASS is schemas.NODE_TYPE_TO_SETTINGS_CLASS + + +def test_nodes_with_defaults_derived(): + from flowfile_core.configs import node_store + + assert BUILTIN_REGISTRY.node_types_with_default_settings() == NODES_WITH_DEFAULTS_SNAPSHOT + assert node_store.nodes_with_defaults == NODES_WITH_DEFAULTS_SNAPSHOT + assert node_store.check_if_has_default_setting("sort") + assert not node_store.check_if_has_default_setting("filter") + + +def test_ai_classification_map_derived(): + from flowfile_core.ai.tools import classification + + derived = classification._NODE_CLASS_MAP + assert derived == BUILTIN_REGISTRY.ai_classification_map() + # Spot-check the buckets that drive executor routing. + assert classification.classify_node_type("manual_input") == "source" + assert classification.classify_node_type("pivot") == "dynamic" + assert classification.classify_node_type("filter") == "static" + assert classification.classify_node_type("promise") == "passthrough" + assert classification.classify_node_type("never_heard_of_it") == "dynamic" + + +def test_every_template_has_exactly_one_spec(): + from flowfile_core.configs.node_store.nodes import get_all_standard_nodes + + nodes_list, node_dict, node_defaults = get_all_standard_nodes() + for template in nodes_list: + spec = get_node_spec(template.item) + assert spec is not None, f"template {template.item} has no NodeSpec" + assert spec.template == template + assert "polars_lazy_frame" in node_dict + assert "polars_lazy_frame" not in node_defaults + assert {t.item for t in nodes_list} == set(node_defaults) + + +def test_every_drawer_spec_has_matching_add_method(): + for spec in BUILTIN_REGISTRY: + if spec.template is None or not spec.drawer_visible: + continue + if spec.node_type == "user_defined": + continue + assert hasattr(FlowGraph, f"add_{spec.node_type}"), ( + f"NodeSpec {spec.node_type} has no FlowGraph.add_{spec.node_type} method" + ) + + +def test_node_defaults_match_templates(): + from flowfile_core.configs.node_store.nodes import get_all_standard_nodes + + _, _, node_defaults = get_all_standard_nodes() + for item, default in node_defaults.items(): + spec = get_node_spec(item) + assert default.node_name == spec.template.name + assert default.node_type == spec.template.node_type + assert default.transform_type == spec.template.transform_type + assert bool(default.has_default_settings) == spec.has_default_settings + + +def test_input_arity_derivation(): + assert get_node_spec("read").input_arity is InputArity.SOURCE + assert get_node_spec("filter").input_arity is InputArity.SINGLE + assert get_node_spec("join").input_arity is InputArity.DOUBLE + assert get_node_spec("union").input_arity is InputArity.MULTI + + +def test_routes_node_model_resolution_matches_registry(): + from flowfile_core.routes.routes import get_node_model + + for spec in BUILTIN_REGISTRY: + if spec.settings_class is None: + continue + ref = get_node_model("node" + spec.node_type.replace("_", ""), node_type=spec.node_type) + assert ref is spec.settings_class, f"{spec.node_type} resolved to {ref}" + # Legacy reflective fallback still resolves types outside the settings map. + from flowfile_core.schemas import input_schema + + assert get_node_model("nodedatasource", node_type="datasource") is input_schema.NodeDatasource diff --git a/flowfile_core/tests/flowfile/execution/test_transport.py b/flowfile_core/tests/flowfile/execution/test_transport.py new file mode 100644 index 000000000..f3c33a3f8 --- /dev/null +++ b/flowfile_core/tests/flowfile/execution/test_transport.py @@ -0,0 +1,143 @@ +import pytest +import requests + +from flowfile_core.flowfile.execution.exceptions import ( + WorkerConnectionError, + WorkerError, + WorkerTaskError, +) +from flowfile_core.flowfile.execution.transport import WorkerTransport, get_default_transport + + +class FakeResponse: + def __init__(self, status_code=200, json_data=None, text="", ok=None): + self.status_code = status_code + self._json_data = json_data or {} + self.text = text + self.ok = ok if ok is not None else status_code < 400 + + def json(self): + return self._json_data + + +def _capture_requests(monkeypatch, response=None, exc=None): + calls = [] + + def fake_request(method, url, **kwargs): + calls.append((method, url, kwargs)) + if exc is not None: + raise exc + return response or FakeResponse() + + monkeypatch.setattr(requests, "request", fake_request) + return calls + + +def test_base_url_from_string(): + t = WorkerTransport(base_url="http://myworker:1234") + assert t.base_url == "http://myworker:1234" + + +def test_base_url_from_callable_resolved_per_call(): + urls = iter(["http://a:1", "http://b:2"]) + t = WorkerTransport(base_url=lambda: next(urls)) + assert t.base_url == "http://a:1" + assert t.base_url == "http://b:2" + + +def test_base_url_defaults_to_settings(): + from flowfile_core.configs.settings import WORKER_URL + + assert WorkerTransport().base_url == WORKER_URL + + +def test_ws_base_url_conversion(): + assert WorkerTransport(base_url="http://w:1").ws_base_url == "ws://w:1" + assert WorkerTransport(base_url="https://w:1").ws_base_url == "wss://w:1" + + +def test_post_builds_url_and_passes_kwargs(monkeypatch): + calls = _capture_requests(monkeypatch) + t = WorkerTransport(base_url="http://w:1") + t.post("/submit_query/", data=b"abc", headers={"X-Task-Id": "t1"}) + method, url, kwargs = calls[0] + assert method == "post" + assert url == "http://w:1/submit_query/" + assert kwargs["data"] == b"abc" + assert kwargs["headers"]["X-Task-Id"] == "t1" + + +def test_connection_error_raises_typed_error(monkeypatch): + _capture_requests(monkeypatch, exc=requests.exceptions.ConnectionError("Connection refused")) + t = WorkerTransport(base_url="http://w:1") + with pytest.raises(WorkerConnectionError) as excinfo: + t.get("/status/abc") + assert "http://w:1/status/abc" in str(excinfo.value) + + +def test_worker_connection_error_is_requests_exception(): + # Legacy call sites catch requests.RequestException; the typed error must stay catchable there. + err = WorkerConnectionError("nope") + assert isinstance(err, requests.exceptions.ConnectionError) + assert isinstance(err, requests.RequestException) + assert isinstance(err, WorkerError) + + +def test_get_status_returns_status(monkeypatch): + payload = { + "background_task_id": "t1", + "status": "Completed", + "file_ref": "t1", + "progress": 100, + "results": None, + } + _capture_requests(monkeypatch, response=FakeResponse(200, payload)) + status = WorkerTransport(base_url="http://w:1").get_status("t1") + assert status.status == "Completed" + assert status.background_task_id == "t1" + + +def test_get_status_raises_on_error(monkeypatch): + _capture_requests(monkeypatch, response=FakeResponse(404, text="not found")) + with pytest.raises(WorkerTaskError): + WorkerTransport(base_url="http://w:1").get_status("t1") + + +def test_results_exist_paths(monkeypatch): + t = WorkerTransport(base_url="http://w:1") + _capture_requests(monkeypatch, response=FakeResponse(200, {"status": "Completed"})) + assert t.results_exist("t1") is True + _capture_requests(monkeypatch, response=FakeResponse(200, {"status": "Processing"})) + assert t.results_exist("t1") is False + _capture_requests(monkeypatch, exc=requests.exceptions.ConnectionError("refused")) + assert t.results_exist("t1") is False + + +def test_cancel_and_clear_task(monkeypatch): + t = WorkerTransport(base_url="http://w:1") + calls = _capture_requests(monkeypatch, response=FakeResponse(200)) + assert t.cancel_task("t1") is True + assert t.clear_task("t1") is True + assert [(m, u) for m, u, _ in calls] == [ + ("post", "http://w:1/cancel_task/t1"), + ("delete", "http://w:1/clear_task/t1"), + ] + + +def test_default_transport_is_singleton(): + assert get_default_transport() is get_default_transport() + + +def test_single_file_mode_url_gets_worker_prefix(): + from flowfile_core.configs import settings + + original = bool(settings.SINGLE_FILE_MODE) + settings.SINGLE_FILE_MODE.set(True) + try: + url = settings.get_default_worker_url(worker_port=63578) + assert url.endswith("/worker") + t = WorkerTransport(base_url=url) + assert t.base_url.endswith("/worker") + assert t.ws_base_url.startswith("ws://") and t.ws_base_url.endswith("/worker") + finally: + settings.SINGLE_FILE_MODE.set(original)