From fdabdbd89bf219c7580c0a51096cabce7234ea15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Fri, 10 Apr 2026 16:14:48 +0000 Subject: [PATCH] Fix: Allow toggling dynamic log links for Ray, Dask, Spark plugins MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- .../go/tasks/plugins/k8s/dask/dask.go | 6 +++ .../go/tasks/plugins/k8s/dask/dask_test.go | 40 ++++++++++++++++ flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 10 +++- .../go/tasks/plugins/k8s/ray/ray_test.go | 48 ++++++++++++++++++- .../go/tasks/plugins/k8s/spark/spark.go | 11 ++++- .../go/tasks/plugins/k8s/spark/spark_test.go | 46 ++++++++++++++++-- 6 files changed, 151 insertions(+), 10 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go index 28be2b444f1..5fd58e8cb02 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask.go @@ -297,6 +297,11 @@ func (p daskResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s OccurredAt: &occurredAt, } + taskTemplate, err := pluginContext.TaskReader().Read(ctx) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() o, err := logPlugin.GetTaskLogs( tasklog.Input{ @@ -304,6 +309,7 @@ func (p daskResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s PodName: job.Status.JobRunnerPodName, LogName: "(Dask Runner Logs)", TaskExecutionID: taskExecID, + TaskTemplate: taskTemplate, }, ) if err != nil { diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go index bd347f38648..b5d34113408 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go @@ -18,12 +18,14 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" stdlibUtils "github.com/flyteorg/flyte/flytestdlib/utils" ) @@ -866,3 +868,41 @@ func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { assert.NoError(t, err) assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1) } + +func TestGetTaskPhase_DynamicLogLinks(t *testing.T) { + daskResourceHandler := daskResourceHandler{} + ctx := context.TODO() + + dynamicLinks := map[string]tasklog.TemplateLogPlugin{ + "test-dynamic-link": { + TemplateURIs: []tasklog.TemplateURI{"https://some-service.com/{{.taskConfig.dynamicParam}}"}, + }, + } + + assert.NoError(t, SetConfig(&Config{ + Logs: logs.LogConfig{ + DynamicLogLinks: dynamicLinks, + }, + })) + + taskTemplate := dummyDaskTaskTemplate("", nil, "") + taskTemplate.Config = map[string]string{ + "link_type": "test-dynamic-link", + "dynamicParam": "dynamic-value", + } + taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{}) + + taskPhase, err := daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobRunning)) + assert.NoError(t, err) + assert.NotNil(t, taskPhase.Info()) + assert.NotNil(t, taskPhase.Info().Logs) + + var dynamicLog *core.TaskLog + for _, l := range taskPhase.Info().Logs { + if l.GetUri() == "https://some-service.com/dynamic-value" { + dynamicLog = l + break + } + } + assert.NotNil(t, dynamicLog, "expected dynamic log link in task logs") +} diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index f34f8ef6f6b..56eb7accde1 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -558,7 +558,7 @@ func (rayJobResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx }, nil } -func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginContext, rayJob *rayv1.RayJob) (*pluginsCore.TaskInfo, error) { +func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginContext, rayJob *rayv1.RayJob, taskTemplate *core.TaskTemplate) (*pluginsCore.TaskInfo, error) { logPlugin, err := logs.InitializeLogPlugins(&logConfig) if err != nil { return nil, fmt.Errorf("failed to initialize log plugins. Error: %w", err) @@ -584,6 +584,7 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon PodUnixStartTime: startTime, PodUnixFinishTime: finishTime, ExtraTemplateVars: []tasklog.TemplateVar{}, + TaskTemplate: taskTemplate, } if rayJob.Status.JobId != "" { input.ExtraTemplateVars = append( @@ -629,7 +630,12 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { rayJob := resource.(*rayv1.RayJob) - info, err := getEventInfoForRayJob(GetConfig().Logs, pluginContext, rayJob) + taskTemplate, err := pluginContext.TaskReader().Read(ctx) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + + info, err := getEventInfoForRayJob(GetConfig().Logs, pluginContext, rayJob, taskTemplate) if err != nil { return pluginsCore.PhaseInfoUndefined, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index e7433bad508..de4e7a5c6d1 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -1128,6 +1128,10 @@ func newPluginContext(pluginState k8s.PluginState) k8s.PluginContext { plg.EXPECT().PluginStateReader().Return(&pluginStateReaderMock) + taskReader := &mocks.TaskReader{} + taskReader.EXPECT().Read(mock.Anything).Return(&core.TaskTemplate{}, nil) + plg.EXPECT().TaskReader().Return(taskReader) + return plg } @@ -1309,10 +1313,12 @@ func TestGetEventInfo_LogTemplates(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) ti, err := getEventInfoForRayJob( logs.LogConfig{Templates: []tasklog.TemplateLogPlugin{tc.logPlugin}}, pluginCtx, &tc.rayJob, + taskTemplate, ) assert.NoError(t, err) assert.Equal(t, tc.expectedTaskLogs, ti.Logs) @@ -1408,10 +1414,13 @@ func TestGetEventInfo_LogTemplates_V1(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) + ti, err := getEventInfoForRayJob( logs.LogConfig{Templates: []tasklog.TemplateLogPlugin{tc.logPlugin}}, pluginCtx, &tc.rayJob, + taskTemplate, ) assert.NoError(t, err) assert.Equal(t, tc.expectedTaskLogs, ti.Logs) @@ -1463,8 +1472,9 @@ func TestGetEventInfo_DashboardURL(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) assert.NoError(t, SetConfig(&Config{DashboardURLTemplate: &tc.dashboardURLTemplate})) - ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob) + ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob, taskTemplate) assert.NoError(t, err) assert.Equal(t, tc.expectedTaskLogs, ti.Logs) }) @@ -1515,14 +1525,48 @@ func TestGetEventInfo_DashboardURL_V1(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) + assert.NoError(t, SetConfig(&Config{DashboardURLTemplate: &tc.dashboardURLTemplate})) - ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob) + ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob, taskTemplate) assert.NoError(t, err) assert.Equal(t, tc.expectedTaskLogs, ti.Logs) }) } } +func TestGetEventInfo_DynamicLogLinks(t *testing.T) { + pluginCtx := newPluginContext(k8s.PluginState{}) + + dynamicLinks := map[string]tasklog.TemplateLogPlugin{ + "test-dynamic-link": { + TemplateURIs: []tasklog.TemplateURI{"https://some-service.com/{{.taskConfig.dynamicParam}}"}, + }, + } + + taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj()) + taskTemplate.Config = map[string]string{ + "link_type": "test-dynamic-link", + "dynamicParam": "dynamic-value", + } + + rayJob := rayv1.RayJob{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "test-namespace", + }, + } + + ti, err := getEventInfoForRayJob( + logs.LogConfig{DynamicLogLinks: dynamicLinks}, + pluginCtx, + &rayJob, + taskTemplate, + ) + assert.NoError(t, err) + assert.Equal(t, 1, len(ti.Logs)) + assert.Equal(t, "https://some-service.com/dynamic-value", ti.Logs[0].GetUri()) +} + func TestGetPropertiesRay(t *testing.T) { rayJobResourceHandler := rayJobResourceHandler{} expected := k8s.PluginProperties{GeneratedNameMaxLength: ptr.To[int](47)} diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index 8329c5b6efe..b6cb1e3708a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -385,7 +385,7 @@ func (sparkResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx p }, nil } -func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkApplication) (*pluginsCore.TaskInfo, error) { +func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkApplication, taskTemplate *core.TaskTemplate) (*pluginsCore.TaskInfo, error) { sparkConfig := GetSparkConfig() taskLogs := make([]*core.TaskLog, 0, 3) @@ -403,6 +403,7 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl Namespace: sj.Namespace, LogName: "(Driver Logs)", TaskExecutionID: taskExecID, + TaskTemplate: taskTemplate, }) if err != nil { @@ -520,7 +521,13 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl func (sparkResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { app := resource.(*sparkOp.SparkApplication) - info, err := getEventInfoForSpark(pluginContext, app) + + taskTemplate, err := pluginContext.TaskReader().Read(ctx) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + + info, err := getEventInfoForSpark(pluginContext, app, taskTemplate) if err != nil { return pluginsCore.PhaseInfoUndefined, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index 6137ed49ec0..c7bdbe83686 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -25,6 +25,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" stdlibUtils "github.com/flyteorg/flyte/flytestdlib/utils" ) @@ -98,8 +99,10 @@ func TestGetEventInfo(t *testing.T) { }, }, })) - taskCtx := dummySparkTaskContext(dummySparkTaskTemplateContainer("blah-1", dummySparkConf), false, k8s.PluginState{}) - info, err := getEventInfoForSpark(taskCtx, dummySparkApplication(sj.RunningState)) + + taskTemplate := dummySparkTaskTemplateContainer("blah-1", dummySparkConf) + taskCtx := dummySparkTaskContext(taskTemplate, false, k8s.PluginState{}) + info, err := getEventInfoForSpark(taskCtx, dummySparkApplication(sj.RunningState), taskTemplate) assert.NoError(t, err) assert.Len(t, info.Logs, 6) assert.Equal(t, "https://spark-ui.flyte", info.CustomInfo.GetFields()[sparkDriverUI].GetStringValue()) @@ -119,7 +122,7 @@ func TestGetEventInfo(t *testing.T) { assert.Equal(t, expectedLinks, generatedLinks) - info, err = getEventInfoForSpark(taskCtx, dummySparkApplication(sj.SubmittedState)) + info, err = getEventInfoForSpark(taskCtx, dummySparkApplication(sj.SubmittedState), taskTemplate) assert.NoError(t, err) generatedLinks = make([]string, 0, len(info.Logs)) @@ -150,7 +153,7 @@ func TestGetEventInfo(t *testing.T) { }, })) - info, err = getEventInfoForSpark(taskCtx, dummySparkApplication(sj.FailedState)) + info, err = getEventInfoForSpark(taskCtx, dummySparkApplication(sj.FailedState), taskTemplate) assert.NoError(t, err) assert.Len(t, info.Logs, 5) assert.Equal(t, "spark-history.flyte/history/app-id", info.CustomInfo.GetFields()[sparkHistoryUI].GetStringValue()) @@ -1141,6 +1144,41 @@ func TestBuildResourceCustomK8SPod(t *testing.T) { assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory) } +func TestGetEventInfo_DynamicLogLinks(t *testing.T) { + dynamicLinks := map[string]tasklog.TemplateLogPlugin{ + "test-dynamic-link": { + TemplateURIs: []tasklog.TemplateURI{"https://some-service.com/{{.taskConfig.dynamicParam}}"}, + }, + } + + assert.NoError(t, setSparkConfig(&Config{ + LogConfig: LogConfig{ + Mixed: logs.LogConfig{ + DynamicLogLinks: dynamicLinks, + }, + }, + })) + + taskTemplate := dummySparkTaskTemplateContainer("blah-1", dummySparkConf) + taskTemplate.Config = map[string]string{ + "link_type": "test-dynamic-link", + "dynamicParam": "dynamic-value", + } + + taskCtx := dummySparkTaskContext(taskTemplate, false, k8s.PluginState{}) + info, err := getEventInfoForSpark(taskCtx, dummySparkApplication(sj.RunningState), taskTemplate) + assert.NoError(t, err) + + var dynamicLog *core.TaskLog + for _, l := range info.Logs { + if l.GetUri() == "https://some-service.com/dynamic-value" { + dynamicLog = l + break + } + } + assert.NotNil(t, dynamicLog, "expected dynamic log link in task logs") +} + func transformStructToStructPB(t *testing.T, obj interface{}) *structpb.Struct { data, err := json.Marshal(obj) assert.Nil(t, err)