Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
- Add agentic search workflow template with flow agent ([#1349](https://github.com/opensearch-project/flow-framework/pull/1349))
- Add agentic search workflow template with conversational agent ([#1353](https://github.com/opensearch-project/flow-framework/pull/1353))
### Enhancements
- Use ML Commons validation methods for name/description fields during workflow parsing ([#1368](https://github.com/opensearch-project/flow-framework/pull/1368))
- Support Jackson 3.x release line ([#1376](https://github.com/opensearch-project/flow-framework/pull/1376))
### Bug Fixes
- Handle ResourceAlreadyExistsException race condition in FlowFrameworkIndicesHandler ([#1378](https://github.com/opensearch-project/flow-framework/pull/1378))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
Expand All @@ -21,6 +22,8 @@
import org.opensearch.flowframework.model.WorkflowEdge;
import org.opensearch.flowframework.model.WorkflowNode;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.common.utils.FieldDescriptor;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.plugins.PluginInfo;
import org.opensearch.plugins.PluginsService;
import org.opensearch.threadpool.ThreadPool;
Expand All @@ -39,6 +42,8 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.MAX_WORKFLOW_STEPS;
Expand All @@ -64,6 +69,22 @@ public class WorkflowProcessSorter {
DeleteSearchPipelineStep.NAME
);

/** ML Commons step types that have name/description field validation, mapped to a human-readable label */
private static final Map<String, String> ML_COMMONS_VALIDATED_STEPS = Map.of(
CreateConnectorStep.NAME,
"Model connector",
RegisterRemoteModelStep.NAME,
"Model",
RegisterLocalCustomModelStep.NAME,
"Model",
RegisterLocalSparseEncodingModelStep.NAME,
"Model",
RegisterLocalPretrainedModelStep.NAME,
"Model",
RegisterModelGroupStep.NAME,
"Model group"
);

private WorkflowStepFactory workflowStepFactory;
private ThreadPool threadPool;
private Integer maxWorkflowSteps;
Expand Down Expand Up @@ -437,6 +458,7 @@ public void validate(List<ProcessNode> processNodes, PluginsService pluginsServi
.collect(Collectors.toList());
validatePluginsInstalled(processNodes, installedPlugins);
validateGraph(processNodes);
validateFieldValues(processNodes);
}

/**
Expand Down Expand Up @@ -510,6 +532,45 @@ public void validateGraph(List<ProcessNode> processNodes) throws Exception {
}
}

/**
* Validates field values in process nodes using ML Commons validation methods.
* Checks that name and description fields in ML Commons steps contain only safe characters.
* @param processNodes A list of process nodes
* @throws Exception on validation failure
*/
public void validateFieldValues(List<ProcessNode> processNodes) throws Exception {
Map<String, FieldDescriptor> fieldsToValidate = new HashMap<>();

for (ProcessNode processNode : processNodes) {
String nodeType = processNode.workflowStep().getName();
if (!ML_COMMONS_VALIDATED_STEPS.containsKey(nodeType)) {
continue;
}
Map<String, Object> userInputs = processNode.input().getContent();
String stepLabel = ML_COMMONS_VALIDATED_STEPS.get(nodeType);

Object nameValue = userInputs.get(NAME_FIELD);
if (nameValue instanceof String) {
fieldsToValidate.put(stepLabel + " name [node " + processNode.id() + "]", new FieldDescriptor((String) nameValue, true));
}

Object descValue = userInputs.get(DESCRIPTION_FIELD);
if (descValue instanceof String) {
fieldsToValidate.put(
stepLabel + " description [node " + processNode.id() + "]",
new FieldDescriptor((String) descValue, false)
);
}
}

if (!fieldsToValidate.isEmpty()) {
ActionRequestValidationException exception = StringUtils.validateFields(fieldsToValidate);
if (exception != null) {
throw new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST);
}
}
}

/**
* A method for parsing workflow timeout value.
* The value could be parsed from node NODE_TIMEOUT_FIELD, the timeout field in workflow-step.json,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.flowframework.workflow;

import org.opensearch.Version;
import org.opensearch.action.admin.cluster.node.info.PluginsAndModules;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
Expand All @@ -26,6 +27,8 @@
import org.opensearch.flowframework.model.WorkflowEdge;
import org.opensearch.flowframework.model.WorkflowNode;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.plugins.PluginInfo;
import org.opensearch.plugins.PluginsService;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ScalingExecutorBuilder;
import org.opensearch.threadpool.TestThreadPool;
Expand Down Expand Up @@ -832,4 +835,121 @@ public void testCreateReprovisionSequenceWithUpdates() throws Exception {
assertTrue(reprovisionWorkflowStepNames.contains(UpdateIndexStep.NAME));
}

public void testValidate() throws Exception {
WorkflowNode createConnector = new WorkflowNode(
"workflow_step_1",
CreateConnectorStep.NAME,
Collections.emptyMap(),
Map.ofEntries(
Map.entry("name", "My Valid Connector"),
Map.entry("description", "A valid description."),
Map.entry("version", "1"),
Map.entry("protocol", "http"),
Map.entry("parameters", "{}"),
Map.entry("credential", "{}"),
Map.entry("actions", "[]")
)
);
WorkflowNode registerModel = new WorkflowNode(
"workflow_step_2",
RegisterRemoteModelStep.NAME,
Map.ofEntries(Map.entry("workflow_step_1", CONNECTOR_ID)),
Map.ofEntries(Map.entry("name", "Valid Model"), Map.entry("function_name", "remote"), Map.entry("description", "Valid desc"))
);
WorkflowEdge edge = new WorkflowEdge(createConnector.id(), registerModel.id());
Workflow workflow = new Workflow(Collections.emptyMap(), List.of(createConnector, registerModel), List.of(edge));
List<ProcessNode> sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123", Collections.emptyMap(), null);

PluginsService pluginsService = mock(PluginsService.class);
PluginsAndModules pluginsAndModules = mock(PluginsAndModules.class);
when(pluginsService.info()).thenReturn(pluginsAndModules);
when(pluginsAndModules.getPluginInfos()).thenReturn(
List.of(
new PluginInfo("opensearch-flow-framework", "", "", Version.CURRENT, "", "", List.of(), false),
new PluginInfo("opensearch-ml", "", "", Version.CURRENT, "", "", List.of(), false)
)
);
workflowProcessSorter.validate(sortedProcessNodes, pluginsService);
}

public void testSuccessfulFieldValueValidation() throws Exception {
WorkflowNode createConnector = new WorkflowNode(
"workflow_step_1",
CreateConnectorStep.NAME,
Collections.emptyMap(),
Map.ofEntries(
Map.entry("name", "My Valid Connector"),
Map.entry("description", "A valid description."),
Map.entry("version", "1"),
Map.entry("protocol", "http"),
Map.entry("parameters", "{}"),
Map.entry("credential", "{}"),
Map.entry("actions", "[]")
)
);
Workflow workflow = new Workflow(Collections.emptyMap(), List.of(createConnector), Collections.emptyList());
List<ProcessNode> sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123", Collections.emptyMap(), null);
workflowProcessSorter.validateFieldValues(sortedProcessNodes);
}

public void testFieldValueValidationInvalidName() throws Exception {
WorkflowNode createConnector = new WorkflowNode(
"workflow_step_1",
CreateConnectorStep.NAME,
Collections.emptyMap(),
Map.ofEntries(
Map.entry("name", "Bad<script>Name"),
Map.entry("description", "A valid description."),
Map.entry("version", "1"),
Map.entry("protocol", "http"),
Map.entry("parameters", "{}"),
Map.entry("credential", "{}"),
Map.entry("actions", "[]")
)
);
Workflow workflow = new Workflow(Collections.emptyMap(), List.of(createConnector), Collections.emptyList());
List<ProcessNode> sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123", Collections.emptyMap(), null);
FlowFrameworkException ex = expectThrows(
FlowFrameworkException.class,
() -> workflowProcessSorter.validateFieldValues(sortedProcessNodes)
);
assertTrue(ex.getMessage().contains("Model connector name"));
assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus());
}

public void testFieldValueValidationInvalidDescription() throws Exception {
WorkflowNode registerModel = new WorkflowNode(
"workflow_step_1",
RegisterRemoteModelStep.NAME,
Collections.emptyMap(),
Map.ofEntries(
Map.entry("name", "Valid Model Name"),
Map.entry("function_name", "remote"),
Map.entry("description", "Bad<script>Desc"),
Map.entry("connector_id", "abc123")
)
);
Workflow workflow = new Workflow(Collections.emptyMap(), List.of(registerModel), Collections.emptyList());
List<ProcessNode> sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123", Collections.emptyMap(), null);
FlowFrameworkException ex = expectThrows(
FlowFrameworkException.class,
() -> workflowProcessSorter.validateFieldValues(sortedProcessNodes)
);
assertTrue(ex.getMessage().contains("Model description"));
assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus());
}

public void testFieldValueValidationSkipsNonMlSteps() throws Exception {
WorkflowNode createIndex = new WorkflowNode(
"workflow_step_1",
CreateIndexStep.NAME,
Collections.emptyMap(),
Map.ofEntries(Map.entry("index_name", "my-index"), Map.entry("configurations", "{}"))
);
Workflow workflow = new Workflow(Collections.emptyMap(), List.of(createIndex), Collections.emptyList());
List<ProcessNode> sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123", Collections.emptyMap(), null);
// Should not throw - non-ML steps are skipped
workflowProcessSorter.validateFieldValues(sortedProcessNodes);
}

}
Loading