Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/dask/dask.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,19 @@ func (p daskResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s
OccurredAt: &occurredAt,
}

taskTemplate, err := pluginContext.TaskReader().Read(ctx)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be an S3 hit, if I recall correctly. Do we want to do that every time?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's okay for now since we already do that for PythonFunctionTask. We can address it in a separate PR.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, here via this line 🤔

But this is a good point; I will check the audit logs next week to see how many of those requests propeller makes!

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, good point...

if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}

taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID()
o, err := logPlugin.GetTaskLogs(
tasklog.Input{
Namespace: job.ObjectMeta.Namespace,
PodName: job.Status.JobRunnerPodName,
LogName: "(Dask Runner Logs)",
TaskExecutionID: taskExecID,
TaskTemplate: taskTemplate,
},
)
if err != nil {
Expand Down
40 changes: 40 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ import (

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/logs"
pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
stdlibUtils "github.com/flyteorg/flyte/flytestdlib/utils"
)
Expand Down Expand Up @@ -866,3 +868,41 @@ func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1)
}

// TestGetTaskPhase_DynamicLogLinks checks that GetTaskPhase renders a dynamic
// log link: the plugin config registers a "test-dynamic-link" template, the
// task template's Config selects it through "link_type" and supplies
// "dynamicParam", and the resulting task logs must contain the rendered URI.
func TestGetTaskPhase_DynamicLogLinks(t *testing.T) {
	daskResourceHandler := daskResourceHandler{}
	ctx := context.TODO()

	dynamicLinks := map[string]tasklog.TemplateLogPlugin{
		"test-dynamic-link": {
			TemplateURIs: []tasklog.TemplateURI{"https://some-service.com/{{.taskConfig.dynamicParam}}"},
		},
	}

	// NOTE(review): SetConfig presumably replaces package-level plugin config
	// and is not restored afterwards — confirm this cannot leak into other tests.
	assert.NoError(t, SetConfig(&Config{
		Logs: logs.LogConfig{
			DynamicLogLinks: dynamicLinks,
		},
	}))

	taskTemplate := dummyDaskTaskTemplate("", nil, "")
	taskTemplate.Config = map[string]string{
		"link_type":    "test-dynamic-link",
		"dynamicParam": "dynamic-value",
	}
	taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{})

	taskPhase, err := daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobRunning))
	// assert.* does not abort the test, so stop explicitly before touching
	// values that are only meaningful on success; otherwise a failure would
	// surface as a nil-pointer panic instead of a readable assertion message.
	if !assert.NoError(t, err) {
		return
	}
	info := taskPhase.Info()
	if !assert.NotNil(t, info) {
		return
	}
	assert.NotNil(t, info.Logs)

	const expectedURI = "https://some-service.com/dynamic-value"
	var dynamicLog *core.TaskLog
	for _, l := range info.Logs {
		if l.GetUri() == expectedURI {
			dynamicLog = l
			break
		}
	}
	assert.NotNil(t, dynamicLog, "expected dynamic log link in task logs")
}
10 changes: 8 additions & 2 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ func (rayJobResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx
}, nil
}

