From 2f2b32ebb5555b6f36a074bac77792a4ae3ecede Mon Sep 17 00:00:00 2001 From: Dan Widdis Date: Tue, 21 Apr 2026 17:24:11 +0000 Subject: [PATCH 1/3] Use ML Commons validation methods in workflow parsing Adds field value validation during workflow template validation using ML Commons StringUtils.validateFields(). When validation=all is set, name and description fields in ML Commons steps (create_connector, register_remote_model, register_local_custom_model, register_local_sparse_encoding_model, register_local_pretrained_model, register_model_group) are now checked for safe characters before provisioning. Resolves #1152 Signed-off-by: Dan Widdis --- .../workflow/WorkflowProcessSorter.java | 61 ++++++++++++++ .../workflow/WorkflowProcessSorterTests.java | 80 +++++++++++++++++++ 2 files changed, 141 insertions(+) diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 2f64a2683..ffd40a007 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; @@ -21,6 +22,8 @@ import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.ml.common.utils.FieldDescriptor; +import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.plugins.PluginInfo; import org.opensearch.plugins.PluginsService; import org.opensearch.threadpool.ThreadPool; @@ -39,6 +42,8 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; @@ -64,6 +69,22 @@ public class WorkflowProcessSorter { DeleteSearchPipelineStep.NAME ); + /** ML Commons step types that have name/description field validation, mapped to a human-readable label */ + private static final Map ML_COMMONS_VALIDATED_STEPS = Map.of( + CreateConnectorStep.NAME, + "Model connector", + RegisterRemoteModelStep.NAME, + "Model", + RegisterLocalCustomModelStep.NAME, + "Model", + RegisterLocalSparseEncodingModelStep.NAME, + "Model", + RegisterLocalPretrainedModelStep.NAME, + "Model", + RegisterModelGroupStep.NAME, + "Model group" + ); + private WorkflowStepFactory workflowStepFactory; private ThreadPool threadPool; private Integer maxWorkflowSteps; @@ -437,6 +458,7 @@ public void validate(List processNodes, PluginsService pluginsServi .collect(Collectors.toList()); validatePluginsInstalled(processNodes, installedPlugins); validateGraph(processNodes); + validateFieldValues(processNodes); } /** @@ -510,6 +532,45 @@ public void validateGraph(List processNodes) throws Exception { } } + /** + * Validates field values in process nodes using ML Commons validation methods. + * Checks that name and description fields in ML Commons steps contain only safe characters. + * @param processNodes A list of process nodes + * @throws Exception on validation failure + */ + public void validateFieldValues(List processNodes) throws Exception { + Map fieldsToValidate = new HashMap<>(); + + for (ProcessNode processNode : processNodes) { + String nodeType = processNode.workflowStep().getName(); + if (!ML_COMMONS_VALIDATED_STEPS.containsKey(nodeType)) { + continue; + } + Map userInputs = processNode.input().getContent(); + String stepLabel = ML_COMMONS_VALIDATED_STEPS.get(nodeType); + + Object nameValue = userInputs.get(NAME_FIELD); + if (nameValue instanceof String) { + fieldsToValidate.put(stepLabel + " name [node " + processNode.id() + "]", new FieldDescriptor((String) nameValue, true)); + } + + Object descValue = userInputs.get(DESCRIPTION_FIELD); + if (descValue instanceof String) { + fieldsToValidate.put( + stepLabel + " description [node " + processNode.id() + "]", + new FieldDescriptor((String) descValue, false) + ); + } + } + + if (!fieldsToValidate.isEmpty()) { + ActionRequestValidationException exception = StringUtils.validateFields(fieldsToValidate); + if (exception != null) { + throw new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST); + } + } + } + /** * A method for parsing workflow timeout value. * The value could be parsed from node NODE_TIMEOUT_FIELD, the timeout field in workflow-step.json, diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index aa4fb9917..2174ad22e 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -832,4 +832,84 @@ public void testCreateReprovisionSequenceWithUpdates() throws Exception { assertTrue(reprovisionWorkflowStepNames.contains(UpdateIndexStep.NAME)); } + public void testSuccessfulFieldValueValidation() throws Exception { + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + CreateConnectorStep.NAME, + Collections.emptyMap(), + Map.ofEntries( + Map.entry("name", "My Valid Connector"), + Map.entry("description", "A valid description."), + Map.entry("version", "1"), + Map.entry("protocol", "http"), + Map.entry("parameters", "{}"), + Map.entry("credential", "{}"), + Map.entry("actions", "[]") + ) + ); + Workflow workflow = new Workflow(Collections.emptyMap(), List.of(createConnector), Collections.emptyList()); + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123", Collections.emptyMap(), null); + workflowProcessSorter.validateFieldValues(sortedProcessNodes); + } + + public void testFieldValueValidationInvalidName() throws Exception { + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + CreateConnectorStep.NAME, + Collections.emptyMap(), + Map.ofEntries( + Map.entry("name", "Bad