diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fdb252aa..d9df8bdce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) - Add agentic search workflow template with flow agent ([#1349](https://github.com/opensearch-project/flow-framework/pull/1349)) - Add agentic search workflow template with conversational agent ([#1353](https://github.com/opensearch-project/flow-framework/pull/1353)) ### Enhancements +- Use ML Commons validation methods for name/description fields during workflow parsing ([#1368](https://github.com/opensearch-project/flow-framework/pull/1368)) - Support Jackson 3.x release line ([#1376](https://github.com/opensearch-project/flow-framework/pull/1376)) ### Bug Fixes - Handle ResourceAlreadyExistsException race condition in FlowFrameworkIndicesHandler ([#1378](https://github.com/opensearch-project/flow-framework/pull/1378)) diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 2f64a2683..ffd40a007 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; @@ -21,6 +22,8 @@ import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.ml.common.utils.FieldDescriptor; +import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.plugins.PluginInfo; import org.opensearch.plugins.PluginsService; import org.opensearch.threadpool.ThreadPool; @@ -39,6 +42,8 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS; @@ -64,6 +69,22 @@ public class WorkflowProcessSorter { DeleteSearchPipelineStep.NAME ); + /** ML Commons step types that have name/description field validation, mapped to a human-readable label */ + private static final Map ML_COMMONS_VALIDATED_STEPS = Map.of( + CreateConnectorStep.NAME, + "Model connector", + RegisterRemoteModelStep.NAME, + "Model", + RegisterLocalCustomModelStep.NAME, + "Model", + RegisterLocalSparseEncodingModelStep.NAME, + "Model", + RegisterLocalPretrainedModelStep.NAME, + "Model", + RegisterModelGroupStep.NAME, + "Model group" + ); + private WorkflowStepFactory workflowStepFactory; private ThreadPool threadPool; private Integer maxWorkflowSteps; @@ -437,6 +458,7 @@ public void validate(List processNodes, PluginsService pluginsServi .collect(Collectors.toList()); validatePluginsInstalled(processNodes, installedPlugins); validateGraph(processNodes); + validateFieldValues(processNodes); } /** @@ -510,6 +532,45 @@ public void validateGraph(List processNodes) throws Exception { } } + /** + * Validates field values in process nodes using ML Commons validation methods. + * Checks that name and description fields in ML Commons steps contain only safe characters. + * @param processNodes A list of process nodes + * @throws Exception on validation failure + */ + public void validateFieldValues(List processNodes) throws Exception { + Map fieldsToValidate = new HashMap<>(); + + for (ProcessNode processNode : processNodes) { + String nodeType = processNode.workflowStep().getName(); + if (!ML_COMMONS_VALIDATED_STEPS.containsKey(nodeType)) { + continue; + } + Map userInputs = processNode.input().getContent(); + String stepLabel = ML_COMMONS_VALIDATED_STEPS.get(nodeType); + + Object nameValue = userInputs.get(NAME_FIELD); + if (nameValue instanceof String) { + fieldsToValidate.put(stepLabel + " name [node " + processNode.id() + "]", new FieldDescriptor((String) nameValue, true)); + } + + Object descValue = userInputs.get(DESCRIPTION_FIELD); + if (descValue instanceof String) { + fieldsToValidate.put( + stepLabel + " description [node " + processNode.id() + "]", + new FieldDescriptor((String) descValue, false) + ); + } + } + + if (!fieldsToValidate.isEmpty()) { + ActionRequestValidationException exception = StringUtils.validateFields(fieldsToValidate); + if (exception != null) { + throw new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST); + } + } + } + /** * A method for parsing workflow timeout value. * The value could be parsed from node NODE_TIMEOUT_FIELD, the timeout field in workflow-step.json, diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index aa4fb9917..514e67156 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -9,6 +9,7 @@ package org.opensearch.flowframework.workflow; import org.opensearch.Version; +import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; @@ -26,6 +27,8 @@ import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.plugins.PluginInfo; +import org.opensearch.plugins.PluginsService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ScalingExecutorBuilder; import org.opensearch.threadpool.TestThreadPool; @@ -832,4 +835,121 @@ public void testCreateReprovisionSequenceWithUpdates() throws Exception { assertTrue(reprovisionWorkflowStepNames.contains(UpdateIndexStep.NAME)); } + public void testValidate() throws Exception { + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + CreateConnectorStep.NAME, + Collections.emptyMap(), + Map.ofEntries( + Map.entry("name", "My Valid Connector"), + Map.entry("description", "A valid description."), + Map.entry("version", "1"), + Map.entry("protocol", "http"), + Map.entry("parameters", "{}"), + Map.entry("credential", "{}"), + Map.entry("actions", "[]") + ) + ); + WorkflowNode registerModel = new WorkflowNode( + "workflow_step_2", + RegisterRemoteModelStep.NAME, + Map.ofEntries(Map.entry("workflow_step_1", CONNECTOR_ID)), + Map.ofEntries(Map.entry("name", "Valid Model"), Map.entry("function_name", "remote"), Map.entry("description", "Valid desc")) + ); + WorkflowEdge edge = new WorkflowEdge(createConnector.id(), registerModel.id()); + Workflow workflow = new Workflow(Collections.emptyMap(), List.of(createConnector, registerModel), List.of(edge)); + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123", Collections.emptyMap(), null); + + PluginsService pluginsService = mock(PluginsService.class); + PluginsAndModules pluginsAndModules = mock(PluginsAndModules.class); + when(pluginsService.info()).thenReturn(pluginsAndModules); + when(pluginsAndModules.getPluginInfos()).thenReturn( + List.of( + new PluginInfo("opensearch-flow-framework", "", "", Version.CURRENT, "", "", List.of(), false), + new PluginInfo("opensearch-ml", "", "", Version.CURRENT, "", "", List.of(), false) + ) + ); + workflowProcessSorter.validate(sortedProcessNodes, pluginsService); + } + + public void testSuccessfulFieldValueValidation() throws Exception { + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + CreateConnectorStep.NAME, + Collections.emptyMap(), + Map.ofEntries( + Map.entry("name", "My Valid Connector"), + Map.entry("description", "A valid description."), + Map.entry("version", "1"), + Map.entry("protocol", "http"), + Map.entry("parameters", "{}"), + Map.entry("credential", "{}"), + Map.entry("actions", "[]") + ) + ); + Workflow workflow = new Workflow(Collections.emptyMap(), List.of(createConnector), Collections.emptyList()); + List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123", Collections.emptyMap(), null); + workflowProcessSorter.validateFieldValues(sortedProcessNodes); + } + + public void testFieldValueValidationInvalidName() throws Exception { + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + CreateConnectorStep.NAME, + Collections.emptyMap(), + Map.ofEntries( + Map.entry("name", "Bad