func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginContext, rayJob *rayv1.RayJob) (*pluginsCore.TaskInfo, error) {
func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginContext, rayJob *rayv1.RayJob, taskTemplate *core.TaskTemplate) (*pluginsCore.TaskInfo, error) {
logPlugin, err := logs.InitializeLogPlugins(&logConfig)
if err != nil {
return nil, fmt.Errorf("failed to initialize log plugins. Error: %w", err)
Expand All @@ -584,6 +584,7 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon
PodUnixStartTime: startTime,
PodUnixFinishTime: finishTime,
ExtraTemplateVars: []tasklog.TemplateVar{},
TaskTemplate: taskTemplate,
}
if rayJob.Status.JobId != "" {
input.ExtraTemplateVars = append(
Expand Down Expand Up @@ -629,7 +630,12 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon

func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
rayJob := resource.(*rayv1.RayJob)
info, err := getEventInfoForRayJob(GetConfig().Logs, pluginContext, rayJob)
taskTemplate, err := pluginContext.TaskReader().Read(ctx)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}

info, err := getEventInfoForRayJob(GetConfig().Logs, pluginContext, rayJob, taskTemplate)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}
Expand Down
48 changes: 46 additions & 2 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,10 @@ func newPluginContext(pluginState k8s.PluginState) k8s.PluginContext {

plg.EXPECT().PluginStateReader().Return(&pluginStateReaderMock)

taskReader := &mocks.TaskReader{}
taskReader.EXPECT().Read(mock.Anything).Return(&core.TaskTemplate{}, nil)
plg.EXPECT().TaskReader().Return(taskReader)

return plg
}

Expand Down Expand Up @@ -1362,10 +1366,12 @@ func TestGetEventInfo_LogTemplates(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj())
ti, err := getEventInfoForRayJob(
logs.LogConfig{Templates: []tasklog.TemplateLogPlugin{tc.logPlugin}},
pluginCtx,
&tc.rayJob,
taskTemplate,
)
assert.NoError(t, err)
assert.Equal(t, tc.expectedTaskLogs, ti.Logs)
Expand Down Expand Up @@ -1461,10 +1467,13 @@ func TestGetEventInfo_LogTemplates_V1(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj())

ti, err := getEventInfoForRayJob(
logs.LogConfig{Templates: []tasklog.TemplateLogPlugin{tc.logPlugin}},
pluginCtx,
&tc.rayJob,
taskTemplate,
)
assert.NoError(t, err)
assert.Equal(t, tc.expectedTaskLogs, ti.Logs)
Expand Down Expand Up @@ -1516,8 +1525,9 @@ func TestGetEventInfo_DashboardURL(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj())
assert.NoError(t, SetConfig(&Config{DashboardURLTemplate: &tc.dashboardURLTemplate}))
ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob)
ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob, taskTemplate)
assert.NoError(t, err)
assert.Equal(t, tc.expectedTaskLogs, ti.Logs)
})
Expand Down Expand Up @@ -1568,14 +1578,48 @@ func TestGetEventInfo_DashboardURL_V1(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj())

assert.NoError(t, SetConfig(&Config{DashboardURLTemplate: &tc.dashboardURLTemplate}))
ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob)
ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob, taskTemplate)
assert.NoError(t, err)
assert.Equal(t, tc.expectedTaskLogs, ti.Logs)
})
}
}

// TestGetEventInfo_DynamicLogLinks checks that getEventInfoForRayJob renders a
// dynamic log link: the log config registers a "test-dynamic-link" template,
// the task template's Config selects it through "link_type" and supplies
// "dynamicParam", and exactly one log with the rendered URI is expected.
func TestGetEventInfo_DynamicLogLinks(t *testing.T) {
	pluginCtx := newPluginContext(k8s.PluginState{})

	dynamicLinks := map[string]tasklog.TemplateLogPlugin{
		"test-dynamic-link": {
			TemplateURIs: []tasklog.TemplateURI{"https://some-service.com/{{.taskConfig.dynamicParam}}"},
		},
	}

	taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj())
	taskTemplate.Config = map[string]string{
		"link_type":    "test-dynamic-link",
		"dynamicParam": "dynamic-value",
	}

	rayJob := rayv1.RayJob{
		ObjectMeta: metav1.ObjectMeta{
			Namespace: "test-namespace",
		},
	}

	ti, err := getEventInfoForRayJob(
		logs.LogConfig{DynamicLogLinks: dynamicLinks},
		pluginCtx,
		&rayJob,
		taskTemplate,
	)
	// assert.* does not abort the test, so guard before dereferencing ti or
	// indexing into ti.Logs; otherwise a failure would show up as a panic
	// rather than a readable assertion message.
	if !assert.NoError(t, err) || !assert.NotNil(t, ti) {
		return
	}
	if assert.Len(t, ti.Logs, 1) {
		assert.Equal(t, "https://some-service.com/dynamic-value", ti.Logs[0].GetUri())
	}
}

func TestGetPropertiesRay(t *testing.T) {
rayJobResourceHandler := rayJobResourceHandler{}
expected := k8s.PluginProperties{GeneratedNameMaxLength: ptr.To[int](47)}
Expand Down
11 changes: 9 additions & 2 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ func (sparkResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx p
}, nil
}

func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkApplication) (*pluginsCore.TaskInfo, error) {
func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkApplication, taskTemplate *core.TaskTemplate) (*pluginsCore.TaskInfo, error) {

sparkConfig := GetSparkConfig()
taskLogs := make([]*core.TaskLog, 0, 3)
Expand All @@ -403,6 +403,7 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl
Namespace: sj.Namespace,
LogName: "(Driver Logs)",
TaskExecutionID: taskExecID,
TaskTemplate: taskTemplate,
})

