Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ packages:
all: true
dir: gen/go/flyteidl2/service/mocks
include-auto-generated: true
github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project/projectconnect:
config:
all: true
dir: gen/go/flyteidl2/project/projectconnect/mocks
include-auto-generated: true
github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect:
config:
all: true
Expand Down
27 changes: 26 additions & 1 deletion dataproxy/service/dataproxy_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"github.com/flyteorg/flyte/v2/dataproxy/logs"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy/dataproxyconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project/projectconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task/taskconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/trigger"
Expand All @@ -42,17 +44,19 @@ type Service struct {
taskClient taskconnect.TaskServiceClient
triggerClient triggerconnect.TriggerServiceClient
runClient workflowconnect.RunServiceClient
projectClient projectconnect.ProjectServiceClient
logStreamer logs.LogStreamer
}

// NewService creates a new DataProxyService instance.
func NewService(cfg config.DataProxyConfig, dataStore *storage.DataStore, taskClient taskconnect.TaskServiceClient, triggerClient triggerconnect.TriggerServiceClient, runClient workflowconnect.RunServiceClient, logStreamer logs.LogStreamer) *Service {
func NewService(cfg config.DataProxyConfig, dataStore *storage.DataStore, taskClient taskconnect.TaskServiceClient, triggerClient triggerconnect.TriggerServiceClient, runClient workflowconnect.RunServiceClient, projectClient projectconnect.ProjectServiceClient, logStreamer logs.LogStreamer) *Service {
return &Service{
cfg: cfg,
dataStore: dataStore,
taskClient: taskClient,
triggerClient: triggerClient,
runClient: runClient,
projectClient: projectClient,
logStreamer: logStreamer,
}
}
Expand All @@ -74,6 +78,9 @@ func (s *Service) CreateUploadLocation(
logger.Errorf(ctx, "Request validation failed: %v", err)
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
if err := s.validateProjectExists(ctx, req.Msg.GetProject()); err != nil {
return nil, err
}

// Build the storage path
storagePath, err := s.constructStoragePath(ctx, req.Msg)
Expand Down Expand Up @@ -123,6 +130,20 @@ func (s *Service) CreateUploadLocation(
return connect.NewResponse(resp), nil
}

// validateProjectExists checks that the given project ID exists by calling the ProjectService.
func (s *Service) validateProjectExists(ctx context.Context, projectID string) error {
if _, err := s.projectClient.GetProject(ctx, connect.NewRequest(&project.GetProjectRequest{
Id: projectID,
})); err != nil {
if connect.CodeOf(err) == connect.CodeNotFound {
return connect.NewError(connect.CodeNotFound, fmt.Errorf("project %q not found", projectID))
}
logger.Errorf(ctx, "Failed to validate project %q: %v", projectID, err)
return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to validate project: %w", err))
}
return nil
}

// checkFileExists validates whether a file upload is safe by checking existing files.
// Returns an error if:
// - File exists without content_md5 provided (cannot verify safe overwrite)
Expand Down Expand Up @@ -226,6 +247,10 @@ func (s *Service) UploadInputs(
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("id is required"))
}

if err := s.validateProjectExists(ctx, project); err != nil {
return nil, err
}

// Resolve the task template to get cache_ignore_input_vars.
taskTemplate, err := s.resolveTaskTemplate(ctx, req.Msg)
if err != nil {
Expand Down
32 changes: 22 additions & 10 deletions dataproxy/service/dataproxy_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ import (
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy/dataproxyconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow"
projectMocks "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project/projectconnect/mocks"
workflowMocks "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect/mocks"
)

Expand Down Expand Up @@ -104,7 +106,12 @@ func TestCreateUploadLocation(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore := setupMockDataStore(t)
service := NewService(cfg, mockStore, nil, nil, nil, nil)
mockProjectClient := projectMocks.NewProjectServiceClient(t)
if !tt.wantErr {
mockProjectClient.On("GetProject", mock.Anything, mock.Anything).Return(
connect.NewResponse(&project.GetProjectResponse{}), nil)
}
service := NewService(cfg, mockStore, nil, nil, nil, mockProjectClient, nil)

req := &connect.Request[dataproxy.CreateUploadLocationRequest]{
Msg: tt.req,
Expand Down Expand Up @@ -225,7 +232,7 @@ func TestCheckFileExists(t *testing.T) {
mockStore = setupMockDataStoreWithExistingFile(t, tt.existingFileMD5)
}

service := NewService(cfg, mockStore, nil, nil, nil, nil)
service := NewService(cfg, mockStore, nil, nil, nil, nil, nil)
storagePath := storage.DataReference("s3://test-bucket/uploads/test-project/test-domain/test-root/test-file.txt")

err := service.checkFileExists(ctx, storagePath, tt.req)
Expand Down Expand Up @@ -303,7 +310,7 @@ func TestConstructStoragePath(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore := setupMockDataStore(t)
service := NewService(cfg, mockStore, nil, nil, nil, nil)
service := NewService(cfg, mockStore, nil, nil, nil, nil, nil)

path, err := service.constructStoragePath(ctx, tt.req)

Expand Down Expand Up @@ -460,7 +467,12 @@ func TestUploadInputs(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore := setupMockDataStoreWithWriteProtobuf(t)
svc := NewService(cfg, mockStore, nil, nil, nil, nil)
mockProjectClient := projectMocks.NewProjectServiceClient(t)
if !tt.wantErr {
mockProjectClient.On("GetProject", mock.Anything, mock.Anything).Return(
connect.NewResponse(&project.GetProjectResponse{}), nil)
}
svc := NewService(cfg, mockStore, nil, nil, nil, mockProjectClient, nil)

req := &connect.Request[dataproxy.UploadInputsRequest]{
Msg: tt.req,
Expand Down Expand Up @@ -624,7 +636,7 @@ func TestGetActionData(t *testing.T) {
ComposedProtobufStore: mockComposedStore,
ReferenceConstructor: &simpleRefConstructor{},
}
svc := NewService(cfg, ds, nil, nil, runClient, nil)
svc := NewService(cfg, ds, nil, nil, runClient, nil, nil)

resp, err := svc.GetActionData(ctx, connect.NewRequest(&dataproxy.GetActionDataRequest{
ActionId: actionID,
Expand Down Expand Up @@ -719,7 +731,7 @@ func TestTailLogs(t *testing.T) {
_ = stream.Send(&dataproxy.TailLogsResponse{})
}).Return(nil)

svc := NewService(config.DataProxyConfig{}, nil, nil, nil, runClient, streamer)
svc := NewService(config.DataProxyConfig{}, nil, nil, nil, runClient, nil, streamer)
client := newTailLogsTestClient(t, svc)

stream, err := client.TailLogs(context.Background(), connect.NewRequest(&dataproxy.TailLogsRequest{
Expand All @@ -742,7 +754,7 @@ func TestTailLogs(t *testing.T) {
nil, connect.NewError(connect.CodeNotFound, assertErr("action missing")))

streamer := &mockLogStreamer{}
svc := NewService(config.DataProxyConfig{}, nil, nil, nil, runClient, streamer)
svc := NewService(config.DataProxyConfig{}, nil, nil, nil, runClient, nil, streamer)
client := newTailLogsTestClient(t, svc)

stream, err := client.TailLogs(context.Background(), connect.NewRequest(&dataproxy.TailLogsRequest{
Expand All @@ -765,7 +777,7 @@ func TestTailLogs(t *testing.T) {
}), nil)

streamer := &mockLogStreamer{}
svc := NewService(config.DataProxyConfig{}, nil, nil, nil, runClient, streamer)
svc := NewService(config.DataProxyConfig{}, nil, nil, nil, runClient, nil, streamer)
client := newTailLogsTestClient(t, svc)

stream, err := client.TailLogs(context.Background(), connect.NewRequest(&dataproxy.TailLogsRequest{
Expand All @@ -789,7 +801,7 @@ func TestTailLogs(t *testing.T) {
streamer.On("TailLogs", mock.Anything, logContext, mock.Anything).Return(
connect.NewError(connect.CodeInternal, assertErr("streamer boom")))

svc := NewService(config.DataProxyConfig{}, nil, nil, nil, runClient, streamer)
svc := NewService(config.DataProxyConfig{}, nil, nil, nil, runClient, nil, streamer)
client := newTailLogsTestClient(t, svc)

stream, err := client.TailLogs(context.Background(), connect.NewRequest(&dataproxy.TailLogsRequest{
Expand All @@ -813,7 +825,7 @@ func TestTailLogs(t *testing.T) {
streamer := &mockLogStreamer{}
streamer.On("TailLogs", mock.Anything, logContext, mock.Anything).Return(nil)

svc := NewService(config.DataProxyConfig{}, nil, nil, nil, runClient, streamer)
svc := NewService(config.DataProxyConfig{}, nil, nil, nil, runClient, nil, streamer)
client := newTailLogsTestClient(t, svc)

stream, err := client.TailLogs(context.Background(), connect.NewRequest(&dataproxy.TailLogsRequest{
Expand Down
4 changes: 3 additions & 1 deletion dataproxy/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/flyteorg/flyte/v2/flytestdlib/logger"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/cluster/clusterconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy/dataproxyconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/project/projectconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/task/taskconnect"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/trigger/triggerconnect"
)
Expand All @@ -27,6 +28,7 @@ func Setup(ctx context.Context, sc *app.SetupContext) error {
taskClient := taskconnect.NewTaskServiceClient(http.DefaultClient, baseURL)
triggerClient := triggerconnect.NewTriggerServiceClient(http.DefaultClient, baseURL)
runClient := workflowconnect.NewRunServiceClient(http.DefaultClient, baseURL)
projectClient := projectconnect.NewProjectServiceClient(http.DefaultClient, baseURL)

var logStreamer logs.LogStreamer
if sc.K8sConfig != nil {
Expand All @@ -37,7 +39,7 @@ func Setup(ctx context.Context, sc *app.SetupContext) error {
}
}

svc := service.NewService(*cfg, sc.DataStore, taskClient, triggerClient, runClient, logStreamer)
svc := service.NewService(*cfg, sc.DataStore, taskClient, triggerClient, runClient, projectClient, logStreamer)

path, handler := dataproxyconnect.NewDataProxyServiceHandler(svc)
sc.Mux.Handle(path, handler)
Expand Down
Loading
Loading