if err != nil {
Expand Down Expand Up @@ -520,7 +521,13 @@ func getEventInfoForSpark(pluginContext k8s.PluginContext, sj *sparkOp.SparkAppl
func (sparkResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {

app := resource.(*sparkOp.SparkApplication)
info, err := getEventInfoForSpark(pluginContext, app)

taskTemplate, err := pluginContext.TaskReader().Read(ctx)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}

info, err := getEventInfoForSpark(pluginContext, app, taskTemplate)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}
Expand Down
46 changes: 42 additions & 4 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"
stdlibUtils "github.com/flyteorg/flyte/flytestdlib/utils"
)
Expand Down Expand Up @@ -98,8 +99,10 @@ func TestGetEventInfo(t *testing.T) {
},
},
}))
taskCtx := dummySparkTaskContext(dummySparkTaskTemplateContainer("blah-1", dummySparkConf), false, k8s.PluginState{})
info, err := getEventInfoForSpark(taskCtx, dummySparkApplication(sj.RunningState))

taskTemplate := dummySparkTaskTemplateContainer("blah-1", dummySparkConf)
taskCtx := dummySparkTaskContext(taskTemplate, false, k8s.PluginState{})
info, err := getEventInfoForSpark(taskCtx, dummySparkApplication(sj.RunningState), taskTemplate)
assert.NoError(t, err)
assert.Len(t, info.Logs, 6)
assert.Equal(t, "https://spark-ui.flyte", info.CustomInfo.GetFields()[sparkDriverUI].GetStringValue())
Expand All @@ -119,7 +122,7 @@ func TestGetEventInfo(t *testing.T) {

assert.Equal(t, expectedLinks, generatedLinks)

info, err = getEventInfoForSpark(taskCtx, dummySparkApplication(sj.SubmittedState))
info, err = getEventInfoForSpark(taskCtx, dummySparkApplication(sj.SubmittedState), taskTemplate)
assert.NoError(t, err)

generatedLinks = make([]string, 0, len(info.Logs))
Expand Down Expand Up @@ -150,7 +153,7 @@ func TestGetEventInfo(t *testing.T) {
},
}))

info, err = getEventInfoForSpark(taskCtx, dummySparkApplication(sj.FailedState))
info, err = getEventInfoForSpark(taskCtx, dummySparkApplication(sj.FailedState), taskTemplate)
assert.NoError(t, err)
assert.Len(t, info.Logs, 5)
assert.Equal(t, "spark-history.flyte/history/app-id", info.CustomInfo.GetFields()[sparkHistoryUI].GetStringValue())
Expand Down Expand Up @@ -1141,6 +1144,41 @@ func TestBuildResourceCustomK8SPod(t *testing.T) {
assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory)
}

// TestGetEventInfo_DynamicLogLinks checks that getEventInfoForSpark renders a
// dynamic log link: the Spark config registers a "test-dynamic-link" template
// under LogConfig.Mixed, the task template's Config selects it through
// "link_type" and supplies "dynamicParam", and the returned logs must contain
// the rendered URI.
func TestGetEventInfo_DynamicLogLinks(t *testing.T) {
	dynamicLinks := map[string]tasklog.TemplateLogPlugin{
		"test-dynamic-link": {
			TemplateURIs: []tasklog.TemplateURI{"https://some-service.com/{{.taskConfig.dynamicParam}}"},
		},
	}

	// NOTE(review): setSparkConfig presumably replaces package-level config and
	// is not restored afterwards — confirm this cannot leak into other tests.
	assert.NoError(t, setSparkConfig(&Config{
		LogConfig: LogConfig{
			Mixed: logs.LogConfig{
				DynamicLogLinks: dynamicLinks,
			},
		},
	}))

	taskTemplate := dummySparkTaskTemplateContainer("blah-1", dummySparkConf)
	taskTemplate.Config = map[string]string{
		"link_type":    "test-dynamic-link",
		"dynamicParam": "dynamic-value",
	}

	taskCtx := dummySparkTaskContext(taskTemplate, false, k8s.PluginState{})
	info, err := getEventInfoForSpark(taskCtx, dummySparkApplication(sj.RunningState), taskTemplate)
	// assert.* does not abort the test, so guard before dereferencing info;
	// otherwise an error return would surface as a nil-pointer panic.
	if !assert.NoError(t, err) || !assert.NotNil(t, info) {
		return
	}

	const expectedURI = "https://some-service.com/dynamic-value"
	var dynamicLog *core.TaskLog
	for _, l := range info.Logs {
		if l.GetUri() == expectedURI {
			dynamicLog = l
			break
		}
	}
	assert.NotNil(t, dynamicLog, "expected dynamic log link in task logs")
}

func transformStructToStructPB(t *testing.T, obj interface{}) *structpb.Struct {
data, err := json.Marshal(obj)
assert.Nil(t, err)
Expand Down
Loading