From 70e6a9b205cdd76c329bb01d8807b3b64976fd83 Mon Sep 17 00:00:00 2001 From: Deep1998 Date: Thu, 1 Feb 2024 22:07:34 +0530 Subject: [PATCH 01/15] [feat] RR Create API 2: Add storage and spanner accessor (#756) * Add accessors for storage and spanner. * Add Unmarshal method * Rename storageacc and spanner acc to storageaccessor and spanneraccessor * Add empty test files * Increase version retention period Add log statements to storage accessor functions * Add storage accessor interface and impl * Add storage client unit tests * Add spanner admin client unit tests * Add spanner instance admin client unit tests * Add spanner client unit tests * Add interface and implementor for Spanner Accessor * Add unit test for storage utils * Add unit test for dataflow utils:UnmarshalDataflowConfig * Move mock methods inside mock struct * Add wrapper for storage client * Resolved storage comments * Add spanner mocks and move out clients out of accessors * add admin client mocks and few tests * add instance accessor unit tests and mocks * Rearrange files into mock.go and interface.go * add unit tests and mock for storage * add dataflow accessor mock * Add comments * Move storage parameters into a struct --- accessors/clients/dataflow/dataflow_client.go | 8 +- accessors/clients/dataflow/interface.go | 44 ++ accessors/clients/dataflow/mocks.go | 31 ++ .../clients/spanner/admin/admin_client.go | 43 ++ .../spanner/admin/admin_client_test.go | 116 +++++ accessors/clients/spanner/admin/interface.go | 90 ++++ accessors/clients/spanner/admin/mocks.go | 61 +++ .../clients/spanner/client/spanner_client.go | 43 ++ .../spanner/client/spanner_client_test.go | 116 +++++ .../spanner/instanceadmin/interface.go | 49 +++ .../clients/spanner/instanceadmin/mocks.go | 36 ++ .../instanceadmin/spanner_instance_admin.go | 43 ++ .../spanner_instance_admin_test.go | 116 +++++ accessors/clients/storage/interface.go | 86 ++++ accessors/clients/storage/mocks.go | 96 +++++ 
accessors/clients/storage/storage_client.go | 43 ++ .../clients/storage/storage_client_test.go | 117 ++++++ accessors/dataflow/dataflow_accessor.go | 4 +- accessors/dataflow/dataflow_accessor_test.go | 17 +- accessors/dataflow/mocks.go | 30 ++ .../helpers/dataflow/dataflow_helpers.go | 43 ++ .../helpers/dataflow/dataflow_helpers_test.go | 141 +++++++ accessors/spanner/mocks.go | 61 +++ accessors/spanner/spanner_accessor.go | 203 +++++++++ accessors/spanner/spanner_accessor_test.go | 336 +++++++++++++++ accessors/storage/mocks.go | 55 +++ accessors/storage/storage_accessor.go | 193 +++++++++ accessors/storage/storage_accessor_test.go | 395 ++++++++++++++++++ accessors/storage/types.go | 24 ++ cmd/data.go | 9 +- common/constants/constants.go | 3 +- common/utils/storage_utils.go | 43 ++ common/utils/storage_utils_test.go | 107 +++++ common/utils/utils.go | 87 +--- conversion/conversion.go | 130 +++--- streaming/streaming.go | 53 +-- .../spanner/spanner_accessor_test.go | 135 ++++++ testing/conversion/conversion_test.go | 23 - webv2/helpers/helpers.go | 9 +- webv2/profile/profile.go | 16 +- webv2/session/session_service.go | 16 +- webv2/web.go | 26 +- 42 files changed, 3049 insertions(+), 248 deletions(-) create mode 100644 accessors/clients/dataflow/interface.go create mode 100644 accessors/clients/dataflow/mocks.go create mode 100644 accessors/clients/spanner/admin/admin_client.go create mode 100644 accessors/clients/spanner/admin/admin_client_test.go create mode 100644 accessors/clients/spanner/admin/interface.go create mode 100644 accessors/clients/spanner/admin/mocks.go create mode 100644 accessors/clients/spanner/client/spanner_client.go create mode 100644 accessors/clients/spanner/client/spanner_client_test.go create mode 100644 accessors/clients/spanner/instanceadmin/interface.go create mode 100644 accessors/clients/spanner/instanceadmin/mocks.go create mode 100644 accessors/clients/spanner/instanceadmin/spanner_instance_admin.go create mode 100644 
accessors/clients/spanner/instanceadmin/spanner_instance_admin_test.go create mode 100644 accessors/clients/storage/interface.go create mode 100644 accessors/clients/storage/mocks.go create mode 100644 accessors/clients/storage/storage_client.go create mode 100644 accessors/clients/storage/storage_client_test.go create mode 100644 accessors/dataflow/mocks.go create mode 100644 accessors/helpers/dataflow/dataflow_helpers.go create mode 100644 accessors/helpers/dataflow/dataflow_helpers_test.go create mode 100644 accessors/spanner/mocks.go create mode 100644 accessors/spanner/spanner_accessor.go create mode 100644 accessors/spanner/spanner_accessor_test.go create mode 100644 accessors/storage/mocks.go create mode 100644 accessors/storage/storage_accessor.go create mode 100644 accessors/storage/storage_accessor_test.go create mode 100644 accessors/storage/types.go create mode 100644 common/utils/storage_utils.go create mode 100644 common/utils/storage_utils_test.go create mode 100644 testing/accessors/spanner/spanner_accessor_test.go diff --git a/accessors/clients/dataflow/dataflow_client.go b/accessors/clients/dataflow/dataflow_client.go index cc76c34fda..5f70698c8d 100644 --- a/accessors/clients/dataflow/dataflow_client.go +++ b/accessors/clients/dataflow/dataflow_client.go @@ -19,19 +19,13 @@ import ( "sync" dataflow "cloud.google.com/go/dataflow/apiv1beta3" - "cloud.google.com/go/dataflow/apiv1beta3/dataflowpb" - "github.com/googleapis/gax-go/v2" ) -type DataflowClient interface { - LaunchFlexTemplate(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) -} - var once sync.Once var dfClient *dataflow.FlexTemplatesClient // This function is declared as a global variable to make it testable. The unit -// tests edit this function, acting like a double. +// tests update this function, acting like a double. 
var newFlexTemplatesClient = dataflow.NewFlexTemplatesClient func GetOrCreateClient(ctx context.Context) (*dataflow.FlexTemplatesClient, error) { diff --git a/accessors/clients/dataflow/interface.go b/accessors/clients/dataflow/interface.go new file mode 100644 index 0000000000..9918b48a02 --- /dev/null +++ b/accessors/clients/dataflow/interface.go @@ -0,0 +1,44 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package dataflowclient + +import ( + "context" + + dataflow "cloud.google.com/go/dataflow/apiv1beta3" + "cloud.google.com/go/dataflow/apiv1beta3/dataflowpb" + "github.com/googleapis/gax-go/v2" +) + +// Use this interface instead of dataflow.FlexTemplatesClient to support mocking. +type DataflowClient interface { + LaunchFlexTemplate(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) +} + +// This implements the DataflowClient interface. This is the primary implementation that should be used in all places other than tests. 
+type DataflowClientImpl struct { + client *dataflow.FlexTemplatesClient +} + +func NewDataflowClientImpl(ctx context.Context) (*DataflowClientImpl, error) { + c, err := GetOrCreateClient(ctx) + if err != nil { + return nil, err + } + return &DataflowClientImpl{client: c}, nil +} + +func (c *DataflowClientImpl) LaunchFlexTemplate(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) { + return c.client.LaunchFlexTemplate(ctx, req, opts...) +} diff --git a/accessors/clients/dataflow/mocks.go b/accessors/clients/dataflow/mocks.go new file mode 100644 index 0000000000..ac899a7345 --- /dev/null +++ b/accessors/clients/dataflow/mocks.go @@ -0,0 +1,31 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package dataflowclient + +import ( + "context" + + "cloud.google.com/go/dataflow/apiv1beta3/dataflowpb" + "github.com/googleapis/gax-go/v2" +) + +// Mock that implements the DataflowClient interface. +// Pass in unit tests where DataflowClient is an input parameter. 
+type DataflowClientMock struct { + LaunchFlexTemplateMock func(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) +} + +func (dcm *DataflowClientMock) LaunchFlexTemplate(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) { + return dcm.LaunchFlexTemplateMock(ctx, req, opts...) +} diff --git a/accessors/clients/spanner/admin/admin_client.go b/accessors/clients/spanner/admin/admin_client.go new file mode 100644 index 0000000000..6d60d8f194 --- /dev/null +++ b/accessors/clients/spanner/admin/admin_client.go @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spanneradmin + +import ( + "context" + "fmt" + "sync" + + database "cloud.google.com/go/spanner/admin/database/apiv1" +) + +var once sync.Once +var spannerAdminClient *database.DatabaseAdminClient + +// This function is declared as a global variable to make it testable. The unit +// tests update this function, acting like a double. 
+var newDatabaseAdminClient = database.NewDatabaseAdminClient + +func GetOrCreateClient(ctx context.Context) (*database.DatabaseAdminClient, error) { + var err error + if spannerAdminClient == nil { + once.Do(func() { + spannerAdminClient, err = newDatabaseAdminClient(ctx) + }) + if err != nil { + return nil, fmt.Errorf("failed to create spanner admin client: %v", err) + } + return spannerAdminClient, nil + } + return spannerAdminClient, nil +} diff --git a/accessors/clients/spanner/admin/admin_client_test.go b/accessors/clients/spanner/admin/admin_client_test.go new file mode 100644 index 0000000000..7c911ee096 --- /dev/null +++ b/accessors/clients/spanner/admin/admin_client_test.go @@ -0,0 +1,116 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package spanneradmin + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + + database "cloud.google.com/go/spanner/admin/database/apiv1" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/api/option" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func resetTest() { + spannerAdminClient = nil + once = sync.Once{} +} + +func TestGetOrCreateClient_Basic(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newDatabaseAdminClient + defer func() { newDatabaseAdminClient = oldFunc }() + newDatabaseAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*database.DatabaseAdminClient, error) { + return &database.DatabaseAdminClient{}, nil + } + c, err := GetOrCreateClient(ctx) + assert.NotNil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaSync(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newDatabaseAdminClient + defer func() { newDatabaseAdminClient = oldFunc }() + + newDatabaseAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*database.DatabaseAdminClient, error) { + return &database.DatabaseAdminClient{}, nil + } + c, err := GetOrCreateClient(ctx) + assert.NotNil(t, c) + assert.Nil(t, err) + // Explicitly set the client to nil. Running GetOrCreateClient should not create a + // new client since sync would already be executed. 
+ spannerAdminClient = nil + newDatabaseAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*database.DatabaseAdminClient, error) { + return nil, fmt.Errorf("test error") + } + c, err = GetOrCreateClient(ctx) + assert.Nil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaIf(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newDatabaseAdminClient + defer func() { newDatabaseAdminClient = oldFunc }() + + newDatabaseAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*database.DatabaseAdminClient, error) { + return &database.DatabaseAdminClient{}, nil + } + oldC, err := GetOrCreateClient(ctx) + assert.NotNil(t, oldC) + assert.Nil(t, err) + + // Explicitly reset once. Running GetOrCreateClient should not create a + // new client since the if condition should prevent it. + once = sync.Once{} + newDatabaseAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*database.DatabaseAdminClient, error) { + return nil, fmt.Errorf("test error") + } + newC, err := GetOrCreateClient(ctx) + assert.Equal(t, oldC, newC) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_Error(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newDatabaseAdminClient + defer func() { newDatabaseAdminClient = oldFunc }() + + newDatabaseAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*database.DatabaseAdminClient, error) { + return nil, fmt.Errorf("test error") + } + c, err := GetOrCreateClient(ctx) + assert.Nil(t, c) + assert.NotNil(t, err) +} diff --git a/accessors/clients/spanner/admin/interface.go b/accessors/clients/spanner/admin/interface.go new file mode 100644 index 0000000000..d0a2ab41b4 --- /dev/null +++ b/accessors/clients/spanner/admin/interface.go @@ -0,0 +1,90 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spanneradmin + +import ( + "context" + + database "cloud.google.com/go/spanner/admin/database/apiv1" + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" + "github.com/googleapis/gax-go/v2" +) + +// Use this interface instead of database.DatabaseAdminClient to support mocking. +type AdminClient interface { + GetDatabase(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) + CreateDatabase(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (CreateDatabaseOperation, error) + UpdateDatabaseDdl(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (UpdateDatabaseDdlOperation, error) +} + +// Use this interface instead of database.CreateDatabaseOperation to support mocking. +type CreateDatabaseOperation interface { + Wait(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) +} + +// Use this interface instead of database.UpdateDatabaseDdlOperation to support mocking. +type UpdateDatabaseDdlOperation interface { + Wait(ctx context.Context, opts ...gax.CallOption) error +} + +// This implements the AdminClient interface. This is the primary implementation that should be used in all places other than tests. 
+type AdminClientImpl struct { + adminClient *database.DatabaseAdminClient +} + +func NewAdminClientImpl(ctx context.Context) (*AdminClientImpl, error) { + c, err := GetOrCreateClient(ctx) + if err != nil { + return nil, err + } + return &AdminClientImpl{adminClient: c}, nil +} + +func (c *AdminClientImpl) GetDatabase(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return c.adminClient.GetDatabase(ctx, req, opts...) +} + +func (c *AdminClientImpl) CreateDatabase(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (CreateDatabaseOperation, error) { + op, err := c.adminClient.CreateDatabase(ctx, req, opts...) + if err != nil { + return nil, err + } + return &CreateDatabaseOperationImpl{dbo: op}, nil +} + +func (c *AdminClientImpl) UpdateDatabaseDdl(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (UpdateDatabaseDdlOperation, error) { + op, err := c.adminClient.UpdateDatabaseDdl(ctx, req, opts...) + if err != nil { + return nil, err + } + return &UpdateDatabaseDdlImpl{dbo: op}, nil +} + +// This implements the CreateDatabaseOperation interface. This is the primary implementation that should be used in all places other than tests. +type CreateDatabaseOperationImpl struct { + dbo *database.CreateDatabaseOperation +} + +func (c *CreateDatabaseOperationImpl) Wait(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { + return c.dbo.Wait(ctx, opts...) +} + +// This implements the UpdateDatabaseDdlOperation interface. This is the primary implementation that should be used in all places other than tests. +type UpdateDatabaseDdlImpl struct { + dbo *database.UpdateDatabaseDdlOperation +} + +func (c *UpdateDatabaseDdlImpl) Wait(ctx context.Context, opts ...gax.CallOption) error { + return c.dbo.Wait(ctx, opts...) 
+} diff --git a/accessors/clients/spanner/admin/mocks.go b/accessors/clients/spanner/admin/mocks.go new file mode 100644 index 0000000000..5d84ddeadd --- /dev/null +++ b/accessors/clients/spanner/admin/mocks.go @@ -0,0 +1,61 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spanneradmin + +import ( + "context" + + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" + "github.com/googleapis/gax-go/v2" +) + +// Mock that implements the AdminClient interface. +// Pass in unit tests where AdminClient is an input parameter. +type AdminClientMock struct { + GetDatabaseMock func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) + CreateDatabaseMock func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (CreateDatabaseOperation, error) + UpdateDatabaseDdlMock func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (UpdateDatabaseDdlOperation, error) +} + +func (acm *AdminClientMock) GetDatabase(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return acm.GetDatabaseMock(ctx, req, opts...) +} + +func (acm *AdminClientMock) CreateDatabase(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (CreateDatabaseOperation, error) { + return acm.CreateDatabaseMock(ctx, req, opts...) 
+} + +func (acm *AdminClientMock) UpdateDatabaseDdl(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (UpdateDatabaseDdlOperation, error) { + return acm.UpdateDatabaseDdlMock(ctx, req, opts...) +} + +// Mock that implements the CreateDatabaseOperation interface. +// Pass in unit tests where CreateDatabaseOperation is an input parameter. +type CreateDatabaseOperationMock struct { + WaitMock func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) +} + +func (dbo *CreateDatabaseOperationMock) Wait(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { + return dbo.WaitMock(ctx, opts...) +} + +// Mock that implements the UpdateDatabaseDdlOperation interface. +// Pass in unit tests where UpdateDatabaseDdlOperation is an input parameter. +type UpdateDatabaseDdlOperationMock struct { + WaitMock func(ctx context.Context, opts ...gax.CallOption) error +} + +func (dbo *UpdateDatabaseDdlOperationMock) Wait(ctx context.Context, opts ...gax.CallOption) error { + return dbo.WaitMock(ctx, opts...) +} diff --git a/accessors/clients/spanner/client/spanner_client.go b/accessors/clients/spanner/client/spanner_client.go new file mode 100644 index 0000000000..d6edd81ece --- /dev/null +++ b/accessors/clients/spanner/client/spanner_client.go @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package spannerclient + +import ( + "context" + "fmt" + "sync" + + sp "cloud.google.com/go/spanner" +) + +var once sync.Once +var spannerClient *sp.Client + +// This function is declared as a global variable to make it testable. The unit +// tests edit this function, acting like a double. +var newClient = sp.NewClient + +func GetOrCreateClient(ctx context.Context, dbURI string) (*sp.Client, error) { + var err error + if spannerClient == nil { + once.Do(func() { + spannerClient, err = newClient(ctx, dbURI) + }) + if err != nil { + return nil, fmt.Errorf("failed to create spanner database client: %v", err) + } + return spannerClient, nil + } + return spannerClient, nil +} diff --git a/accessors/clients/spanner/client/spanner_client_test.go b/accessors/clients/spanner/client/spanner_client_test.go new file mode 100644 index 0000000000..66f5059591 --- /dev/null +++ b/accessors/clients/spanner/client/spanner_client_test.go @@ -0,0 +1,116 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package spannerclient + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + + sp "cloud.google.com/go/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/api/option" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func resetTest() { + spannerClient = nil + once = sync.Once{} +} + +func TestGetOrCreateClient_Basic(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return &sp.Client{}, nil + } + c, err := GetOrCreateClient(ctx, "testURI") + assert.NotNil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaSync(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return &sp.Client{}, nil + } + c, err := GetOrCreateClient(ctx, "testURI") + assert.NotNil(t, c) + assert.Nil(t, err) + // Explicitly set the client to nil. Running GetOrCreateClient should not create a + // new client since sync would already be executed. 
+ spannerClient = nil + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return nil, fmt.Errorf("test error") + } + c, err = GetOrCreateClient(ctx, "testURI") + assert.Nil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaIf(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return &sp.Client{}, nil + } + oldC, err := GetOrCreateClient(ctx, "testURI") + assert.NotNil(t, oldC) + assert.Nil(t, err) + + // Explicitly reset once. Running GetOrCreateClient should not create a + // new client since the if condition should prevent it. + once = sync.Once{} + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return nil, fmt.Errorf("test error") + } + newC, err := GetOrCreateClient(ctx, "testURI") + assert.Equal(t, oldC, newC) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_Error(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return nil, fmt.Errorf("test error") + } + c, err := GetOrCreateClient(ctx, "testURI") + assert.Nil(t, c) + assert.NotNil(t, err) +} diff --git a/accessors/clients/spanner/instanceadmin/interface.go b/accessors/clients/spanner/instanceadmin/interface.go new file mode 100644 index 0000000000..5e195ebf1c --- /dev/null +++ b/accessors/clients/spanner/instanceadmin/interface.go @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spinstanceadmin + +import ( + "context" + + instance "cloud.google.com/go/spanner/admin/instance/apiv1" + "cloud.google.com/go/spanner/admin/instance/apiv1/instancepb" + "github.com/googleapis/gax-go/v2" +) + +// Use this interface instead of instance.InstanceAdminClient to support mocking. +type InstanceAdminClient interface { + GetInstance(ctx context.Context, req *instancepb.GetInstanceRequest, opts ...gax.CallOption) (*instancepb.Instance, error) + GetInstanceConfig(ctx context.Context, req *instancepb.GetInstanceConfigRequest, opts ...gax.CallOption) (*instancepb.InstanceConfig, error) +} + +// This implements the InstanceAdminClient interface. This is the primary implementation that should be used in all places other than tests. +type InstanceAdminClientImpl struct { + client *instance.InstanceAdminClient +} + +func NewInstanceAdminClientImpl(ctx context.Context) (*InstanceAdminClientImpl, error) { + c, err := GetOrCreateClient(ctx) + if err != nil { + return nil, err + } + return &InstanceAdminClientImpl{client: c}, nil +} + +func (c *InstanceAdminClientImpl) GetInstance(ctx context.Context, req *instancepb.GetInstanceRequest, opts ...gax.CallOption) (*instancepb.Instance, error) { + return c.client.GetInstance(ctx, req, opts...) +} + +func (c *InstanceAdminClientImpl) GetInstanceConfig(ctx context.Context, req *instancepb.GetInstanceConfigRequest, opts ...gax.CallOption) (*instancepb.InstanceConfig, error) { + return c.client.GetInstanceConfig(ctx, req, opts...) 
+} diff --git a/accessors/clients/spanner/instanceadmin/mocks.go b/accessors/clients/spanner/instanceadmin/mocks.go new file mode 100644 index 0000000000..bb88cca00a --- /dev/null +++ b/accessors/clients/spanner/instanceadmin/mocks.go @@ -0,0 +1,36 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spinstanceadmin + +import ( + "context" + + "cloud.google.com/go/spanner/admin/instance/apiv1/instancepb" + "github.com/googleapis/gax-go/v2" +) + +// Mock that implements the InstanceAdminClient interface. +// Pass in unit tests where InstanceAdminClient is an input parameter. +type InstanceAdminClientMock struct { + GetInstanceMock func(ctx context.Context, req *instancepb.GetInstanceRequest, opts ...gax.CallOption) (*instancepb.Instance, error) + GetInstanceConfigMock func(ctx context.Context, req *instancepb.GetInstanceConfigRequest, opts ...gax.CallOption) (*instancepb.InstanceConfig, error) +} + +func (iac *InstanceAdminClientMock) GetInstance(ctx context.Context, req *instancepb.GetInstanceRequest, opts ...gax.CallOption) (*instancepb.Instance, error) { + return iac.GetInstanceMock(ctx, req, opts...) +} + +func (iac *InstanceAdminClientMock) GetInstanceConfig(ctx context.Context, req *instancepb.GetInstanceConfigRequest, opts ...gax.CallOption) (*instancepb.InstanceConfig, error) { + return iac.GetInstanceConfigMock(ctx, req, opts...) 
+} diff --git a/accessors/clients/spanner/instanceadmin/spanner_instance_admin.go b/accessors/clients/spanner/instanceadmin/spanner_instance_admin.go new file mode 100644 index 0000000000..a3d597e173 --- /dev/null +++ b/accessors/clients/spanner/instanceadmin/spanner_instance_admin.go @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spinstanceadmin + +import ( + "context" + "fmt" + "sync" + + instance "cloud.google.com/go/spanner/admin/instance/apiv1" +) + +var once sync.Once +var instanceAdminClient *instance.InstanceAdminClient + +// This function is declared as a global variable to make it testable. The unit +// tests update this function, acting like a double. 
+var newInstanceAdminClient = instance.NewInstanceAdminClient + +func GetOrCreateClient(ctx context.Context) (*instance.InstanceAdminClient, error) { + var err error + if instanceAdminClient == nil { + once.Do(func() { + instanceAdminClient, err = newInstanceAdminClient(ctx) + }) + if err != nil { + return nil, fmt.Errorf("failed to create spanner instance admin client: %v", err) + } + return instanceAdminClient, nil + } + return instanceAdminClient, nil +} diff --git a/accessors/clients/spanner/instanceadmin/spanner_instance_admin_test.go b/accessors/clients/spanner/instanceadmin/spanner_instance_admin_test.go new file mode 100644 index 0000000000..2792d03e82 --- /dev/null +++ b/accessors/clients/spanner/instanceadmin/spanner_instance_admin_test.go @@ -0,0 +1,116 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package spinstanceadmin + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + + instance "cloud.google.com/go/spanner/admin/instance/apiv1" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/api/option" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func resetTest() { + instanceAdminClient = nil + once = sync.Once{} +} + +func TestGetOrCreateClient_Basic(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newInstanceAdminClient + defer func() { newInstanceAdminClient = oldFunc }() + newInstanceAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*instance.InstanceAdminClient, error) { + return &instance.InstanceAdminClient{}, nil + } + c, err := GetOrCreateClient(ctx) + assert.NotNil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaSync(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newInstanceAdminClient + defer func() { newInstanceAdminClient = oldFunc }() + + newInstanceAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*instance.InstanceAdminClient, error) { + return &instance.InstanceAdminClient{}, nil + } + c, err := GetOrCreateClient(ctx) + assert.NotNil(t, c) + assert.Nil(t, err) + // Explicitly set the client to nil. Running GetOrCreateClient should not create a + // new client since sync would already be executed. 
+ instanceAdminClient = nil + newInstanceAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*instance.InstanceAdminClient, error) { + return nil, fmt.Errorf("test error") + } + c, err = GetOrCreateClient(ctx) + assert.Nil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaIf(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newInstanceAdminClient + defer func() { newInstanceAdminClient = oldFunc }() + + newInstanceAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*instance.InstanceAdminClient, error) { + return &instance.InstanceAdminClient{}, nil + } + oldC, err := GetOrCreateClient(ctx) + assert.NotNil(t, oldC) + assert.Nil(t, err) + + // Explicitly reset once. Running GetOrCreateClient should not create a + // new client; the if condition should prevent it. + once = sync.Once{} + newInstanceAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*instance.InstanceAdminClient, error) { + return nil, fmt.Errorf("test error") + } + newC, err := GetOrCreateClient(ctx) + assert.Equal(t, oldC, newC) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_Error(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newInstanceAdminClient + defer func() { newInstanceAdminClient = oldFunc }() + + newInstanceAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*instance.InstanceAdminClient, error) { + return nil, fmt.Errorf("test error") + } + c, err := GetOrCreateClient(ctx) + assert.Nil(t, c) + assert.NotNil(t, err) +} diff --git a/accessors/clients/storage/interface.go b/accessors/clients/storage/interface.go new file mode 100644 index 0000000000..978acd841c --- /dev/null +++ b/accessors/clients/storage/interface.go @@ -0,0 +1,86 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package storageclient + +import ( + "context" + "io" + + "cloud.google.com/go/storage" +) + +// Use this interface instead of storage.Client to support mocking. +type StorageClient interface { + Bucket(name string) BucketHandle +} + +// Use this interface instead of storage.BucketHandle to support mocking. +type BucketHandle interface { + Create(ctx context.Context, projectID string, attrs *storage.BucketAttrs) (err error) + Update(ctx context.Context, uattrs storage.BucketAttrsToUpdate) (attrs *storage.BucketAttrs, err error) + Object(name string) ObjectHandle +} + +// Use this interface instead of storage.ObjectHandle to support mocking. +type ObjectHandle interface { + NewWriter(ctx context.Context) io.WriteCloser + NewReader(ctx context.Context) (io.ReadCloser, error) +} + +// This implements the StorageClient interface. This is the primary implementation that should be used in all places other than tests. +type StorageClientImpl struct { + client *storage.Client +} + +func NewStorageClientImpl(ctx context.Context) (*StorageClientImpl, error) { + c, err := GetOrCreateClient(ctx) + if err != nil { + return nil, err + } + return &StorageClientImpl{client: c}, nil +} + +func (c *StorageClientImpl) Bucket(name string) BucketHandle { + return &BucketHandleImpl{bucketHandle: c.client.Bucket(name)} +} + +// This implements the BucketHandle interface. This is the primary implementation that should be used in all places other than tests. 
+type BucketHandleImpl struct { + bucketHandle *storage.BucketHandle +} + +func (b *BucketHandleImpl) Create(ctx context.Context, projectID string, attrs *storage.BucketAttrs) (err error) { + return b.bucketHandle.Create(ctx, projectID, attrs) +} + +func (b *BucketHandleImpl) Update(ctx context.Context, uattrs storage.BucketAttrsToUpdate) (attrs *storage.BucketAttrs, err error) { + return b.bucketHandle.Update(ctx, uattrs) +} + +func (b *BucketHandleImpl) Object(name string) ObjectHandle { + return &ObjectHandleImpl{objectHandle: b.bucketHandle.Object(name)} +} + +// This implements the ObjectHandle interface. This is the primary implementation that should be used in all places other than tests. +type ObjectHandleImpl struct { + objectHandle *storage.ObjectHandle +} + +func (o *ObjectHandleImpl) NewWriter(ctx context.Context) io.WriteCloser { + return o.objectHandle.NewWriter(ctx) +} + +func (o *ObjectHandleImpl) NewReader(ctx context.Context) (io.ReadCloser, error) { + return o.objectHandle.NewReader(ctx) +} diff --git a/accessors/clients/storage/mocks.go b/accessors/clients/storage/mocks.go new file mode 100644 index 0000000000..f6effe2414 --- /dev/null +++ b/accessors/clients/storage/mocks.go @@ -0,0 +1,96 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package storageclient + +import ( + "context" + "io" + + "cloud.google.com/go/storage" +) + +// Mock that implements the StorageClient interface. 
+// Pass in unit tests where StorageClient is an input parameter. +type StorageClientMock struct { + BucketMock func(name string) BucketHandle +} + +func (scm *StorageClientMock) Bucket(name string) BucketHandle { + return scm.BucketMock(name) +} + +// Mock that implements the BucketHandle interface. +// Pass in unit tests where BucketHandle is an input parameter. +type BucketHandleMock struct { + CreateMock func(ctx context.Context, projectID string, attrs *storage.BucketAttrs) (err error) + UpdateMock func(ctx context.Context, uattrs storage.BucketAttrsToUpdate) (attrs *storage.BucketAttrs, err error) + ObjectMock func(name string) ObjectHandle +} + +func (b *BucketHandleMock) Create(ctx context.Context, projectID string, attrs *storage.BucketAttrs) (err error) { + return b.CreateMock(ctx, projectID, attrs) +} + +func (b *BucketHandleMock) Update(ctx context.Context, uattrs storage.BucketAttrsToUpdate) (attrs *storage.BucketAttrs, err error) { + return b.UpdateMock(ctx, uattrs) +} + +func (b *BucketHandleMock) Object(name string) ObjectHandle { + return b.ObjectMock(name) +} + +// Mock that implements the ObjectHandle interface. +// Pass in unit tests where ObjectHandle is an input parameter. +type ObjectHandleMock struct { + NewWriterMock func(ctx context.Context) io.WriteCloser + NewReaderMock func(ctx context.Context) (io.ReadCloser, error) +} + +func (o *ObjectHandleMock) NewWriter(ctx context.Context) io.WriteCloser { + return o.NewWriterMock(ctx) +} + +func (o *ObjectHandleMock) NewReader(ctx context.Context) (io.ReadCloser, error) { + return o.NewReaderMock(ctx) +} + +// Mock that implements the io.WriteCloser interface. +// Pass in unit tests where io.WriteCloser is an input parameter. 
+type WriterMock struct { + WriteMock func(p []byte) (n int, err error) + CloseMock func() error +} + +func (w *WriterMock) Write(p []byte) (n int, err error) { + return w.WriteMock(p) +} + +func (w *WriterMock) Close() error { + return w.CloseMock() +} + +// Mock that implements the io.ReadCloser interface. +// Pass in unit tests where io.ReadCloser is an input parameter. +type ReaderMock struct { + ReadMock func(p []byte) (n int, err error) + CloseMock func() error +} + +func (r *ReaderMock) Read(p []byte) (n int, err error) { + return r.ReadMock(p) +} + +func (r *ReaderMock) Close() error { + return r.CloseMock() +} diff --git a/accessors/clients/storage/storage_client.go b/accessors/clients/storage/storage_client.go new file mode 100644 index 0000000000..09a66c401a --- /dev/null +++ b/accessors/clients/storage/storage_client.go @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package storageclient + +import ( + "context" + "fmt" + "sync" + + "cloud.google.com/go/storage" +) + +var once sync.Once +var gcsClient *storage.Client + +// This function is declared as a global variable to make it testable. The unit +// tests update this function, acting like a double. 
+var newClient = storage.NewClient + +func GetOrCreateClient(ctx context.Context) (*storage.Client, error) { + var err error + if gcsClient == nil { + once.Do(func() { + gcsClient, err = newClient(ctx) + }) + if err != nil { + return nil, fmt.Errorf("failed to create storage client: %v", err) + } + return gcsClient, nil + } + return gcsClient, nil +} diff --git a/accessors/clients/storage/storage_client_test.go b/accessors/clients/storage/storage_client_test.go new file mode 100644 index 0000000000..73bd6b873a --- /dev/null +++ b/accessors/clients/storage/storage_client_test.go @@ -0,0 +1,117 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package storageclient + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + + "cloud.google.com/go/storage" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/api/option" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func resetTest() { + gcsClient = nil + once = sync.Once{} +} + +func TestGetOrCreateClient_Basic(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + newClient = func(ctx context.Context, opts ...option.ClientOption) (*storage.Client, error) { + return &storage.Client{}, nil + } + c, err := GetOrCreateClient(ctx) + assert.NotNil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaSync(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, opts ...option.ClientOption) (*storage.Client, error) { + return &storage.Client{}, nil + } + c, err := GetOrCreateClient(ctx) + assert.NotNil(t, c) + assert.Nil(t, err) + // Explicitly set the client to nil. Running GetOrCreateClient should not create a + // new client since sync would already be executed. 
+ gcsClient = nil + + newClient = func(ctx context.Context, opts ...option.ClientOption) (*storage.Client, error) { + return nil, fmt.Errorf("test error") + } + c, err = GetOrCreateClient(ctx) + assert.Nil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaIf(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, opts ...option.ClientOption) (*storage.Client, error) { + return &storage.Client{}, nil + } + oldC, err := GetOrCreateClient(ctx) + assert.NotNil(t, oldC) + assert.Nil(t, err) + + // Explicitly reset once. Running GetOrCreateClient should not create a + // new client; the if condition should prevent it. + once = sync.Once{} + newClient = func(ctx context.Context, opts ...option.ClientOption) (*storage.Client, error) { + return nil, fmt.Errorf("test error") + } + newC, err := GetOrCreateClient(ctx) + assert.Equal(t, oldC, newC) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_Error(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, opts ...option.ClientOption) (*storage.Client, error) { + return nil, fmt.Errorf("test error") + } + c, err := GetOrCreateClient(ctx) + assert.Nil(t, c) + assert.NotNil(t, err) +} diff --git a/accessors/dataflow/dataflow_accessor.go b/accessors/dataflow/dataflow_accessor.go index aa7f1bea49..3fed9b3cc9 100644 --- a/accessors/dataflow/dataflow_accessor.go +++ b/accessors/dataflow/dataflow_accessor.go @@ -25,12 +25,14 @@ import ( "golang.org/x/exp/maps" ) +// The DataflowAccessor provides methods that internally use the dataflow client. Methods should only contain generic logic here that can be used by multiple workflows. 
type DataflowAccessor interface { // This function takes the template parameters (@parameters) and runtime environment config (@cfg) as input, and returns // the generated jobId, equivalentGcloudCommand and error if any. - LaunchFlexTemplate(ctx context.Context, c dataflowclient.DataflowClient, parameters map[string]string, cfg DataflowTuningConfig) (string, string, error) + LaunchDataflowTemplate(ctx context.Context, c dataflowclient.DataflowClient, parameters map[string]string, cfg DataflowTuningConfig) (string, string, error) } +// This implements the DataflowAccessor interface. This is the primary implementation that should be used in all places other than tests. type DataflowAccessorImpl struct{} func (dfA *DataflowAccessorImpl) LaunchDataflowTemplate(ctx context.Context, c dataflowclient.DataflowClient, parameters map[string]string, cfg DataflowTuningConfig) (string, string, error) { diff --git a/accessors/dataflow/dataflow_accessor_test.go b/accessors/dataflow/dataflow_accessor_test.go index ffb8b6537b..76cec60007 100644 --- a/accessors/dataflow/dataflow_accessor_test.go +++ b/accessors/dataflow/dataflow_accessor_test.go @@ -20,6 +20,7 @@ import ( "testing" "cloud.google.com/go/dataflow/apiv1beta3/dataflowpb" + dataflowclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/dataflow" "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" "github.com/google/go-cmp/cmp" "github.com/googleapis/gax-go/v2" @@ -157,14 +158,6 @@ func getExpectedGcloudCmd2() string { "transformationContextFilePath=gs://transformationContext.json" } -type DataflowClientMock struct { - LaunchFlexTemplateMock func(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) -} - -func (dcm *DataflowClientMock) LaunchFlexTemplate(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) { - return 
dcm.LaunchFlexTemplateMock(ctx, req, opts...) -} - func TestLaunchDataflowTemplate(t *testing.T) { ctx := context.Background() da := DataflowAccessorImpl{} @@ -172,7 +165,7 @@ func TestLaunchDataflowTemplate(t *testing.T) { name string params map[string]string cfg DataflowTuningConfig - dcm DataflowClientMock + dcm dataflowclient.DataflowClientMock expectError bool expectedJobId string expectedGcloudCmd string @@ -181,7 +174,7 @@ func TestLaunchDataflowTemplate(t *testing.T) { name: "Basic Correct", params: getParameters(), cfg: getTuningConfig(), - dcm: DataflowClientMock{ + dcm: dataflowclient.DataflowClientMock{ LaunchFlexTemplateMock: func(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) { return &dataflowpb.LaunchFlexTemplateResponse{Job: &dataflowpb.Job{Id: "1234"}}, nil }, @@ -194,7 +187,7 @@ func TestLaunchDataflowTemplate(t *testing.T) { name: "Request builder error", params: getParameters(), cfg: DataflowTuningConfig{Subnetwork: "test"}, - dcm: DataflowClientMock{ + dcm: dataflowclient.DataflowClientMock{ LaunchFlexTemplateMock: func(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) { return &dataflowpb.LaunchFlexTemplateResponse{Job: &dataflowpb.Job{Id: "1234"}}, nil }, @@ -207,7 +200,7 @@ func TestLaunchDataflowTemplate(t *testing.T) { name: "Launch flex template throws error", params: getParameters(), cfg: getTuningConfig(), - dcm: DataflowClientMock{ + dcm: dataflowclient.DataflowClientMock{ LaunchFlexTemplateMock: func(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) { return nil, fmt.Errorf("test error") }, diff --git a/accessors/dataflow/mocks.go b/accessors/dataflow/mocks.go new file mode 100644 index 0000000000..9a65d9d512 --- /dev/null +++ b/accessors/dataflow/mocks.go @@ -0,0 +1,30 @@ +// 
Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package dataflowaccessor + +import ( + "context" + + dataflowclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/dataflow" +) + +// Mock that implements the DataflowAccessor interface. +// Pass in unit tests where DataflowAccessor is an input parameter. +type DataflowAccessorMock struct { + LaunchFlexTemplateMock func(ctx context.Context, c dataflowclient.DataflowClient, parameters map[string]string, cfg DataflowTuningConfig) (string, string, error) +} + +func (d *DataflowAccessorMock) LaunchDataflowTemplate(ctx context.Context, c dataflowclient.DataflowClient, parameters map[string]string, cfg DataflowTuningConfig) (string, string, error) { + return d.LaunchFlexTemplateMock(ctx, c, parameters, cfg) +} diff --git a/accessors/helpers/dataflow/dataflow_helpers.go b/accessors/helpers/dataflow/dataflow_helpers.go new file mode 100644 index 0000000000..ee117ccced --- /dev/null +++ b/accessors/helpers/dataflow/dataflow_helpers.go @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This package is kept with accessors because some functions import other accessors. +// The common/utils package should not import any SMT dependency. +package dataflowhelpers + +import ( + "context" + "encoding/json" + + storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" + dataflowaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/dataflow" + storageaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/storage" +) + +// This package contains common helper methods using the accessor package. This will be used by multiple flows. +// Do not move to util since util only expects methods that do not import any other internal dependency. + +// Reads any local or gcs file and unmarshals the data into a DataflowTuningConfig struct. 
+func UnmarshalDataflowTuningConfig(ctx context.Context, sc storageclient.StorageClient, sa storageaccessor.StorageAccessor, filePath string) (dataflowaccessor.DataflowTuningConfig, error) { + jsonStr, err := sa.ReadAnyFile(ctx, sc, filePath) + if err != nil { + return dataflowaccessor.DataflowTuningConfig{}, err + } + tuningCfg := dataflowaccessor.DataflowTuningConfig{} + err = json.Unmarshal([]byte(jsonStr), &tuningCfg) + if err != nil { + return dataflowaccessor.DataflowTuningConfig{}, err + } + return tuningCfg, nil +} diff --git a/accessors/helpers/dataflow/dataflow_helpers_test.go b/accessors/helpers/dataflow/dataflow_helpers_test.go new file mode 100644 index 0000000000..1418dea448 --- /dev/null +++ b/accessors/helpers/dataflow/dataflow_helpers_test.go @@ -0,0 +1,141 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package dataflowhelpers + +import ( + "context" + "fmt" + "os" + "testing" + + storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" + dataflowaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/dataflow" + storageaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/storage" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func TestUnmarshalDataflowTuningConfig(t *testing.T) { + testCases := []struct { + name string + sam storageaccessor.StorageAccessorMock + expectError bool + want dataflowaccessor.DataflowTuningConfig + }{ + { + name: "Basic", + sam: storageaccessor.StorageAccessorMock{ + ReadAnyFileMock: func(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) { + return `{ + "projectId": "test-project", + "jobName": "test-job-name", + "location": "us-central1", + "network": "test-network", + "subnetwork": "test-subnetwork", + "hostProjectId": "test-host-project", + "maxWorkers": 3, + "numWorkers": 2, + "serviceAccountEmail": "abc@xyz.com", + "machineType": "n1-standard-8", + "additionalUserLabels": {"my": "label"}, + "kmsKeyName": "test-key", + "gcsTemplatePath": "gs://path", + "additionalExperiments": ["xyz","123"], + "enableStreamingEngine": true + }`, nil + }, + }, + expectError: false, + want: dataflowaccessor.DataflowTuningConfig{ + ProjectId: "test-project", + JobName: "test-job-name", + Location: "us-central1", + Network: "test-network", + Subnetwork: "test-subnetwork", + VpcHostProjectId: "test-host-project", + MaxWorkers: 3, + NumWorkers: 2, + ServiceAccountEmail: "abc@xyz.com", + MachineType: "n1-standard-8", + AdditionalUserLabels: map[string]string{"my": "label"}, + KmsKeyName: "test-key", + GcsTemplatePath: "gs://path", + 
AdditionalExperiments: []string{"xyz", "123"}, + EnableStreamingEngine: true, + }, + }, + { + name: "Defaults", + sam: storageaccessor.StorageAccessorMock{ + ReadAnyFileMock: func(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) { + return `{}`, nil + }, + }, + expectError: false, + want: dataflowaccessor.DataflowTuningConfig{ + ProjectId: "", + JobName: "", + Location: "", + Network: "", + Subnetwork: "", + VpcHostProjectId: "", + MaxWorkers: 0, + NumWorkers: 0, + ServiceAccountEmail: "", + MachineType: "", + AdditionalUserLabels: nil, + KmsKeyName: "", + GcsTemplatePath: "", + AdditionalExperiments: nil, + EnableStreamingEngine: false, + }, + }, + { + name: "ReadAnyFile throws error", + sam: storageaccessor.StorageAccessorMock{ + ReadAnyFileMock: func(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) { + return "", fmt.Errorf("test error") + }, + }, + expectError: true, + want: dataflowaccessor.DataflowTuningConfig{}, + }, + { + name: "Json unmarshall throws error", + sam: storageaccessor.StorageAccessorMock{ + ReadAnyFileMock: func(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) { + return "{\"abc\"", nil + }, + }, + expectError: true, + want: dataflowaccessor.DataflowTuningConfig{}, + }, + } + ctx := context.Background() + for _, tc := range testCases { + got, err := UnmarshalDataflowTuningConfig(ctx, nil, &tc.sam, "unused/path/due/to/mock") + assert.Equal(t, tc.expectError, err != nil, tc.name) + assert.Equal(t, tc.want, got, tc.name) + } +} diff --git a/accessors/spanner/mocks.go b/accessors/spanner/mocks.go new file mode 100644 index 0000000000..828d2e17ce --- /dev/null +++ b/accessors/spanner/mocks.go @@ -0,0 +1,61 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spanneraccessor + +import ( + "context" + + spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin" + spinstanceadmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/instanceadmin" +) + +// Mock that implements the SpannerAccessor interface. +// Pass in unit tests where SpannerAccessor is an input parameter. +type SpannerAccessorMock struct { + GetDatabaseDialectMock func(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (string, error) + CheckExistingDbMock func(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (bool, error) + CreateEmptyDatabaseMock func(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) error + GetSpannerLeaderLocationMock func(ctx context.Context, instanceClient spinstanceadmin.InstanceAdminClient, instanceURI string) (string, error) + CheckIfChangeStreamExistsMock func(ctx context.Context, changeStreamName, dbURI string) (bool, error) + ValidateChangeStreamOptionsMock func(ctx context.Context, changeStreamName, dbURI string) error + CreateChangeStreamMock func(ctx context.Context, adminClient spanneradmin.AdminClient, changeStreamName, dbURI string) error +} + +func (sam *SpannerAccessorMock) GetDatabaseDialect(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (string, error) { + return sam.GetDatabaseDialectMock(ctx, adminClient, dbURI) +} + +func (sam *SpannerAccessorMock) CheckExistingDb(ctx context.Context, adminClient spanneradmin.AdminClient, 
dbURI string) (bool, error) { + return sam.CheckExistingDbMock(ctx, adminClient, dbURI) +} + +func (sam *SpannerAccessorMock) CreateEmptyDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) error { + return sam.CreateEmptyDatabaseMock(ctx, adminClient, dbURI) +} + +func (sam *SpannerAccessorMock) GetSpannerLeaderLocation(ctx context.Context, instanceClient spinstanceadmin.InstanceAdminClient, instanceURI string) (string, error) { + return sam.GetSpannerLeaderLocationMock(ctx, instanceClient, instanceURI) +} + +func (sam *SpannerAccessorMock) CheckIfChangeStreamExists(ctx context.Context, changeStreamName, dbURI string) (bool, error) { + return sam.CheckIfChangeStreamExistsMock(ctx, changeStreamName, dbURI) +} + +func (sam *SpannerAccessorMock) ValidateChangeStreamOptions(ctx context.Context, changeStreamName, dbURI string) error { + return sam.ValidateChangeStreamOptionsMock(ctx, changeStreamName, dbURI) +} + +func (sam *SpannerAccessorMock) CreateChangeStream(ctx context.Context, adminClient spanneradmin.AdminClient, changeStreamName, dbURI string) error { + return sam.CreateChangeStreamMock(ctx, adminClient, changeStreamName, dbURI) +} diff --git a/accessors/spanner/spanner_accessor.go b/accessors/spanner/spanner_accessor.go new file mode 100644 index 0000000000..fd5c2cae22 --- /dev/null +++ b/accessors/spanner/spanner_accessor.go @@ -0,0 +1,203 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package spanneraccessor
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"time"
+
+	"cloud.google.com/go/spanner"
+	"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
+	"cloud.google.com/go/spanner/admin/instance/apiv1/instancepb"
+	spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin"
+	spannerclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/client"
+	spinstanceadmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/instanceadmin"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils"
+	"google.golang.org/api/iterator"
+)
+
+// The SpannerAccessor provides methods that internally use a spanner client (can be adminClient/databaseclient/instanceclient etc).
+// Methods should only contain generic logic here that can be used by multiple workflows.
+type SpannerAccessor interface {
+	// GetDatabaseDialect fetches the dialect of the spanner database, returned as the
+	// lower-cased enum name (for example "google_standard_sql" or "postgresql").
+	GetDatabaseDialect(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (string, error)
+	// CheckExistingDb checks whether the database with dbURI exists or not.
+	// If the API call doesn't respond, the user is informed every 5 minutes on the command line.
+	CheckExistingDb(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (bool, error)
+	// CreateEmptyDatabase creates a database with no schema and waits for the operation to finish.
+	CreateEmptyDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) error
+	// GetSpannerLeaderLocation fetches the default leader location of the Spanner instance.
+	GetSpannerLeaderLocation(ctx context.Context, instanceClient spinstanceadmin.InstanceAdminClient, instanceURI string) (string, error)
+	// CheckIfChangeStreamExists checks if a change stream with the given name already exists.
+	CheckIfChangeStreamExists(ctx context.Context, changeStreamName, dbURI string) (bool, error)
+	// ValidateChangeStreamOptions validates that change stream option 'VALUE_CAPTURE_TYPE' is 'NEW_ROW'.
+	ValidateChangeStreamOptions(ctx context.Context, changeStreamName, dbURI string) error
+	// CreateChangeStream creates a change stream with default options.
+	CreateChangeStream(ctx context.Context, adminClient spanneradmin.AdminClient, changeStreamName, dbURI string) error
+}
+
+// This implements the SpannerAccessor interface. This is the primary implementation that should be used in all places other than tests.
+type SpannerAccessorImpl struct{}
+
+// GetDatabaseDialect looks up the database identified by dbURI and returns the
+// lower-cased name of its dialect enum. The underlying error is wrapped so
+// callers can inspect it with errors.Is / errors.As.
+func (sp *SpannerAccessorImpl) GetDatabaseDialect(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (string, error) {
+	result, err := adminClient.GetDatabase(ctx, &databasepb.GetDatabaseRequest{Name: dbURI})
+	if err != nil {
+		return "", fmt.Errorf("cannot connect to database: %w", err)
+	}
+	// The generated getter is nil-safe, so a nil response yields the
+	// unspecified dialect instead of a nil-pointer panic.
+	return strings.ToLower(result.GetDatabaseDialect().String()), nil
+}
+
+// CheckExistingDb reports whether the database identified by dbURI exists.
+//
+// The GetDatabase call runs in a separate goroutine so that a warning can be
+// printed every 5 minutes while it has not returned. An error whose
+// lower-cased text contains "database not found" is reported as (false, nil);
+// any other error is wrapped and returned.
+func (sp *SpannerAccessorImpl) CheckExistingDb(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (bool, error) {
+	// Buffered so the goroutine can always deliver its result and exit,
+	// and so the error travels over the channel instead of a shared variable.
+	done := make(chan error, 1)
+	go func() {
+		_, err := adminClient.GetDatabase(ctx, &databasepb.GetDatabaseRequest{Name: dbURI})
+		done <- err
+	}()
+	// One ticker for the whole wait instead of a fresh timer per loop iteration.
+	ticker := time.NewTicker(5 * time.Minute)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ticker.C:
+			fmt.Println("WARNING! API call not responding: make sure that spanner api endpoint is configured properly")
+		case err := <-done:
+			if err == nil {
+				return true, nil
+			}
+			if utils.ContainsAny(strings.ToLower(err.Error()), []string{"database not found"}) {
+				return false, nil
+			}
+			return false, fmt.Errorf("can't get database info: %w", err)
+		}
+	}
+}
+
+// CreateEmptyDatabase issues a CREATE DATABASE for the database named in dbURI
+// under its parent instance and blocks until the long-running operation
+// completes. Errors from either step are passed through utils.AnalyzeError.
+func (sp *SpannerAccessorImpl) CreateEmptyDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) error {
+	project, instance, dbName := utils.ParseDbURI(dbURI)
+	req := &databasepb.CreateDatabaseRequest{
+		Parent:          fmt.Sprintf("projects/%s/instances/%s", project, instance),
+		CreateStatement: "CREATE DATABASE `" + dbName + "`",
+	}
+	op, err := adminClient.CreateDatabase(ctx, req)
+	if err != nil {
+		return fmt.Errorf("can't build CreateDatabaseRequest: %w", utils.AnalyzeError(err, dbURI))
+	}
+	if _, err := op.Wait(ctx); err != nil {
+		return fmt.Errorf("createDatabase call failed: %w", utils.AnalyzeError(err, dbURI))
+	}
+	return nil
+}
+
+// GetSpannerLeaderLocation fetches the instance, then its instance config, and
+// returns the location of the first replica flagged as the default leader
+// location. It returns an error when no replica carries that flag.
+func (sp *SpannerAccessorImpl) GetSpannerLeaderLocation(ctx context.Context, instanceClient spinstanceadmin.InstanceAdminClient, instanceURI string) (string, error) {
+	instanceInfo, err := instanceClient.GetInstance(ctx, &instancepb.GetInstanceRequest{Name: instanceURI})
+	if err != nil {
+		return "", err
+	}
+	// Generated getters are nil-safe, so nil responses or nil replica entries
+	// fall through to the "no leader found" error instead of panicking.
+	instanceConfig, err := instanceClient.GetInstanceConfig(ctx, &instancepb.GetInstanceConfigRequest{Name: instanceInfo.GetConfig()})
+	if err != nil {
+		return "", err
+	}
+	for _, replica := range instanceConfig.GetReplicas() {
+		if replica.GetDefaultLeaderLocation() {
+			return replica.GetLocation(), nil
+		}
+	}
+	return "", fmt.Errorf("no leader found for spanner instance %s while trying fetch location", instanceURI)
+}
+
+// CheckIfChangeStreamExists lists the change stream names of the database and
+// reports whether one of them equals changeStreamName.
+//
+// Consider using a CreateChangestream operation and check for alreadyExists error. That uses adminClient which can be unit tested.
+func (sp *SpannerAccessorImpl) CheckIfChangeStreamExists(ctx context.Context, changeStreamName, dbURI string) (bool, error) {
+	spClient, err := spannerclient.GetOrCreateClient(ctx, dbURI)
+	if err != nil {
+		return false, err
+	}
+	stmt := spanner.Statement{
+		SQL: `SELECT CHANGE_STREAM_NAME FROM information_schema.change_streams`,
+	}
+	iter := spClient.Single().Query(ctx, stmt)
+	defer iter.Stop()
+	var csName string
+	for {
+		row, err := iter.Next()
+		if err == iterator.Done {
+			return false, nil
+		}
+		if err != nil {
+			return false, fmt.Errorf("couldn't read row from change_streams table: %w", err)
+		}
+		if err := row.Columns(&csName); err != nil {
+			return false, fmt.Errorf("can't scan row from change_streams table: %w", err)
+		}
+		if csName == changeStreamName {
+			return true, nil
+		}
+	}
+}
+
+// ValidateChangeStreamOptions reads the 'value_capture_type' option rows of the
+// named change stream and returns an error if any of them is not 'NEW_ROW'.
+//
+// NOTE(review): when the query returns no rows this returns nil. Presumably a
+// stream created without an explicit value_capture_type has no option row and
+// would therefore pass validation - confirm against the Spanner
+// information_schema documentation.
+func (sp *SpannerAccessorImpl) ValidateChangeStreamOptions(ctx context.Context, changeStreamName, dbURI string) error {
+	spClient, err := spannerclient.GetOrCreateClient(ctx, dbURI)
+	if err != nil {
+		return err
+	}
+	// Validate if change stream options are set correctly.
+	stmt := spanner.Statement{
+		SQL: `SELECT option_value FROM information_schema.change_stream_options
+		WHERE change_stream_name = @p1 AND option_name = 'value_capture_type'`,
+		Params: map[string]interface{}{
+			"p1": changeStreamName,
+		},
+	}
+	iter := spClient.Single().Query(ctx, stmt)
+	defer iter.Stop()
+	var optionValue string
+	for {
+		row, err := iter.Next()
+		if err == iterator.Done {
+			return nil
+		}
+		if err != nil {
+			return fmt.Errorf("couldn't read row from change_stream_options table: %w", err)
+		}
+		if err := row.Columns(&optionValue); err != nil {
+			return fmt.Errorf("can't scan row from change_stream_options table: %w", err)
+		}
+		if optionValue != "NEW_ROW" {
+			return fmt.Errorf("VALUE_CAPTURE_TYPE for changestream %s is not NEW_ROW. Please update the changestream option or create a new one", changeStreamName)
+		}
+	}
+}
+
+// CreateChangeStream submits a CREATE CHANGE STREAM ... FOR ALL DDL statement
+// with value_capture_type 'NEW_ROW' and a 7 day retention period, then waits
+// for the schema update to complete.
+func (sp *SpannerAccessorImpl) CreateChangeStream(ctx context.Context, adminClient spanneradmin.AdminClient, changeStreamName, dbURI string) error {
+	op, err := adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{
+		Database: dbURI,
+		// TODO: create change stream for only the tables present in Spanner.
+		Statements: []string{fmt.Sprintf("CREATE CHANGE STREAM %s FOR ALL OPTIONS (value_capture_type = 'NEW_ROW', retention_period = '7d')", changeStreamName)},
+	})
+	if err != nil {
+		return fmt.Errorf("cannot submit request create change stream request: %w", err)
+	}
+	if err := op.Wait(ctx); err != nil {
+		return fmt.Errorf("could not update database ddl: %w", err)
+	}
+	fmt.Println("Successfully created changestream", changeStreamName)
+	return nil
+}
diff --git a/accessors/spanner/spanner_accessor_test.go b/accessors/spanner/spanner_accessor_test.go
new file mode 100644
index 0000000000..9f93050496
--- /dev/null
+++ b/accessors/spanner/spanner_accessor_test.go
@@ -0,0 +1,336 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package spanneraccessor
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"testing"
+
+	"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
+	"cloud.google.com/go/spanner/admin/instance/apiv1/instancepb"
+	spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin"
+	spinstanceadmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/instanceadmin"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
+	"github.com/googleapis/gax-go/v2"
+	"github.com/stretchr/testify/assert"
+	"go.uber.org/zap"
+)
+
+// Replace the package logger with a no-op logger for the tests in this package.
+func init() {
+	logger.Log = zap.NewNop()
+}
+
+// TestMain runs the package's tests and exits with their result code.
+func TestMain(m *testing.M) {
+	res := m.Run()
+	os.Exit(res)
+}
+
+// Covers each dialect enum value plus the error path of GetDatabase.
+func TestSpannerAccessorImpl_GetDatabaseDialect(t *testing.T) {
+	testCases := []struct {
+		name        string
+		acm         spanneradmin.AdminClientMock
+		expectError bool
+		want        string
+	}{
+		{
+			name: "Basic",
+			acm: spanneradmin.AdminClientMock{
+				GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) {
+					return &databasepb.Database{DatabaseDialect: databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL}, nil
+				},
+			},
+			expectError: false,
+			want:        "google_standard_sql",
+		},
+		{
+			name: "Pg Dialect",
+			acm: spanneradmin.AdminClientMock{
+				GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) {
+					return &databasepb.Database{DatabaseDialect: databasepb.DatabaseDialect_POSTGRESQL}, nil
+				},
+			},
+			expectError: false,
+			want:        "postgresql",
+		},
+		{
+			name: "Unspecified Dialect",
+			acm: spanneradmin.AdminClientMock{
+				GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) {
+					return &databasepb.Database{DatabaseDialect: databasepb.DatabaseDialect_DATABASE_DIALECT_UNSPECIFIED}, nil
+				},
+			},
+			expectError: false,
+			want:        "database_dialect_unspecified",
+		},
+		{
+			name: "Error case",
+			acm: spanneradmin.AdminClientMock{
+				GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) {
+					return nil, fmt.Errorf("test-error")
+				},
+			},
+			expectError: true,
+			want:        "",
+		},
+	}
+	ctx := context.Background()
+	spA := SpannerAccessorImpl{}
+	for _, tc := range testCases {
+		// Subtests give each table entry its own pass/fail line and allow -run filtering.
+		t.Run(tc.name, func(t *testing.T) {
+			got, err := spA.GetDatabaseDialect(ctx, &tc.acm, "testUri")
+			assert.Equal(t, tc.expectError, err != nil, tc.name)
+			assert.Equal(t, tc.want, got, tc.name)
+		})
+	}
+}
+
+// Covers the found, "database not found" and generic error outcomes.
+func TestSpannerAccessorImpl_CheckExistingDb(t *testing.T) {
+	testCases := []struct {
+		name        string
+		acm         spanneradmin.AdminClientMock
+		expectError bool
+		want        bool
+	}{
+		{
+			name: "Basic",
+			acm: spanneradmin.AdminClientMock{
+				GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) {
+					return nil, nil
+				},
+			},
+			expectError: false,
+			want:        true,
+		},
+		{
+			name: "Database not found error",
+			acm: spanneradmin.AdminClientMock{
+				GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) {
+					return nil, fmt.Errorf("database not found")
+				},
+			},
+			expectError: false,
+			want:        false,
+		},
+		{
+			name: "Could not get db info",
+			acm: spanneradmin.AdminClientMock{
+				GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) {
+					return nil, fmt.Errorf("failed to connect")
+				},
+			},
+			expectError: true,
+			want:        false,
+		},
+	}
+	ctx := context.Background()
+	spA := SpannerAccessorImpl{}
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			got, err := spA.CheckExistingDb(ctx, &tc.acm, "testUri")
+			assert.Equal(t, tc.expectError, err != nil, tc.name)
+			assert.Equal(t, tc.want, got, tc.name)
+		})
+	}
+}
+
+// Covers success, a failing CreateDatabase call and a failing operation wait.
+func TestSpannerAccessorImpl_CreateEmptyDatabase(t *testing.T) {
+	testCases := []struct {
+		name        string
+		acm         spanneradmin.AdminClientMock
+		expectError bool
+	}{
+		{
+			name: "Basic",
+			acm: spanneradmin.AdminClientMock{
+				CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) {
+					return &spanneradmin.CreateDatabaseOperationMock{
+						WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { return nil, nil },
+					}, nil
+				},
+			},
+			expectError: false,
+		},
+		{
+			name: "Create database returns error",
+			acm: spanneradmin.AdminClientMock{
+				CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) {
+					return nil, fmt.Errorf("test error")
+				},
+			},
+			expectError: true,
+		},
+		{
+			name: "Wait returns error",
+			acm: spanneradmin.AdminClientMock{
+				CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) {
+					return &spanneradmin.CreateDatabaseOperationMock{
+						WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) {
+							return nil, fmt.Errorf("test error")
+						},
+					}, nil
+				},
+			},
+			expectError: true,
+		},
+	}
+	ctx := context.Background()
+	spA := SpannerAccessorImpl{}
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			err := spA.CreateEmptyDatabase(ctx, &tc.acm, "projects/test-project/instances/test-instance/databases/mydb")
+			assert.Equal(t, tc.expectError, err != nil, tc.name)
+		})
+	}
+}
+
+// Covers success, a failing UpdateDatabaseDdl call and a failing operation wait.
+func TestSpannerAccessorImpl_CreateChangeStream(t *testing.T) {
+	testCases := []struct {
+		name        string
+		acm         spanneradmin.AdminClientMock
+		expectError bool
+	}{
+		{
+			name: "Basic",
+			acm: spanneradmin.AdminClientMock{
+				UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) {
+					return &spanneradmin.UpdateDatabaseDdlOperationMock{
+						WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return nil },
+					}, nil
+				},
+			},
+			expectError: false,
+		},
+		{
+			name: "Update database ddl returns error",
+			acm: spanneradmin.AdminClientMock{
+				UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) {
+					return nil, fmt.Errorf("test error")
+				},
+			},
+			expectError: true,
+		},
+		{
+			name: "Wait returns error",
+			acm: spanneradmin.AdminClientMock{
+				UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) {
+					return &spanneradmin.UpdateDatabaseDdlOperationMock{
+						WaitMock: func(ctx context.Context, opts ...gax.CallOption) error {
+							return fmt.Errorf("test error")
+						},
+					}, nil
+				},
+			},
+			expectError: true,
+		},
+	}
+	ctx := context.Background()
+	spA := SpannerAccessorImpl{}
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			err := spA.CreateChangeStream(ctx, &tc.acm, "my-changestream", "projects/test-project/instances/test-instance/databases/mydb")
+			assert.Equal(t, tc.expectError, err != nil, tc.name)
+		})
+	}
+}
+
+// Covers a config with a leader replica, both client error paths, and a
+// config in which no replica is flagged as the default leader location.
+func TestSpannerAccessorImpl_GetSpannerLeaderLocation(t *testing.T) {
+	testCases := []struct {
+		name        string
+		iac         spinstanceadmin.InstanceAdminClientMock
+		expectError bool
+		want        string
+	}{
+		{
+			name: "Basic",
+			iac: spinstanceadmin.InstanceAdminClientMock{
+				GetInstanceMock: func(ctx context.Context, req *instancepb.GetInstanceRequest, opts ...gax.CallOption) (*instancepb.Instance, error) {
+					return &instancepb.Instance{Config: "projects/test-project/instanceConfigs/test-config"}, nil
+				},
+				GetInstanceConfigMock: func(ctx context.Context, req *instancepb.GetInstanceConfigRequest, opts ...gax.CallOption) (*instancepb.InstanceConfig, error) {
+					return &instancepb.InstanceConfig{Replicas: []*instancepb.ReplicaInfo{
+						{
+							Location:              "us-east1",
+							DefaultLeaderLocation: false,
+						},
+						{
+							Location:              "india1",
+							DefaultLeaderLocation: true,
+						},
+						{
+							Location:              "europe2",
+							DefaultLeaderLocation: false,
+						},
+					}}, nil
+				},
+			},
+			expectError: false,
+			want:        "india1",
+		},
+		{
+			name: "GetInstanceMock returns error",
+			iac: spinstanceadmin.InstanceAdminClientMock{
+				GetInstanceMock: func(ctx context.Context, req *instancepb.GetInstanceRequest, opts ...gax.CallOption) (*instancepb.Instance, error) {
+					return nil, fmt.Errorf("test-error")
+				},
+				GetInstanceConfigMock: func(ctx context.Context, req *instancepb.GetInstanceConfigRequest, opts ...gax.CallOption) (*instancepb.InstanceConfig, error) {
+					return nil, nil
+				},
+			},
+			expectError: true,
+			want:        "",
+		},
+		{
+			name: "GetInstanceConfigMock returns error",
+			iac: spinstanceadmin.InstanceAdminClientMock{
+				GetInstanceMock: func(ctx context.Context, req *instancepb.GetInstanceRequest, opts ...gax.CallOption) (*instancepb.Instance, error) {
+					return &instancepb.Instance{Config: "projects/test-project/instanceConfigs/test-config"}, nil
+				},
+				GetInstanceConfigMock: func(ctx context.Context, req *instancepb.GetInstanceConfigRequest, opts ...gax.CallOption) (*instancepb.InstanceConfig, error) {
+					return nil, fmt.Errorf("test-error")
+				},
+			},
+			expectError: true,
+			want:        "",
+		},
+		{
+			name: "No leader found returns error",
+			iac: spinstanceadmin.InstanceAdminClientMock{
+				GetInstanceMock: func(ctx context.Context, req *instancepb.GetInstanceRequest, opts ...gax.CallOption) (*instancepb.Instance, error) {
+					return &instancepb.Instance{Config: "projects/test-project/instanceConfigs/test-config"}, nil
+				},
+				GetInstanceConfigMock: func(ctx context.Context, req *instancepb.GetInstanceConfigRequest, opts ...gax.CallOption) (*instancepb.InstanceConfig, error) {
+					return &instancepb.InstanceConfig{Replicas: []*instancepb.ReplicaInfo{
+						{
+							Location:              "us-east1",
+							DefaultLeaderLocation: false,
+						},
+						{
+							Location:              "india1",
+							DefaultLeaderLocation: false,
+						},
+						{
+							Location:              "europe2",
+							DefaultLeaderLocation: false,
+						},
+					}}, nil
+				},
+			},
+			expectError: true,
+			want:        "",
+		},
+	}
+	ctx := context.Background()
+	spA := SpannerAccessorImpl{}
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			got, err := spA.GetSpannerLeaderLocation(ctx, &tc.iac, "projects/test-project/instances/test-instance")
+			assert.Equal(t, tc.expectError, err != nil, tc.name)
+			assert.Equal(t, tc.want, got, tc.name)
+		})
+	}
+}
diff --git a/accessors/storage/mocks.go b/accessors/storage/mocks.go
new file mode 100644
index 0000000000..21673be905
--- /dev/null
+++ b/accessors/storage/mocks.go
@@ -0,0 +1,55 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package storageaccessor
+
+import (
+	"context"
+
+	storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage"
+)
+
+// Mock that implements the StorageAccessor interface.
+// Pass in unit tests where StorageAccessor is an input parameter.
+type StorageAccessorMock struct { + CreateGCSBucketMock func(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error + ApplyBucketLifecycleDeleteRuleMock func(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error + UploadLocalFileToGCSMock func(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, localFilePath string) error + WriteDataToGCSMock func(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, data string) error + ReadGcsFileMock func(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) + ReadAnyFileMock func(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) +} + +func (sam *StorageAccessorMock) CreateGCSBucket(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error { + return sam.CreateGCSBucketMock(ctx, sc, req) +} + +func (sam *StorageAccessorMock) ApplyBucketLifecycleDeleteRule(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error { + return sam.ApplyBucketLifecycleDeleteRuleMock(ctx, sc, req) +} + +func (sam *StorageAccessorMock) UploadLocalFileToGCS(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, localFilePath string) error { + return sam.UploadLocalFileToGCSMock(ctx, sc, filePath, fileName, localFilePath) +} + +func (sam *StorageAccessorMock) WriteDataToGCS(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, data string) error { + return sam.WriteDataToGCSMock(ctx, sc, filePath, fileName, data) +} + +func (sam *StorageAccessorMock) ReadGcsFile(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) { + return sam.ReadGcsFileMock(ctx, sc, filePath) +} + +func (sam *StorageAccessorMock) ReadAnyFile(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) { + return sam.ReadAnyFileMock(ctx, sc, filePath) +} diff --git 
a/accessors/storage/storage_accessor.go b/accessors/storage/storage_accessor.go new file mode 100644 index 0000000000..42399648ce --- /dev/null +++ b/accessors/storage/storage_accessor.go @@ -0,0 +1,193 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package storageaccessor + +import ( + "context" + "fmt" + "io" + "os" + "strings" + + "cloud.google.com/go/storage" + storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "google.golang.org/api/googleapi" +) + +// The StorageAccessor provides methods that internally use a storage client. +// Methods should only contain generic logic here that can be used by multiple workflows. +type StorageAccessor interface { + // Create a GCS bucket with the given name in the input projectId and location. If ttl is > 0, + // also apply a delete lifecycle rule with the input ttl and prefixes. Set @ttl to 0 to skip creating lifecycle rules. + CreateGCSBucket(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error + // Applies the bucket lifecycle with delete rule. Only accepts the Age and prefix rule conditions as it is only used for the Datastream destination + // bucket currently. 
+ ApplyBucketLifecycleDeleteRule(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error + // UploadLocalFileToGCS uploads a local file at @localFilePath to a gcs file path @filePath with name @fileName. + UploadLocalFileToGCS(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, localFilePath string) error + // Uploads a gcs object to gs://@filePath/@fileName with @data as content. + WriteDataToGCS(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, data string) error + // Read a Gcs file path and returns the contents as a string. + ReadGcsFile(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) + // Read a local or gcs file path. Files starting with a 'gs://' are treated as GCS files. + ReadAnyFile(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) +} + +// This implements the StorageAccessor interface. This is the primary implementation that should be used in all places other than tests. +type StorageAccessorImpl struct{} + +func (sa *StorageAccessorImpl) CreateGCSBucket(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error { + bucket := sc.Bucket(req.BucketName) + attrs := storage.BucketAttrs{ + Location: req.Location, + } + if req.Ttl > 0 { + attrs.Lifecycle = storage.Lifecycle{ + Rules: []storage.LifecycleRule{ + { + Action: storage.LifecycleAction{Type: "Delete"}, + Condition: storage.LifecycleCondition{ + AgeInDays: req.Ttl, + // The prefixes should not contain the bucket names and starting slash. + // For object gs://my_bucket/pictures/paris_2022.jpg, + // you would use a condition such as "matchesPrefix":["pictures/paris_"]. + MatchesPrefix: req.MatchesPrefix, + }, + }, + }, + } + } + + if err := bucket.Create(ctx, req.ProjectID, &attrs); err != nil { + if e, ok := err.(*googleapi.Error); ok { + // Ignoring the bucket already exists error. 
+ if e.Code != 409 { + return fmt.Errorf("failed to create bucket: %v", err) + } else { + fmt.Printf("Using the existing bucket: %v \n", req.BucketName) + } + } else { + return fmt.Errorf("failed to create bucket: %v", err) + } + + } else { + logger.Log.Info(fmt.Sprintf("Created new GCS bucket: %v\n", req.BucketName)) + } + return nil +} + +func (sa *StorageAccessorImpl) ApplyBucketLifecycleDeleteRule(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error { + for i, str := range req.MatchesPrefix { + req.MatchesPrefix[i] = strings.TrimPrefix(str, "/") + } + bucket := sc.Bucket(req.BucketName) + bucketAttrsToUpdate := storage.BucketAttrsToUpdate{ + Lifecycle: &storage.Lifecycle{ + Rules: []storage.LifecycleRule{ + { + Action: storage.LifecycleAction{Type: "Delete"}, + Condition: storage.LifecycleCondition{ + AgeInDays: req.Ttl, + // The prefixes should not contain the bucket names and starting slash. + // For object gs://my_bucket/pictures/paris_2022.jpg, + // you would use a condition such as "matchesPrefix":["pictures/paris_"]. + MatchesPrefix: req.MatchesPrefix, + }, + }, + }, + }, + } + + attrs, err := bucket.Update(ctx, bucketAttrsToUpdate) + if err != nil { + return fmt.Errorf("could not bucket with lifecycle: %w", err) + } + logger.Log.Info(fmt.Sprintf("Added lifecycle rule to bucket %v\n. 
Rule Action: %v\t Rule Condition: %v\n", + req.BucketName, attrs.Lifecycle.Rules[0].Action, attrs.Lifecycle.Rules[0].Condition)) + return nil +} + +func (sa *StorageAccessorImpl) UploadLocalFileToGCS(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, localFilePath string) error { + data, err := os.ReadFile(localFilePath) + if err != nil { + return fmt.Errorf("could not read file %s: %w", localFilePath, err) + } + return sa.WriteDataToGCS(ctx, sc, filePath, fileName, string(data)) +} + +func (sa *StorageAccessorImpl) WriteDataToGCS(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, data string) error { + u, err := utils.ParseGCSFilePath(filePath) + if err != nil { + return fmt.Errorf("parseFilePath: unable to parse file path: %v", err) + } + bucketName := u.Host + bucket := sc.Bucket(bucketName) + fullFilePath := u.Path + fileName + if strings.HasPrefix(fullFilePath, "/") { + fullFilePath = u.Path[1:] + fileName + } + obj := bucket.Object(fullFilePath) + + w := obj.NewWriter(ctx) + logger.Log.Info(fmt.Sprintf("Writing data to %s", filePath)) + n, err := fmt.Fprint(w, data) + if err != nil { + fmt.Printf("Failed to write to Cloud Storage: %s\n", filePath) + return err + } + logger.Log.Info(fmt.Sprintf("Wrote %d bytes to GCS", n)) + + if err := w.Close(); err != nil { + fmt.Printf("Failed to close GCS file: %s\n", filePath) + return err + } + return nil +} + +func (sa *StorageAccessorImpl) ReadGcsFile(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) { + u, err := utils.ParseGCSFilePath(filePath) + if err != nil { + return "", fmt.Errorf("unable to parse file path: %v", err) + } + bucketName := u.Host + bucket := sc.Bucket(bucketName) + obj := bucket.Object(u.Path[1:]) + rc, err := obj.NewReader(ctx) + if err != nil { + return "", err + } + defer rc.Close() + buf := new(strings.Builder) + logger.Log.Info(fmt.Sprintf("Reading from %s", filePath)) + n, err := io.Copy(buf, rc) + if err != nil 
{ + return "", err + } + logger.Log.Info(fmt.Sprintf("Read %d bytes", n)) + return buf.String(), nil +} + +func (sa *StorageAccessorImpl) ReadAnyFile(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) { + if strings.HasPrefix(filePath, constants.GCS_FILE_PREFIX) { + return sa.ReadGcsFile(ctx, sc, filePath) + } + buf, err := os.ReadFile(filePath) + if err != nil { + return "", err + } + return string(buf), nil +} diff --git a/accessors/storage/storage_accessor_test.go b/accessors/storage/storage_accessor_test.go new file mode 100644 index 0000000000..4380a6a8f1 --- /dev/null +++ b/accessors/storage/storage_accessor_test.go @@ -0,0 +1,395 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a +package storageaccessor + +import ( + "context" + "fmt" + "io" + "os" + "strings" + "testing" + + "cloud.google.com/go/storage" + storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/api/googleapi" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func TestStorageAccessorImpl_CreateGCSBucket(t *testing.T) { + testCases := []struct { + name string + scm storageclient.StorageClientMock + expectError bool + }{ + { + name: "Basic", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + CreateMock: func(ctx context.Context, projectID string, attrs *storage.BucketAttrs) (err error) { + return nil + }, + } + }, + }, + expectError: false, + }, + { + name: "random error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { 
+ return &storageclient.BucketHandleMock{ + CreateMock: func(ctx context.Context, projectID string, attrs *storage.BucketAttrs) (err error) { + return fmt.Errorf("random error") + }, + } + }, + }, + expectError: true, + }, + { + name: "Bucket already exists", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + CreateMock: func(ctx context.Context, projectID string, attrs *storage.BucketAttrs) (err error) { + return &googleapi.Error{Code: 409} + }, + } + }, + }, + expectError: false, + }, + { + name: "Other google api error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + CreateMock: func(ctx context.Context, projectID string, attrs *storage.BucketAttrs) (err error) { + return &googleapi.Error{Code: 100} + }, + } + }, + }, + expectError: true, + }, + } + ctx := context.Background() + sa := StorageAccessorImpl{} + for _, tc := range testCases { + err := sa.CreateGCSBucket(ctx, &tc.scm, StorageBucketMetadata{ + BucketName: "test-bucket", + ProjectID: "test-project", + Location: "india2", + Ttl: 1, + MatchesPrefix: nil, + }) + assert.Equal(t, tc.expectError, err != nil, tc.name) + } +} + +func TestStorageAccessorImpl_ApplyBucketLifecycleDeleteRule(t *testing.T) { + testCases := []struct { + name string + scm storageclient.StorageClientMock + ttl int64 + matchesPrefix []string + expectError bool + }{ + { + name: "Basic", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + UpdateMock: func(ctx context.Context, uattrs storage.BucketAttrsToUpdate) (attrs *storage.BucketAttrs, err error) { + if strings.HasPrefix(uattrs.Lifecycle.Rules[0].Condition.MatchesPrefix[0], "/") { + return nil, fmt.Errorf("test error") + } + return &storage.BucketAttrs{Lifecycle: storage.Lifecycle{Rules: 
[]storage.LifecycleRule{ + { + Action: storage.LifecycleAction{Type: "Delete"}, + Condition: storage.LifecycleCondition{ + AgeInDays: 5, + MatchesPrefix: []string{}, + }, + }, + }}}, nil + }, + } + }, + }, + ttl: 5, + matchesPrefix: []string{"test"}, + expectError: false, + }, + { + name: "Update error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + UpdateMock: func(ctx context.Context, uattrs storage.BucketAttrsToUpdate) (attrs *storage.BucketAttrs, err error) { + return nil, fmt.Errorf("test error") + }, + } + }, + }, + ttl: 5, + matchesPrefix: []string{"test"}, + expectError: true, + }, + { + name: "Prefix '/' gets removed", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + UpdateMock: func(ctx context.Context, uattrs storage.BucketAttrsToUpdate) (attrs *storage.BucketAttrs, err error) { + if strings.HasPrefix(uattrs.Lifecycle.Rules[0].Condition.MatchesPrefix[0], "/") { + return nil, fmt.Errorf("test error") + } + return &storage.BucketAttrs{Lifecycle: storage.Lifecycle{Rules: []storage.LifecycleRule{ + { + Action: storage.LifecycleAction{Type: "Delete"}, + Condition: storage.LifecycleCondition{ + AgeInDays: 5, + MatchesPrefix: []string{}, + }, + }, + }}}, nil + }, + } + }, + }, + ttl: 5, + matchesPrefix: []string{"/test"}, + expectError: false, + }, + } + ctx := context.Background() + sa := StorageAccessorImpl{} + for _, tc := range testCases { + err := sa.ApplyBucketLifecycleDeleteRule(ctx, &tc.scm, StorageBucketMetadata{ + BucketName: "test-bucket", + Ttl: tc.ttl, + MatchesPrefix: tc.matchesPrefix, + }) + assert.Equal(t, tc.expectError, err != nil, tc.name) + } +} + +func TestStorageAccessorImpl_WriteDataToGCS(t *testing.T) { + testCases := []struct { + name string + scm storageclient.StorageClientMock + filePath string + fileName string + data string + expectError 
bool + }{ + { + name: "Basic", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + ObjectMock: func(name string) storageclient.ObjectHandle { + return &storageclient.ObjectHandleMock{ + NewWriterMock: func(ctx context.Context) io.WriteCloser { + return &storageclient.WriterMock{ + WriteMock: func(p []byte) (n int, err error) { return len(p), nil }, + CloseMock: func() error { return nil }, + } + }, + } + }, + } + }, + }, + filePath: "gs://bucket/path", + fileName: "test-file", + data: "abcd", + expectError: false, + }, + { + name: "File parsing error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + ObjectMock: func(name string) storageclient.ObjectHandle { + return &storageclient.ObjectHandleMock{ + NewWriterMock: func(ctx context.Context) io.WriteCloser { + return &storageclient.WriterMock{ + WriteMock: func(p []byte) (n int, err error) { return len(p), nil }, + CloseMock: func() error { return nil }, + } + }, + } + }, + } + }, + }, + filePath: "://bucket/path", + fileName: "test-file", + data: "abcd", + expectError: true, + }, + { + name: "Write error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + ObjectMock: func(name string) storageclient.ObjectHandle { + return &storageclient.ObjectHandleMock{ + NewWriterMock: func(ctx context.Context) io.WriteCloser { + return &storageclient.WriterMock{ + WriteMock: func(p []byte) (n int, err error) { return 0, fmt.Errorf("test-error") }, + CloseMock: func() error { return nil }, + } + }, + } + }, + } + }, + }, + filePath: "gs://bucket/path", + fileName: "test-file", + data: "abcd", + expectError: true, + }, + { + name: "Close error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + 
return &storageclient.BucketHandleMock{ + ObjectMock: func(name string) storageclient.ObjectHandle { + return &storageclient.ObjectHandleMock{ + NewWriterMock: func(ctx context.Context) io.WriteCloser { + return &storageclient.WriterMock{ + WriteMock: func(p []byte) (n int, err error) { return len(p), nil }, + CloseMock: func() error { return fmt.Errorf("test error") }, + } + }, + } + }, + } + }, + }, + filePath: "gs://bucket/path", + fileName: "test-file", + data: "abcd", + expectError: true, + }, + } + ctx := context.Background() + sa := StorageAccessorImpl{} + for _, tc := range testCases { + err := sa.WriteDataToGCS(ctx, &tc.scm, tc.filePath, tc.fileName, tc.data) + assert.Equal(t, tc.expectError, err != nil, tc.name) + } +} + +func TestStorageAccessorImpl_ReadGcsFile(t *testing.T) { + testCases := []struct { + name string + scm storageclient.StorageClientMock + filePath string + expectError bool + want string + }{ + { + name: "Basic", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + ObjectMock: func(name string) storageclient.ObjectHandle { + return &storageclient.ObjectHandleMock{ + NewReaderMock: func(ctx context.Context) (io.ReadCloser, error) { + return &storageclient.ReaderMock{ + ReadMock: func(p []byte) (n int, err error) { + copy(p, "hello") + return 5, io.EOF + }, + CloseMock: func() error { return nil }, + }, nil + }, + } + }, + } + }, + }, + filePath: "gs://bucket/path", + expectError: false, + want: "hello", + }, + { + name: "Parse error", + scm: storageclient.StorageClientMock{}, + filePath: "://bucket/path", + expectError: true, + want: "", + }, + { + name: "New reader error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + ObjectMock: func(name string) storageclient.ObjectHandle { + return &storageclient.ObjectHandleMock{ + NewReaderMock: func(ctx 
context.Context) (io.ReadCloser, error) { + return nil, fmt.Errorf("test error") + }, + } + }, + } + }, + }, + filePath: "gs://bucket/path", + expectError: true, + want: "", + }, + { + name: "Read error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + ObjectMock: func(name string) storageclient.ObjectHandle { + return &storageclient.ObjectHandleMock{ + NewReaderMock: func(ctx context.Context) (io.ReadCloser, error) { + return &storageclient.ReaderMock{ + ReadMock: func(p []byte) (n int, err error) { + return 0, fmt.Errorf("test error") + }, + CloseMock: func() error { return nil }, + }, nil + }, + } + }, + } + }, + }, + filePath: "gs://bucket/path", + expectError: true, + want: "", + }, + } + ctx := context.Background() + sa := StorageAccessorImpl{} + for _, tc := range testCases { + got, err := sa.ReadGcsFile(ctx, &tc.scm, tc.filePath) + assert.Equal(t, tc.expectError, err != nil, tc.name) + assert.Equal(t, tc.want, got, tc.name) + } +} diff --git a/accessors/storage/types.go b/accessors/storage/types.go new file mode 100644 index 0000000000..5a20e1a002 --- /dev/null +++ b/accessors/storage/types.go @@ -0,0 +1,24 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package storageaccessor + +type StorageBucketMetadata struct { + BucketName string + // Not required for Updates. + ProjectID string + // Not required for Updates. 
+ Location string + Ttl int64 + MatchesPrefix []string +} diff --git a/cmd/data.go b/cmd/data.go index c908f8ad94..9be772042e 100644 --- a/cmd/data.go +++ b/cmd/data.go @@ -26,6 +26,8 @@ import ( sp "cloud.google.com/go/spanner" database "cloud.google.com/go/spanner/admin/database/apiv1" + spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin" + spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" "github.com/GoogleCloudPlatform/spanner-migration-tool/conversion" @@ -178,7 +180,12 @@ func (cmd *DataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface // validateExistingDb validates that the existing spanner schema is in accordance with the one specified in the session file. func validateExistingDb(ctx context.Context, spDialect, dbURI string, adminClient *database.DatabaseAdminClient, client *sp.Client, conv *internal.Conv) error { - dbExists, err := conversion.CheckExistingDb(ctx, adminClient, dbURI) + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) + if err != nil { + return err + } + spA := spanneraccessor.SpannerAccessorImpl{} + dbExists, err := spA.CheckExistingDb(ctx, adminClientImpl, dbURI) if err != nil { err = fmt.Errorf("can't verify target database: %v", err) return err diff --git a/common/constants/constants.go b/common/constants/constants.go index e98855a81d..c0eaa87685 100644 --- a/common/constants/constants.go +++ b/common/constants/constants.go @@ -64,7 +64,8 @@ const ( MigrationMetadataKey string = "cloud-spanner-migration-metadata" // Scheme used for GCS paths - GCS_SCHEME string = "gs" + GCS_SCHEME string = "gs" + GCS_FILE_PREFIX string = "gs://" // File upload prefix for dump and session load. 
UPLOAD_FILE_DIR string = "upload-file" diff --git a/common/utils/storage_utils.go b/common/utils/storage_utils.go new file mode 100644 index 0000000000..3968b1731e --- /dev/null +++ b/common/utils/storage_utils.go @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* +Package utils contains common helper functions used across multiple other packages. +Utils should not import any Spanner migration tool packages. +*/ +package utils + +import ( + "fmt" + "net/url" + + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" +) + +func ParseGCSFilePath(filePath string) (*url.URL, error) { + if len(filePath) == 0 { + return nil, fmt.Errorf("found empty GCS path") + } + if filePath[len(filePath)-1] != '/' { + filePath = filePath + "/" + } + u, err := url.Parse(filePath) + if err != nil { + return nil, fmt.Errorf("parseFilePath: unable to parse file path %s", filePath) + } + if u.Scheme != constants.GCS_SCHEME { + return nil, fmt.Errorf("not a valid GCS path: %s, should start with 'gs'", filePath) + } + return u, nil +} diff --git a/common/utils/storage_utils_test.go b/common/utils/storage_utils_test.go new file mode 100644 index 0000000000..fa00f0c302 --- /dev/null +++ b/common/utils/storage_utils_test.go @@ -0,0 +1,107 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package utils + +import ( + "net/url" + "os" + "testing" + + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func TestParseGCSFilePath(t *testing.T) { + testCases := []struct { + name string + filePath string + expectError bool + want *url.URL + }{ + { + name: "Basic", + filePath: "gs://test-bucket/path/to/folder/", + expectError: false, + want: &url.URL{ + Scheme: "gs", + Host: "test-bucket", + Path: "/path/to/folder/", + }, + }, + { + name: "Append Slash", + filePath: "gs://test-bucket/path/to/folder", + expectError: false, + want: &url.URL{ + Scheme: "gs", + Host: "test-bucket", + Path: "/path/to/folder/", + }, + }, + { + name: "Empty path", + filePath: "gs://test-bucket", + expectError: false, + want: &url.URL{ + Scheme: "gs", + Host: "test-bucket", + Path: "/", + }, + }, + { + name: "Empty path with leading slash", + filePath: "gs://test-bucket/", + expectError: false, + want: &url.URL{ + Scheme: "gs", + Host: "test-bucket", + Path: "/", + }, + }, + { + name: "Empty File path", + filePath: "", + expectError: true, + want: nil, + }, + { + name: "Wrong Scheme", + filePath: "ab://testpath", + expectError: true, + want: nil, + }, + { + name: "Malformed Path", + filePath: "://path", + expectError: true, + want: nil, + }, + } + + for _, tc := range testCases { + got, err := ParseGCSFilePath(tc.filePath) + assert.Equal(t, tc.expectError, err != nil, 
tc.name) + assert.Equal(t, tc.want, got, tc.name) + } +} diff --git a/common/utils/utils.go b/common/utils/utils.go index 70ec53f882..932da34f11 100644 --- a/common/utils/utils.go +++ b/common/utils/utils.go @@ -19,11 +19,11 @@ package utils import ( "bufio" "context" + "crypto/rand" "fmt" "io" "io/ioutil" "log" - "math/rand" "net/url" "os" "os/exec" @@ -37,6 +37,7 @@ import ( sp "cloud.google.com/go/spanner" database "cloud.google.com/go/spanner/admin/database/apiv1" instance "cloud.google.com/go/spanner/admin/instance/apiv1" + "cloud.google.com/go/spanner/admin/instance/apiv1/instancepb" "cloud.google.com/go/storage" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" @@ -44,10 +45,8 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/spanner" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "golang.org/x/crypto/ssh/terminal" - "google.golang.org/api/googleapi" "google.golang.org/api/iterator" "google.golang.org/api/option" - instancepb "google.golang.org/genproto/googleapis/spanner/admin/instance/v1" ) // IOStreams is a struct that contains the file descriptor for dumpFile. 
@@ -173,82 +172,6 @@ func PreloadGCSFiles(tables []ManifestTable) ([]ManifestTable, error) { return tables, nil } -func ParseGCSFilePath(filePath string) (*url.URL, error) { - if len(filePath) == 0 { - return nil, fmt.Errorf("found empty GCS path") - } - if filePath[len(filePath)-1] != '/' { - filePath = filePath + "/" - } - u, err := url.Parse(filePath) - if err != nil { - return nil, fmt.Errorf("parseFilePath: unable to parse file path %s", filePath) - } - if u.Scheme != constants.GCS_SCHEME { - return nil, fmt.Errorf("not a valid GCS path: %s, should start with 'gs'", filePath) - } - return u, nil -} - -func WriteToGCS(filePath, fileName, data string) error { - ctx := context.Background() - - client, err := storage.NewClient(ctx) - if err != nil { - fmt.Printf("Failed to create GCS client") - return err - } - defer client.Close() - u, err := ParseGCSFilePath(filePath) - if err != nil { - return fmt.Errorf("parseFilePath: unable to parse file path: %v", err) - } - bucketName := u.Host - bucket := client.Bucket(bucketName) - obj := bucket.Object(u.Path[1:] + fileName) - - w := obj.NewWriter(ctx) - if _, err := fmt.Fprint(w, data); err != nil { - fmt.Printf("Failed to write to Cloud Storage: %s", filePath) - return err - } - if err := w.Close(); err != nil { - fmt.Printf("Failed to close GCS file: %s", filePath) - return err - } - return nil -} - -func CreateGCSBucket(bucketName, projectID, location string) error { - ctx := context.Background() - - client, err := storage.NewClient(ctx) - if err != nil { - return fmt.Errorf("failed to create GCS client: %v", err) - } - defer client.Close() - bucket := client.Bucket(bucketName) - attrs := storage.BucketAttrs{ - Location: location, - } - if err := bucket.Create(ctx, projectID, &attrs); err != nil { - if e, ok := err.(*googleapi.Error); ok { - // Ignoring the bucket already exists error. 
- if e.Code != 409 { - return fmt.Errorf("failed to create bucket: %v", err) - } else { - fmt.Printf("Using the existing bucket: %v \n", bucketName) - } - } else { - return fmt.Errorf("failed to create bucket: %v", err) - } - - } else { - fmt.Printf("Created new GCS bucket: %v\n", bucketName) - } - return nil -} - // GetProject returns the cloud project we should use for accessing Spanner. // Use environment variable GCLOUD_PROJECT if it is set. // Otherwise, use the default project returned from gcloud. @@ -347,6 +270,12 @@ func GenerateName(prefix string) (string, error) { return fmt.Sprintf("%s_%x-%x", prefix, b[0:2], b[2:4]), nil } +func GenerateHashStr() string { + b := make([]byte, 4) + rand.Read(b) + return fmt.Sprintf("%x-%x", b[0:2], b[2:4]) +} + // parseURI parses an unknown URI string that could be a database, instance or project URI. func parseURI(URI string) (project, instance, dbName string) { project, instance, dbName = "", "", "" diff --git a/conversion/conversion.go b/conversion/conversion.go index 20216dfea4..f1cf091dc4 100644 --- a/conversion/conversion.go +++ b/conversion/conversion.go @@ -44,6 +44,10 @@ import ( datastream "cloud.google.com/go/datastream/apiv1" sp "cloud.google.com/go/spanner" database "cloud.google.com/go/spanner/admin/database/apiv1" + spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin" + storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" + spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner" + storageaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/storage" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/metrics" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" @@ -67,12 +71,12 @@ import ( dydb "github.com/aws/aws-sdk-go/service/dynamodb" 
"github.com/aws/aws-sdk-go/service/dynamodbstreams" mysqldriver "github.com/go-sql-driver/mysql" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" "go.uber.org/zap" adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/stdlib" ) var ( @@ -354,8 +358,17 @@ func dataFromDatabase(ctx context.Context, sourceProfile profiles.SourceProfile, // Try to apply lifecycle rule to Datastream destination bucket. gcsConfig := streamingCfg.GcsCfg + sc, err := storageclient.NewStorageClientImpl(ctx) + if err != nil { + return nil, err + } + sa := storageaccessor.StorageAccessorImpl{} if gcsConfig.TtlInDaysSet { - err = streaming.EnableBucketLifecycleDeleteRule(ctx, gcsBucket, []string{gcsDestPrefix}, gcsConfig.TtlInDays) + err = sa.ApplyBucketLifecycleDeleteRule(ctx, sc, storageaccessor.StorageBucketMetadata{ + BucketName: gcsBucket, + Ttl: gcsConfig.TtlInDays, + MatchesPrefix: []string{gcsDestPrefix}, + }) if err != nil { logger.Log.Warn(fmt.Sprintf("\nWARNING: could not update Datastream destination GCS bucket with lifecycle rule, error: %v\n", err)) logger.Log.Warn("Please apply the lifecycle rule manually. Continuing...\n") @@ -475,8 +488,17 @@ func dataFromDatabaseForDataflowMigration(targetProfile profiles.TargetProfile, // Try to apply lifecycle rule to Datastream destination bucket. 
gcsConfig := streamingCfg.GcsCfg + sc, err := storageclient.NewStorageClientImpl(ctx) + if err != nil { + return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} + } + sa := storageaccessor.StorageAccessorImpl{} if gcsConfig.TtlInDaysSet { - err = streaming.EnableBucketLifecycleDeleteRule(ctx, gcsBucket, []string{gcsDestPrefix}, gcsConfig.TtlInDays) + err = sa.ApplyBucketLifecycleDeleteRule(ctx, sc, storageaccessor.StorageBucketMetadata{ + BucketName: gcsBucket, + Ttl: gcsConfig.TtlInDays, + MatchesPrefix: []string{gcsDestPrefix}, + }) if err != nil { logger.Log.Warn(fmt.Sprintf("\nWARNING: could not update Datastream destination GCS bucket with lifecycle rule, error: %v\n", err)) logger.Log.Warn("Please apply the lifecycle rule manually. Continuing...\n") @@ -520,11 +542,11 @@ func dataFromDatabaseForDataflowMigration(targetProfile profiles.TargetProfile, // create monitoring aggregated dashboard for sharded migration aggMonitoringResources := metrics.MonitoringMetricsResources{ - ProjectId: targetProfile.Conn.Sp.Project, - SpannerInstanceId: targetProfile.Conn.Sp.Instance, - SpannerDatabaseId: targetProfile.Conn.Sp.Dbname, - ShardToShardResourcesMap: conv.Audit.StreamingStats.ShardToShardResourcesMap, - MigrationRequestId: conv.Audit.MigrationRequestId, + ProjectId: targetProfile.Conn.Sp.Project, + SpannerInstanceId: targetProfile.Conn.Sp.Instance, + SpannerDatabaseId: targetProfile.Conn.Sp.Dbname, + ShardToShardResourcesMap: conv.Audit.StreamingStats.ShardToShardResourcesMap, + MigrationRequestId: conv.Audit.MigrationRequestId, } aggRespDash, dashboardErr := aggMonitoringResources.CreateDataflowAggMonitoringDashboard(ctx) if dashboardErr != nil { @@ -798,7 +820,12 @@ func getSeekable(f *os.File) (*os.File, int64, error) { // VerifyDb checks whether the db exists and if it does, verifies if the schema is what we currently support. 
func VerifyDb(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string) (dbExists bool, err error) { - dbExists, err = CheckExistingDb(ctx, adminClient, dbURI) + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) + if err != nil { + return dbExists, err + } + spA := spanneraccessor.SpannerAccessorImpl{} + dbExists, err = spA.CheckExistingDb(ctx, adminClientImpl, dbURI) if err != nil { return dbExists, err } @@ -808,31 +835,6 @@ func VerifyDb(ctx context.Context, adminClient *database.DatabaseAdminClient, db return dbExists, err } -// CheckExistingDb checks whether the database with dbURI exists or not. -// If API call doesn't respond then user is informed after every 5 minutes on command line. -func CheckExistingDb(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string) (bool, error) { - gotResponse := make(chan bool) - var err error - go func() { - _, err = adminClient.GetDatabase(ctx, &adminpb.GetDatabaseRequest{Name: dbURI}) - gotResponse <- true - }() - for { - select { - case <-time.After(5 * time.Minute): - fmt.Println("WARNING! API call not responding: make sure that spanner api endpoint is configured properly") - case <-gotResponse: - if err != nil { - if utils.ContainsAny(strings.ToLower(err.Error()), []string{"database not found"}) { - return false, nil - } - return false, fmt.Errorf("can't get database info: %s", err) - } - return true, nil - } - } -} - // ValidateTables validates that all the tables in the database are empty. // It returns the name of the first non-empty table if found, and an empty string otherwise. 
func ValidateTables(ctx context.Context, client *sp.Client, spDialect string) (string, error) { @@ -1305,20 +1307,20 @@ func GetInfoSchemaFromCloudSQL(sourceProfile profiles.SourceProfile, targetProfi switch driver { case constants.MYSQL: d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithIAMAuthN()) - if err != nil { - return nil, fmt.Errorf("cloudsqlconn.NewDialer: %w", err) - } - var opts []cloudsqlconn.DialOption + if err != nil { + return nil, fmt.Errorf("cloudsqlconn.NewDialer: %w", err) + } + var opts []cloudsqlconn.DialOption instanceName := fmt.Sprintf("%s:%s:%s", sourceProfile.ConnCloudSQL.Mysql.Project, sourceProfile.ConnCloudSQL.Mysql.Region, sourceProfile.ConnCloudSQL.Mysql.InstanceName) - mysqldriver.RegisterDialContext("cloudsqlconn", - func(ctx context.Context, addr string) (net.Conn, error) { - return d.Dial(ctx, instanceName, opts...) - }) + mysqldriver.RegisterDialContext("cloudsqlconn", + func(ctx context.Context, addr string) (net.Conn, error) { + return d.Dial(ctx, instanceName, opts...) 
+ }) - dbURI := fmt.Sprintf("%s:empty@cloudsqlconn(localhost:3306)/%s?parseTime=true", - sourceProfile.ConnCloudSQL.Mysql.User, sourceProfile.ConnCloudSQL.Mysql.Db) + dbURI := fmt.Sprintf("%s:empty@cloudsqlconn(localhost:3306)/%s?parseTime=true", + sourceProfile.ConnCloudSQL.Mysql.User, sourceProfile.ConnCloudSQL.Mysql.Db) - db, err := sql.Open("mysql", dbURI) + db, err := sql.Open("mysql", dbURI) if err != nil { return nil, fmt.Errorf("sql.Open: %w", err) } @@ -1330,25 +1332,25 @@ func GetInfoSchemaFromCloudSQL(sourceProfile profiles.SourceProfile, targetProfi }, nil case constants.POSTGRES: d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithIAMAuthN()) - if err != nil { - return nil, fmt.Errorf("cloudsqlconn.NewDialer: %w", err) - } - var opts []cloudsqlconn.DialOption - - dsn := fmt.Sprintf("user=%s database=%s", sourceProfile.ConnCloudSQL.Pg.User, sourceProfile.ConnCloudSQL.Pg.Db) - config, err := pgx.ParseConfig(dsn) - if err != nil { - return nil, err - } + if err != nil { + return nil, fmt.Errorf("cloudsqlconn.NewDialer: %w", err) + } + var opts []cloudsqlconn.DialOption + + dsn := fmt.Sprintf("user=%s database=%s", sourceProfile.ConnCloudSQL.Pg.User, sourceProfile.ConnCloudSQL.Pg.Db) + config, err := pgx.ParseConfig(dsn) + if err != nil { + return nil, err + } instanceName := fmt.Sprintf("%s:%s:%s", sourceProfile.ConnCloudSQL.Pg.Project, sourceProfile.ConnCloudSQL.Pg.Region, sourceProfile.ConnCloudSQL.Pg.InstanceName) - config.DialFunc = func(ctx context.Context, network, instance string) (net.Conn, error) { - return d.Dial(ctx, instanceName, opts...) - } - dbURI := stdlib.RegisterConnConfig(config) - db, err := sql.Open("pgx", dbURI) - if err != nil { - return nil, fmt.Errorf("sql.Open: %w", err) - } + config.DialFunc = func(ctx context.Context, network, instance string) (net.Conn, error) { + return d.Dial(ctx, instanceName, opts...) 
+ } + dbURI := stdlib.RegisterConnConfig(config) + db, err := sql.Open("pgx", dbURI) + if err != nil { + return nil, fmt.Errorf("sql.Open: %w", err) + } temp := false return postgres.InfoSchemaImpl{ Db: db, diff --git a/streaming/streaming.go b/streaming/streaming.go index 1d95521993..550ffffc31 100644 --- a/streaming/streaming.go +++ b/streaming/streaming.go @@ -33,11 +33,14 @@ import ( resourcemanager "cloud.google.com/go/resourcemanager/apiv3" resourcemanagerpb "cloud.google.com/go/resourcemanager/apiv3/resourcemanagerpb" + storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" dataflowaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/dataflow" + storageaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/storage" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" "github.com/google/uuid" "github.com/googleapis/gax-go/v2" @@ -859,12 +862,16 @@ func StartDatastream(ctx context.Context, streamingCfg StreamingCfg, sourceProfi } func StartDataflow(ctx context.Context, targetProfile profiles.TargetProfile, streamingCfg StreamingCfg, conv *internal.Conv) (internal.DataflowOutput, error) { - + sc, err := storageclient.NewStorageClientImpl(ctx) + if err != nil { + return internal.DataflowOutput{}, err + } + sa := storageaccessor.StorageAccessorImpl{} convJSON, err := json.MarshalIndent(conv, "", " ") if err != nil { return internal.DataflowOutput{}, fmt.Errorf("can't encode session state to JSON: %v", err) } - err = utils.WriteToGCS(streamingCfg.TmpDir, "session.json", string(convJSON)) + err = sa.WriteDataToGCS(ctx, sc, streamingCfg.TmpDir, "session.json", string(convJSON)) if err != nil { return 
internal.DataflowOutput{}, fmt.Errorf("error while writing to GCS: %v", err) } @@ -875,7 +882,7 @@ func StartDataflow(ctx context.Context, targetProfile profiles.TargetProfile, st if err != nil { return internal.DataflowOutput{}, fmt.Errorf("failed to compute transformation context: %s", err.Error()) } - err = utils.WriteToGCS(streamingCfg.TmpDir, "transformationContext.json", string(transformationContext)) + err = sa.WriteDataToGCS(ctx, sc, streamingCfg.TmpDir, "transformationContext.json", string(transformationContext)) if err != nil { return internal.DataflowOutput{}, fmt.Errorf("error while writing to GCS: %v", err) } @@ -885,43 +892,3 @@ func StartDataflow(ctx context.Context, targetProfile profiles.TargetProfile, st } return dfOutput, nil } - -// Applies the bucket lifecycle with delete rule. Only accepts the Age and -// prefix rule conditions as it is only used for the Datastream destination -// bucket currently. -func EnableBucketLifecycleDeleteRule(ctx context.Context, bucketName string, matchesPrefix []string, ttl int64) error { - client, err := storage.NewClient(ctx) - if err != nil { - return fmt.Errorf("could not create client while enabling lifecycle: %w", err) - } - defer client.Close() - - for i, str := range matchesPrefix { - matchesPrefix[i] = strings.TrimPrefix(str, "/") - } - bucket := client.Bucket(bucketName) - bucketAttrsToUpdate := storage.BucketAttrsToUpdate{ - Lifecycle: &storage.Lifecycle{ - Rules: []storage.LifecycleRule{ - { - Action: storage.LifecycleAction{Type: "Delete"}, - Condition: storage.LifecycleCondition{ - AgeInDays: ttl, - // The prefixes should not contain the bucket names and starting slash. - // For object gs://my_bucket/pictures/paris_2022.jpg, - // you would use a condition such as "matchesPrefix":["pictures/paris_"]. 
- MatchesPrefix: matchesPrefix, - }, - }, - }, - }, - } - - attrs, err := bucket.Update(ctx, bucketAttrsToUpdate) - if err != nil { - return fmt.Errorf("could not bucket with lifecycle: %w", err) - } - logger.Log.Info(fmt.Sprintf("Added lifecycle rule to bucket %v\n. Rule Action: %v\t Rule Condition: %v\n", - bucketName, attrs.Lifecycle.Rules[0].Action, attrs.Lifecycle.Rules[0].Condition)) - return nil -} diff --git a/testing/accessors/spanner/spanner_accessor_test.go b/testing/accessors/spanner/spanner_accessor_test.go new file mode 100644 index 0000000000..311d2de0eb --- /dev/null +++ b/testing/accessors/spanner/spanner_accessor_test.go @@ -0,0 +1,135 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// TODO: Refactor this file and other integration tests by moving all common code +// to remove redundancy. 
+ +package spanneraccessor_test + +import ( + "context" + "flag" + "fmt" + "log" + "os" + "testing" + "time" + + database "cloud.google.com/go/spanner/admin/database/apiv1" + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" + spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin" + spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "github.com/GoogleCloudPlatform/spanner-migration-tool/conversion" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" +) + +var ( + projectID string + instanceID string + + ctx context.Context + databaseAdmin *database.DatabaseAdminClient +) + +// This test should move as a mock unit test inside accessors itself. +func TestMain(m *testing.M) { + cleanup := initTests() + res := m.Run() + cleanup() + os.Exit(res) +} + +func init() { + logger.Log = zap.NewNop() +} + +func initTests() (cleanup func()) { + projectID = os.Getenv("SPANNER_MIGRATION_TOOL_TESTS_GCLOUD_PROJECT_ID") + instanceID = os.Getenv("SPANNER_MIGRATION_TOOL_TESTS_GCLOUD_INSTANCE_ID") + + ctx = context.Background() + flag.Parse() // Needed for testing.Short(). 
+ noop := func() {} + + if testing.Short() { + log.Println("Unit test for UpdateDDLForeignKeys skipped in -short mode.") + return noop + } + + if projectID == "" { + log.Println("Unit test for UpdateDDLForeignKeys skipped: SPANNER_MIGRATION_TOOL_TESTS_GCLOUD_PROJECT_ID is missing") + return noop + } + + if instanceID == "" { + log.Println("Unit test for UpdateDDLForeignKeys skipped: SPANNER_MIGRATION_TOOL_TESTS_GCLOUD_INSTANCE_ID is missing") + return noop + } + + var err error + databaseAdmin, err = database.NewDatabaseAdminClient(ctx) + if err != nil { + log.Fatalf("cannot create databaseAdmin client: %v", err) + } + + return func() { + databaseAdmin.Close() + } +} + +func dropDatabase(t *testing.T, dbPath string) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + // Drop the testing database. + if err := databaseAdmin.DropDatabase(ctx, &databasepb.DropDatabaseRequest{Database: dbPath}); err != nil { + t.Fatalf("failed to drop testing database %v: %v", dbPath, err) + } +} + +func TestCheckExistingDb(t *testing.T) { + onlyRunForEmulatorTest(t) + dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, "check-db-exists") + err := conversion.CreateDatabase(ctx, databaseAdmin, dbURI, internal.MakeConv(), os.Stdout, "", constants.BULK_MIGRATION) + if err != nil { + t.Fatal(err) + } + defer dropDatabase(t, dbURI) + testCases := []struct { + dbName string + dbExists bool + }{ + {"check-db-exists", true}, + {"check-db-does-not-exist", false}, + } + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) + if err != nil { + t.Fatal(err) + } + spA := spanneraccessor.SpannerAccessorImpl{} + for _, tc := range testCases { + dbExists, err := spA.CheckExistingDb(ctx, adminClientImpl, fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, tc.dbName)) + assert.Nil(t, err) + assert.Equal(t, tc.dbExists, dbExists) + } +} + +func onlyRunForEmulatorTest(t *testing.T) { + if 
os.Getenv("SPANNER_EMULATOR_HOST") == "" { + t.Skip("Skipping tests only running against the emulator.") + } +} diff --git a/testing/conversion/conversion_test.go b/testing/conversion/conversion_test.go index ff9a8e49ef..ba08519a71 100644 --- a/testing/conversion/conversion_test.go +++ b/testing/conversion/conversion_test.go @@ -248,29 +248,6 @@ func TestVerifyDb(t *testing.T) { } } -func TestCheckExistingDb(t *testing.T) { - onlyRunForEmulatorTest(t) - dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, "check-db-exists") - err := conversion.CreateDatabase(ctx, databaseAdmin, dbURI, internal.MakeConv(), os.Stdout, "", constants.BULK_MIGRATION) - if err != nil { - t.Fatal(err) - } - defer dropDatabase(t, dbURI) - testCases := []struct { - dbName string - dbExists bool - }{ - {"check-db-exists", true}, - {"check-db-does-not-exist", false}, - } - - for _, tc := range testCases { - dbExists, err := conversion.CheckExistingDb(ctx, databaseAdmin, fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, tc.dbName)) - assert.Nil(t, err) - assert.Equal(t, tc.dbExists, dbExists) - } -} - func TestValidateDDL(t *testing.T) { onlyRunForEmulatorTest(t) diff --git a/webv2/helpers/helpers.go b/webv2/helpers/helpers.go index d0dbea31f9..980bc7eac2 100644 --- a/webv2/helpers/helpers.go +++ b/webv2/helpers/helpers.go @@ -21,8 +21,9 @@ import ( "strings" database "cloud.google.com/go/spanner/admin/database/apiv1" + spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin" + spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" - "github.com/GoogleCloudPlatform/spanner-migration-tool/conversion" adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" ) @@ -153,14 +154,14 @@ func CheckOrCreateMetadataDb(projectId string, instanceId string) bool { } ctx := 
context.Background() - adminClient, err := database.NewDatabaseAdminClient(ctx) + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) if err != nil { fmt.Println(err) return false } - defer adminClient.Close() - dbExists, err := conversion.CheckExistingDb(ctx, adminClient, uri) + spA := spanneraccessor.SpannerAccessorImpl{} + dbExists, err := spA.CheckExistingDb(ctx, adminClientImpl, uri) if err != nil { fmt.Println(err) return false diff --git a/webv2/profile/profile.go b/webv2/profile/profile.go index 9808a594f6..f3935f5a29 100644 --- a/webv2/profile/profile.go +++ b/webv2/profile/profile.go @@ -11,6 +11,8 @@ import ( "strings" datastream "cloud.google.com/go/datastream/apiv1" + storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" + storageaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/storage" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" "github.com/GoogleCloudPlatform/spanner-migration-tool/streaming" @@ -153,6 +155,12 @@ func CreateConnectionProfile(w http.ResponseWriter, r *http.Request) { ValidateOnly: details.ValidateOnly, } var bucketName string + sc, err := storageclient.NewStorageClientImpl(ctx) + if err != nil { + http.Error(w, fmt.Sprintf("Error while StorageClientImpl: %v", err), http.StatusBadRequest) + return + } + sa := storageaccessor.StorageAccessorImpl{} if !details.IsSource { if sessionState.IsSharded { @@ -160,7 +168,13 @@ func CreateConnectionProfile(w http.ResponseWriter, r *http.Request) { } else { bucketName = strings.ToLower(sessionState.Conv.Audit.MigrationRequestId) } - err = utils.CreateGCSBucket(bucketName, sessionState.GCPProjectID, sessionState.Region) + err = sa.CreateGCSBucket(ctx, sc, storageaccessor.StorageBucketMetadata{ + BucketName: bucketName, + ProjectID: sessionState.GCPProjectID, + Location: sessionState.Region, + Ttl: 0, + MatchesPrefix: nil, + }) 
if err != nil { http.Error(w, fmt.Sprintf("Error while creating bucket: %v", err), http.StatusBadRequest) return diff --git a/webv2/session/session_service.go b/webv2/session/session_service.go index df190eac4c..c30e9d8a62 100644 --- a/webv2/session/session_service.go +++ b/webv2/session/session_service.go @@ -7,7 +7,8 @@ import ( "cloud.google.com/go/spanner" database "cloud.google.com/go/spanner/admin/database/apiv1" "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" - "github.com/GoogleCloudPlatform/spanner-migration-tool/conversion" + spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin" + spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner" helpers "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/helpers" ) @@ -80,15 +81,15 @@ func getOldMetadataDbUri(projectId string, instanceId string) string { func migrateMetadataDb(projectId, instanceId string) { ctx := context.Background() - adminClient, err := database.NewDatabaseAdminClient(ctx) + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) if err != nil { fmt.Println(err) return } - defer adminClient.Close() + spA := spanneraccessor.SpannerAccessorImpl{} oldMetadataDbUri := getOldMetadataDbUri(projectId, instanceId) - oldMetadataDBExists, err := conversion.CheckExistingDb(ctx, adminClient, oldMetadataDbUri) + oldMetadataDBExists, err := spA.CheckExistingDb(ctx, adminClientImpl, oldMetadataDbUri) if err != nil { fmt.Printf("could not check if oldMetadataDB exists. 
error=%v\n", err) return @@ -157,7 +158,12 @@ func migrateMetadataDb(projectId, instanceId string) { } fmt.Println("Successfully wrote data to new metadata DB.") - + adminClient, err := database.NewDatabaseAdminClient(ctx) + if err != nil { + fmt.Println(err) + return + } + defer adminClient.Close() err = adminClient.DropDatabase(ctx, &databasepb.DropDatabaseRequest{ Database: oldMetadataDbUri, }) diff --git a/webv2/web.go b/webv2/web.go index 7b8c84b8cc..22ed1a40f1 100644 --- a/webv2/web.go +++ b/webv2/web.go @@ -36,6 +36,8 @@ import ( "time" instance "cloud.google.com/go/spanner/admin/instance/apiv1" + storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" + storageaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/storage" "github.com/GoogleCloudPlatform/spanner-migration-tool/cmd" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" @@ -2263,7 +2265,7 @@ func migrate(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("Can't get source and target profiles: %v", err), http.StatusBadRequest) return } - err = writeSessionFile(sessionState) + err = writeSessionFile(ctx, sessionState) if err != nil { log.Println("can't write session file") http.Error(w, fmt.Sprintf("Can't write session file to GCS: %v", err), http.StatusBadRequest) @@ -2484,9 +2486,19 @@ func createConfigFileForShardedBulkMigration(sessionState *session.SessionState, return nil } -func writeSessionFile(sessionState *session.SessionState) error { - - err := utils.CreateGCSBucket(sessionState.Bucket, sessionState.GCPProjectID, sessionState.Region) +func writeSessionFile(ctx context.Context, sessionState *session.SessionState) error { + sc, err := storageclient.NewStorageClientImpl(ctx) + if err != nil { + return err + } + sa := storageaccessor.StorageAccessorImpl{} + err = sa.CreateGCSBucket(ctx, sc, 
storageaccessor.StorageBucketMetadata{ + BucketName: sessionState.Bucket, + ProjectID: sessionState.GCPProjectID, + Location: sessionState.Region, + Ttl: 0, + MatchesPrefix: nil, + }) if err != nil { return fmt.Errorf("error while creating bucket: %v", err) } @@ -2495,7 +2507,7 @@ func writeSessionFile(sessionState *session.SessionState) error { if err != nil { return fmt.Errorf("can't encode session state to JSON: %v", err) } - err = utils.WriteToGCS("gs://"+sessionState.Bucket+sessionState.RootPath, "session.json", string(convJSON)) + err = sa.WriteDataToGCS(ctx, sc, "gs://"+sessionState.Bucket+sessionState.RootPath, "session.json", string(convJSON)) if err != nil { return fmt.Errorf("error while writing to GCS: %v", err) } @@ -3031,7 +3043,7 @@ type ResourceDetails struct { ResourceType string `json:"ResourceType"` ResourceName string `json:"ResourceName"` ResourceUrl string `json:"ResourceUrl"` - GcloudCmd string `json:"GcloudCmd"` + GcloudCmd string `json:"GcloudCmd"` } type GeneratedResources struct { MigrationJobId string `json:"MigrationJobId"` @@ -3054,7 +3066,7 @@ type GeneratedResources struct { AggMonitoringDashboardName string `json:"AggMonitoringDashboardName"` AggMonitoringDashboardUrl string `json:"AggMonitoringDashboardUrl"` //Used for sharded migration flow - ShardToShardResourcesMap map[string][]ResourceDetails `json:"ShardToShardResourcesMap"` + ShardToShardResourcesMap map[string][]ResourceDetails `json:"ShardToShardResourcesMap"` } func addTypeToList(convertedType string, spType string, issues []internal.SchemaIssue, l []typeIssue) []typeIssue { From 339b840af241d03a23d9ea0c839b4885b3701ed4 Mon Sep 17 00:00:00 2001 From: Astha Mohta <35952883+asthamohta@users.noreply.github.com> Date: Fri, 2 Feb 2024 11:06:28 +0530 Subject: [PATCH 02/15] tests: Unit Tests for SMT (#757) * tests test * [feat] RR Create API 1: Add dataflow accessor (#745) * Add dataflow accessor * Add enable streaming engine struct tag Mofe Unmarshall Method to acc2 due ot 
storage dependency * Moved dataflow utils to accessor and creates types.go * Create dataflowutils package * Renamed testing package for dataflow util * Added unit tests * Added empty test files for clients * Move test to same package * Add tests for dataflow client * Update fake for client test * Make dataflow accessor interface and struct to make it testable * Remove interface from accessor package * Add dataflow accessor interface * Add comments to dataflow client and comments on unit tests * Move all dataflow dependencies to accessors and remove dataflow utils * Create dataflow client interface for accessor method to make it unit testable * tests tests * common testing * change * changes on comments --------- Co-authored-by: Deep1998 --- cmd/utils.go | 8 +- common/utils/utils.go | 18 +- conversion/conversion.go | 4 +- go.mod | 1 + go.sum | 1 + profiles/common.go | 6 +- profiles/common_test.go | 212 +++++ profiles/source_profile.go | 83 +- profiles/source_profile_test.go | 803 +++++++++++++++++- profiles/target_profile.go | 12 +- profiles/target_profile_test.go | 16 + streaming/cleanup.go | 5 +- streaming/store.go | 6 +- streaming/streaming.go | 2 +- testing/dynamodb/snapshot/integration_test.go | 3 +- .../dynamodb/streaming/integration_test.go | 3 +- testing/postgres/integration_test.go | 12 +- webv2/utilities/utilities.go | 3 +- webv2/web.go | 3 +- 19 files changed, 1129 insertions(+), 72 deletions(-) create mode 100644 profiles/common_test.go create mode 100644 profiles/target_profile_test.go diff --git a/cmd/utils.go b/cmd/utils.go index a1cf40523a..1f25d967d9 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -45,7 +45,7 @@ func CreateDatabaseClient(ctx context.Context, targetProfile profiles.TargetProf if targetProfile.Conn.Sp.Dbname == "" { targetProfile.Conn.Sp.Dbname = dbName } - project, instance, dbName, err := targetProfile.GetResourceIds(ctx, time.Now(), driver, ioHelper.Out) + project, instance, dbName, err := targetProfile.GetResourceIds(ctx, 
time.Now(), driver, ioHelper.Out, &utils.GetUtilInfoImpl{}) if err != nil { return nil, nil, "", err } @@ -74,7 +74,8 @@ func PrepareMigrationPrerequisites(sourceProfileString, targetProfileString, sou return profiles.SourceProfile{}, profiles.TargetProfile{}, utils.IOStreams{}, "", err } - sourceProfile, err := profiles.NewSourceProfile(sourceProfileString, source) + n := profiles.NewSourceProfileImpl{} + sourceProfile, err := profiles.NewSourceProfile(sourceProfileString, source, &n) if err != nil { return profiles.SourceProfile{}, targetProfile, utils.IOStreams{}, "", err } @@ -92,7 +93,8 @@ func PrepareMigrationPrerequisites(sourceProfileString, targetProfileString, sou defer ioHelper.In.Close() } - dbName, err := utils.GetDatabaseName(sourceProfile.Driver, time.Now()) + getInfo := utils.GetUtilInfoImpl{} + dbName, err := getInfo.GetDatabaseName(sourceProfile.Driver, time.Now()) if err != nil { err = fmt.Errorf("can't generate database name for prefix: %v", err) return sourceProfile, targetProfile, ioHelper, "", err diff --git a/common/utils/utils.go b/common/utils/utils.go index 932da34f11..226afeacd5 100644 --- a/common/utils/utils.go +++ b/common/utils/utils.go @@ -61,6 +61,16 @@ type ManifestTable struct { File_patterns []string `json:"file_patterns"` } +// Interface to fetch spanner details +type GetUtilInfoInterface interface { + GetProject() (string, error) + GetInstance(ctx context.Context, project string, out *os.File) (string, error) + GetPassword() string + GetDatabaseName(driver string, now time.Time) (string, error) +} + +type GetUtilInfoImpl struct{} + // NewIOStreams returns a new IOStreams struct such that input stream is set // to open file descriptor for dumpFile if driver is PGDUMP or MYSQLDUMP. // Input stream defaults to stdin. Output stream is always set to stdout. @@ -175,7 +185,7 @@ func PreloadGCSFiles(tables []ManifestTable) ([]ManifestTable, error) { // GetProject returns the cloud project we should use for accessing Spanner. 
// Use environment variable GCLOUD_PROJECT if it is set. // Otherwise, use the default project returned from gcloud. -func GetProject() (string, error) { +func (gui *GetUtilInfoImpl) GetProject() (string, error) { project := os.Getenv("GCLOUD_PROJECT") if project != "" { return project, nil @@ -192,7 +202,7 @@ func GetProject() (string, error) { // GetInstance returns the Spanner instance we should use for creating DBs. // If the user specified instance (via flag 'instance') then use that. // Otherwise try to deduce the instance using gcloud. -func GetInstance(ctx context.Context, project string, out *os.File) (string, error) { +func (gui *GetUtilInfoImpl) GetInstance(ctx context.Context, project string, out *os.File) (string, error) { l, err := getInstances(ctx, project) if err != nil { return "", err @@ -238,7 +248,7 @@ func getInstances(ctx context.Context, project string) ([]string, error) { return l, nil } -func GetPassword() string { +func (gui *GetUtilInfoImpl) GetPassword() string { calledFromGCloud := os.Getenv("GCLOUD_HB_PLUGIN") if strings.EqualFold(calledFromGCloud, "true") { fmt.Println("\n Please specify password in enviroment variables (recommended) or --source-profile " + @@ -256,7 +266,7 @@ func GetPassword() string { } // GetDatabaseName generates database name with driver_date prefix. -func GetDatabaseName(driver string, now time.Time) (string, error) { +func (gui *GetUtilInfoImpl) GetDatabaseName(driver string, now time.Time) (string, error) { return GenerateName(fmt.Sprintf("%s_%s", driver, now.Format("2006-01-02"))) } diff --git a/conversion/conversion.go b/conversion/conversion.go index f1cf091dc4..908b1de53a 100644 --- a/conversion/conversion.go +++ b/conversion/conversion.go @@ -199,7 +199,9 @@ func getInfoSchemaForShard(shardConnInfo profiles.DirectConnectionConfig, driver //while adding other sources, a switch-case will be added here on the basis of the driver input param passed. 
//pased on the driver name, profiles.NewSourceProfileConnection will need to be called to create //the source profile information. - sourceProfileConnectionMySQL, err := profiles.NewSourceProfileConnectionMySQL(params) + sourceProfileDialect := &profiles.SourceProfileDialectImpl{} + getInfo := utils.GetUtilInfoImpl{} + sourceProfileConnectionMySQL, err := sourceProfileDialect.NewSourceProfileConnectionMySQL(params, &getInfo) if err != nil { return nil, fmt.Errorf("cannot parse connection configuration for the primary shard") } diff --git a/go.mod b/go.mod index 273334f922..37d79c1a05 100644 --- a/go.mod +++ b/go.mod @@ -43,6 +43,7 @@ require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/stretchr/objx v0.5.0 // indirect golang.org/x/time v0.5.0 // indirect ) diff --git a/go.sum b/go.sum index a02f8938cd..f1d5ecaf7a 100644 --- a/go.sum +++ b/go.sum @@ -549,6 +549,7 @@ github.com/stathat/consistent v1.0.0 h1:ZFJ1QTRn8npNBKW065raSZ8xfOqhpb8vLOkfp4Cc github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/profiles/common.go b/profiles/common.go index 0f59fc6e2f..e9f0646f7f 100644 --- a/profiles/common.go +++ b/profiles/common.go @@ -115,7 +115,8 @@ func GeneratePGSQLConnectionStr() (string, error) { } password := os.Getenv("PGPASSWORD") if password == "" { - password = 
utils.GetPassword() + getInfo := utils.GetUtilInfoImpl{} + password = getInfo.GetPassword() } return getPGSQLConnectionStr(server, port, user, password, dbName), nil } @@ -135,7 +136,8 @@ func GenerateMYSQLConnectionStr() (string, error) { } password := os.Getenv("MYSQLPWD") if password == "" { - password = utils.GetPassword() + getInfo := utils.GetUtilInfoImpl{} + password = getInfo.GetPassword() } return getMYSQLConnectionStr(server, port, user, password, dbName), nil } diff --git a/profiles/common_test.go b/profiles/common_test.go new file mode 100644 index 0000000000..41c44dc18a --- /dev/null +++ b/profiles/common_test.go @@ -0,0 +1,212 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package profiles + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// code for testing parse map +func TestParseMap(t *testing.T) { + // Avoid getting/setting env variables in the unit tests. 
+ testCases := []struct { + name string + inputString string + expectedParams map[string]string + errorExpected bool + }{ + { + name: "empty params", + inputString: "", + expectedParams: map[string]string{}, + errorExpected: false, + }, + { + name: "valid params=", + inputString: "instance=instance", + expectedParams: map[string]string{"instance": "instance"}, + errorExpected: false, + }, + { + name: "invalid params incorrect format", + inputString: "uuwy", + expectedParams: map[string]string{}, + errorExpected: true, + }, + { + name: "invalid params new line char", + inputString: "uuwy\n hjgse", + expectedParams: map[string]string{}, + errorExpected: true, + }, + { + name: "invalid params duplicates", + inputString: "instance=instance, instance=instance", + expectedParams: map[string]string{"instance": "instance"}, + errorExpected: true, + }, + } + + for _, tc := range testCases { + res, err := ParseMap(tc.inputString) + assert.Equal(t, tc.expectedParams, res, tc.name) + assert.Equal(t, tc.errorExpected, err != nil, tc.name) + } +} + + +// code for testing parse list +func TestParseList(t *testing.T) { + // Avoid getting/setting env variables in the unit tests. 
+ testCases := []struct { + name string + inputString string + expectedParams []string + errorExpected bool + }{ + { + name: "empty input string", + inputString: "", + expectedParams: nil, + errorExpected: false, + }, + { + name: "valid input string", + inputString: "hello, world", + expectedParams: []string{"hello","world"}, + errorExpected: false, + }, + { + name: "invalid input string new line char", + inputString: "hello, world\n, !", + expectedParams: nil, + errorExpected: true, + }, + } + + for _, tc := range testCases { + res, err := ParseList(tc.inputString) + assert.Equal(t, tc.expectedParams, res, tc.name) + assert.Equal(t, tc.errorExpected, err != nil, tc.name) + } +} + +// code for testing sql connection string +func TestGetSQLConnectionStr(t *testing.T) { + // Avoid getting/setting env variables in the unit tests. + host := "0.0.0.0" + port := "3306" + user := "user" + pwd := "password" + db := "database" + testCases := []struct { + name string + inputSourceProfileConn SourceProfileConnection + expectedOutput string + }{ + { + name: "source profile connection type mysql", + inputSourceProfileConn: SourceProfileConnection{Ty: SourceProfileConnectionTypeMySQL, Mysql: SourceProfileConnectionMySQL{Host: host, Port: port, User: user, Pwd: pwd, Db: db}}, + expectedOutput: "user:password@tcp(0.0.0.0:3306)/database", + }, + { + name: "source profile connection type postgres", + inputSourceProfileConn: SourceProfileConnection{Ty: SourceProfileConnectionTypePostgreSQL, Pg: SourceProfileConnectionPostgreSQL{Host: host, Port: port, User: user, Pwd: pwd, Db: db}}, + expectedOutput: "host=0.0.0.0 port=3306 user=user password=password dbname=database sslmode=disable", + }, + { + name: "source profile connection type dynamodb", + inputSourceProfileConn: SourceProfileConnection{Ty: SourceProfileConnectionTypeDynamoDB}, + expectedOutput: "", + }, + { + name: "source profile connection type sql server", + inputSourceProfileConn: SourceProfileConnection{Ty: 
SourceProfileConnectionTypeSqlServer, SqlServer: SourceProfileConnectionSqlServer{Host: host, Port: port, User: user, Pwd: pwd, Db: db}}, + expectedOutput: "sqlserver://user:password@0.0.0.0:3306?database=database", + }, + { + name: "source profile connection type oracle", + inputSourceProfileConn: SourceProfileConnection{Ty: SourceProfileConnectionTypeOracle, Oracle: SourceProfileConnectionOracle{Host: host, Port: port, User: user, Pwd: pwd, Db: db}}, + expectedOutput: "oracle://user:password@0.0.0.0:3306/database", + }, + } + + for _, tc := range testCases { + res:= GetSQLConnectionStr(SourceProfile{Ty: SourceProfileType(SourceProfileTypeConnection), Conn: tc.inputSourceProfileConn}) + assert.Equal(t, tc.expectedOutput, res, tc.name) + } +} + +func TestGenerateConnectionStr(t *testing.T) { + // Avoid getting/setting env variables in the unit tests. + before := func(){ + setEnvVariables() + } + + after := func(){ + unsetEnvVariables() + } + testCases := []struct { + name string + expectedOutputPg string + expectedOutputMysql string + errorExpected bool + }{ + { + name: "valid get mysql and postgres conn string", + expectedOutputPg: "host=0.0.0.0 port=3306 user=user password=password dbname=db sslmode=disable", + expectedOutputMysql: "user:password@tcp(0.0.0.0:3306)/db", + errorExpected: false, + }, + } + + for _, tc := range testCases { + before() + res, err:= GeneratePGSQLConnectionStr() + assert.Equal(t, tc.expectedOutputPg, res, tc.name) + assert.Equal(t, tc.errorExpected, err != nil, tc.name) + res, err= GenerateMYSQLConnectionStr() + assert.Equal(t, tc.expectedOutputMysql, res, tc.name) + assert.Equal(t, tc.errorExpected, err != nil, tc.name) + after() + } +} + +func TestGetSchemaSampleSize(t *testing.T) { + // Avoid getting/setting env variables in the unit tests. 
+ testCases := []struct { + name string + inputSourceProfile SourceProfile + expectedOutput int64 + }{ + { + name: "mysql source profile type", + inputSourceProfile: SourceProfile{Ty: SourceProfileType(SourceProfileTypeConnection), Conn: SourceProfileConnection{Ty: SourceProfileConnectionTypeMySQL}}, + expectedOutput: int64(100000), + }, + { + name: "dynamo db source profile type", + inputSourceProfile: SourceProfile{Ty: SourceProfileType(SourceProfileTypeConnection), Conn: SourceProfileConnection{Ty: SourceProfileConnectionTypeDynamoDB, Dydb: SourceProfileConnectionDynamoDB{SchemaSampleSize: int64(5000)}}}, + expectedOutput: int64(5000), + }, + } + + for _, tc := range testCases { + res := GetSchemaSampleSize(tc.inputSourceProfile) + assert.Equal(t, tc.expectedOutput, res, tc.name) + } +} \ No newline at end of file diff --git a/profiles/source_profile.go b/profiles/source_profile.go index a02b02d61f..754d9b6aa1 100644 --- a/profiles/source_profile.go +++ b/profiles/source_profile.go @@ -28,6 +28,7 @@ import ( type SourceProfileType int + const ( SourceProfileTypeUnset = iota SourceProfileTypeFile @@ -42,7 +43,30 @@ type SourceProfileFile struct { Format string } -func NewSourceProfileFile(params map[string]string) SourceProfileFile { +// Interface to create source profiles for different database dialects +type SourceProfileDialectInterface interface { + NewSourceProfileConnectionCloudSQLMySQL(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionCloudSQLMySQL, error) + NewSourceProfileConnectionMySQL(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionMySQL, error) + NewSourceProfileConnectionCloudSQLPostgreSQL(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionCloudSQLPostgreSQL, error) + NewSourceProfileConnectionPostgreSQL(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionPostgreSQL, error) + NewSourceProfileConnectionSqlServer(params 
map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionSqlServer, error) + NewSourceProfileConnectionDynamoDB(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionDynamoDB, error) + NewSourceProfileConnectionOracle(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionOracle, error) +} + +type SourceProfileDialectImpl struct {} + +// Interface to create new source profiles for different input types +type NewSourceProfileInterface interface { + NewSourceProfileFile(params map[string]string) SourceProfileFile + NewSourceProfileConfig(source string, path string) (SourceProfileConfig, error) + NewSourceProfileConnectionCloudSQL(source string, params map[string]string, s SourceProfileDialectInterface) (SourceProfileConnectionCloudSQL, error) + NewSourceProfileConnection(source string, params map[string]string, s SourceProfileDialectInterface) (SourceProfileConnection, error) +} + +type NewSourceProfileImpl struct{} + +func (nsp *NewSourceProfileImpl) NewSourceProfileFile(params map[string]string) SourceProfileFile { profile := SourceProfileFile{} if !filePipedToStdin() { profile.Path = params["file"] @@ -84,7 +108,7 @@ type SourceProfileConnectionCloudSQLMySQL struct { Region string } -func NewSourceProfileConnectionCloudSQLMySQL(params map[string]string) (SourceProfileConnectionCloudSQLMySQL, error) { +func (spd *SourceProfileDialectImpl) NewSourceProfileConnectionCloudSQLMySQL(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionCloudSQLMySQL, error) { mysql := SourceProfileConnectionCloudSQLMySQL{} user, userOk := params["user"] db, dbOk := params["dbName"] @@ -92,7 +116,7 @@ func NewSourceProfileConnectionCloudSQLMySQL(params map[string]string) (SourcePr project, projectOk := params["project"] var err error if !projectOk { - project, err = utils.GetProject() + project, err = g.GetProject() if err != nil { return mysql, fmt.Errorf("project for cloudsql instance not 
specified in source-profile, and unable to fetch from gcloud. Please specify project in the source-profile or configure in gcloud") } @@ -118,7 +142,7 @@ type SourceProfileConnectionMySQL struct { StreamingConfig string } -func NewSourceProfileConnectionMySQL(params map[string]string) (SourceProfileConnectionMySQL, error) { +func (spd *SourceProfileDialectImpl) NewSourceProfileConnectionMySQL(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionMySQL, error) { mysql := SourceProfileConnectionMySQL{} host, hostOk := params["host"] @@ -171,7 +195,7 @@ func NewSourceProfileConnectionMySQL(params map[string]string) (SourceProfileCon mysql.Port = "3306" } if mysql.Pwd == "" { - mysql.Pwd = utils.GetPassword() + mysql.Pwd = g.GetPassword() } return mysql, nil @@ -185,7 +209,7 @@ type SourceProfileConnectionCloudSQLPostgreSQL struct { Region string } -func NewSourceProfileConnectionCloudSQLPostgreSQL(params map[string]string) (SourceProfileConnectionCloudSQLPostgreSQL, error) { +func (spd *SourceProfileDialectImpl) NewSourceProfileConnectionCloudSQLPostgreSQL(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionCloudSQLPostgreSQL, error) { postgres := SourceProfileConnectionCloudSQLPostgreSQL{} user, userOk := params["user"] db, dbOk := params["dbName"] @@ -193,7 +217,7 @@ func NewSourceProfileConnectionCloudSQLPostgreSQL(params map[string]string) (Sou project, projectOk := params["project"] var err error if !projectOk { - project, err = utils.GetProject() + project, err = g.GetProject() if err != nil { return postgres, fmt.Errorf("project for cloudsql instance not specified in source-profile, and unable to fetch from gcloud. 
Please specify project in the source-profile or configure in gcloud") } @@ -219,7 +243,7 @@ type SourceProfileConnectionPostgreSQL struct { StreamingConfig string } -func NewSourceProfileConnectionPostgreSQL(params map[string]string) (SourceProfileConnectionPostgreSQL, error) { +func (spd *SourceProfileDialectImpl) NewSourceProfileConnectionPostgreSQL(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionPostgreSQL, error) { pg := SourceProfileConnectionPostgreSQL{} host, hostOk := params["host"] user, userOk := params["user"] @@ -265,7 +289,7 @@ func NewSourceProfileConnectionPostgreSQL(params map[string]string) (SourceProfi pg.Port = "5432" } if pg.Pwd == "" { - pg.Pwd = utils.GetPassword() + pg.Pwd = g.GetPassword() } return pg, nil @@ -279,7 +303,7 @@ type SourceProfileConnectionSqlServer struct { Pwd string } -func NewSourceProfileConnectionSqlServer(params map[string]string) (SourceProfileConnectionSqlServer, error) { +func (spd *SourceProfileDialectImpl) NewSourceProfileConnectionSqlServer(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionSqlServer, error) { ss := SourceProfileConnectionSqlServer{} host, hostOk := params["host"] user, userOk := params["user"] @@ -331,7 +355,7 @@ func NewSourceProfileConnectionSqlServer(params map[string]string) (SourceProfil // If source profile and env do not have password then get password via prompt. 
if ss.Pwd == "" { - ss.Pwd = utils.GetPassword() + ss.Pwd = g.GetPassword() } return ss, nil @@ -349,7 +373,7 @@ type SourceProfileConnectionDynamoDB struct { enableStreaming string // Used for confirming streaming migration (valid options: `yes`,`no`,`true`,`false`) } -func NewSourceProfileConnectionDynamoDB(params map[string]string) (SourceProfileConnectionDynamoDB, error) { +func (spd *SourceProfileDialectImpl) NewSourceProfileConnectionDynamoDB(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionDynamoDB, error) { dydb := SourceProfileConnectionDynamoDB{} if schemaSampleSize, ok := params["schema-sample-size"]; ok { schemaSampleSizeInt, err := strconv.Atoi(schemaSampleSize) @@ -396,7 +420,7 @@ type SourceProfileConnectionOracle struct { StreamingConfig string } -func NewSourceProfileConnectionOracle(params map[string]string) (SourceProfileConnectionOracle, error) { +func (spd *SourceProfileDialectImpl) NewSourceProfileConnectionOracle(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionOracle, error) { ss := SourceProfileConnectionOracle{} host, hostOk := params["host"] user, userOk := params["user"] @@ -427,7 +451,7 @@ func NewSourceProfileConnectionOracle(params map[string]string) (SourceProfileCo ss.Port = "1521" } if ss.Pwd == "" { - ss.Pwd = utils.GetPassword() + ss.Pwd = g.GetPassword() } return ss, nil @@ -449,14 +473,14 @@ type SourceProfileConnectionCloudSQL struct { Pg SourceProfileConnectionCloudSQLPostgreSQL } -func NewSourceProfileConnection(source string, params map[string]string) (SourceProfileConnection, error) { +func (nsp *NewSourceProfileImpl) NewSourceProfileConnection(source string, params map[string]string, s SourceProfileDialectInterface) (SourceProfileConnection, error) { conn := SourceProfileConnection{} var err error switch strings.ToLower(source) { case "mysql": { conn.Ty = SourceProfileConnectionTypeMySQL - conn.Mysql, err = NewSourceProfileConnectionMySQL(params) + conn.Mysql, 
err = s.NewSourceProfileConnectionMySQL(params, &utils.GetUtilInfoImpl{}) if err != nil { return conn, err } @@ -467,7 +491,7 @@ func NewSourceProfileConnection(source string, params map[string]string) (Source case "postgresql", "postgres", "pg": { conn.Ty = SourceProfileConnectionTypePostgreSQL - conn.Pg, err = NewSourceProfileConnectionPostgreSQL(params) + conn.Pg, err = s.NewSourceProfileConnectionPostgreSQL(params, &utils.GetUtilInfoImpl{}) if err != nil { return conn, err } @@ -478,7 +502,7 @@ func NewSourceProfileConnection(source string, params map[string]string) (Source case "dynamodb": { conn.Ty = SourceProfileConnectionTypeDynamoDB - conn.Dydb, err = NewSourceProfileConnectionDynamoDB(params) + conn.Dydb, err = s.NewSourceProfileConnectionDynamoDB(params, &utils.GetUtilInfoImpl{}) if err != nil { return conn, err } @@ -490,7 +514,7 @@ func NewSourceProfileConnection(source string, params map[string]string) (Source case "sqlserver", "mssql": { conn.Ty = SourceProfileConnectionTypeSqlServer - conn.SqlServer, err = NewSourceProfileConnectionSqlServer(params) + conn.SqlServer, err = s.NewSourceProfileConnectionSqlServer(params, &utils.GetUtilInfoImpl{}) if err != nil { return conn, err } @@ -498,7 +522,7 @@ func NewSourceProfileConnection(source string, params map[string]string) (Source case "oracle": { conn.Ty = SourceProfileConnectionTypeOracle - conn.Oracle, err = NewSourceProfileConnectionOracle(params) + conn.Oracle, err = s.NewSourceProfileConnectionOracle(params, &utils.GetUtilInfoImpl{}) if err != nil { return conn, err } @@ -512,14 +536,14 @@ func NewSourceProfileConnection(source string, params map[string]string) (Source return conn, nil } -func NewSourceProfileConnectionCloudSQL(source string, params map[string]string) (SourceProfileConnectionCloudSQL, error) { +func (nsp *NewSourceProfileImpl) NewSourceProfileConnectionCloudSQL(source string, params map[string]string, s SourceProfileDialectInterface) (SourceProfileConnectionCloudSQL, error) { conn 
:= SourceProfileConnectionCloudSQL{} var err error switch strings.ToLower(source) { case "mysql": { conn.Ty = SourceProfileConnectionTypeCloudSQLMySQL - conn.Mysql, err = NewSourceProfileConnectionCloudSQLMySQL(params) + conn.Mysql, err = s.NewSourceProfileConnectionCloudSQLMySQL(params, &utils.GetUtilInfoImpl{}) if err != nil { return conn, err } @@ -527,7 +551,7 @@ func NewSourceProfileConnectionCloudSQL(source string, params map[string]string) case "postgresql", "postgres", "pg": { conn.Ty = SourceProfileConnectionTypeCloudSQLPostgreSQL - conn.Pg, err = NewSourceProfileConnectionCloudSQLPostgreSQL(params) + conn.Pg, err = s.NewSourceProfileConnectionCloudSQLPostgreSQL(params, &utils.GetUtilInfoImpl{}) if err != nil { return conn, err } @@ -618,7 +642,7 @@ type SourceProfileConfig struct { ShardConfigurationDMS ShardConfigurationDMS `json:"shardConfigurationDMS"` } -func NewSourceProfileConfig(source string, path string) (SourceProfileConfig, error) { +func (nsp *NewSourceProfileImpl) NewSourceProfileConfig(source string, path string) (SourceProfileConfig, error) { //given the source, the fact that this 'config=', determine the appropiate object to marshal into switch source { case constants.MYSQL: @@ -752,7 +776,7 @@ func (src SourceProfile) ToLegacyDriver(source string) (string, error) { // from envrironment variables. // // Format 3. Specify a config file that specifies source connection profile. 
-func NewSourceProfile(s string, source string) (SourceProfile, error) { +func NewSourceProfile(s string, source string, n NewSourceProfileInterface) (SourceProfile, error) { if source == "" { return SourceProfile{}, fmt.Errorf("cannot leave -source flag empty, please specify source databases e.g., -source=postgres etc") } @@ -765,22 +789,23 @@ func NewSourceProfile(s string, source string) (SourceProfile, error) { } if _, ok := params["file"]; ok || filePipedToStdin() { - profile := NewSourceProfileFile(params) + profile := n.NewSourceProfileFile(params) return SourceProfile{Ty: SourceProfileTypeFile, File: profile}, nil } else if format, ok := params["format"]; ok { // File is not passed in from stdin or specified using "file" flag. return SourceProfile{Ty: SourceProfileTypeFile}, fmt.Errorf("file not specified, but format set to %v", format) } else if file, ok := params["config"]; ok { - config, err := NewSourceProfileConfig(strings.ToLower(source), file) + config, err := n.NewSourceProfileConfig(strings.ToLower(source), file) return SourceProfile{Ty: SourceProfileTypeConfig, Config: config}, err } else if _, ok := params["instance"]; ok { - conn, err := NewSourceProfileConnectionCloudSQL(source, params) + conn, err := n.NewSourceProfileConnectionCloudSQL(source, params, &SourceProfileDialectImpl{}) return SourceProfile{Ty: SourceProfileTypeCloudSQL, ConnCloudSQL: conn}, err } else { // Assume connection profile type connection by default, since // connection parameters could be specified as part of environment // variables. 
- conn, err := NewSourceProfileConnection(source, params) + + conn, err := n.NewSourceProfileConnection(source, params, &SourceProfileDialectImpl{}) return SourceProfile{Ty: SourceProfileTypeConnection, Conn: conn}, err } } diff --git a/profiles/source_profile_test.go b/profiles/source_profile_test.go index 8a88181216..de8d68c6db 100644 --- a/profiles/source_profile_test.go +++ b/profiles/source_profile_test.go @@ -15,12 +15,159 @@ package profiles import ( + "fmt" + "os" "path/filepath" "testing" + "time" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/net/context" ) +type MockSourceProfileDialect struct { + mock.Mock +} + +func (m *MockSourceProfileDialect) NewSourceProfileConnectionCloudSQLMySQL(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionCloudSQLMySQL, error) { + args := m.Called(params, g) + return args.Get(0).(SourceProfileConnectionCloudSQLMySQL), args.Error(1) +} + +func (m *MockSourceProfileDialect) NewSourceProfileConnectionMySQL(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionMySQL, error) { + args := m.Called(params, g) + return args.Get(0).(SourceProfileConnectionMySQL), args.Error(1) +} + +func (m *MockSourceProfileDialect) NewSourceProfileConnectionCloudSQLPostgreSQL(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionCloudSQLPostgreSQL, error) { + args := m.Called(params, g) + return args.Get(0).(SourceProfileConnectionCloudSQLPostgreSQL), args.Error(1) +} + +func (m *MockSourceProfileDialect) NewSourceProfileConnectionPostgreSQL(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionPostgreSQL, error) { + args := m.Called(params, g) + return args.Get(0).(SourceProfileConnectionPostgreSQL), args.Error(1) +} + +func (m *MockSourceProfileDialect) 
NewSourceProfileConnectionSqlServer(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionSqlServer, error) { + args := m.Called(params, g) + return args.Get(0).(SourceProfileConnectionSqlServer), args.Error(1) +} + +func (m *MockSourceProfileDialect) NewSourceProfileConnectionDynamoDB(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionDynamoDB, error) { + args := m.Called(params, g) + return args.Get(0).(SourceProfileConnectionDynamoDB), args.Error(1) +} + +func (m *MockSourceProfileDialect) NewSourceProfileConnectionOracle(params map[string]string, g utils.GetUtilInfoInterface) (SourceProfileConnectionOracle, error) { + args := m.Called(params, g) + return args.Get(0).(SourceProfileConnectionOracle), args.Error(1) +} + +func setEnvVariables() { + // My Sql variables + os.Setenv("MYSQLHOST", "0.0.0.0") + os.Setenv("MYSQLUSER", "user") + os.Setenv("MYSQLDATABASE", "db") + os.Setenv("MYSQLPORT", "3306") + os.Setenv("MYSQLPWD", "password") + + //PG Variables + os.Setenv("PGHOST", "0.0.0.0") + os.Setenv("PGUSER", "user") + os.Setenv("PGDATABASE", "db") + os.Setenv("PGPORT", "3306") + os.Setenv("PGPASSWORD", "password") + + // My Sql Server Connection + os.Setenv("MSSQL_IP_ADDRESS", "0.0.0.0") + os.Setenv("MSSQL_SA_USER", "user") + os.Setenv("MSSQL_DATABASE", "db") + os.Setenv("MSSQL_TCP_PORT", "3306") + os.Setenv("MSSQL_SA_PASSWORD", "password") +} + +func unsetEnvVariables() { + // My Sql Server Connection + os.Setenv("MSSQL_IP_ADDRESS", "") + os.Setenv("MSSQL_SA_USER", "") + os.Setenv("MSSQL_DATABASE", "") + os.Setenv("MSSQL_TCP_PORT", "") + os.Setenv("MSSQL_SA_PASSWORD", "") + + //PG Variables + os.Setenv("PGHOST", "") + os.Setenv("PGUSER", "") + os.Setenv("PGDATABASE", "") + os.Setenv("PGPORT", "") + os.Setenv("PGPASSWORD", "") + + // My Sql Server Connection + os.Setenv("MSSQL_IP_ADDRESS", "") + os.Setenv("MSSQL_SA_USER", "") + os.Setenv("MSSQL_DATABASE", "") + os.Setenv("MSSQL_TCP_PORT", "") + 
os.Setenv("MSSQL_SA_PASSWORD", "") +} + +type GetUtilInfoMock struct { + mock.Mock +} + +func (gui *GetUtilInfoMock) GetProject() (string, error) { + args := gui.Called() + return args.Get(0).(string), args.Error(1) + +} + +func (gui *GetUtilInfoMock) GetInstance(ctx context.Context, project string, out *os.File) (string, error) { + args := gui.Called() + return args.Get(0).(string), args.Error(1) +} + +func (gui *GetUtilInfoMock) GetPassword() string { + args := gui.Called() + return args.Get(0).(string) +} + +func (gui *GetUtilInfoMock) GetDatabaseName(driver string, now time.Time) (string, error) { + args := gui.Called() + return args.Get(0).(string), args.Error(1) +} + +func setGetInfoMockValues(g *GetUtilInfoMock) { + g.On("GetDatabaseName", mock.AnythingOfType("string"), mock.AnythingOfType("time.Time")).Return("database-id", nil) + g.On("GetInstance", mock.AnythingOfType("*context.Context"), mock.AnythingOfType("string"), mock.AnythingOfType("*os.File")).Return("instance-id", nil) + g.On("GetPassword").Return("password") +} + +type MockNewSourceProfile struct { + mock.Mock +} + +func (nspm *MockNewSourceProfile) NewSourceProfileFile(params map[string]string) SourceProfileFile { + args := nspm.Called() + return args.Get(0).(SourceProfileFile) +} + +func (nspm *MockNewSourceProfile) NewSourceProfileConfig(source string, path string) (SourceProfileConfig, error) { + args := nspm.Called() + return args.Get(0).(SourceProfileConfig), args.Error(1) +} + +func (nspm *MockNewSourceProfile) NewSourceProfileConnectionCloudSQL(source string, params map[string]string, s SourceProfileDialectInterface) (SourceProfileConnectionCloudSQL, error) { + args := nspm.Called() + return args.Get(0).(SourceProfileConnectionCloudSQL), args.Error(1) +} + +func (nspm *MockNewSourceProfile) NewSourceProfileConnection(source string, params map[string]string, s SourceProfileDialectInterface) (SourceProfileConnection, error) { + args := nspm.Called() + return 
args.Get(0).(SourceProfileConnection), args.Error(1) +} + func TestNewSourceProfileFile(t *testing.T) { testCases := []struct { name string @@ -64,7 +211,8 @@ func TestNewSourceProfileFile(t *testing.T) { // Override filePipedToStdin with the test value. filePipedToStdin = func() bool { return tc.pipedToStdin } - profile := NewSourceProfileFile(tc.params) + n := NewSourceProfileImpl{} + profile := n.NewSourceProfileFile(tc.params) assert.Equal(t, profile, tc.want, tc.name) } } @@ -130,8 +278,9 @@ func TestNewSourceProfileConfigFile(t *testing.T) { }, } for _, tc := range testCases { - sourceProfileConfig, err := NewSourceProfileConfig(tc.source, tc.path) - assert.Equal(t, tc.errorExpected, err != nil) + n := NewSourceProfileImpl{} + sourceProfileConfig, err := n.NewSourceProfileConfig(tc.source, tc.path) + assert.Equal(t, tc.errorExpected, err != nil, tc.name) tc.validationFn(sourceProfileConfig) } } @@ -178,13 +327,41 @@ func TestNewSourceProfileConnectionSQL(t *testing.T) { params: map[string]string{"host": "a", "user": "b", "dbName": "c", "password": "e"}, errorExpected: false, }, + { + name: "mandatory params provided", + params: map[string]string{"host": "a", "user": "b", "dbName": "c", "password": "e", "streamingCfg": ""}, + errorExpected: true, + }, + { + name: "mandatory params provided", + params: map[string]string{}, + errorExpected: false, + }, + { + name: "empty password", + params: map[string]string{"host": "a", "user": "b", "dbName": "c"}, + errorExpected: false, + }, + } + + before := func() { + setEnvVariables() + } + + after := func() { + unsetEnvVariables() } for _, tc := range testCases { - _, pgErr := NewSourceProfileConnectionPostgreSQL(tc.params) - _, mysqlErr := NewSourceProfileConnectionMySQL(tc.params) - assert.Equal(t, tc.errorExpected, pgErr != nil) - assert.Equal(t, tc.errorExpected, mysqlErr != nil) + before() + sourceProfileDialect := SourceProfileDialectImpl{} + g := GetUtilInfoMock{} + setGetInfoMockValues(&g) + _, pgErr := 
sourceProfileDialect.NewSourceProfileConnectionPostgreSQL(tc.params, &g) + _, mysqlErr := sourceProfileDialect.NewSourceProfileConnectionMySQL(tc.params, &g) + assert.Equal(t, tc.errorExpected, pgErr != nil, tc.name) + assert.Equal(t, tc.errorExpected, mysqlErr != nil, tc.name) + after() } } @@ -210,11 +387,42 @@ func TestNewSourceProfileConnectionDynamoDB(t *testing.T) { params: map[string]string{"schema-sample-size": "a"}, errorExpected: true, }, + { + name: "valid aws access key id ", + params: map[string]string{"aws-access-key-id": "hdsjg"}, + errorExpected: false, + }, + { + name: "valid aws region", + params: map[string]string{"aws-region": "us-central"}, + errorExpected: false, + }, + { + name: "valid dydb endpoint", + params: map[string]string{"dydb-endpoint": "0.0.0.0"}, + errorExpected: false, + }, + { + name: "enable streaming true", + params: map[string]string{"enableStreaming": "true"}, + errorExpected: false, + }, + { + name: "enable streaming false", + params: map[string]string{"enableStreaming": "false"}, + errorExpected: false, + }, + { + name: "invalid enable streaming", + params: map[string]string{"enableStreaming": "ujeh"}, + errorExpected: true, + }, } for _, tc := range testCases { - _, err := NewSourceProfileConnectionDynamoDB(tc.params) - assert.Equal(t, tc.errorExpected, err != nil) + sourceProfileDialect := SourceProfileDialectImpl{} + _, err := sourceProfileDialect.NewSourceProfileConnectionDynamoDB(tc.params, &GetUtilInfoMock{}) + assert.Equal(t, tc.errorExpected, err != nil, tc.name) } } @@ -278,13 +486,26 @@ func TestNewSourceProfileConnectionSqlServer(t *testing.T) { { name: "No param provided", params: map[string]string{}, - errorExpected: true, + errorExpected: false, }, } + before := func() { + setEnvVariables() + } + + after := func() { + unsetEnvVariables() + } + for _, tc := range testCases { - _, sqlServer := NewSourceProfileConnectionSqlServer(tc.params) - assert.Equal(t, tc.errorExpected, sqlServer != nil) + before() + 
sourceProfileDialect := SourceProfileDialectImpl{} + g := GetUtilInfoMock{} + setGetInfoMockValues(&g) + _, sqlServer := sourceProfileDialect.NewSourceProfileConnectionSqlServer(tc.params, &g) + assert.Equal(t, tc.errorExpected, sqlServer != nil, tc.name) + after() } } @@ -344,7 +565,561 @@ func TestNewSourceProfileConnectionOracle(t *testing.T) { } for _, tc := range testCases { - _, oracleErr := NewSourceProfileConnectionOracle(tc.params) - assert.Equal(t, tc.errorExpected, oracleErr != nil) + sourceProfileDialect := SourceProfileDialectImpl{} + g := GetUtilInfoMock{} + setGetInfoMockValues(&g) + _, oracleErr := sourceProfileDialect.NewSourceProfileConnectionOracle(tc.params, &g) + assert.Equal(t, tc.errorExpected, oracleErr != nil, tc.name) + } +} + +// code for testing cloud sql mysql connection +func TestNewSourceProfileConnectionCloudSQLMySQL(t *testing.T) { + // Avoid getting/setting env variables in the unit tests. + testCases := []struct { + name string + params map[string]string + errorExpected bool + }{ + { + name: "user is blank", + params: map[string]string{"dbName": "b", "instance": "c", "region": "d", "project": "e"}, + errorExpected: true, + }, + { + name: "dbname is blank", + params: map[string]string{"user": "a", "instance": "c", "region": "d", "project": "e"}, + errorExpected: true, + }, + { + name: "instance is blank", + params: map[string]string{"user": "a", "dbName": "b", "region": "d", "project": "e"}, + errorExpected: true, + }, + { + name: "region is blank", + params: map[string]string{"user": "a", "dbName": "b", "instance": "c", "project": "e"}, + errorExpected: true, + }, + { + name: "project is blank and util getProject () returns project successfully", + params: map[string]string{"user": "a", "dbName": "b", "instance": "c", "region": "d"}, + errorExpected: false, + }, + { + name: "project is blank and util getProject () fails", + params: map[string]string{"user": "a", "dbName": "b", "instance": "c", "region": "d"}, + errorExpected: 
true, + }, + { + name: "test runs successfully", + params: map[string]string{"user": "a", "dbName": "b", "instance": "c", "region": "d", "project": "e"}, + errorExpected: false, + }, + } + + for _, tc := range testCases { + sourceProfileDialect := SourceProfileDialectImpl{} + g := GetUtilInfoMock{} + setGetInfoMockValues(&g) + if tc.name == "project is blank and util getProject () fails" { + g.On("GetProject").Return("", fmt.Errorf("error")) + } else { + g.On("GetProject").Return("project-id", nil) + } + _, mysqlErr := sourceProfileDialect.NewSourceProfileConnectionCloudSQLMySQL(tc.params, &g) + assert.Equal(t, tc.errorExpected, mysqlErr != nil, tc.name) + } +} + +// code for testing postgres sql source connection profile +func TestNewSourceProfileConnectionCloudSQLPostgreSQL(t *testing.T) { + // Avoid getting/setting env variables in the unit tests. + testCases := []struct { + name string + params map[string]string + errorExpected bool + }{ + { + name: "user is blank", + params: map[string]string{"dbName": "b", "instance": "c", "region": "d", "project": "e"}, + errorExpected: true, + }, + { + name: "dbname is blank", + params: map[string]string{"user": "a", "instance": "c", "region": "d", "project": "e"}, + errorExpected: true, + }, + { + name: "instance is blank", + params: map[string]string{"user": "a", "dbName": "b", "region": "d", "project": "e"}, + errorExpected: true, + }, + { + name: "region is blank", + params: map[string]string{"user": "a", "dbName": "b", "instance": "c", "project": "e"}, + errorExpected: true, + }, + { + name: "project is blank and util getProject () returns project successfully", + params: map[string]string{"user": "a", "dbName": "b", "instance": "c", "region": "d"}, + errorExpected: false, + }, + { + name: "project is blank and util getProject () fails", + params: map[string]string{"user": "a", "dbName": "b", "instance": "c", "region": "d"}, + errorExpected: true, + }, + { + name: "test runs successfully", + params: 
map[string]string{"user": "a", "dbName": "b", "instance": "c", "region": "d", "project": "e"}, + errorExpected: false, + }, + } + + for _, tc := range testCases { + sourceProfileDialect := SourceProfileDialectImpl{} + g := GetUtilInfoMock{} + setGetInfoMockValues(&g) + if tc.name == "project is blank and util getProject () fails" { + g.On("GetProject").Return("", fmt.Errorf("error")) + } else { + g.On("GetProject").Return("project-id", nil) + } + _, mysqlErr := sourceProfileDialect.NewSourceProfileConnectionCloudSQLPostgreSQL(tc.params, &g) + assert.Equal(t, tc.errorExpected, mysqlErr != nil, tc.name) + } +} + +// code for testing new source connection profile +func TestNewSourceProfileConnection(t *testing.T) { + // Avoid getting/setting env variables in the unit tests. + testCases := []struct { + name string + source string + params map[string]string + function string + returnConnProfile interface{} + errorExpected bool + }{ + { + name: "source mysql", + source: "mysql", + params: map[string]string{}, + function: "NewSourceProfileConnectionMySQL", + returnConnProfile: SourceProfileConnectionMySQL{}, + errorExpected: false, + }, + { + name: "source postgresql", + source: "postgresql", + params: map[string]string{}, + function: "NewSourceProfileConnectionPostgreSQL", + returnConnProfile: SourceProfileConnectionPostgreSQL{}, + errorExpected: false, + }, + { + name: "source dynamodb", + source: "dynamodb", + params: map[string]string{}, + function: "NewSourceProfileConnectionDynamoDB", + returnConnProfile: SourceProfileConnectionDynamoDB{}, + errorExpected: false, + }, + { + name: "source sqlserver", + source: "sqlserver", + params: map[string]string{}, + function: "NewSourceProfileConnectionSqlServer", + returnConnProfile: SourceProfileConnectionSqlServer{}, + errorExpected: false, + }, + { + name: "source oracle", + source: "oracle", + params: map[string]string{}, + function: "NewSourceProfileConnectionOracle", + returnConnProfile: SourceProfileConnectionOracle{}, 
+ errorExpected: false, + }, + { + name: "invalid source", + source: "invalid", + params: map[string]string{}, + function: "", + returnConnProfile: nil, + errorExpected: true, + }, + } + + for _, tc := range testCases { + m := MockSourceProfileDialect{} + m.On(tc.function, mock.Anything, mock.Anything).Return(tc.returnConnProfile, nil) + n := NewSourceProfileImpl{} + _, err := n.NewSourceProfileConnection(tc.source, tc.params, &m) + assert.Equal(t, tc.errorExpected, err != nil, tc.name) + if err == nil { + m.AssertExpectations(t) + } + } +} + +// code for testing cloud sql source connection profile +func TestNewSourceProfileConnectionCloudSQL(t *testing.T) { + // Avoid getting/setting env variables in the unit tests. + testCases := []struct { + name string + source string + params map[string]string + function string + returnConnProfile interface{} + returnError error + errorExpected bool + }{ + { + name: "source mysql", + source: "mysql", + params: map[string]string{}, + function: "NewSourceProfileConnectionCloudSQLMySQL", + returnConnProfile: SourceProfileConnectionCloudSQLMySQL{}, + returnError: nil, + errorExpected: false, + }, + { + name: "source mysql", + source: "mysql", + params: map[string]string{}, + function: "NewSourceProfileConnectionCloudSQLMySQL", + returnConnProfile: SourceProfileConnectionCloudSQLMySQL{}, + returnError: fmt.Errorf("error"), + errorExpected: true, + }, + { + name: "source postgresql error", + source: "postgresql", + params: map[string]string{}, + function: "NewSourceProfileConnectionCloudSQLPostgreSQL", + returnConnProfile: SourceProfileConnectionCloudSQLPostgreSQL{}, + returnError: nil, + errorExpected: false, + }, + { + name: "source postgres error", + source: "postgresql", + params: map[string]string{}, + function: "NewSourceProfileConnectionCloudSQLPostgreSQL", + returnConnProfile: SourceProfileConnectionCloudSQLPostgreSQL{}, + returnError: fmt.Errorf("error"), + errorExpected: true, + }, + } + + for _, tc := range testCases { + 
m := MockSourceProfileDialect{} + m.On(tc.function, mock.Anything, mock.Anything).Return(tc.returnConnProfile, tc.returnError) + n := NewSourceProfileImpl{} + _, err := n.NewSourceProfileConnectionCloudSQL(tc.source, tc.params, &m) + assert.Equal(t, tc.errorExpected, err != nil, tc.name) + m.AssertExpectations(t) + } +} + +// code for testing csv source profile +func TestNewSourceProfileCsv(t *testing.T) { + // Avoid getting/setting env variables in the unit tests. + testCases := []struct { + name string + params map[string]string + returnCsvProfile SourceProfileCsv + }{ + { + name: "default params", + params: map[string]string{"manifest": "manifest.txt"}, + returnCsvProfile: SourceProfileCsv{Manifest: "manifest.txt", Delimiter: ",", NullStr: ""}, + }, + { + name: "override delimiter", + params: map[string]string{"manifest": "manifest.txt", "delimiter": "/"}, + returnCsvProfile: SourceProfileCsv{Manifest: "manifest.txt", Delimiter: "/", NullStr: ""}, + }, + { + name: "override nulltr", + params: map[string]string{"manifest": "manifest.txt", "nullStr": "/n"}, + returnCsvProfile: SourceProfileCsv{Manifest: "manifest.txt", Delimiter: ",", NullStr: "/n"}, + }, + } + + for _, tc := range testCases { + res := NewSourceProfileCsv(tc.params) + assert.Equal(t, res, tc.returnCsvProfile, tc.name) + } +} + +// code to test boolean UseTargetSchema functionality +func TestUseTargetSchema(t *testing.T) { + // Avoid getting/setting env variables in the unit tests. 
+ testCases := []struct { + name string + srcDriver string + returnBoolean bool + }{ + { + name: "csv as source driver", + srcDriver: "csv", + returnBoolean: true, + }, + { + name: "not csv as source driver", + srcDriver: "cfg", + returnBoolean: false, + }, + } + for _, tc := range testCases { + src := SourceProfile{ + Driver: tc.srcDriver, + } + res := src.UseTargetSchema() + assert.Equal(t, res, tc.returnBoolean, tc.name) + } +} + +// code to test legacy driver +func TestToLegacyDriver(t *testing.T) { + // Avoid getting/setting env variables in the unit tests. + const ( + SourceProfileTypeUnset = iota + SourceProfileTypeFile + SourceProfileTypeConnection + SourceProfileTypeConfig + SourceProfileTypeCsv + SourceProfileTypeCloudSQL + InvalidType + ) + testCases := []struct { + name string + srcDriver SourceProfile + source string + returnConstant string + errorExpected bool + }{ + { + name: "source profile type FILE and source mysql", + srcDriver: SourceProfile{Ty: SourceProfileTypeFile}, + source: "mysql", + returnConstant: constants.MYSQLDUMP, + errorExpected: false, + }, + { + name: "source profile type FILE and source postgresql", + srcDriver: SourceProfile{Ty: SourceProfileTypeFile}, + source: "postgresql", + returnConstant: constants.PGDUMP, + errorExpected: false, + }, + { + name: "source profile type FILE and source dynamodb", + srcDriver: SourceProfile{Ty: SourceProfileTypeFile}, + source: "dynamodb", + returnConstant: "", + errorExpected: true, + }, + { + name: "source profile type FILE and source invalid", + srcDriver: SourceProfile{Ty: SourceProfileTypeFile}, + source: "invalid", + returnConstant: "", + errorExpected: true, + }, + { + name: "source profile type CONNECTION and source mysql", + srcDriver: SourceProfile{Ty: SourceProfileTypeConnection}, + source: "mysql", + returnConstant: constants.MYSQL, + errorExpected: false, + }, + { + name: "source profile type CONNECTION and source postgresql", + srcDriver: SourceProfile{Ty: 
SourceProfileTypeConnection}, + source: "postgresql", + returnConstant: constants.POSTGRES, + errorExpected: false, + }, + { + name: "source profile type CONNECTION and source dynamodb", + srcDriver: SourceProfile{Ty: SourceProfileTypeConnection}, + source: "dynamodb", + returnConstant: constants.DYNAMODB, + errorExpected: false, + }, + { + name: "source profile type CONNECTION and source mssql", + srcDriver: SourceProfile{Ty: SourceProfileTypeConnection}, + source: "mssql", + returnConstant: constants.SQLSERVER, + errorExpected: false, + }, + { + name: "source profile type CONNECTION and source oracle", + srcDriver: SourceProfile{Ty: SourceProfileTypeConnection}, + source: "oracle", + returnConstant: constants.ORACLE, + errorExpected: false, + }, + { + name: "source profile type CONNECTION and source invalid", + srcDriver: SourceProfile{Ty: SourceProfileTypeConnection}, + source: "invalid", + returnConstant: "", + errorExpected: true, + }, + { + name: "source profile type CLOUD SQL and source mysql", + srcDriver: SourceProfile{Ty: SourceProfileTypeCloudSQL}, + source: "mysql", + returnConstant: constants.MYSQL, + errorExpected: false, + }, + { + name: "source profile type CLOUD SQL and source postgresql", + srcDriver: SourceProfile{Ty: SourceProfileTypeCloudSQL}, + source: "postgresql", + returnConstant: constants.POSTGRES, + errorExpected: false, + }, + { + name: "source profile type CLOUD SQL and source invalid", + srcDriver: SourceProfile{Ty: SourceProfileTypeCloudSQL}, + source: "invalid", + returnConstant: "", + errorExpected: true, + }, + { + name: "source profile type CONFIG and source mysql", + srcDriver: SourceProfile{Ty: SourceProfileTypeConfig}, + source: "mysql", + returnConstant: constants.MYSQL, + errorExpected: false, + }, + { + name: "source profile type CONFIG and source invalid", + srcDriver: SourceProfile{Ty: SourceProfileTypeConfig}, + source: "invalid", + returnConstant: "", + errorExpected: true, + }, + { + name: "source profile type CSV and 
source mysql", + srcDriver: SourceProfile{Ty: SourceProfileTypeCsv}, + source: "", + returnConstant: constants.CSV, + errorExpected: false, + }, + { + name: "source profile type CONFIG and source invalid", + srcDriver: SourceProfile{Ty: InvalidType}, + source: "", + returnConstant: "", + errorExpected: true, + }, + } + for _, tc := range testCases { + src := tc.srcDriver + res, err := src.ToLegacyDriver(tc.source) + assert.Equal(t, res, tc.returnConstant, tc.name) + assert.Equal(t, tc.errorExpected, err != nil, tc.name) + } +} + +// code for testing new source profile +func TestNewSourceProfile(t *testing.T) { + // Avoid getting/setting env variables in the unit tests. + const ( + SourceProfileTypeUnset = iota + SourceProfileTypeFile + SourceProfileTypeConnection + SourceProfileTypeConfig + SourceProfileTypeCsv + SourceProfileTypeCloudSQL + ) + testCases := []struct { + name string + params string + source string + function string + mockReturn interface{} + returnTy int + errorExpected bool + }{ + { + name: "source profile for file", + params: "file='file.txt'", + source: "file", + function: "NewSourceProfileFile", + mockReturn: SourceProfileFile{}, + returnTy: SourceProfileTypeFile, + errorExpected: false, + }, + { + name: "invalid source profile for file", + params: "format='some-format'", + source: "file", + function: "", + mockReturn: SourceProfileFile{}, + returnTy: SourceProfileTypeFile, + errorExpected: true, + }, + { + name: "source profile for config", + params: "config='file.cfg'", + source: "cfg", + function: "NewSourceProfileConfig", + mockReturn: SourceProfileConfig{}, + returnTy: SourceProfileTypeConfig, + errorExpected: false, + }, + { + name: "source profile for cloud sql instance", + params: "instance='instance'", + source: "instance", + function: "NewSourceProfileConnectionCloudSQL", + mockReturn: SourceProfileConnectionCloudSQL{}, + returnTy: SourceProfileTypeCloudSQL, + errorExpected: false, + }, + { + name: "source profile for csv", + params: 
"", + source: "csv", + function: "", + mockReturn: SourceProfile{}, + returnTy: SourceProfileTypeCsv, + errorExpected: false, + }, + { + name: "unset source profile params", + params: "", + source: "source", + function: "NewSourceProfileConnection", + mockReturn: SourceProfileConnection{}, + returnTy: SourceProfileTypeConnection, + errorExpected: false, + }, + { + name: "unset source", + params: "", + source: "", + function: "", + mockReturn: SourceProfile{}, + returnTy: SourceProfileTypeUnset, + errorExpected: true, + }, + } + + for _, tc := range testCases { + n := MockNewSourceProfile{} + n.On(tc.function, mock.Anything, mock.Anything, mock.Anything).Return(tc.mockReturn, nil) + res, err := NewSourceProfile(tc.params, tc.source, &n) + assert.Equal(t, SourceProfileType(tc.returnTy), res.Ty, tc.name) + assert.Equal(t, tc.errorExpected, err != nil, tc.name) } } diff --git a/profiles/target_profile.go b/profiles/target_profile.go index 97c1419d70..394359ab36 100644 --- a/profiles/target_profile.go +++ b/profiles/target_profile.go @@ -65,7 +65,7 @@ func (trg TargetProfile) FetchTargetDialect(ctx context.Context) (string, error) // Ideally we should use the client we create at the beginning, but we can fix that with the refactoring. adminClient, _ := utils.NewDatabaseAdminClient(ctx) // The parameters are irrelevant because the results are already cached when called the first time. 
- project, instance, dbName, _ := trg.GetResourceIds(ctx, time.Now(), "", nil) + project, instance, dbName, _ := trg.GetResourceIds(ctx, time.Now(), "", nil, &utils.GetUtilInfoImpl{}) result, err := adminClient.GetDatabase(ctx, &adminpb.GetDatabaseRequest{Name: fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, dbName)}) if err != nil { return "", fmt.Errorf("cannot connect to target: %v", err) @@ -73,11 +73,11 @@ func (trg TargetProfile) FetchTargetDialect(ctx context.Context) (string, error) return strings.ToLower(result.DatabaseDialect.String()), nil } -func (targetProfile *TargetProfile) GetResourceIds(ctx context.Context, now time.Time, driverName string, out *os.File) (string, string, string, error) { +func (targetProfile *TargetProfile) GetResourceIds(ctx context.Context, now time.Time, driverName string, out *os.File, g utils.GetUtilInfoInterface) (string, string, string, error) { var err error project := targetProfile.Conn.Sp.Project if project == "" { - project, err = utils.GetProject() + project, err = g.GetProject() if err != nil { return "", "", "", fmt.Errorf("can't get project: %v", err) } @@ -86,7 +86,8 @@ func (targetProfile *TargetProfile) GetResourceIds(ctx context.Context, now time instance := targetProfile.Conn.Sp.Instance if instance == "" { - instance, err = utils.GetInstance(ctx, project, out) + g := utils.GetUtilInfoImpl{} + instance, err = g.GetInstance(ctx, project, out) if err != nil { return "", "", "", fmt.Errorf("can't get instance: %v", err) } @@ -95,7 +96,8 @@ func (targetProfile *TargetProfile) GetResourceIds(ctx context.Context, now time dbName := targetProfile.Conn.Sp.Dbname if dbName == "" { - dbName, err = utils.GetDatabaseName(driverName, now) + g := utils.GetUtilInfoImpl{} + dbName, err = g.GetDatabaseName(driverName, now) if err != nil { return "", "", "", fmt.Errorf("can't get database name: %v", err) } diff --git a/profiles/target_profile_test.go b/profiles/target_profile_test.go new file mode 100644 
index 0000000000..33a6692eed --- /dev/null +++ b/profiles/target_profile_test.go @@ -0,0 +1,16 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package profiles + diff --git a/streaming/cleanup.go b/streaming/cleanup.go index 62ef84bf3b..4f782f4b5f 100644 --- a/streaming/cleanup.go +++ b/streaming/cleanup.go @@ -198,8 +198,9 @@ func FetchResources(ctx context.Context, migrationJobId string, resourceType str func GetInstanceDetails(ctx context.Context, targetProfile profiles.TargetProfile) (string, string, error) { var err error project := targetProfile.Conn.Sp.Project + g := utils.GetUtilInfoImpl{} if project == "" { - project, err = utils.GetProject() + project, err = g.GetProject() if err != nil { return "", "", fmt.Errorf("can't get project: %v", err) } @@ -207,7 +208,7 @@ func GetInstanceDetails(ctx context.Context, targetProfile profiles.TargetProfil instance := targetProfile.Conn.Sp.Instance if instance == "" { - instance, err = utils.GetInstance(ctx, project, os.Stdout) + instance, err = g.GetInstance(ctx, project, os.Stdout) if err != nil { return "", "", fmt.Errorf("can't get instance: %v", err) } diff --git a/streaming/store.go b/streaming/store.go index 61ac82291d..f9780f4b19 100644 --- a/streaming/store.go +++ b/streaming/store.go @@ -32,7 +32,7 @@ import ( // PersistJobDetails stores all the metadata associated with a job orchestration for a minimal downtime migration in the metadata db. 
An example of this metadata is job level data such as the spanner database name. func PersistJobDetails(ctx context.Context, targetProfile profiles.TargetProfile, sourceProfile profiles.SourceProfile, conv *internal.Conv, migrationJobId string, isSharded bool) (err error) { - project, instance, dbName, err := targetProfile.GetResourceIds(ctx, time.Now(), sourceProfile.Driver, nil) + project, instance, dbName, err := targetProfile.GetResourceIds(ctx, time.Now(), sourceProfile.Driver, nil, &utils.GetUtilInfoImpl{}) if err != nil { err = fmt.Errorf("can't get resource ids: %v", err) return err @@ -59,7 +59,7 @@ func PersistJobDetails(ctx context.Context, targetProfile profiles.TargetProfile func PersistAggregateMonitoringResources(ctx context.Context, targetProfile profiles.TargetProfile, sourceProfile profiles.SourceProfile, conv *internal.Conv, migrationJobId string) error { logger.Log.Debug(fmt.Sprintf("Storing aggregate monitoring dashboard for migration jobId: %s\n", migrationJobId)) - project, instance, _, err := targetProfile.GetResourceIds(ctx, time.Now(), sourceProfile.Driver, nil) + project, instance, _, err := targetProfile.GetResourceIds(ctx, time.Now(), sourceProfile.Driver, nil, &utils.GetUtilInfoImpl{}) if err != nil { err = fmt.Errorf("can't get resource ids: %v", err) return err @@ -96,7 +96,7 @@ func PersistAggregateMonitoringResources(ctx context.Context, targetProfile prof // PersistResources stores all the metadata associated with a shard orchestration for a minimal downtime migration in the metadata db. An example of this metadata is generated resources. 
func PersistResources(ctx context.Context, targetProfile profiles.TargetProfile, sourceProfile profiles.SourceProfile, conv *internal.Conv, migrationJobId string, dataShardId string) (err error) { - project, instance, _, err := targetProfile.GetResourceIds(ctx, time.Now(), sourceProfile.Driver, nil) + project, instance, _, err := targetProfile.GetResourceIds(ctx, time.Now(), sourceProfile.Driver, nil, &utils.GetUtilInfoImpl{}) if err != nil { err = fmt.Errorf("can't get resource ids: %v", err) return err diff --git a/streaming/streaming.go b/streaming/streaming.go index 550ffffc31..475b132d72 100644 --- a/streaming/streaming.go +++ b/streaming/streaming.go @@ -589,7 +589,7 @@ func LaunchStream(ctx context.Context, sourceProfile profiles.SourceProfile, dbL // LaunchDataflowJob populates the parameters from the streaming config and triggers a Dataflow job. func LaunchDataflowJob(ctx context.Context, targetProfile profiles.TargetProfile, streamingCfg StreamingCfg, conv *internal.Conv) (internal.DataflowOutput, error) { - project, instance, dbName, _ := targetProfile.GetResourceIds(ctx, time.Now(), "", nil) + project, instance, dbName, _ := targetProfile.GetResourceIds(ctx, time.Now(), "", nil, &utils.GetUtilInfoImpl{}) dataflowCfg := streamingCfg.DataflowCfg datastreamCfg := streamingCfg.DatastreamCfg diff --git a/testing/dynamodb/snapshot/integration_test.go b/testing/dynamodb/snapshot/integration_test.go index c616edd254..936a78678a 100644 --- a/testing/dynamodb/snapshot/integration_test.go +++ b/testing/dynamodb/snapshot/integration_test.go @@ -215,7 +215,8 @@ func TestIntegration_DYNAMODB_Command(t *testing.T) { defer os.RemoveAll(tmpdir) now := time.Now() - dbName, _ := utils.GetDatabaseName(constants.DYNAMODB, now) + g := utils.GetUtilInfoImpl{} + dbName, _ := g.GetDatabaseName(constants.DYNAMODB, now) dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) filePrefix := filepath.Join(tmpdir, dbName) diff --git 
a/testing/dynamodb/streaming/integration_test.go b/testing/dynamodb/streaming/integration_test.go index ae169e80bb..9c44ade5f0 100644 --- a/testing/dynamodb/streaming/integration_test.go +++ b/testing/dynamodb/streaming/integration_test.go @@ -323,7 +323,8 @@ func TestIntegration_DYNAMODB_Streaming_Command(t *testing.T) { defer os.RemoveAll(tmpdir) now := time.Now() - dbName, _ := utils.GetDatabaseName(constants.DYNAMODB, now) + g := utils.GetUtilInfoImpl{} + dbName, _ := g.GetDatabaseName(constants.DYNAMODB, now) dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) filePrefix := filepath.Join(tmpdir, dbName) diff --git a/testing/postgres/integration_test.go b/testing/postgres/integration_test.go index a543b85e81..59e73885af 100644 --- a/testing/postgres/integration_test.go +++ b/testing/postgres/integration_test.go @@ -113,7 +113,8 @@ func TestIntegration_PGDUMP_SchemaAndDataSubcommand(t *testing.T) { defer os.RemoveAll(tmpdir) now := time.Now() - dbName, _ := utils.GetDatabaseName(constants.PGDUMP, now) + g := utils.GetUtilInfoImpl{} + dbName, _ := g.GetDatabaseName(constants.PGDUMP, now) dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) dataFilepath := "../../test_data/pg_dump.test.out" @@ -137,7 +138,8 @@ func TestIntegration_PGDUMP_SchemaSubcommand(t *testing.T) { tmpdir := prepareIntegrationTest(t) defer os.RemoveAll(tmpdir) now := time.Now() - dbName, _ := utils.GetDatabaseName(constants.PGDUMP, now) + g := utils.GetUtilInfoImpl{} + dbName, _ := g.GetDatabaseName(constants.PGDUMP, now) dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) dataFilepath := "../../test_data/pg_dump.test.out" @@ -159,7 +161,8 @@ func TestIntegration_POSTGRES_SchemaAndDataSubcommand(t *testing.T) { defer os.RemoveAll(tmpdir) now := time.Now() - dbName, _ := utils.GetDatabaseName(constants.POSTGRES, now) + g := utils.GetUtilInfoImpl{} + dbName, _ := 
g.GetDatabaseName(constants.POSTGRES, now) dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) filePrefix := filepath.Join(tmpdir, dbName) @@ -182,7 +185,8 @@ func TestIntegration_POSTGRES_SchemaSubcommand(t *testing.T) { defer os.RemoveAll(tmpdir) now := time.Now() - dbName, _ := utils.GetDatabaseName(constants.POSTGRES, now) + g := utils.GetUtilInfoImpl{} + dbName, _ := g.GetDatabaseName(constants.POSTGRES, now) dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) filePrefix := filepath.Join(tmpdir, dbName) diff --git a/webv2/utilities/utilities.go b/webv2/utilities/utilities.go index d122741445..b05ac6f80f 100644 --- a/webv2/utilities/utilities.go +++ b/webv2/utilities/utilities.go @@ -308,7 +308,8 @@ func GetFilePrefix(now time.Time) (string, error) { dbName := sessionState.DbName var err error if dbName == "" { - dbName, err = utils.GetDatabaseName(sessionState.Driver, now) + g := utils.GetUtilInfoImpl{} + dbName, err = g.GetDatabaseName(sessionState.Driver, now) if err != nil { return "", fmt.Errorf("Can not create database name : %v", err) } diff --git a/webv2/web.go b/webv2/web.go index 22ed1a40f1..70f987a0f5 100644 --- a/webv2/web.go +++ b/webv2/web.go @@ -595,7 +595,8 @@ func convertSchemaDump(w http.ResponseWriter, r *http.Request) { return } // We don't support Dynamodb in web hence no need to pass schema sample size here. 
- sourceProfile, _ := profiles.NewSourceProfile("", dc.Config.Driver) + n := profiles.NewSourceProfileImpl{} + sourceProfile, _ := profiles.NewSourceProfile("", dc.Config.Driver, &n) sourceProfile.Driver = dc.Config.Driver conv, err := conversion.SchemaFromDump(sourceProfile.Driver, dc.SpannerDetails.Dialect, &utils.IOStreams{In: f, Out: os.Stdout}) if err != nil { From fec076ceda6a4319957585e262a2624d252b698e Mon Sep 17 00:00:00 2001 From: Vardhan Vinay Thigle <39047439+VardhanThigle@users.noreply.github.com> Date: Mon, 5 Feb 2024 16:23:50 +0530 Subject: [PATCH 03/15] Add note on retries for client API calls in streaming.go (#762) --- streaming/streaming.go | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/streaming/streaming.go b/streaming/streaming.go index 475b132d72..89ca10c891 100644 --- a/streaming/streaming.go +++ b/streaming/streaming.go @@ -48,6 +48,17 @@ import ( "google.golang.org/grpc/codes" ) +// Uber Note on API retries: +// This file makes a lot of API calls. +// Many of the Google Cloud API calls have out of box retries for a vetted list of errorcodes (typically `UNAVAILABLE`) +// If that's the case for the given call, please add a comment. In case we face an issue in integration test and it is needed +// to add additional retry error codes for that call, please use the gax retry options for the call. +// There are cases where the retry is not out of box. For such calls: +// 1. Read documentation for how to make the call idempotent, for example in some cases, special fields (like RequestId) need to be set. +// 2. Add retry for `UNAVAILABLE` followed by any other retriable error you might have seen in testing only if you are sure your call is idempotent. +// It might be good to run the PR at some scale (depends on each case) to avoid surprises. +// If any retry related exploration is not immediately feasible, please do add a TODO comment in the code. 
+ var ( // Default value for max concurrent backfill tasks in Datastream. Datastream resorts to its default value for 0. maxCdcTasks int32 = 5 @@ -243,6 +254,8 @@ func VerifyAndUpdateCfg(streamingCfg *StreamingCfg, dbName string, schemaDetails return fmt.Errorf("failed to create GCS client") } defer client.Close() + // The Get calls for Google Cloud Storage API have out of box retries. + // Reference - https://cloud.google.com/storage/docs/retry-strategy#idempotency-operations bucket := client.Bucket(bucketName) _, err = bucket.Attrs(ctx) if err != nil { @@ -457,11 +470,15 @@ func createPubsubTopicAndSubscription(ctx context.Context, pubsubClient *pubsub. pubsubCfg.TopicId = topicId // Create Topic and Subscription + // CreateTopic has out of box retries + // Ref - https://github.com/googleapis/googleapis/blob/master/google/pubsub/v1/pubsub_grpc_service_config.json topicObj, err := pubsubClient.CreateTopic(ctx, pubsubCfg.TopicId) if err != nil { return pubsubCfg, fmt.Errorf("pubsub topic could not be created: %v", err) } + // CreateSubscription has out of box retries + // Ref - https://github.com/googleapis/googleapis/blob/master/google/pubsub/v1/pubsub_grpc_service_config.json _, err = pubsubClient.CreateSubscription(ctx, pubsubCfg.SubscriptionId, pubsub.SubscriptionConfig{ Topic: topicObj, AckDeadline: time.Minute * 10, @@ -479,11 +496,14 @@ func FetchTargetBucketAndPath(ctx context.Context, datastreamClient *datastream. return "", "", fmt.Errorf("datastream client could not be created") } dstProf := fmt.Sprintf("projects/%s/locations/%s/connectionProfiles/%s", projectID, datastreamDestinationConnCfg.Location, datastreamDestinationConnCfg.Name) + // `GetConnectionProfile` has out of box retries. 
Ref - https://github.com/googleapis/googleapis/blob/master/google/cloud/datastream/v1/datastream_grpc_service_config.json res, err := datastreamClient.GetConnectionProfile(ctx, &datastreampb.GetConnectionProfileRequest{Name: dstProf}) if err != nil { return "", "", fmt.Errorf("could not get connection profiles: %v", err) } // Fetch the GCS path from the target connection profile. + // The Get calls for Google Cloud Storage API have out of box retries. + // Reference - https://cloud.google.com/storage/docs/retry-strategy#idempotency-operations gcsProfile := res.Profile.(*datastreampb.ConnectionProfile_GcsProfile).GcsProfile bucketName := gcsProfile.Bucket prefix := gcsProfile.RootPath + datastreamDestinationConnCfg.Prefix @@ -499,6 +519,9 @@ func createNotificationOnBucket(ctx context.Context, storageClient *storage.Clie ObjectNamePrefix: prefix, } + // TODO: Explore if there's a way to make this idempotent or retriable. + // The classification for this call is never idempotent + // Ref - https://cloud.google.com/storage/docs/retry-strategy createdNotification, err := storageClient.Bucket(bucketName).AddNotification(ctx, &notification) if err != nil { return "", fmt.Errorf("GCS Notification could not be created: %v", err) @@ -721,6 +744,10 @@ func LaunchDataflowJob(ctx context.Context, targetProfile profiles.TargetProfile } fmt.Println("Created flex template request body...") + // LaunchFlexTemplate does not have out of box retries or any direct documentation on how + // to make the call idempotent. + // Ref - https://github.com/googleapis/googleapis/blob/master/google/dataflow/v1beta3/dataflow_grpc_service_config.json + // TODO explore retries. respDf, err := c.LaunchFlexTemplate(ctx, req) if err != nil { fmt.Printf("flexTemplateRequest: %+v\n", req) @@ -830,6 +857,8 @@ func GetProjectNumberResource(ctx context.Context, projectID string) string { return projectID } defer rmClient.Close() + // `GetProjectRequest` has out of box retries. 
+ // Ref - https://github.com/googleapis/googleapis/blob/master/google/cloud/resourcemanager/v3/cloudresourcemanager_v3_grpc_service_config.json req := resourcemanagerpb.GetProjectRequest{Name: projectID} project, err := rmClient.GetProject(ctx, &req) if err != nil { From 60e96b1a0a29aa052e9beaa5ecdbdaf65a8a7bc4 Mon Sep 17 00:00:00 2001 From: Astha Mohta <35952883+asthamohta@users.noreply.github.com> Date: Wed, 7 Feb 2024 13:43:52 +0530 Subject: [PATCH 04/15] test: Conversion refactoring (#759) * tests test * [feat] RR Create API 1: Add dataflow accessor (#745) * Add dataflow accessor * Add enable streaming engine struct tag Mofe Unmarshall Method to acc2 due ot storage dependency * Moved dataflow utils to accessor and creates types.go * Create dataflowutils package * Renamed testing package for dataflow util * Added unit tests * Added empty test files for clients * Move test to same package * Add tests for dataflow client * Update fake for client test * Make dataflow accessor interface and struct to make it testable * Remove interface from accessor package * Add dataflow accessor interface * Add comments to dataflow client and comments on unit tests * Move all dataflow dependencies to accessors and remove dataflow utils * Create dataflow client interface for accessor method to make it unit testable * tests tests * common testing * change * change * changes on comments * change * accessor mysql conn * changes * change * change --------- Co-authored-by: Deep1998 --- cmd/data.go | 3 +- cmd/schema.go | 3 +- cmd/schema_and_data.go | 5 +- cmd/utils.go | 6 +- common/utils/utils.go | 5 +- conversion/conversion.go | 1098 +-------------------- conversion/conversion_from_source.go | 339 +++++++ conversion/conversion_from_source_test.go | 175 ++++ conversion/conversion_helper.go | 212 ++++ conversion/conversion_test.go | 203 ++++ conversion/data_from_database.go | 224 +++++ conversion/get_info.go | 210 ++++ conversion/mocks.go | 83 ++ conversion/snapshot_migration.go | 66 ++ 
conversion/store_files.go | 274 +++++ conversion/validations.go | 108 ++ sources/common/dbdump.go | 8 +- sources/common/infoschema.go | 37 +- sources/common/mocks.go | 71 ++ sources/common/toddl.go | 16 +- sources/common/utils.go | 10 +- sources/csv/data.go | 14 +- sources/csv/data_test.go | 6 +- sources/dynamodb/schema_test.go | 12 +- sources/dynamodb/toddl_test.go | 6 +- sources/mysql/infoschema.go | 3 +- sources/mysql/infoschema_test.go | 20 +- sources/mysql/toddl_test.go | 9 +- sources/oracle/infoschema.go | 4 +- sources/oracle/infoschema_test.go | 3 +- sources/oracle/toddl_test.go | 6 +- sources/postgres/infoschema.go | 3 +- sources/postgres/infoschema_test.go | 15 +- sources/postgres/toddl_test.go | 6 +- sources/sqlserver/infoschema_test.go | 3 +- sources/sqlserver/toddl_test.go | 6 +- webv2/web.go | 12 +- 37 files changed, 2129 insertions(+), 1155 deletions(-) create mode 100644 conversion/conversion_from_source.go create mode 100644 conversion/conversion_from_source_test.go create mode 100644 conversion/conversion_helper.go create mode 100644 conversion/conversion_test.go create mode 100644 conversion/data_from_database.go create mode 100644 conversion/get_info.go create mode 100644 conversion/mocks.go create mode 100644 conversion/snapshot_migration.go create mode 100644 conversion/store_files.go create mode 100644 conversion/validations.go create mode 100644 sources/common/mocks.go diff --git a/cmd/data.go b/cmd/data.go index 9be772042e..d779efeb5d 100644 --- a/cmd/data.go +++ b/cmd/data.go @@ -156,7 +156,8 @@ func (cmd *DataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface banner = utils.GetBanner(now, dbURI) } else { conv.Audit.DryRun = true - bw, err = conversion.DataConv(ctx, sourceProfile, targetProfile, &ioHelper, nil, conv, true, cmd.WriteLimit) + convImpl := &conversion.ConvImpl{} + bw, err = convImpl.DataConv(ctx, sourceProfile, targetProfile, &ioHelper, nil, conv, true, cmd.WriteLimit, &conversion.DataFromSourceImpl{}) if err != 
nil { err = fmt.Errorf("can't finish data conversion for db %s: %v", dbName, err) return subcommands.ExitFailure diff --git a/cmd/schema.go b/cmd/schema.go index ec408bddb1..c89c2923a8 100644 --- a/cmd/schema.go +++ b/cmd/schema.go @@ -112,7 +112,8 @@ func (cmd *SchemaCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interfa schemaConversionStartTime := time.Now() var conv *internal.Conv - conv, err = conversion.SchemaConv(sourceProfile, targetProfile, &ioHelper) + convImpl := &conversion.ConvImpl{} + conv, err = convImpl.SchemaConv(sourceProfile, targetProfile, &ioHelper, &conversion.SchemaFromSourceImpl{}) if err != nil { return subcommands.ExitFailure } diff --git a/cmd/schema_and_data.go b/cmd/schema_and_data.go index 1296f6fc55..51e7f30dac 100644 --- a/cmd/schema_and_data.go +++ b/cmd/schema_and_data.go @@ -121,7 +121,8 @@ func (cmd *SchemaAndDataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ... banner string dbURI string ) - conv, err = conversion.SchemaConv(sourceProfile, targetProfile, &ioHelper) + convImpl := &conversion.ConvImpl{} + conv, err = convImpl.SchemaConv(sourceProfile, targetProfile, &ioHelper, &conversion.SchemaFromSourceImpl{}) if err != nil { panic(err) } @@ -152,7 +153,7 @@ func (cmd *SchemaAndDataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ... 
conv.Audit.DryRun = true schemaCoversionEndTime := time.Now() conv.Audit.SchemaConversionDuration = schemaCoversionEndTime.Sub(schemaConversionStartTime) - bw, err = conversion.DataConv(ctx, sourceProfile, targetProfile, &ioHelper, nil, conv, true, cmd.WriteLimit) + bw, err = convImpl.DataConv(ctx, sourceProfile, targetProfile, &ioHelper, nil, conv, true, cmd.WriteLimit, &conversion.DataFromSourceImpl{}) if err != nil { err = fmt.Errorf("can't finish data conversion for db %s: %v", dbName, err) return subcommands.ExitFailure diff --git a/cmd/utils.go b/cmd/utils.go index 1f25d967d9..d8e4604ae5 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -162,7 +162,8 @@ func migrateData(ctx context.Context, targetProfile profiles.TargetProfile, sour } fmt.Printf("Schema validated successfully for data migration for db %s\n", dbURI) } - bw, err = conversion.DataConv(ctx, sourceProfile, targetProfile, ioHelper, client, conv, true, cmd.WriteLimit) + c := &conversion.ConvImpl{} + bw, err = c.DataConv(ctx, sourceProfile, targetProfile, ioHelper, client, conv, true, cmd.WriteLimit, &conversion.DataFromSourceImpl{}) if err != nil { err = fmt.Errorf("can't finish data conversion for db %s: %v", dbURI, err) return nil, err @@ -185,7 +186,8 @@ func migrateSchemaAndData(ctx context.Context, targetProfile profiles.TargetProf return nil, err } conv.Audit.Progress.UpdateProgress("Schema migration complete.", completionPercentage, internal.SchemaMigrationComplete) - bw, err := conversion.DataConv(ctx, sourceProfile, targetProfile, ioHelper, client, conv, true, cmd.WriteLimit) + convImpl := &conversion.ConvImpl{} + bw, err := convImpl.DataConv(ctx, sourceProfile, targetProfile, ioHelper, client, conv, true, cmd.WriteLimit, &conversion.DataFromSourceImpl{}) if err != nil { err = fmt.Errorf("can't finish data conversion for db %s: %v", dbURI, err) return nil, err diff --git a/common/utils/utils.go b/common/utils/utils.go index 226afeacd5..956bdc65c4 100644 --- a/common/utils/utils.go +++ 
b/common/utils/utils.go @@ -504,7 +504,8 @@ func GetLegacyModeSupportedDrivers() []string { // ReadSpannerSchema fills conv by querying Spanner infoschema treating Spanner as both the source and dest. func ReadSpannerSchema(ctx context.Context, conv *internal.Conv, client *sp.Client) error { infoSchema := spanner.InfoSchemaImpl{Client: client, Ctx: ctx, SpDialect: conv.SpDialect} - err := common.ProcessSchema(conv, infoSchema, common.DefaultWorkers, internal.AdditionalSchemaAttributes{IsSharded: false}) + processSchema := common.ProcessSchemaImpl{} + err := processSchema.ProcessSchema(conv, infoSchema, common.DefaultWorkers, internal.AdditionalSchemaAttributes{IsSharded: false}, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) if err != nil { return fmt.Errorf("error trying to read and convert spanner schema: %v", err) } @@ -590,7 +591,7 @@ func CompareSchema(sessionFileConv, actualSpannerConv *internal.Conv) error { } else { if sessionColDef.Name != spannerColDef.Name || sessionColDef.T.IsArray != spannerColDef.T.IsArray || sessionColDef.T.Len != spannerColDef.T.Len || sessionColDef.T.Name != spannerColDef.T.Name || sessionColDef.NotNull != spannerColDef.NotNull { - return fmt.Errorf("column detail for table %v don't match: session column: %v, spanner column: %v", sessionTable.Name, sessionColDef, spannerColDef) + return fmt.Errorf("column detail for table %v don't match: session column: %v, spanner column: %v", sessionTable.Name, sessionColDef, spannerColDef) } } } diff --git a/conversion/conversion.go b/conversion/conversion.go index 908b1de53a..b9ec31f5e3 100644 --- a/conversion/conversion.go +++ b/conversion/conversion.go @@ -26,28 +26,17 @@ package conversion import ( "bufio" "context" - "database/sql" "encoding/base64" "encoding/json" "fmt" - "io" - "io/ioutil" - "net" "os" "strings" "sync" - "sync/atomic" - "syscall" "time" - "cloud.google.com/go/cloudsqlconn" datastream "cloud.google.com/go/datastream/apiv1" sp 
"cloud.google.com/go/spanner" database "cloud.google.com/go/spanner/admin/database/apiv1" - spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin" - storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" - spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner" - storageaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/storage" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/metrics" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" @@ -57,22 +46,8 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/csv" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/dynamodb" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/mysql" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/oracle" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/postgres" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/spanner" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/sqlserver" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer" - "github.com/GoogleCloudPlatform/spanner-migration-tool/streaming" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - dydb "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodbstreams" - mysqldriver "github.com/go-sql-driver/mysql" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/stdlib" "go.uber.org/zap" adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" "google.golang.org/grpc/metadata" @@ -99,14 
+74,20 @@ func getDatastreamClient(ctx context.Context) *datastream.Client { return datastreamClient } +type ConvInterface interface { + SchemaConv(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams, schemaFromSource SchemaFromSourceInterface) (*internal.Conv, error) + DataConv(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool, writeLimit int64, dataFromSource DataFromSourceInterface) (*writer.BatchWriter, error) +} +type ConvImpl struct {} + // SchemaConv performs the schema conversion // The SourceProfile param provides the connection details to use the go SQL library. -func SchemaConv(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams) (*internal.Conv, error) { +func (ci *ConvImpl) SchemaConv(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams, schemaFromSource SchemaFromSourceInterface) (*internal.Conv, error) { switch sourceProfile.Driver { case constants.POSTGRES, constants.MYSQL, constants.DYNAMODB, constants.SQLSERVER, constants.ORACLE: - return schemaFromDatabase(sourceProfile, targetProfile) + return schemaFromSource.schemaFromDatabase(sourceProfile, targetProfile, &GetInfoImpl{}, &common.ProcessSchemaImpl{}) case constants.PGDUMP, constants.MYSQLDUMP: - return SchemaFromDump(sourceProfile.Driver, targetProfile.Conn.Sp.Dialect, ioHelper) + return schemaFromSource.SchemaFromDump(sourceProfile.Driver, targetProfile.Conn.Sp.Dialect, ioHelper, &ProcessDumpByDialectImpl{}) default: return nil, fmt.Errorf("schema conversion for driver %s not supported", sourceProfile.Driver) } @@ -114,7 +95,7 @@ func SchemaConv(sourceProfile profiles.SourceProfile, targetProfile profiles.Tar // DataConv performs the data conversion // The SourceProfile param provides the connection details to use the go 
SQL library. -func DataConv(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool, writeLimit int64) (*writer.BatchWriter, error) { +func (ci *ConvImpl) DataConv(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool, writeLimit int64, dataFromSource DataFromSourceInterface) (*writer.BatchWriter, error) { config := writer.BatchWriterConfig{ BytesLimit: 100 * 1000 * 1000, WriteLimit: writeLimit, @@ -123,619 +104,19 @@ func DataConv(ctx context.Context, sourceProfile profiles.SourceProfile, targetP } switch sourceProfile.Driver { case constants.POSTGRES, constants.MYSQL, constants.DYNAMODB, constants.SQLSERVER, constants.ORACLE: - return dataFromDatabase(ctx, sourceProfile, targetProfile, config, conv, client) + return dataFromSource.dataFromDatabase(ctx, sourceProfile, targetProfile, config, conv, client, &GetInfoImpl{}, &DataFromDatabaseImpl{}, &SnapshotMigrationImpl{}) case constants.PGDUMP, constants.MYSQLDUMP: if conv.SpSchema.CheckInterleaved() { return nil, fmt.Errorf("spanner migration tool does not currently support data conversion from dump files\nif the schema contains interleaved tables. Suggest using direct access to source database\ni.e. 
using drivers postgres and mysql") } - return dataFromDump(sourceProfile.Driver, config, ioHelper, client, conv, dataOnly) + return dataFromSource.dataFromDump(sourceProfile.Driver, config, ioHelper, client, conv, dataOnly, &ProcessDumpByDialectImpl{}, &PopulateDataConvImpl{}) case constants.CSV: - return dataFromCSV(ctx, sourceProfile, targetProfile, config, conv, client) + return dataFromSource.dataFromCSV(ctx, sourceProfile, targetProfile, config, conv, client, &PopulateDataConvImpl{}, &csv.CsvImpl{}) default: return nil, fmt.Errorf("data conversion for driver %s not supported", sourceProfile.Driver) } } -func connectionConfig(sourceProfile profiles.SourceProfile) (interface{}, error) { - switch sourceProfile.Driver { - // For PG and MYSQL, When called as part of the subcommand flow, host/user/db etc will - // never be empty as we error out right during source profile creation. If any of them - // are empty, that means this was called through the legacy cmd flow and we create the - // string using env vars. - case constants.POSTGRES: - pgConn := sourceProfile.Conn.Pg - if !(pgConn.Host != "" && pgConn.User != "" && pgConn.Db != "") { - return profiles.GeneratePGSQLConnectionStr() - } else { - return profiles.GetSQLConnectionStr(sourceProfile), nil - } - case constants.MYSQL: - // If empty, this is called as part of the legacy mode witih global CLI flags. - // When using source-profile mode is used, the sqlConnectionStr is already populated. - mysqlConn := sourceProfile.Conn.Mysql - if !(mysqlConn.Host != "" && mysqlConn.User != "" && mysqlConn.Db != "") { - return profiles.GenerateMYSQLConnectionStr() - } else { - return profiles.GetSQLConnectionStr(sourceProfile), nil - } - // For Dynamodb, both legacy and new flows use env vars. 
- case constants.DYNAMODB: - return getDynamoDBClientConfig() - case constants.SQLSERVER: - return profiles.GetSQLConnectionStr(sourceProfile), nil - case constants.ORACLE: - return profiles.GetSQLConnectionStr(sourceProfile), nil - default: - return "", fmt.Errorf("driver %s not supported", sourceProfile.Driver) - } -} - -func getDbNameFromSQLConnectionStr(driver, sqlConnectionStr string) string { - switch driver { - case constants.POSTGRES: - dbParam := strings.Split(sqlConnectionStr, " ")[4] - return strings.Split(dbParam, "=")[1] - case constants.MYSQL: - return strings.Split(sqlConnectionStr, ")/")[1] - case constants.SQLSERVER: - splts := strings.Split(sqlConnectionStr, "?database=") - return splts[len(splts)-1] - case constants.ORACLE: - // connection string formate : "oracle://user:password@104.108.154.85:1521/XE" - substr := sqlConnectionStr[9:] - dbName := strings.Split(substr, ":")[0] - return dbName - } - return "" -} - -func getInfoSchemaForShard(shardConnInfo profiles.DirectConnectionConfig, driver string, targetProfile profiles.TargetProfile) (common.InfoSchema, error) { - params := make(map[string]string) - params["host"] = shardConnInfo.Host - params["user"] = shardConnInfo.User - params["dbName"] = shardConnInfo.DbName - params["port"] = shardConnInfo.Port - params["password"] = shardConnInfo.Password - //while adding other sources, a switch-case will be added here on the basis of the driver input param passed. - //pased on the driver name, profiles.NewSourceProfileConnection will need to be called to create - //the source profile information. 
- sourceProfileDialect := &profiles.SourceProfileDialectImpl{} - getInfo := utils.GetUtilInfoImpl{} - sourceProfileConnectionMySQL, err := sourceProfileDialect.NewSourceProfileConnectionMySQL(params, &getInfo) - if err != nil { - return nil, fmt.Errorf("cannot parse connection configuration for the primary shard") - } - sourceProfileConnection := profiles.SourceProfileConnection{Mysql: sourceProfileConnectionMySQL, Ty: profiles.SourceProfileConnectionTypeMySQL} - //create a source profile which contains the sourceProfileConnection object for the primary shard - //this is done because GetSQLConnectionStr() should not be aware of sharding - newSourceProfile := profiles.SourceProfile{Conn: sourceProfileConnection, Ty: profiles.SourceProfileTypeConnection} - newSourceProfile.Driver = driver - infoSchema, err := GetInfoSchema(newSourceProfile, targetProfile) - if err != nil { - return nil, err - } - return infoSchema, nil -} - -func schemaFromDatabase(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (*internal.Conv, error) { - conv := internal.MakeConv() - conv.SpDialect = targetProfile.Conn.Sp.Dialect - //handle fetching schema differently for sharded migrations, we only connect to the primary shard to - //fetch the schema. We reuse the SourceProfileConnection object for this purpose. 
- var infoSchema common.InfoSchema - var err error - isSharded := false - switch sourceProfile.Ty { - case profiles.SourceProfileTypeConfig: - isSharded = true - //Find Primary Shard Name - if sourceProfile.Config.ConfigType == constants.BULK_MIGRATION { - schemaSource := sourceProfile.Config.ShardConfigurationBulk.SchemaSource - infoSchema, err = getInfoSchemaForShard(schemaSource, sourceProfile.Driver, targetProfile) - if err != nil { - return conv, err - } - } else if sourceProfile.Config.ConfigType == constants.DATAFLOW_MIGRATION { - schemaSource := sourceProfile.Config.ShardConfigurationDataflow.SchemaSource - infoSchema, err = getInfoSchemaForShard(schemaSource, sourceProfile.Driver, targetProfile) - if err != nil { - return conv, err - } - } else if sourceProfile.Config.ConfigType == constants.DMS_MIGRATION { - // TODO: Define the schema processing logic for DMS migrations here. - return conv, fmt.Errorf("dms based migrations are not implemented yet") - } else { - return conv, fmt.Errorf("unknown type of migration, please select one of bulk, dataflow or dms") - } - case profiles.SourceProfileTypeCloudSQL: - infoSchema, err = GetInfoSchemaFromCloudSQL(sourceProfile, targetProfile) - if err != nil { - return conv, err - } - - default: - infoSchema, err = GetInfoSchema(sourceProfile, targetProfile) - if err != nil { - return conv, err - } - } - additionalSchemaAttributes := internal.AdditionalSchemaAttributes{ - IsSharded: isSharded, - } - return conv, common.ProcessSchema(conv, infoSchema, common.DefaultWorkers, additionalSchemaAttributes) -} - -func performSnapshotMigration(config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, infoSchema common.InfoSchema, additionalAttributes internal.AdditionalDataAttributes) *writer.BatchWriter { - common.SetRowStats(conv, infoSchema) - totalRows := conv.Rows() - if !conv.Audit.DryRun { - conv.Audit.Progress = *internal.NewProgress(totalRows, "Writing data to Spanner", internal.Verbose(), false, 
int(internal.DataWriteInProgress)) - } - batchWriter := populateDataConv(conv, config, client) - common.ProcessData(conv, infoSchema, additionalAttributes) - batchWriter.Flush() - return batchWriter -} - -func snapshotMigrationHandler(sourceProfile profiles.SourceProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, infoSchema common.InfoSchema) (*writer.BatchWriter, error) { - switch sourceProfile.Driver { - // Skip snapshot migration via Spanner migration tool for mysql and oracle since dataflow job will job will handle this from backfilled data. - case constants.MYSQL, constants.ORACLE, constants.POSTGRES: - return &writer.BatchWriter{}, nil - case constants.DYNAMODB: - return performSnapshotMigration(config, conv, client, infoSchema, internal.AdditionalDataAttributes{ShardId: ""}), nil - default: - return &writer.BatchWriter{}, fmt.Errorf("streaming migration not supported for driver %s", sourceProfile.Driver) - } -} - -func updateShardsWithTuningConfigs(shardedTuningConfig profiles.ShardConfigurationDataflow) { - for _, dataShard := range shardedTuningConfig.DataShards { - dataShard.DatastreamConfig = shardedTuningConfig.DatastreamConfig - dataShard.GcsConfig = shardedTuningConfig.GcsConfig - dataShard.DataflowConfig = shardedTuningConfig.DataflowConfig - } -} - -func dataFromDatabase(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client) (*writer.BatchWriter, error) { - //handle migrating data for sharded migrations differently - //sharded migrations are identified via the config= flag, if that flag is not present - //carry on with the existing code path in the else block - switch sourceProfile.Ty { - case profiles.SourceProfileTypeConfig: - ////There are three cases to cover here, bulk migrations and sharded migrations (and later DMS) - //We provide an if-else based handling for each within the sharded code branch - 
//This will be determined via the configType, which can be "bulk", "dataflow" or "dms" - if sourceProfile.Config.ConfigType == constants.BULK_MIGRATION { - return dataFromDatabaseForBulkMigration(sourceProfile, targetProfile, config, conv, client) - } else if sourceProfile.Config.ConfigType == constants.DATAFLOW_MIGRATION { - return dataFromDatabaseForDataflowMigration(targetProfile, ctx, sourceProfile, conv) - } else if sourceProfile.Config.ConfigType == constants.DMS_MIGRATION { - return dataFromDatabaseForDMSMigration() - } else { - return nil, fmt.Errorf("configType should be one of 'bulk', 'dataflow' or 'dms'") - } - default: - var infoSchema common.InfoSchema - var err error - if sourceProfile.Ty == profiles.SourceProfileTypeCloudSQL { - infoSchema, err = GetInfoSchemaFromCloudSQL(sourceProfile, targetProfile) - if err != nil { - return nil, err - } - } else { - infoSchema, err = GetInfoSchema(sourceProfile, targetProfile) - if err != nil { - return nil, err - } - } - var streamInfo map[string]interface{} - // minimal downtime migration for a single shard - if sourceProfile.Conn.Streaming { - //Generate a job Id - migrationJobId := conv.Audit.MigrationRequestId - logger.Log.Info(fmt.Sprintf("Creating a migration job with id: %v. 
This jobId can be used in future commmands (such as cleanup) to refer to this job.\n", migrationJobId)) - streamInfo, err = infoSchema.StartChangeDataCapture(ctx, conv) - if err != nil { - return nil, err - } - bw, err := snapshotMigrationHandler(sourceProfile, config, conv, client, infoSchema) - if err != nil { - return nil, err - } - dfOutput, err := infoSchema.StartStreamingMigration(ctx, client, conv, streamInfo) - if err != nil { - return nil, err - } - dfJobId := dfOutput.JobID - gcloudCmd := dfOutput.GCloudCmd - streamingCfg, _ := streamInfo["streamingCfg"].(streaming.StreamingCfg) - // Fetch and store the GCS bucket associated with the datastream - dsClient := getDatastreamClient(ctx) - gcsBucket, gcsDestPrefix, fetchGcsErr := streaming.FetchTargetBucketAndPath(ctx, dsClient, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg.DestinationConnectionConfig) - if fetchGcsErr != nil { - logger.Log.Info("Could not fetch GCS Bucket, hence Monitoring Dashboard will not contain Metrics for the gcs bucket\n") - logger.Log.Debug("Error", zap.Error(fetchGcsErr)) - } - - // Try to apply lifecycle rule to Datastream destination bucket. - gcsConfig := streamingCfg.GcsCfg - sc, err := storageclient.NewStorageClientImpl(ctx) - if err != nil { - return nil, err - } - sa := storageaccessor.StorageAccessorImpl{} - if gcsConfig.TtlInDaysSet { - err = sa.ApplyBucketLifecycleDeleteRule(ctx, sc, storageaccessor.StorageBucketMetadata{ - BucketName: gcsBucket, - Ttl: gcsConfig.TtlInDays, - MatchesPrefix: []string{gcsDestPrefix}, - }) - if err != nil { - logger.Log.Warn(fmt.Sprintf("\nWARNING: could not update Datastream destination GCS bucket with lifecycle rule, error: %v\n", err)) - logger.Log.Warn("Please apply the lifecycle rule manually. 
Continuing...\n") - } - } - - monitoringResources := metrics.MonitoringMetricsResources{ - ProjectId: targetProfile.Conn.Sp.Project, - DataflowJobId: dfOutput.JobID, - DatastreamId: streamingCfg.DatastreamCfg.StreamId, - JobMetadataGcsBucket: gcsBucket, - PubsubSubscriptionId: streamingCfg.PubsubCfg.SubscriptionId, - SpannerInstanceId: targetProfile.Conn.Sp.Instance, - SpannerDatabaseId: targetProfile.Conn.Sp.Dbname, - ShardId: "", - MigrationRequestId: conv.Audit.MigrationRequestId, - } - respDash, dashboardErr := monitoringResources.CreateDataflowShardMonitoringDashboard(ctx) - var dashboardName string - if dashboardErr != nil { - dashboardName = "" - logger.Log.Info("Creation of the monitoring dashboard failed, please create the dashboard manually") - logger.Log.Debug("Error", zap.Error(dashboardErr)) - } else { - dashboardName = strings.Split(respDash.Name, "/")[3] - fmt.Printf("Monitoring Dashboard: %+v\n", dashboardName) - } - // store the generated resources locally in conv, this is used as source of truth for persistence and the UI (should change to persisted values) - streaming.StoreGeneratedResources(conv, streamingCfg, dfJobId, gcloudCmd, targetProfile.Conn.Sp.Project, "", internal.GcsResources{BucketName: gcsBucket}, dashboardName) - //persist job and shard level data in the metadata db - err = streaming.PersistJobDetails(ctx, targetProfile, sourceProfile, conv, migrationJobId, false) - if err != nil { - logger.Log.Info(fmt.Sprintf("Error storing job details in SMT metadata store...the migration job will still continue as intended. %v", err)) - } else { - //only attempt persisting shard level data if the job level data is persisted - err = streaming.PersistResources(ctx, targetProfile, sourceProfile, conv, migrationJobId, constants.DEFAULT_SHARD_ID) - if err != nil { - logger.Log.Info(fmt.Sprintf("Error storing details for migration job: %s, data shard: %s in SMT metadata store...the migration job will still continue as intended. 
err = %v\n", migrationJobId, constants.DEFAULT_SHARD_ID, err)) - } - } - return bw, nil - } - //bulk migration for a single shard - return performSnapshotMigration(config, conv, client, infoSchema, internal.AdditionalDataAttributes{ShardId: ""}), nil - } -} - -// TODO: Define the data processing logic for DMS migrations here. -func dataFromDatabaseForDMSMigration() (*writer.BatchWriter, error) { - return nil, fmt.Errorf("dms configType is not implemented yet, please use one of 'bulk' or 'dataflow'") -} - -// 1. Create batch for each physical shard -// 2. Create streaming cfg from the config source type. -// 3. Verify the CFG and update it with SMT defaults -// 4. Launch the stream for the physical shard -// 5. Perform streaming migration via dataflow -func dataFromDatabaseForDataflowMigration(targetProfile profiles.TargetProfile, ctx context.Context, sourceProfile profiles.SourceProfile, conv *internal.Conv) (*writer.BatchWriter, error) { - updateShardsWithTuningConfigs(sourceProfile.Config.ShardConfigurationDataflow) - //Generate a job Id - migrationJobId := conv.Audit.MigrationRequestId - fmt.Printf("Creating a migration job with id: %v. This jobId can be used in future commmands (such as cleanup) to refer to this job.\n", migrationJobId) - conv.Audit.StreamingStats.ShardToShardResourcesMap = make(map[string]internal.ShardResources) - schemaDetails, err := common.GetIncludedSrcTablesFromConv(conv) - if err != nil { - fmt.Printf("unable to determine tableList from schema, falling back to full database") - schemaDetails = map[string]internal.SchemaDetails{} - } - err = streaming.PersistJobDetails(ctx, targetProfile, sourceProfile, conv, migrationJobId, true) - if err != nil { - logger.Log.Info(fmt.Sprintf("Error storing job details in SMT metadata store...the migration job will still continue as intended. 
%v", err)) - } - asyncProcessShards := func(p *profiles.DataShard, mutex *sync.Mutex) common.TaskResult[*profiles.DataShard] { - dbNameToShardIdMap := make(map[string]string) - for _, l := range p.LogicalShards { - dbNameToShardIdMap[l.DbName] = l.LogicalShardId - } - if p.DataShardId == "" { - dataShardId, err := utils.GenerateName("smt-datashard") - dataShardId = strings.Replace(dataShardId, "_", "-", -1) - if err != nil { - return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} - } - p.DataShardId = dataShardId - fmt.Printf("Data shard id generated: %v\n", p.DataShardId) - } - streamingCfg := streaming.CreateStreamingConfig(*p) - err := streaming.VerifyAndUpdateCfg(&streamingCfg, targetProfile.Conn.Sp.Dbname, schemaDetails) - if err != nil { - err = fmt.Errorf("failed to process shard: %s, there seems to be an error in the sharding configuration, error: %v", p.DataShardId, err) - return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} - } - fmt.Printf("Initiating migration for shard: %v\n", p.DataShardId) - pubsubCfg, err := streaming.CreatePubsubResources(ctx, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg.DestinationConnectionConfig, targetProfile.Conn.Sp.Dbname) - if err != nil { - return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} - } - streamingCfg.PubsubCfg = *pubsubCfg - err = streaming.LaunchStream(ctx, sourceProfile, p.LogicalShards, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg) - if err != nil { - return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} - } - streamingCfg.DataflowCfg.DbNameToShardIdMap = dbNameToShardIdMap - dfOutput, err := streaming.StartDataflow(ctx, targetProfile, streamingCfg, conv) - if err != nil { - return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} - } - // store the generated resources locally in conv, this is used as source of truth for persistence and the UI (should change to persisted values) - - // Fetch and store the GCS bucket 
associated with the datastream - dsClient := getDatastreamClient(ctx) - gcsBucket, gcsDestPrefix, fetchGcsErr := streaming.FetchTargetBucketAndPath(ctx, dsClient, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg.DestinationConnectionConfig) - if fetchGcsErr != nil { - logger.Log.Info(fmt.Sprintf("Could not fetch GCS Bucket for Shard %s hence Monitoring Dashboard will not contain Metrics for the gcs bucket\n", p.DataShardId)) - logger.Log.Debug("Error", zap.Error(fetchGcsErr)) - } - - // Try to apply lifecycle rule to Datastream destination bucket. - gcsConfig := streamingCfg.GcsCfg - sc, err := storageclient.NewStorageClientImpl(ctx) - if err != nil { - return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} - } - sa := storageaccessor.StorageAccessorImpl{} - if gcsConfig.TtlInDaysSet { - err = sa.ApplyBucketLifecycleDeleteRule(ctx, sc, storageaccessor.StorageBucketMetadata{ - BucketName: gcsBucket, - Ttl: gcsConfig.TtlInDays, - MatchesPrefix: []string{gcsDestPrefix}, - }) - if err != nil { - logger.Log.Warn(fmt.Sprintf("\nWARNING: could not update Datastream destination GCS bucket with lifecycle rule, error: %v\n", err)) - logger.Log.Warn("Please apply the lifecycle rule manually. 
Continuing...\n") - } - } - - // create monitoring dashboard for a single shard - monitoringResources := metrics.MonitoringMetricsResources{ - ProjectId: targetProfile.Conn.Sp.Project, - DataflowJobId: dfOutput.JobID, - DatastreamId: streamingCfg.DatastreamCfg.StreamId, - JobMetadataGcsBucket: gcsBucket, - PubsubSubscriptionId: streamingCfg.PubsubCfg.SubscriptionId, - SpannerInstanceId: targetProfile.Conn.Sp.Instance, - SpannerDatabaseId: targetProfile.Conn.Sp.Dbname, - ShardId: p.DataShardId, - MigrationRequestId: conv.Audit.MigrationRequestId, - } - respDash, dashboardErr := monitoringResources.CreateDataflowShardMonitoringDashboard(ctx) - var dashboardName string - if dashboardErr != nil { - dashboardName = "" - logger.Log.Info(fmt.Sprintf("Creation of the monitoring dashboard for shard %s failed, please create the dashboard manually\n", p.DataShardId)) - logger.Log.Debug("Error", zap.Error(dashboardErr)) - } else { - dashboardName = strings.Split(respDash.Name, "/")[3] - fmt.Printf("Monitoring Dashboard for shard %v: %+v\n", p.DataShardId, dashboardName) - } - streaming.StoreGeneratedResources(conv, streamingCfg, dfOutput.JobID, dfOutput.GCloudCmd, targetProfile.Conn.Sp.Project, p.DataShardId, internal.GcsResources{BucketName: gcsBucket}, dashboardName) - //persist the generated resources in a metadata db - err = streaming.PersistResources(ctx, targetProfile, sourceProfile, conv, migrationJobId, p.DataShardId) - if err != nil { - fmt.Printf("Error storing generated resources in SMT metadata store for dataShardId: %s...the migration job will still continue as intended, error: %v\n", p.DataShardId, err) - } - return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} - } - _, err = common.RunParallelTasks(sourceProfile.Config.ShardConfigurationDataflow.DataShards, 20, asyncProcessShards, true) - if err != nil { - return nil, fmt.Errorf("unable to start minimal downtime migrations: %v", err) - } - - // create monitoring aggregated dashboard for sharded 
migration - aggMonitoringResources := metrics.MonitoringMetricsResources{ - ProjectId: targetProfile.Conn.Sp.Project, - SpannerInstanceId: targetProfile.Conn.Sp.Instance, - SpannerDatabaseId: targetProfile.Conn.Sp.Dbname, - ShardToShardResourcesMap: conv.Audit.StreamingStats.ShardToShardResourcesMap, - MigrationRequestId: conv.Audit.MigrationRequestId, - } - aggRespDash, dashboardErr := aggMonitoringResources.CreateDataflowAggMonitoringDashboard(ctx) - if dashboardErr != nil { - logger.Log.Error(fmt.Sprintf("Creation of the aggregated monitoring dashboard failed, please create the dashboard manually\n error=%v\n", dashboardErr)) - } else { - fmt.Printf("Aggregated Monitoring Dashboard: %+v\n", strings.Split(aggRespDash.Name, "/")[3]) - conv.Audit.StreamingStats.AggMonitoringResources = internal.MonitoringResources{DashboardName: strings.Split(aggRespDash.Name, "/")[3]} - } - err = streaming.PersistAggregateMonitoringResources(ctx, targetProfile, sourceProfile, conv, migrationJobId) - if err != nil { - logger.Log.Info(fmt.Sprintf("Unable to store aggregated monitoring dashboard in metadata database\n error=%v\n", err)) - } else { - logger.Log.Debug("Aggregate monitoring resources stored successfully.\n") - } - return &writer.BatchWriter{}, nil -} - -// 1. Migrate the data from the data shards, the schema shard needs to be specified here again. -// 2. Create a connection profile object for it -// 3. Perform a snapshot migration for the shard -// 4. 
Once all shard migrations are complete, return the batch writer object -func dataFromDatabaseForBulkMigration(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client) (*writer.BatchWriter, error) { - var bw *writer.BatchWriter - for _, dataShard := range sourceProfile.Config.ShardConfigurationBulk.DataShards { - - fmt.Printf("Initiating migration for shard: %v\n", dataShard.DbName) - infoSchema, err := getInfoSchemaForShard(dataShard, sourceProfile.Driver, targetProfile) - if err != nil { - return nil, err - } - additionalDataAttributes := internal.AdditionalDataAttributes{ - ShardId: dataShard.DataShardId, - } - bw = performSnapshotMigration(config, conv, client, infoSchema, additionalDataAttributes) - } - - return bw, nil -} - -func getDynamoDBClientConfig() (*aws.Config, error) { - cfg := aws.Config{} - endpointOverride := os.Getenv("DYNAMODB_ENDPOINT_OVERRIDE") - if endpointOverride != "" { - cfg.Endpoint = aws.String(endpointOverride) - } - return &cfg, nil -} - -func SchemaFromDump(driver string, spDialect string, ioHelper *utils.IOStreams) (*internal.Conv, error) { - f, n, err := getSeekable(ioHelper.In) - if err != nil { - utils.PrintSeekError(driver, err, ioHelper.Out) - return nil, fmt.Errorf("can't get seekable input file") - } - ioHelper.SeekableIn = f - ioHelper.BytesRead = n - conv := internal.MakeConv() - conv.SpDialect = spDialect - p := internal.NewProgress(n, "Generating schema", internal.Verbose(), false, int(internal.SchemaCreationInProgress)) - r := internal.NewReader(bufio.NewReader(f), p) - conv.SetSchemaMode() // Build schema and ignore data in dump. 
- conv.SetDataSink(nil) - err = ProcessDump(driver, conv, r) - if err != nil { - fmt.Fprintf(ioHelper.Out, "Failed to parse the data file: %v", err) - return nil, fmt.Errorf("failed to parse the data file") - } - p.Done() - return conv, nil -} - -func dataFromDump(driver string, config writer.BatchWriterConfig, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool) (*writer.BatchWriter, error) { - // TODO: refactor of the way we handle getSeekable - // to avoid the code duplication here - if !dataOnly { - _, err := ioHelper.SeekableIn.Seek(0, 0) - if err != nil { - fmt.Printf("\nCan't seek to start of file (preparation for second pass): %v\n", err) - return nil, fmt.Errorf("can't seek to start of file") - } - } else { - // Note: input file is kept seekable to plan for future - // changes in showing progress for data migration. - f, n, err := getSeekable(ioHelper.In) - if err != nil { - utils.PrintSeekError(driver, err, ioHelper.Out) - return nil, fmt.Errorf("can't get seekable input file") - } - ioHelper.SeekableIn = f - ioHelper.BytesRead = n - } - totalRows := conv.Rows() - - conv.Audit.Progress = *internal.NewProgress(totalRows, "Writing data to Spanner", internal.Verbose(), false, int(internal.DataWriteInProgress)) - r := internal.NewReader(bufio.NewReader(ioHelper.SeekableIn), nil) - batchWriter := populateDataConv(conv, config, client) - ProcessDump(driver, conv, r) - batchWriter.Flush() - conv.Audit.Progress.Done() - - return batchWriter, nil -} - -func dataFromCSV(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client) (*writer.BatchWriter, error) { - if targetProfile.Conn.Sp.Dbname == "" { - return nil, fmt.Errorf("dbName is mandatory in target-profile for csv source") - } - conv.SpDialect = targetProfile.Conn.Sp.Dialect - dialect, err := targetProfile.FetchTargetDialect(ctx) - if err != nil { - return nil, 
fmt.Errorf("could not fetch dialect: %v", err) - } - if strings.ToLower(dialect) != constants.DIALECT_POSTGRESQL { - dialect = constants.DIALECT_GOOGLESQL - } - - if dialect != conv.SpDialect { - return nil, fmt.Errorf("dialect specified in target profile does not match spanner dialect") - } - - delimiterStr := sourceProfile.Csv.Delimiter - if len(delimiterStr) != 1 { - return nil, fmt.Errorf("delimiter should only be a single character long, found '%s'", delimiterStr) - } - - delimiter := rune(delimiterStr[0]) - - err = utils.ReadSpannerSchema(ctx, conv, client) - if err != nil { - return nil, fmt.Errorf("error trying to read and convert spanner schema: %v", err) - } - - tables, err := csv.GetCSVFiles(conv, sourceProfile) - if err != nil { - return nil, fmt.Errorf("error finding csv files: %v", err) - } - - // Find the number of rows in each csv file for generating stats. - err = csv.SetRowStats(conv, tables, delimiter) - if err != nil { - return nil, err - } - - totalRows := conv.Rows() - conv.Audit.Progress = *internal.NewProgress(totalRows, "Writing data to Spanner", internal.Verbose(), false, int(internal.DataWriteInProgress)) - batchWriter := populateDataConv(conv, config, client) - err = csv.ProcessCSV(conv, tables, sourceProfile.Csv.NullStr, delimiter) - if err != nil { - return nil, fmt.Errorf("can't process csv: %v", err) - } - batchWriter.Flush() - conv.Audit.Progress.Done() - return batchWriter, nil -} - -func populateDataConv(conv *internal.Conv, config writer.BatchWriterConfig, client *sp.Client) *writer.BatchWriter { - rows := int64(0) - config.Write = func(m []*sp.Mutation) error { - ctx := context.Background() - if !conv.Audit.SkipMetricsPopulation { - migrationData := metrics.GetMigrationData(conv, "", constants.DataConv) - serializedMigrationData, _ := proto.Marshal(migrationData) - migrationMetadataValue := base64.StdEncoding.EncodeToString(serializedMigrationData) - ctx = metadata.AppendToOutgoingContext(context.Background(), 
constants.MigrationMetadataKey, migrationMetadataValue) - } - _, err := client.Apply(ctx, m) - if err != nil { - return err - } - atomic.AddInt64(&rows, int64(len(m))) - conv.Audit.Progress.MaybeReport(atomic.LoadInt64(&rows)) - return nil - } - batchWriter := writer.NewBatchWriter(config) - conv.SetDataMode() - if !conv.Audit.DryRun { - conv.SetDataSink( - func(table string, cols []string, vals []interface{}) { - batchWriter.AddRow(table, cols, vals) - }) - conv.DataFlush = func() { - batchWriter.Flush() - } - } - - return batchWriter -} // Report generates a report of schema and data conversion. func Report(driver string, badWrites map[string]int64, BytesRead int64, banner string, conv *internal.Conv, reportFileName string, dbName string, out *os.File) { @@ -788,88 +169,6 @@ func Report(driver string, badWrites map[string]int64, BytesRead int64, banner s } } -// getSeekable returns a seekable file (with same content as f) and the size of the content (in bytes). -func getSeekable(f *os.File) (*os.File, int64, error) { - _, err := f.Seek(0, 0) - if err == nil { // Stdin is seekable, let's just use that. This happens when you run 'cmd < file'. - n, err := utils.GetFileSize(f) - return f, n, err - } - internal.VerbosePrintln("Creating a tmp file with a copy of stdin because stdin is not seekable.") - logger.Log.Debug("Creating a tmp file with a copy of stdin because stdin is not seekable.") - - // Create file in os.TempDir. Its not clear this is a good idea e.g. if the - // pg_dump/mysqldump output is large (tens of GBs) and os.TempDir points to a directory - // (such as /tmp) that's configured with a small amount of disk space. - // To workaround such limits on Unix, set $TMPDIR to a directory with lots - // of disk space. - fcopy, err := ioutil.TempFile("", "spanner-migration-tool.data") - if err != nil { - return nil, 0, err - } - syscall.Unlink(fcopy.Name()) // File will be deleted when this process exits. 
- _, err = io.Copy(fcopy, f) - if err != nil { - return nil, 0, fmt.Errorf("can't write stdin to tmp file: %w", err) - } - _, err = fcopy.Seek(0, 0) - if err != nil { - return nil, 0, fmt.Errorf("can't reset file offset: %w", err) - } - n, _ := utils.GetFileSize(fcopy) - return fcopy, n, nil -} - -// VerifyDb checks whether the db exists and if it does, verifies if the schema is what we currently support. -func VerifyDb(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string) (dbExists bool, err error) { - adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) - if err != nil { - return dbExists, err - } - spA := spanneraccessor.SpannerAccessorImpl{} - dbExists, err = spA.CheckExistingDb(ctx, adminClientImpl, dbURI) - if err != nil { - return dbExists, err - } - if dbExists { - err = ValidateDDL(ctx, adminClient, dbURI) - } - return dbExists, err -} - -// ValidateTables validates that all the tables in the database are empty. -// It returns the name of the first non-empty table if found, and an empty string otherwise. -func ValidateTables(ctx context.Context, client *sp.Client, spDialect string) (string, error) { - infoSchema := spanner.InfoSchemaImpl{Client: client, Ctx: ctx, SpDialect: spDialect} - tables, err := infoSchema.GetTables() - if err != nil { - return "", err - } - for _, table := range tables { - count, err := infoSchema.GetRowCount(table) - if err != nil { - return "", err - } - if count != 0 { - return table.Name, nil - } - } - return "", nil -} - -// ValidateDDL verifies if an existing DB's ddl follows what is supported by Spanner migration tool. Currently, -// we only support empty schema when db already exists. 
-func ValidateDDL(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string) error { - dbDdl, err := adminClient.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{Database: dbURI}) - if err != nil { - return fmt.Errorf("can't fetch database ddl: %v", err) - } - if len(dbDdl.Statements) != 0 { - return fmt.Errorf("spanner migration tool supports writing to existing databases only if they have an empty schema") - } - return nil -} - // CreatesOrUpdatesDatabase updates an existing Spanner database or creates a new one if one does not exist. func CreateOrUpdateDatabase(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI, driver string, conv *internal.Conv, out *os.File, migrationType string) error { dbExists, err := VerifyDb(ctx, adminClient, dbURI) @@ -1056,374 +355,3 @@ admin quota limit by spreading the FK creation requests over time.`) conv.Audit.Progress.Done() return nil } - -// WriteSchemaFile writes DDL statements in a file. It includes CREATE TABLE -// statements and ALTER TABLE statements to add foreign keys. -// The parameter name should end with a .txt. -func WriteSchemaFile(conv *internal.Conv, now time.Time, name string, out *os.File, driver string) { - f, err := os.Create(name) - if err != nil { - fmt.Fprintf(out, "Can't create schema file %s: %v\n", name, err) - return - } - - // The schema file we write out below is optimized for reading. It includes comments, foreign keys - // and doesn't add backticks around table and column names. This file is - // intended for explanatory and documentation purposes, and is not strictly - // legal Cloud Spanner DDL (Cloud Spanner doesn't currently support comments). 
- spDDL := conv.SpSchema.GetDDL(ddl.Config{Comments: true, ProtectIds: false, Tables: true, ForeignKeys: true, SpDialect: conv.SpDialect, Source: driver}) - if len(spDDL) == 0 { - spDDL = []string{"\n-- Schema is empty -- no tables found\n"} - } - l := []string{ - fmt.Sprintf("-- Schema generated %s\n", now.Format("2006-01-02 15:04:05")), - strings.Join(spDDL, ";\n\n"), - "\n", - } - if _, err := f.WriteString(strings.Join(l, "")); err != nil { - fmt.Fprintf(out, "Can't write out schema file: %v\n", err) - return - } - fmt.Fprintf(out, "Wrote schema to file '%s'.\n", name) - - // Convert . to .ddl.. - nameSplit := strings.Split(name, ".") - nameSplit = append(nameSplit[:len(nameSplit)-1], "ddl", nameSplit[len(nameSplit)-1]) - name = strings.Join(nameSplit, ".") - f, err = os.Create(name) - if err != nil { - fmt.Fprintf(out, "Can't create legal schema ddl file %s: %v\n", name, err) - return - } - - // We change 'Comments' to false and 'ProtectIds' to true below to write out a - // schema file that is a legal Cloud Spanner DDL. - spDDL = conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: true, SpDialect: conv.SpDialect, Source: driver}) - if len(spDDL) == 0 { - spDDL = []string{"\n-- Schema is empty -- no tables found\n"} - } - l = []string{ - strings.Join(spDDL, ";\n\n"), - "\n", - } - if _, err = f.WriteString(strings.Join(l, "")); err != nil { - fmt.Fprintf(out, "Can't write out legal schema ddl file: %v\n", err) - return - } - fmt.Fprintf(out, "Wrote legal schema ddl to file '%s'.\n", name) -} - -// WriteSessionFile writes conv struct to a file in JSON format. -func WriteSessionFile(conv *internal.Conv, name string, out *os.File) { - f, err := os.Create(name) - if err != nil { - fmt.Fprintf(out, "Can't create session file %s: %v\n", name, err) - return - } - // Session file will basically contain 'conv' struct in JSON format. - // It contains all the information for schema and data conversion state. 
- convJSON, err := json.MarshalIndent(conv, "", " ") - if err != nil { - fmt.Fprintf(out, "Can't encode session state to JSON: %v\n", err) - return - } - if _, err := f.Write(convJSON); err != nil { - fmt.Fprintf(out, "Can't write out session file: %v\n", err) - return - } - fmt.Fprintf(out, "Wrote session to file '%s'.\n", name) -} - -// WriteConvGeneratedFiles creates a directory labeled downloads with the current timestamp -// where it writes the sessionfile, report summary and DDLs then returns the directory where it writes. -func WriteConvGeneratedFiles(conv *internal.Conv, dbName string, driver string, BytesRead int64, out *os.File) (string, error) { - now := time.Now() - dirPath := "spanner_migration_tool_output/" + dbName + "/" - err := os.MkdirAll(dirPath, os.ModePerm) - if err != nil { - fmt.Fprintf(out, "Can't create directory %s: %v\n", dirPath, err) - return "", err - } - schemaFileName := dirPath + dbName + "_schema.txt" - WriteSchemaFile(conv, now, schemaFileName, out, driver) - reportFileName := dirPath + dbName - Report(driver, nil, BytesRead, "", conv, reportFileName, dbName, out) - sessionFileName := dirPath + dbName + ".session.json" - WriteSessionFile(conv, sessionFileName, out) - return dirPath, nil -} - -// ReadSessionFile reads a session JSON file and -// unmarshal it's content into *internal.Conv. -func ReadSessionFile(conv *internal.Conv, sessionJSON string) error { - s, err := ioutil.ReadFile(sessionJSON) - if err != nil { - return err - } - err = json.Unmarshal(s, &conv) - if err != nil { - return err - } - return nil -} - -// WriteBadData prints summary stats about bad rows and writes detailed info -// to file 'name'. 
-func WriteBadData(bw *writer.BatchWriter, conv *internal.Conv, banner, name string, out *os.File) { - badConversions := conv.BadRows() - badWrites := utils.SumMapValues(bw.DroppedRowsByTable()) - - badDataStreaming := int64(0) - if conv.Audit.StreamingStats.Streaming { - badDataStreaming = getBadStreamingDataCount(conv) - } - - if badConversions == 0 && badWrites == 0 && badDataStreaming == 0 { - os.Remove(name) // Cleanup bad-data file from previous run. - return - } - f, err := os.Create(name) - if err != nil { - fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) - return - } - f.WriteString(banner) - maxRows := 100 - if badConversions > 0 { - l := conv.SampleBadRows(maxRows) - if int64(len(l)) < badConversions { - f.WriteString("A sample of rows that generated conversion errors:\n") - } else { - f.WriteString("Rows that generated conversion errors:\n") - } - for _, r := range l { - _, err := f.WriteString(" " + r + "\n") - if err != nil { - fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) - return - } - } - } - if badWrites > 0 { - l := bw.SampleBadRows(maxRows) - if int64(len(l)) < badWrites { - f.WriteString("A sample of rows that successfully converted but couldn't be written to Spanner:\n") - } else { - f.WriteString("Rows that successfully converted but couldn't be written to Spanner:\n") - } - for _, r := range l { - _, err := f.WriteString(" " + r + "\n") - if err != nil { - fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) - return - } - } - } - if badDataStreaming > 0 { - err = writeBadStreamingData(conv, f) - if err != nil { - fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) - return - } - } - - fmt.Fprintf(out, "See file '%s' for details of bad rows\n", name) -} - -// getBadStreamingDataCount returns the total sum of bad and dropped records during -// streaming migration process. 
-func getBadStreamingDataCount(conv *internal.Conv) int64 { - badDataCount := int64(0) - - for _, x := range conv.Audit.StreamingStats.BadRecords { - badDataCount += utils.SumMapValues(x) - } - for _, x := range conv.Audit.StreamingStats.DroppedRecords { - badDataCount += utils.SumMapValues(x) - } - return badDataCount -} - -// writeBadStreamingData writes sample of bad records and dropped records during streaming -// migration process to bad data file. -func writeBadStreamingData(conv *internal.Conv, f *os.File) error { - f.WriteString("\nBad data encountered during streaming migration:\n\n") - - stats := (conv.Audit.StreamingStats) - - badRecords := int64(0) - for _, x := range stats.BadRecords { - badRecords += utils.SumMapValues(x) - } - droppedRecords := int64(0) - for _, x := range stats.DroppedRecords { - droppedRecords += utils.SumMapValues(x) - } - - if badRecords > 0 { - l := stats.SampleBadRecords - if int64(len(l)) < badRecords { - f.WriteString("A sample of records that generated conversion errors:\n") - } else { - f.WriteString("Records that generated conversion errors:\n") - } - for _, r := range l { - _, err := f.WriteString(" " + r + "\n") - if err != nil { - return err - } - } - f.WriteString("\n") - } - if droppedRecords > 0 { - l := stats.SampleBadWrites - if int64(len(l)) < droppedRecords { - f.WriteString("A sample of records that successfully converted but couldn't be written to Spanner:\n") - } else { - f.WriteString("Records that successfully converted but couldn't be written to Spanner:\n") - } - for _, r := range l { - _, err := f.WriteString(" " + r + "\n") - if err != nil { - return err - } - } - } - return nil -} - -// ProcessDump invokes process dump function from a sql package based on driver selected. 
-func ProcessDump(driver string, conv *internal.Conv, r *internal.Reader) error { - switch driver { - case constants.MYSQLDUMP: - return common.ProcessDbDump(conv, r, mysql.DbDumpImpl{}) - case constants.PGDUMP: - return common.ProcessDbDump(conv, r, postgres.DbDumpImpl{}) - default: - return fmt.Errorf("process dump for driver %s not supported", driver) - } -} - -func GetInfoSchemaFromCloudSQL(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) { - driver := sourceProfile.Driver - switch driver { - case constants.MYSQL: - d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithIAMAuthN()) - if err != nil { - return nil, fmt.Errorf("cloudsqlconn.NewDialer: %w", err) - } - var opts []cloudsqlconn.DialOption - instanceName := fmt.Sprintf("%s:%s:%s", sourceProfile.ConnCloudSQL.Mysql.Project, sourceProfile.ConnCloudSQL.Mysql.Region, sourceProfile.ConnCloudSQL.Mysql.InstanceName) - mysqldriver.RegisterDialContext("cloudsqlconn", - func(ctx context.Context, addr string) (net.Conn, error) { - return d.Dial(ctx, instanceName, opts...) 
- }) - - dbURI := fmt.Sprintf("%s:empty@cloudsqlconn(localhost:3306)/%s?parseTime=true", - sourceProfile.ConnCloudSQL.Mysql.User, sourceProfile.ConnCloudSQL.Mysql.Db) - - db, err := sql.Open("mysql", dbURI) - if err != nil { - return nil, fmt.Errorf("sql.Open: %w", err) - } - return mysql.InfoSchemaImpl{ - DbName: sourceProfile.ConnCloudSQL.Mysql.Db, - Db: db, - SourceProfile: sourceProfile, - TargetProfile: targetProfile, - }, nil - case constants.POSTGRES: - d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithIAMAuthN()) - if err != nil { - return nil, fmt.Errorf("cloudsqlconn.NewDialer: %w", err) - } - var opts []cloudsqlconn.DialOption - - dsn := fmt.Sprintf("user=%s database=%s", sourceProfile.ConnCloudSQL.Pg.User, sourceProfile.ConnCloudSQL.Pg.Db) - config, err := pgx.ParseConfig(dsn) - if err != nil { - return nil, err - } - instanceName := fmt.Sprintf("%s:%s:%s", sourceProfile.ConnCloudSQL.Pg.Project, sourceProfile.ConnCloudSQL.Pg.Region, sourceProfile.ConnCloudSQL.Pg.InstanceName) - config.DialFunc = func(ctx context.Context, network, instance string) (net.Conn, error) { - return d.Dial(ctx, instanceName, opts...) 
- } - dbURI := stdlib.RegisterConnConfig(config) - db, err := sql.Open("pgx", dbURI) - if err != nil { - return nil, fmt.Errorf("sql.Open: %w", err) - } - temp := false - return postgres.InfoSchemaImpl{ - Db: db, - SourceProfile: sourceProfile, - TargetProfile: targetProfile, - IsSchemaUnique: &temp, //this is a workaround to set a bool pointer - }, nil - default: - return nil, fmt.Errorf("driver %s not supported", driver) - } -} - -func GetInfoSchema(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) { - connectionConfig, err := connectionConfig(sourceProfile) - if err != nil { - return nil, err - } - driver := sourceProfile.Driver - switch driver { - case constants.MYSQL: - db, err := sql.Open(driver, connectionConfig.(string)) - dbName := getDbNameFromSQLConnectionStr(driver, connectionConfig.(string)) - if err != nil { - return nil, err - } - return mysql.InfoSchemaImpl{ - DbName: dbName, - Db: db, - SourceProfile: sourceProfile, - TargetProfile: targetProfile, - }, nil - case constants.POSTGRES: - db, err := sql.Open(driver, connectionConfig.(string)) - if err != nil { - return nil, err - } - temp := false - return postgres.InfoSchemaImpl{ - Db: db, - SourceProfile: sourceProfile, - TargetProfile: targetProfile, - IsSchemaUnique: &temp, //this is a workaround to set a bool pointer - }, nil - case constants.DYNAMODB: - mySession := session.Must(session.NewSession()) - dydbClient := dydb.New(mySession, connectionConfig.(*aws.Config)) - var dydbStreamsClient *dynamodbstreams.DynamoDBStreams - if sourceProfile.Conn.Streaming { - newSession := session.Must(session.NewSession()) - dydbStreamsClient = dynamodbstreams.New(newSession, connectionConfig.(*aws.Config)) - } - return dynamodb.InfoSchemaImpl{ - DynamoClient: dydbClient, - SampleSize: profiles.GetSchemaSampleSize(sourceProfile), - DynamoStreamsClient: dydbStreamsClient, - }, nil - case constants.SQLSERVER: - db, err := sql.Open(driver, 
connectionConfig.(string)) - dbName := getDbNameFromSQLConnectionStr(driver, connectionConfig.(string)) - if err != nil { - return nil, err - } - return sqlserver.InfoSchemaImpl{DbName: dbName, Db: db}, nil - case constants.ORACLE: - db, err := sql.Open(driver, connectionConfig.(string)) - dbName := getDbNameFromSQLConnectionStr(driver, connectionConfig.(string)) - if err != nil { - return nil, err - } - return oracle.InfoSchemaImpl{DbName: strings.ToUpper(dbName), Db: db, SourceProfile: sourceProfile, TargetProfile: targetProfile}, nil - default: - return nil, fmt.Errorf("driver %s not supported", driver) - } -} diff --git a/conversion/conversion_from_source.go b/conversion/conversion_from_source.go new file mode 100644 index 0000000000..8a1b73c87f --- /dev/null +++ b/conversion/conversion_from_source.go @@ -0,0 +1,339 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package conversion handles initial setup for the command line tool +// and web APIs. + +// TODO:(searce) Organize code in go style format to make this file more readable. 
+// +// public constants first +// key public type definitions next (although often it makes sense to put them next to public functions that use them) +// then public functions (and relevant type definitions) +// and helper functions and other non-public definitions last (generally in order of importance) +package conversion + +import ( + "bufio" + "context" + "fmt" + "strings" + + sp "cloud.google.com/go/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/metrics" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/csv" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer" + "github.com/GoogleCloudPlatform/spanner-migration-tool/streaming" + storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" + storageaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/storage" + "go.uber.org/zap" +) + +type SchemaFromSourceInterface interface { + schemaFromDatabase(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, getInfo GetInfoInterface, processSchema common.ProcessSchemaInterface) (*internal.Conv, error) + SchemaFromDump(driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error) + } + +type SchemaFromSourceImpl struct{} + +type DataFromSourceInterface interface { + dataFromDatabase(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, getInfo 
GetInfoInterface, dataFromDb DataFromDatabaseInterface, snapshotMigration SnapshotMigrationInterface) (*writer.BatchWriter, error) + dataFromDump(driver string, config writer.BatchWriterConfig, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool, processDump ProcessDumpByDialectInterface, populateDataConv PopulateDataConvInterface) (*writer.BatchWriter, error) + dataFromCSV(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, populateDataConv PopulateDataConvInterface, csv csv.CsvInterface) (*writer.BatchWriter, error) +} + +type DataFromSourceImpl struct{} + +func (sads *SchemaFromSourceImpl) schemaFromDatabase(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, getInfo GetInfoInterface, processSchema common.ProcessSchemaInterface) (*internal.Conv, error) { + conv := internal.MakeConv() + conv.SpDialect = targetProfile.Conn.Sp.Dialect + //handle fetching schema differently for sharded migrations, we only connect to the primary shard to + //fetch the schema. We reuse the SourceProfileConnection object for this purpose. 
+ var infoSchema common.InfoSchema + var err error + isSharded := false + switch sourceProfile.Ty { + case profiles.SourceProfileTypeConfig: + isSharded = true + //Find Primary Shard Name + if sourceProfile.Config.ConfigType == constants.BULK_MIGRATION { + schemaSource := sourceProfile.Config.ShardConfigurationBulk.SchemaSource + infoSchema, err = getInfo.getInfoSchemaForShard(schemaSource, sourceProfile.Driver, targetProfile, &profiles.SourceProfileDialectImpl{}, &GetInfoImpl{}) + if err != nil { + return conv, err + } + } else if sourceProfile.Config.ConfigType == constants.DATAFLOW_MIGRATION { + schemaSource := sourceProfile.Config.ShardConfigurationDataflow.SchemaSource + infoSchema, err = getInfo.getInfoSchemaForShard(schemaSource, sourceProfile.Driver, targetProfile, &profiles.SourceProfileDialectImpl{}, &GetInfoImpl{}) + if err != nil { + return conv, err + } + } else if sourceProfile.Config.ConfigType == constants.DMS_MIGRATION { + // TODO: Define the schema processing logic for DMS migrations here. 
+ return conv, fmt.Errorf("dms based migrations are not implemented yet") + } else { + return conv, fmt.Errorf("unknown type of migration, please select one of bulk, dataflow or dms") + } + case profiles.SourceProfileTypeCloudSQL: + infoSchema, err = getInfo.GetInfoSchemaFromCloudSQL(sourceProfile, targetProfile) + if err != nil { + return conv, err + } + + default: + infoSchema, err = getInfo.GetInfoSchema(sourceProfile, targetProfile) + if err != nil { + return conv, err + } + } + additionalSchemaAttributes := internal.AdditionalSchemaAttributes{ + IsSharded: isSharded, + } + return conv, processSchema.ProcessSchema(conv, infoSchema, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) +} + +func (sads *SchemaFromSourceImpl) SchemaFromDump(driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error) { + f, n, err := getSeekable(ioHelper.In) + if err != nil { + utils.PrintSeekError(driver, err, ioHelper.Out) + return nil, fmt.Errorf("can't get seekable input file") + } + ioHelper.SeekableIn = f + ioHelper.BytesRead = n + conv := internal.MakeConv() + conv.SpDialect = spDialect + p := internal.NewProgress(n, "Generating schema", internal.Verbose(), false, int(internal.SchemaCreationInProgress)) + r := internal.NewReader(bufio.NewReader(f), p) + conv.SetSchemaMode() // Build schema and ignore data in dump. 
+ conv.SetDataSink(nil) + err = processDump.ProcessDump(driver, conv, r) + if err != nil { + fmt.Fprintf(ioHelper.Out, "Failed to parse the data file: %v", err) + return nil, fmt.Errorf("failed to parse the data file") + } + p.Done() + return conv, nil +} + + +func (sads *DataFromSourceImpl) dataFromDump(driver string, config writer.BatchWriterConfig, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool, processDump ProcessDumpByDialectInterface, populateDataConv PopulateDataConvInterface) (*writer.BatchWriter, error) { + // TODO: refactor of the way we handle getSeekable + // to avoid the code duplication here + if !dataOnly { + _, err := ioHelper.SeekableIn.Seek(0, 0) + if err != nil { + fmt.Printf("\nCan't seek to start of file (preparation for second pass): %v\n", err) + return nil, fmt.Errorf("can't seek to start of file") + } + } else { + // Note: input file is kept seekable to plan for future + // changes in showing progress for data migration. + f, n, err := getSeekable(ioHelper.In) + if err != nil { + utils.PrintSeekError(driver, err, ioHelper.Out) + return nil, fmt.Errorf("can't get seekable input file") + } + ioHelper.SeekableIn = f + ioHelper.BytesRead = n + } + totalRows := conv.Rows() + + conv.Audit.Progress = *internal.NewProgress(totalRows, "Writing data to Spanner", internal.Verbose(), false, int(internal.DataWriteInProgress)) + r := internal.NewReader(bufio.NewReader(ioHelper.SeekableIn), nil) + batchWriter := populateDataConv.populateDataConv(conv, config, client) + processDump.ProcessDump(driver, conv, r) + batchWriter.Flush() + conv.Audit.Progress.Done() + + return batchWriter, nil +} + +func (sads *DataFromSourceImpl) dataFromCSV(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, populateDataConv PopulateDataConvInterface, csv csv.CsvInterface) (*writer.BatchWriter, error) { + if 
targetProfile.Conn.Sp.Dbname == "" { + return nil, fmt.Errorf("dbName is mandatory in target-profile for csv source") + } + conv.SpDialect = targetProfile.Conn.Sp.Dialect + dialect, err := targetProfile.FetchTargetDialect(ctx) + if err != nil { + return nil, fmt.Errorf("could not fetch dialect: %v", err) + } + if strings.ToLower(dialect) != constants.DIALECT_POSTGRESQL { + dialect = constants.DIALECT_GOOGLESQL + } + + if dialect != conv.SpDialect { + return nil, fmt.Errorf("dialect specified in target profile does not match spanner dialect") + } + + delimiterStr := sourceProfile.Csv.Delimiter + if len(delimiterStr) != 1 { + return nil, fmt.Errorf("delimiter should only be a single character long, found '%s'", delimiterStr) + } + + delimiter := rune(delimiterStr[0]) + + err = utils.ReadSpannerSchema(ctx, conv, client) + if err != nil { + return nil, fmt.Errorf("error trying to read and convert spanner schema: %v", err) + } + + tables, err := csv.GetCSVFiles(conv, sourceProfile) + if err != nil { + return nil, fmt.Errorf("error finding csv files: %v", err) + } + + // Find the number of rows in each csv file for generating stats. 
+ err = csv.SetRowStats(conv, tables, delimiter) + if err != nil { + return nil, err + } + + totalRows := conv.Rows() + conv.Audit.Progress = *internal.NewProgress(totalRows, "Writing data to Spanner", internal.Verbose(), false, int(internal.DataWriteInProgress)) + batchWriter := populateDataConv.populateDataConv(conv, config, client) + err = csv.ProcessCSV(conv, tables, sourceProfile.Csv.NullStr, delimiter) + if err != nil { + return nil, fmt.Errorf("can't process csv: %v", err) + } + batchWriter.Flush() + conv.Audit.Progress.Done() + return batchWriter, nil +} + + +func (sads *DataFromSourceImpl) dataFromDatabase(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, getInfo GetInfoInterface, dataFromDb DataFromDatabaseInterface, snapshotMigration SnapshotMigrationInterface) (*writer.BatchWriter, error) { + //handle migrating data for sharded migrations differently + //sharded migrations are identified via the config= flag, if that flag is not present + //carry on with the existing code path in the else block + switch sourceProfile.Ty { + case profiles.SourceProfileTypeConfig: + ////There are three cases to cover here, bulk migrations and sharded migrations (and later DMS) + //We provide an if-else based handling for each within the sharded code branch + //This will be determined via the configType, which can be "bulk", "dataflow" or "dms" + if sourceProfile.Config.ConfigType == constants.BULK_MIGRATION { + return dataFromDb.dataFromDatabaseForBulkMigration(sourceProfile, targetProfile, config, conv, client, getInfo, snapshotMigration) + } else if sourceProfile.Config.ConfigType == constants.DATAFLOW_MIGRATION { + return dataFromDb.dataFromDatabaseForDataflowMigration(targetProfile, ctx, sourceProfile, conv, &common.InfoSchemaImpl{}) + } else if sourceProfile.Config.ConfigType == constants.DMS_MIGRATION { + return 
dataFromDb.dataFromDatabaseForDMSMigration() + } else { + return nil, fmt.Errorf("configType should be one of 'bulk', 'dataflow' or 'dms'") + } + default: + var infoSchema common.InfoSchema + var err error + if sourceProfile.Ty == profiles.SourceProfileTypeCloudSQL { + infoSchema, err = getInfo.GetInfoSchemaFromCloudSQL(sourceProfile, targetProfile) + if err != nil { + return nil, err + } + } else { + infoSchema, err = getInfo.GetInfoSchema(sourceProfile, targetProfile) + if err != nil { + return nil, err + } + } + var streamInfo map[string]interface{} + // minimal downtime migration for a single shard + if sourceProfile.Conn.Streaming { + //Generate a job Id + migrationJobId := conv.Audit.MigrationRequestId + logger.Log.Info(fmt.Sprintf("Creating a migration job with id: %v. This jobId can be used in future commmands (such as cleanup) to refer to this job.\n", migrationJobId)) + streamInfo, err = infoSchema.StartChangeDataCapture(ctx, conv) + if err != nil { + return nil, err + } + bw, err := snapshotMigration.snapshotMigrationHandler(sourceProfile, config, conv, client, infoSchema) + if err != nil { + return nil, err + } + dfOutput, err := infoSchema.StartStreamingMigration(ctx, client, conv, streamInfo) + if err != nil { + return nil, err + } + dfJobId := dfOutput.JobID + gcloudCmd := dfOutput.GCloudCmd + streamingCfg, _ := streamInfo["streamingCfg"].(streaming.StreamingCfg) + // Fetch and store the GCS bucket associated with the datastream + dsClient := getDatastreamClient(ctx) + gcsBucket, gcsDestPrefix, fetchGcsErr := streaming.FetchTargetBucketAndPath(ctx, dsClient, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg.DestinationConnectionConfig) + if fetchGcsErr != nil { + logger.Log.Info("Could not fetch GCS Bucket, hence Monitoring Dashboard will not contain Metrics for the gcs bucket\n") + logger.Log.Debug("Error", zap.Error(fetchGcsErr)) + } + + // Try to apply lifecycle rule to Datastream destination bucket. 
+ gcsConfig := streamingCfg.GcsCfg + sc, err := storageclient.NewStorageClientImpl(ctx) + if err != nil { + return nil, err + } + sa := storageaccessor.StorageAccessorImpl{} + if gcsConfig.TtlInDaysSet { + err = sa.ApplyBucketLifecycleDeleteRule(ctx, sc, storageaccessor.StorageBucketMetadata{ + BucketName: gcsBucket, + Ttl: gcsConfig.TtlInDays, + MatchesPrefix: []string{gcsDestPrefix}, + }) + if err != nil { + logger.Log.Warn(fmt.Sprintf("\nWARNING: could not update Datastream destination GCS bucket with lifecycle rule, error: %v\n", err)) + logger.Log.Warn("Please apply the lifecycle rule manually. Continuing...\n") + } + } + + monitoringResources := metrics.MonitoringMetricsResources{ + ProjectId: targetProfile.Conn.Sp.Project, + DataflowJobId: dfOutput.JobID, + DatastreamId: streamingCfg.DatastreamCfg.StreamId, + JobMetadataGcsBucket: gcsBucket, + PubsubSubscriptionId: streamingCfg.PubsubCfg.SubscriptionId, + SpannerInstanceId: targetProfile.Conn.Sp.Instance, + SpannerDatabaseId: targetProfile.Conn.Sp.Dbname, + ShardId: "", + MigrationRequestId: conv.Audit.MigrationRequestId, + } + respDash, dashboardErr := monitoringResources.CreateDataflowShardMonitoringDashboard(ctx) + var dashboardName string + if dashboardErr != nil { + dashboardName = "" + logger.Log.Info("Creation of the monitoring dashboard failed, please create the dashboard manually") + logger.Log.Debug("Error", zap.Error(dashboardErr)) + } else { + dashboardName = strings.Split(respDash.Name, "/")[3] + fmt.Printf("Monitoring Dashboard: %+v\n", dashboardName) + } + // store the generated resources locally in conv, this is used as source of truth for persistence and the UI (should change to persisted values) + streaming.StoreGeneratedResources(conv, streamingCfg, dfJobId, gcloudCmd, targetProfile.Conn.Sp.Project, "", internal.GcsResources{BucketName: gcsBucket}, dashboardName) + //persist job and shard level data in the metadata db + err = streaming.PersistJobDetails(ctx, targetProfile, sourceProfile, 
conv, migrationJobId, false) + if err != nil { + logger.Log.Info(fmt.Sprintf("Error storing job details in SMT metadata store...the migration job will still continue as intended. %v", err)) + } else { + //only attempt persisting shard level data if the job level data is persisted + err = streaming.PersistResources(ctx, targetProfile, sourceProfile, conv, migrationJobId, constants.DEFAULT_SHARD_ID) + if err != nil { + logger.Log.Info(fmt.Sprintf("Error storing details for migration job: %s, data shard: %s in SMT metadata store...the migration job will still continue as intended. err = %v\n", migrationJobId, constants.DEFAULT_SHARD_ID, err)) + } + } + return bw, nil + } + //bulk migration for a single shard + return snapshotMigration.performSnapshotMigration(config, conv, client, infoSchema, internal.AdditionalDataAttributes{ShardId: ""}, &common.InfoSchemaImpl{}, &PopulateDataConvImpl{}), nil + } +} diff --git a/conversion/conversion_from_source_test.go b/conversion/conversion_from_source_test.go new file mode 100644 index 0000000000..f236807553 --- /dev/null +++ b/conversion/conversion_from_source_test.go @@ -0,0 +1,175 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package conversion handles initial setup for the command line tool +// and web APIs. + +// TODO:(searce) Organize code in go style format to make this file more readable. 
+// +// public constants first +// key public type definitions next (although often it makes sense to put them next to public functions that use them) +// then public functions (and relevant type definitions) +// and helper functions and other non-public definitions last (generally in order of importance) +package conversion + +import ( + "fmt" + "testing" + + "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/mysql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + + +func TestSchemaFromDatabase(t *testing.T) { + targetProfile := profiles.TargetProfile{ + Conn: profiles.TargetProfileConnection{ + Sp: profiles.TargetProfileConnectionSpanner{ + Dialect: "google_standard_sql", + }, + }, + } + + sourceProfileConfigBulk := profiles.SourceProfile{ + Ty: profiles.SourceProfileType(3), + Config: profiles.SourceProfileConfig{ + ConfigType: "bulk", + }, + } + sourceProfileConfigDataflow := profiles.SourceProfile{ + Ty: profiles.SourceProfileType(3), + Config: profiles.SourceProfileConfig{ + ConfigType: "dataflow", + }, + } + sourceProfileConfigDms := profiles.SourceProfile{ + Ty: profiles.SourceProfileType(3), + Config: profiles.SourceProfileConfig{ + ConfigType: "dms", + }, + } + sourceProfileConfigInvalid := profiles.SourceProfile{ + Ty: profiles.SourceProfileType(3), + Config: profiles.SourceProfileConfig{ + ConfigType: "invalid", + }, + } + sourceProfileCloudSql := profiles.SourceProfile{ + Ty: profiles.SourceProfileType(5), + } + sourceProfileCloudDefault := profiles.SourceProfile{} + // Avoid getting/setting env variables in the unit tests. 
+ testCases := []struct { + name string + sourceProfile profiles.SourceProfile + getInfoError error + processSchemaError error + errorExpected bool + }{ + { + name: "successful source profile config for bulk migration", + sourceProfile: sourceProfileConfigBulk, + getInfoError: nil, + processSchemaError: nil, + errorExpected: false, + }, + { + name: "source profile config for bulk migration: get info error", + sourceProfile: sourceProfileConfigBulk, + getInfoError: fmt.Errorf("error"), + processSchemaError: nil, + errorExpected: true, + }, + { + name: "source profile config for bulk migration: process schema error", + sourceProfile: sourceProfileConfigBulk, + getInfoError: nil, + processSchemaError: fmt.Errorf("error"), + errorExpected: true, + }, + { + name: "successful source profile config for dataflow migration", + sourceProfile: sourceProfileConfigDataflow, + getInfoError: nil, + processSchemaError: nil, + errorExpected: false, + }, + { + name: "source profile config for dataflow migration: get info error", + sourceProfile: sourceProfileConfigDataflow, + getInfoError: fmt.Errorf("error"), + processSchemaError: nil, + errorExpected: true, + }, + { + name: "source profile config for dms migration", + sourceProfile: sourceProfileConfigDms, + getInfoError: nil, + processSchemaError: nil, + errorExpected: true, + }, + { + name: "invalid source profile config", + sourceProfile: sourceProfileConfigInvalid, + getInfoError: nil, + processSchemaError: nil, + errorExpected: true, + }, + { + name: "successful source profile cloud sql", + sourceProfile: sourceProfileCloudSql, + getInfoError: nil, + processSchemaError: nil, + errorExpected: false, + }, + { + name: "source profile cloud sql: get info error", + sourceProfile: sourceProfileCloudSql, + getInfoError: fmt.Errorf("error"), + processSchemaError: nil, + errorExpected: true, + }, + { + name: "successful source profile default", + sourceProfile: sourceProfileCloudDefault, + getInfoError: nil, + processSchemaError: nil, 
+ errorExpected: false, + }, + { + name: "source profile default: get info error", + sourceProfile: sourceProfileCloudDefault, + getInfoError: fmt.Errorf("error"), + processSchemaError: nil, + errorExpected: true, + }, + } + + for _, tc := range testCases { + gim := MockGetInfo{} + ps := common.MockProcessSchema{} + + gim.On("getInfoSchemaForShard", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(mysql.InfoSchemaImpl{}, tc.getInfoError) + gim.On("GetInfoSchemaFromCloudSQL", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(mysql.InfoSchemaImpl{}, tc.getInfoError) + gim.On("GetInfoSchema", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(mysql.InfoSchemaImpl{}, tc.getInfoError) + ps.On("ProcessSchema", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.processSchemaError) + + s := SchemaFromSourceImpl{} + _, err := s.schemaFromDatabase(tc.sourceProfile, targetProfile, &gim, &ps) + assert.Equal(t, tc.errorExpected, err != nil, tc.name) + } +} \ No newline at end of file diff --git a/conversion/conversion_helper.go b/conversion/conversion_helper.go new file mode 100644 index 0000000000..8aa9d30b8e --- /dev/null +++ b/conversion/conversion_helper.go @@ -0,0 +1,212 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+// Package conversion handles initial setup for the command line tool
+// and web APIs.
+
+// TODO:(searce) Organize code in go style format to make this file more readable.
+//
+// public constants first
+// key public type definitions next (although often it makes sense to put them next to public functions that use them)
+// then public functions (and relevant type definitions)
+// and helper functions and other non-public definitions last (generally in order of importance)
+
+package conversion
+
+import (
+	"context"
+	"encoding/base64"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"os"
+	"strings"
+	"sync/atomic"
+	"syscall"
+
+	sp "cloud.google.com/go/spanner"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/metrics"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/profiles"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/mysql"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/postgres"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer"
+	"github.com/aws/aws-sdk-go/aws"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/protobuf/proto"
+)
+
+// ProcessDumpByDialectInterface abstracts the processing of a database dump
+// for a given driver, so that callers can substitute another implementation.
+type ProcessDumpByDialectInterface interface{
+	// ProcessDump processes the dump read from r into conv for the given
+	// driver and returns an error if processing fails.
+	ProcessDump(driver string, conv *internal.Conv, r *internal.Reader) error
+}
+
+// ProcessDumpByDialectImpl is the stateless implementation of
+// ProcessDumpByDialectInterface defined in this package.
+type ProcessDumpByDialectImpl struct{}
+
+// PopulateDataConvInterface abstracts the construction of the
+// writer.BatchWriter used for data conversion. Its only method is unexported,
+// so the interface can only be implemented inside this package.
+type PopulateDataConvInterface interface{
+	// populateDataConv takes the conversion state, a batch writer
+	// configuration and a Spanner client, and returns a batch writer.
+	populateDataConv(conv *internal.Conv, config writer.BatchWriterConfig, client *sp.Client) *writer.BatchWriter
+}
+
+// PopulateDataConvImpl is the stateless implementation of
+// PopulateDataConvInterface defined in this package.
+type PopulateDataConvImpl struct{}
+// getSeekable returns a seekable file (with same content as f) and the size of the content (in bytes).
+func getSeekable(f *os.File) (*os.File, int64, error) {
+	// Fast path: f is already seekable, so use it directly. This happens
+	// when you run 'cmd < file'.
+	if _, err := f.Seek(0, 0); err == nil {
+		n, err := utils.GetFileSize(f)
+		return f, n, err
+	}
+	internal.VerbosePrintln("Creating a tmp file with a copy of stdin because stdin is not seekable.")
+	logger.Log.Debug("Creating a tmp file with a copy of stdin because stdin is not seekable.")
+
+	// Create file in os.TempDir. Its not clear this is a good idea e.g. if the
+	// pg_dump/mysqldump output is large (tens of GBs) and os.TempDir points to a directory
+	// (such as /tmp) that's configured with a small amount of disk space.
+	// To workaround such limits on Unix, set $TMPDIR to a directory with lots
+	// of disk space.
+	fcopy, err := ioutil.TempFile("", "spanner-migration-tool.data")
+	if err != nil {
+		return nil, 0, err
+	}
+	// Unlink right away: the data stays reachable through the open handle and
+	// the file is removed once that handle is closed or this process exits.
+	syscall.Unlink(fcopy.Name())
+
+	if _, err = io.Copy(fcopy, f); err != nil {
+		fcopy.Close() // Do not leak the handle on the error path.
+		return nil, 0, fmt.Errorf("can't write stdin to tmp file: %w", err)
+	}
+	if _, err = fcopy.Seek(0, 0); err != nil {
+		fcopy.Close()
+		return nil, 0, fmt.Errorf("can't reset file offset: %w", err)
+	}
+	// Report a size failure instead of silently returning size 0, matching
+	// what the seekable fast path above already does.
+	n, err := utils.GetFileSize(fcopy)
+	if err != nil {
+		fcopy.Close()
+		return nil, 0, fmt.Errorf("can't get size of tmp file: %w", err)
+	}
+	return fcopy, n, nil
+}
+
+// ProcessDump invokes process dump function from a sql package based on driver selected.
+func (pdd *ProcessDumpByDialectImpl) ProcessDump(driver string, conv *internal.Conv, r *internal.Reader) error {
+	switch driver {
+	case constants.MYSQLDUMP:
+		return common.ProcessDbDump(conv, r, mysql.DbDumpImpl{})
+	case constants.PGDUMP:
+		return common.ProcessDbDump(conv, r, postgres.DbDumpImpl{})
+	default:
+		return fmt.Errorf("process dump for driver %s not supported", driver)
+	}
+}
+
+// populateDataConv installs a Write callback on config that applies mutations
+// through client, builds a BatchWriter from that config and switches conv to
+// data mode. Unless conv.Audit.DryRun is set, conv's data sink and flush hook
+// are pointed at the returned writer.
+func (pdc *PopulateDataConvImpl) populateDataConv(conv *internal.Conv, config writer.BatchWriterConfig, client *sp.Client) *writer.BatchWriter {
+	rows := int64(0)
+	config.Write = func(m []*sp.Mutation) error {
+		ctx := context.Background()
+		if !conv.Audit.SkipMetricsPopulation {
+			migrationData := metrics.GetMigrationData(conv, "", constants.DataConv)
+			// NOTE(review): the marshal error is dropped, so a failure sends
+			// whatever bytes were returned as metadata. Presumably migration
+			// metadata is best-effort and must not block the write; confirm.
+			serializedMigrationData, _ := proto.Marshal(migrationData)
+			migrationMetadataValue := base64.StdEncoding.EncodeToString(serializedMigrationData)
+			ctx = metadata.AppendToOutgoingContext(context.Background(), constants.MigrationMetadataKey, migrationMetadataValue)
+		}
+		_, err := client.Apply(ctx, m)
+		if err != nil {
+			return err
+		}
+		// rows is updated atomically because it is shared by every
+		// invocation of this callback.
+		atomic.AddInt64(&rows, int64(len(m)))
+		conv.Audit.Progress.MaybeReport(atomic.LoadInt64(&rows))
+		return nil
+	}
+	batchWriter := writer.NewBatchWriter(config)
+	conv.SetDataMode()
+	if !conv.Audit.DryRun {
+		conv.SetDataSink(
+			func(table string, cols []string, vals []interface{}) {
+				batchWriter.AddRow(table, cols, vals)
+			})
+		conv.DataFlush = func() {
+			batchWriter.Flush()
+		}
+	}
+
+	return batchWriter
+}
+
+// connectionConfig returns the driver-specific connection configuration for
+// sourceProfile: a connection string for the SQL drivers and an *aws.Config
+// for DynamoDB. Unsupported drivers yield an error.
+func connectionConfig(sourceProfile profiles.SourceProfile) (interface{}, error) {
+	switch sourceProfile.Driver {
+	// For PG and MYSQL, When called as part of the subcommand flow, host/user/db etc will
+	// never be empty as we error out right during source profile creation. If any of them
+	// are empty, that means this was called through the legacy cmd flow and we create the
+	// string using env vars.
+	case constants.POSTGRES:
+		pgConn := sourceProfile.Conn.Pg
+		if pgConn.Host == "" || pgConn.User == "" || pgConn.Db == "" {
+			return profiles.GeneratePGSQLConnectionStr()
+		}
+		return profiles.GetSQLConnectionStr(sourceProfile), nil
+	case constants.MYSQL:
+		// If empty, this is called as part of the legacy mode with global CLI flags.
+		// When source-profile mode is used, the sqlConnectionStr is already populated.
+		mysqlConn := sourceProfile.Conn.Mysql
+		if mysqlConn.Host == "" || mysqlConn.User == "" || mysqlConn.Db == "" {
+			return profiles.GenerateMYSQLConnectionStr()
+		}
+		return profiles.GetSQLConnectionStr(sourceProfile), nil
+	// For Dynamodb, both legacy and new flows use env vars.
+	case constants.DYNAMODB:
+		return getDynamoDBClientConfig()
+	case constants.SQLSERVER, constants.ORACLE:
+		return profiles.GetSQLConnectionStr(sourceProfile), nil
+	default:
+		return "", fmt.Errorf("driver %s not supported", sourceProfile.Driver)
+	}
+}
+
+// getDbNameFromSQLConnectionStr extracts the database name from a connection
+// string for the given driver. It returns "" for an unknown driver or when
+// the string does not have the expected shape, rather than panicking on an
+// out-of-range index.
+func getDbNameFromSQLConnectionStr(driver, sqlConnectionStr string) string {
+	switch driver {
+	case constants.POSTGRES:
+		// The name is the value of the fifth space-separated "key=value" token.
+		fields := strings.Split(sqlConnectionStr, " ")
+		if len(fields) < 5 {
+			return ""
+		}
+		kv := strings.Split(fields[4], "=")
+		if len(kv) < 2 {
+			return ""
+		}
+		return kv[1]
+	case constants.MYSQL:
+		// The name is whatever follows the first ")/" separator.
+		parts := strings.Split(sqlConnectionStr, ")/")
+		if len(parts) < 2 {
+			return ""
+		}
+		return parts[1]
+	case constants.SQLSERVER:
+		splts := strings.Split(sqlConnectionStr, "?database=")
+		return splts[len(splts)-1]
+	case constants.ORACLE:
+		// connection string format: "oracle://user:password@104.108.154.85:1521/XE"
+		// NOTE(review): with this format the text before the first ':' after
+		// the scheme is the user, not "XE"; confirm that is the intended name.
+		const schemeLen = 9 // len("oracle://")
+		if len(sqlConnectionStr) < schemeLen {
+			return ""
+		}
+		return strings.Split(sqlConnectionStr[schemeLen:], ":")[0]
+	}
+	return ""
+}
+
+// updateShardsWithTuningConfigs copies the Datastream, GCS and Dataflow
+// configs set on shardedTuningConfig onto each of its data shards.
+// NOTE(review): the writes only persist if DataShards holds pointers; verify
+// against the profiles.ShardConfigurationDataflow definition.
+func updateShardsWithTuningConfigs(shardedTuningConfig profiles.ShardConfigurationDataflow) {
+	for _, dataShard := range shardedTuningConfig.DataShards {
+		dataShard.DatastreamConfig = shardedTuningConfig.DatastreamConfig
+		dataShard.GcsConfig = shardedTuningConfig.GcsConfig
+		dataShard.DataflowConfig = shardedTuningConfig.DataflowConfig
+	}
+}
+
+// getDynamoDBClientConfig builds the AWS config for the DynamoDB client,
+// setting the endpoint from the DYNAMODB_ENDPOINT_OVERRIDE environment
+// variable when it is non-empty. The returned error is always nil.
+func getDynamoDBClientConfig() (*aws.Config, error) {
+	cfg := aws.Config{}
+	if endpointOverride := os.Getenv("DYNAMODB_ENDPOINT_OVERRIDE"); endpointOverride != "" {
+		cfg.Endpoint = aws.String(endpointOverride)
+	}
+	return &cfg, nil
+}
\ No newline at end of file
diff --git a/conversion/conversion_test.go b/conversion/conversion_test.go
new file mode 100644
index 0000000000..2eb4ed3808
--- /dev/null
+++ b/conversion/conversion_test.go
@@ -0,0 +1,203 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package conversion handles initial setup for the command line tool
+// and web APIs.
+
+// TODO:(searce) Organize code in go style format to make this file more readable.
+//
+// public constants first
+// key public type definitions next (although often it makes sense to put them next to public functions that use them)
+// then public functions (and relevant type definitions)
+// and helper functions and other non-public definitions last (generally in order of importance)
+package conversion
+
+import (
+	"context"
+	"testing"
+
+	sp "cloud.google.com/go/spanner"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/profiles"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+)
+
+// TestSchemaConv checks that ConvImpl.SchemaConv calls the mocked
+// schema-source method expected for each driver and returns an error for an
+// unknown driver.
+func TestSchemaConv(t *testing.T) {
+	// Avoid getting/setting env variables in the unit tests.
+	testCases := []struct {
+		name                string
+		sourceProfileDriver string
+		output              interface{}
+		function            string
+		errorExpected       bool
+	}{
+		{
+			name:                "postgres driver",
+			sourceProfileDriver: "postgres",
+			output:              &internal.Conv{},
+			function:            "schemaFromDatabase",
+			errorExpected:       false,
+		},
+		{
+			name:                "mysql driver",
+			sourceProfileDriver: "mysql",
+			output:              &internal.Conv{},
+			function:            "schemaFromDatabase",
+			errorExpected:       false,
+		},
+		{
+			name:                "dynamodb driver",
+			sourceProfileDriver: "dynamodb",
+			output:              &internal.Conv{},
+			function:            "schemaFromDatabase",
+			errorExpected:       false,
+		},
+		{
+			name:                "sqlserver driver",
+			sourceProfileDriver: "sqlserver",
+			output:              &internal.Conv{},
+			function:            "schemaFromDatabase",
+			errorExpected:       false,
+		},
+		{
+			name:                "oracle driver",
+			sourceProfileDriver: "oracle",
+			output:              &internal.Conv{},
+			function:            "schemaFromDatabase",
+			errorExpected:       false,
+		},
+		{
+			name:                "pg dump driver",
+			sourceProfileDriver: "pg_dump",
+			output:              &internal.Conv{},
+			function:            "SchemaFromDump",
+			errorExpected:       false,
+		},
+		{
+			name:                "mysql dump driver",
+			sourceProfileDriver: "mysqldump",
+			output:              &internal.Conv{},
+			function:            "SchemaFromDump",
+			errorExpected:       false,
+		},
+		{
+			name:                "invalid driver",
+			sourceProfileDriver: "invalid",
+			output:              nil,
+			function:            "",
+			errorExpected:       true,
+		},
+	}
+
+	for _, tc := range testCases {
+		// Run each case as a subtest so a failure names the case and the
+		// remaining cases still execute.
+		t.Run(tc.name, func(t *testing.T) {
+			m := MockSchemaFromSource{}
+			m.On(tc.function, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.output, nil)
+			c := ConvImpl{}
+			_, err := c.SchemaConv(profiles.SourceProfile{Driver: tc.sourceProfileDriver}, profiles.TargetProfile{}, &utils.IOStreams{}, &m)
+			assert.Equal(t, tc.errorExpected, err != nil, tc.name)
+			if err == nil {
+				m.AssertExpectations(t)
+			}
+		})
+	}
+}
+
+// TestDataConv checks that ConvImpl.DataConv calls the mocked data-source
+// method expected for each driver and returns an error for an unknown driver.
+func TestDataConv(t *testing.T) {
+	// Avoid getting/setting env variables in the unit tests.
+	testCases := []struct {
+		name                string
+		sourceProfileDriver string
+		output              interface{}
+		function            string
+		errorExpected       bool
+	}{
+		{
+			name:                "postgres driver",
+			sourceProfileDriver: "postgres",
+			output:              &writer.BatchWriter{},
+			function:            "dataFromDatabase",
+			errorExpected:       false,
+		},
+		{
+			name:                "mysql driver",
+			sourceProfileDriver: "mysql",
+			output:              &writer.BatchWriter{},
+			function:            "dataFromDatabase",
+			errorExpected:       false,
+		},
+		{
+			name:                "dynamodb driver",
+			sourceProfileDriver: "dynamodb",
+			output:              &writer.BatchWriter{},
+			function:            "dataFromDatabase",
+			errorExpected:       false,
+		},
+		{
+			name:                "sqlserver driver",
+			sourceProfileDriver: "sqlserver",
+			output:              &writer.BatchWriter{},
+			function:            "dataFromDatabase",
+			errorExpected:       false,
+		},
+		{
+			name:                "oracle driver",
+			sourceProfileDriver: "oracle",
+			output:              &writer.BatchWriter{},
+			function:            "dataFromDatabase",
+			errorExpected:       false,
+		},
+		{
+			name:                "pg dump driver",
+			sourceProfileDriver: "pg_dump",
+			output:              &writer.BatchWriter{},
+			function:            "dataFromDump",
+			errorExpected:       false,
+		},
+		{
+			name:                "mysql dump driver",
+			sourceProfileDriver: "mysqldump",
+			output:              &writer.BatchWriter{},
+			function:            "dataFromDump",
+			errorExpected:       false,
+		},
+		{
+			name:                "csv driver",
+			sourceProfileDriver: "csv",
+			output:              &writer.BatchWriter{},
+			function:            "dataFromCSV",
+			errorExpected:       false,
+		},
+		{
+			name:                "invalid driver",
+			sourceProfileDriver: "invalid",
+			output:              nil,
+			function:            "",
+			errorExpected:       true,
+		},
+	}
+
+	ctx := context.Background()
+	for _, tc := range testCases {
+		// Run each case as a subtest so a failure names the case and the
+		// remaining cases still execute.
+		t.Run(tc.name, func(t *testing.T) {
+			m := MockDataFromSource{}
+			m.On(tc.function, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.output, nil)
+			c := ConvImpl{}
+			_, err := c.DataConv(ctx, profiles.SourceProfile{Driver: tc.sourceProfileDriver}, profiles.TargetProfile{}, &utils.IOStreams{}, &sp.Client{}, &internal.Conv{}, true, int64(5), &m)
+			assert.Equal(t, tc.errorExpected, err != nil, tc.name)
+			if err == nil {
+				m.AssertExpectations(t)
+			}
+		})
+	}
+}
\ No newline at end of file
diff --git a/conversion/data_from_database.go b/conversion/data_from_database.go
new file mode 100644
index 0000000000..706b61dd90
--- /dev/null
+++ b/conversion/data_from_database.go
@@ -0,0 +1,224 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package conversion handles initial setup for the command line tool
+// and web APIs.
+
+// TODO:(searce) Organize code in go style format to make this file more readable.
+//
+// public constants first
+// key public type definitions next (although often it makes sense to put them next to public functions that use them)
+// then public functions (and relevant type definitions)
+// and helper functions and other non-public definitions last (generally in order of importance)
+package conversion
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"sync"
+
+	sp "cloud.google.com/go/spanner"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/metrics"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/profiles"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/streaming"
+	storageaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/storage"
+	storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage"
+	"go.uber.org/zap"
+)
+
+// DataFromDatabaseInterface groups the data-migration entry points for
+// database sources, one per migration type (DMS, dataflow, bulk). All methods
+// are unexported, so the interface can only be implemented inside this package.
+type DataFromDatabaseInterface interface{
+	// dataFromDatabaseForDMSMigration is the entry point for DMS migrations.
+	dataFromDatabaseForDMSMigration() (*writer.BatchWriter, error)
+	// dataFromDatabaseForDataflowMigration is the entry point for dataflow
+	// migrations. Note that ctx is the second parameter, not the first.
+	dataFromDatabaseForDataflowMigration(targetProfile profiles.TargetProfile, ctx context.Context, sourceProfile profiles.SourceProfile, conv *internal.Conv, is common.InfoSchemaInterface) (*writer.BatchWriter, error)
+	// dataFromDatabaseForBulkMigration is the entry point for bulk migrations.
+	dataFromDatabaseForBulkMigration(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, gi GetInfoInterface, sm SnapshotMigrationInterface) (*writer.BatchWriter, error)
+
+}
+
+// DataFromDatabaseImpl is the stateless implementation of
+// DataFromDatabaseInterface defined in this package.
+type DataFromDatabaseImpl struct{}
+
+
+// TODO: Define the data processing logic for DMS migrations here.
+func (dd *DataFromDatabaseImpl) dataFromDatabaseForDMSMigration() (*writer.BatchWriter, error) {
+	return nil, fmt.Errorf("dms configType is not implemented yet, please use one of 'bulk' or 'dataflow'")
+}
+
+// dataFromDatabaseForDataflowMigration starts a dataflow-based migration for
+// every data shard in sourceProfile.Config.ShardConfigurationDataflow.
+// For each shard it:
+//
+// 1. Create batch for each physical shard
+// 2. Create streaming cfg from the config source type.
+// 3. Verify the CFG and update it with SMT defaults
+// 4. Launch the stream for the physical shard
+// 5. Perform streaming migration via dataflow
+//
+// After all shards are processed it creates and persists an aggregated
+// monitoring dashboard. Failures while persisting metadata or creating
+// dashboards are logged and do not fail the migration; failures while
+// setting up a shard's streaming resources are returned. On success an empty
+// BatchWriter is returned.
+func (dd *DataFromDatabaseImpl) dataFromDatabaseForDataflowMigration(targetProfile profiles.TargetProfile, ctx context.Context, sourceProfile profiles.SourceProfile, conv *internal.Conv, is common.InfoSchemaInterface) (*writer.BatchWriter, error) {
+	updateShardsWithTuningConfigs(sourceProfile.Config.ShardConfigurationDataflow)
+	// The migration request id doubles as the job id.
+	migrationJobId := conv.Audit.MigrationRequestId
+	fmt.Printf("Creating a migration job with id: %v. This jobId can be used in future commmands (such as cleanup) to refer to this job.\n", migrationJobId)
+	conv.Audit.StreamingStats.ShardToShardResourcesMap = make(map[string]internal.ShardResources)
+	schemaDetails, err := is.GetIncludedSrcTablesFromConv(conv)
+	if err != nil {
+		fmt.Printf("unable to determine tableList from schema, falling back to full database")
+		schemaDetails = map[string]internal.SchemaDetails{}
+	}
+	err = streaming.PersistJobDetails(ctx, targetProfile, sourceProfile, conv, migrationJobId, true)
+	if err != nil {
+		logger.Log.Info(fmt.Sprintf("Error storing job details in SMT metadata store...the migration job will still continue as intended. %v", err))
+	}
+	// asyncProcessShards sets up all streaming resources for one data shard.
+	asyncProcessShards := func(p *profiles.DataShard, mutex *sync.Mutex) common.TaskResult[*profiles.DataShard] {
+		dbNameToShardIdMap := make(map[string]string)
+		for _, l := range p.LogicalShards {
+			dbNameToShardIdMap[l.DbName] = l.LogicalShardId
+		}
+		if p.DataShardId == "" {
+			// Check the error before touching the generated name.
+			dataShardId, err := utils.GenerateName("smt-datashard")
+			if err != nil {
+				return common.TaskResult[*profiles.DataShard]{Result: p, Err: err}
+			}
+			p.DataShardId = strings.Replace(dataShardId, "_", "-", -1)
+			fmt.Printf("Data shard id generated: %v\n", p.DataShardId)
+		}
+		streamingCfg := streaming.CreateStreamingConfig(*p)
+		err := streaming.VerifyAndUpdateCfg(&streamingCfg, targetProfile.Conn.Sp.Dbname, schemaDetails)
+		if err != nil {
+			err = fmt.Errorf("failed to process shard: %s, there seems to be an error in the sharding configuration, error: %v", p.DataShardId, err)
+			return common.TaskResult[*profiles.DataShard]{Result: p, Err: err}
+		}
+		fmt.Printf("Initiating migration for shard: %v\n", p.DataShardId)
+		pubsubCfg, err := streaming.CreatePubsubResources(ctx, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg.DestinationConnectionConfig, targetProfile.Conn.Sp.Dbname)
+		if err != nil {
+			return common.TaskResult[*profiles.DataShard]{Result: p, Err: err}
+		}
+		streamingCfg.PubsubCfg = *pubsubCfg
+		err = streaming.LaunchStream(ctx, sourceProfile, p.LogicalShards, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg)
+		if err != nil {
+			return common.TaskResult[*profiles.DataShard]{Result: p, Err: err}
+		}
+		streamingCfg.DataflowCfg.DbNameToShardIdMap = dbNameToShardIdMap
+		dfOutput, err := streaming.StartDataflow(ctx, targetProfile, streamingCfg, conv)
+		if err != nil {
+			return common.TaskResult[*profiles.DataShard]{Result: p, Err: err}
+		}
+		// store the generated resources locally in conv, this is used as source of truth for persistence and the UI (should change to persisted values)
+
+		// Fetch and store the GCS bucket associated with the datastream
+		dsClient := getDatastreamClient(ctx)
+		gcsBucket, gcsDestPrefix, fetchGcsErr := streaming.FetchTargetBucketAndPath(ctx, dsClient, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg.DestinationConnectionConfig)
+		if fetchGcsErr != nil {
+			logger.Log.Info(fmt.Sprintf("Could not fetch GCS Bucket for Shard %s hence Monitoring Dashboard will not contain Metrics for the gcs bucket\n", p.DataShardId))
+			logger.Log.Debug("Error", zap.Error(fetchGcsErr))
+		}
+
+		// Try to apply lifecycle rule to Datastream destination bucket.
+		gcsConfig := streamingCfg.GcsCfg
+		sc, err := storageclient.NewStorageClientImpl(ctx)
+		if err != nil {
+			return common.TaskResult[*profiles.DataShard]{Result: p, Err: err}
+		}
+		sa := storageaccessor.StorageAccessorImpl{}
+		if gcsConfig.TtlInDaysSet {
+			err = sa.ApplyBucketLifecycleDeleteRule(ctx, sc, storageaccessor.StorageBucketMetadata{
+				BucketName:    gcsBucket,
+				Ttl:           gcsConfig.TtlInDays,
+				MatchesPrefix: []string{gcsDestPrefix},
+			})
+			if err != nil {
+				logger.Log.Warn(fmt.Sprintf("\nWARNING: could not update Datastream destination GCS bucket with lifecycle rule, error: %v\n", err))
+				logger.Log.Warn("Please apply the lifecycle rule manually. Continuing...\n")
+			}
+		}
+
+		// create monitoring dashboard for a single shard
+		monitoringResources := metrics.MonitoringMetricsResources{
+			ProjectId:            targetProfile.Conn.Sp.Project,
+			DataflowJobId:        dfOutput.JobID,
+			DatastreamId:         streamingCfg.DatastreamCfg.StreamId,
+			JobMetadataGcsBucket: gcsBucket,
+			PubsubSubscriptionId: streamingCfg.PubsubCfg.SubscriptionId,
+			SpannerInstanceId:    targetProfile.Conn.Sp.Instance,
+			SpannerDatabaseId:    targetProfile.Conn.Sp.Dbname,
+			ShardId:              p.DataShardId,
+			MigrationRequestId:   conv.Audit.MigrationRequestId,
+		}
+		respDash, dashboardErr := monitoringResources.CreateDataflowShardMonitoringDashboard(ctx)
+		var dashboardName string
+		if dashboardErr != nil {
+			dashboardName = ""
+			logger.Log.Info(fmt.Sprintf("Creation of the monitoring dashboard for shard %s failed, please create the dashboard manually\n", p.DataShardId))
+			logger.Log.Debug("Error", zap.Error(dashboardErr))
+		} else {
+			// The dashboard name is the fourth "/"-separated segment of respDash.Name.
+			dashboardName = strings.Split(respDash.Name, "/")[3]
+			fmt.Printf("Monitoring Dashboard for shard %v: %+v\n", p.DataShardId, dashboardName)
+		}
+		streaming.StoreGeneratedResources(conv, streamingCfg, dfOutput.JobID, dfOutput.GCloudCmd, targetProfile.Conn.Sp.Project, p.DataShardId, internal.GcsResources{BucketName: gcsBucket}, dashboardName)
+		//persist the generated resources in a metadata db
+		if err := streaming.PersistResources(ctx, targetProfile, sourceProfile, conv, migrationJobId, p.DataShardId); err != nil {
+			// Persisting metadata is best-effort, like PersistJobDetails above:
+			// report the failure but do not fail the shard, so the migration
+			// really does continue as the message promises.
+			fmt.Printf("Error storing generated resources in SMT metadata store for dataShardId: %s...the migration job will still continue as intended, error: %v\n", p.DataShardId, err)
+		}
+		return common.TaskResult[*profiles.DataShard]{Result: p, Err: nil}
+	}
+	_, err = common.RunParallelTasks(sourceProfile.Config.ShardConfigurationDataflow.DataShards, 20, asyncProcessShards, true)
+	if err != nil {
+		return nil, fmt.Errorf("unable to start minimal downtime migrations: %v", err)
+	}
+
+	// create monitoring aggregated dashboard for sharded migration
+	aggMonitoringResources := metrics.MonitoringMetricsResources{
+		ProjectId:                targetProfile.Conn.Sp.Project,
+		SpannerInstanceId:        targetProfile.Conn.Sp.Instance,
+		SpannerDatabaseId:        targetProfile.Conn.Sp.Dbname,
+		ShardToShardResourcesMap: conv.Audit.StreamingStats.ShardToShardResourcesMap,
+		MigrationRequestId:       conv.Audit.MigrationRequestId,
+	}
+	aggRespDash, dashboardErr := aggMonitoringResources.CreateDataflowAggMonitoringDashboard(ctx)
+	if dashboardErr != nil {
+		logger.Log.Error(fmt.Sprintf("Creation of the aggregated monitoring dashboard failed, please create the dashboard manually\n error=%v\n", dashboardErr))
+	} else {
+		// Split once and reuse the result for both the message and conv.
+		aggDashboardName := strings.Split(aggRespDash.Name, "/")[3]
+		fmt.Printf("Aggregated Monitoring Dashboard: %+v\n", aggDashboardName)
+		conv.Audit.StreamingStats.AggMonitoringResources = internal.MonitoringResources{DashboardName: aggDashboardName}
+	}
+	err = streaming.PersistAggregateMonitoringResources(ctx, targetProfile, sourceProfile, conv, migrationJobId)
+	if err != nil {
+		logger.Log.Info(fmt.Sprintf("Unable to store aggregated monitoring dashboard in metadata database\n error=%v\n", err))
+	} else {
+		logger.Log.Debug("Aggregate monitoring resources stored successfully.\n")
+	}
+	return &writer.BatchWriter{}, nil
+}
+
+// dataFromDatabaseForBulkMigration runs a snapshot migration for every data
+// shard in sourceProfile.Config.ShardConfigurationBulk, one after another:
+//
+// 1. Migrate the data from the data shards, the schema shard needs to be specified here again.
+// 2. Create a connection profile object for it
+// 3. Perform a snapshot migration for the shard
+// 4. Once all shard migrations are complete, return the batch writer object
+//
+// It returns the BatchWriter of the last shard processed (nil when there are
+// no data shards) and stops at the first shard whose info schema cannot be
+// obtained.
+func (dd *DataFromDatabaseImpl) dataFromDatabaseForBulkMigration(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, gi GetInfoInterface, sm SnapshotMigrationInterface) (*writer.BatchWriter, error) {
+	var bw *writer.BatchWriter
+	for _, dataShard := range sourceProfile.Config.ShardConfigurationBulk.DataShards {
+
+		fmt.Printf("Initiating migration for shard: %v\n", dataShard.DbName)
+		infoSchema, err := gi.getInfoSchemaForShard(dataShard, sourceProfile.Driver, targetProfile, &profiles.SourceProfileDialectImpl{}, &GetInfoImpl{})
+		if err != nil {
+			return nil, err
+		}
+		additionalDataAttributes := internal.AdditionalDataAttributes{
+			ShardId: dataShard.DataShardId,
+		}
+		bw = sm.performSnapshotMigration(config, conv, client, infoSchema, additionalDataAttributes, &common.InfoSchemaImpl{}, &PopulateDataConvImpl{})
+	}
+
+	return bw, nil
+}
\ No newline at end of file
diff --git a/conversion/get_info.go b/conversion/get_info.go
new file mode 100644
index 0000000000..a43fb8a0f0
--- /dev/null
+++ b/conversion/get_info.go
@@ -0,0 +1,210 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package conversion handles initial setup for the command line tool
+// and web APIs.
+
+// TODO:(searce) Organize code in go style format to make this file more readable.
+// +// public constants first +// key public type definitions next (although often it makes sense to put them next to public functions that use them) +// then public functions (and relevant type definitions) +// and helper functions and other non-public definitions last (generally in order of importance) +package conversion + +import ( + "context" + "database/sql" + "fmt" + "net" + "strings" + + "cloud.google.com/go/cloudsqlconn" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/dynamodb" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/mysql" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/oracle" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/postgres" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/sqlserver" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + dydb "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodbstreams" + mysqldriver "github.com/go-sql-driver/mysql" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" +) + +type GetInfoInterface interface{ + getInfoSchemaForShard(shardConnInfo profiles.DirectConnectionConfig, driver string, targetProfile profiles.TargetProfile, sourceProfileDialect profiles.SourceProfileDialectInterface, getInfo GetInfoInterface) (common.InfoSchema, error) + GetInfoSchemaFromCloudSQL(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) + GetInfoSchema(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) +} +type GetInfoImpl struct{} + +func (gi *GetInfoImpl) 
getInfoSchemaForShard(shardConnInfo profiles.DirectConnectionConfig, driver string, targetProfile profiles.TargetProfile, sourceProfileDialect profiles.SourceProfileDialectInterface, getInfo GetInfoInterface) (common.InfoSchema, error) { + params := make(map[string]string) + params["host"] = shardConnInfo.Host + params["user"] = shardConnInfo.User + params["dbName"] = shardConnInfo.DbName + params["port"] = shardConnInfo.Port + params["password"] = shardConnInfo.Password + //while adding other sources, a switch-case will be added here on the basis of the driver input param passed. + //pased on the driver name, profiles.NewSourceProfileConnection will need to be called to create + //the source profile information. + getUtilsInfo := utils.GetUtilInfoImpl{} + sourceProfileConnectionMySQL, err := sourceProfileDialect.NewSourceProfileConnectionMySQL(params, &getUtilsInfo) + if err != nil { + return nil, fmt.Errorf("cannot parse connection configuration for the primary shard") + } + sourceProfileConnection := profiles.SourceProfileConnection{Mysql: sourceProfileConnectionMySQL, Ty: profiles.SourceProfileConnectionTypeMySQL} + //create a source profile which contains the sourceProfileConnection object for the primary shard + //this is done because GetSQLConnectionStr() should not be aware of sharding + newSourceProfile := profiles.SourceProfile{Conn: sourceProfileConnection, Ty: profiles.SourceProfileTypeConnection} + newSourceProfile.Driver = driver + infoSchema, err := getInfo.GetInfoSchema(newSourceProfile, targetProfile) + if err != nil { + return nil, err + } + return infoSchema, nil +} + + +func (gi *GetInfoImpl) GetInfoSchemaFromCloudSQL(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) { + driver := sourceProfile.Driver + switch driver { + case constants.MYSQL: + d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithIAMAuthN()) + if err != nil { + return nil, fmt.Errorf("cloudsqlconn.NewDialer: 
%w", err) + } + var opts []cloudsqlconn.DialOption + instanceName := fmt.Sprintf("%s:%s:%s", sourceProfile.ConnCloudSQL.Mysql.Project, sourceProfile.ConnCloudSQL.Mysql.Region, sourceProfile.ConnCloudSQL.Mysql.InstanceName) + mysqldriver.RegisterDialContext("cloudsqlconn", + func(ctx context.Context, addr string) (net.Conn, error) { + return d.Dial(ctx, instanceName, opts...) + }) + + dbURI := fmt.Sprintf("%s:empty@cloudsqlconn(localhost:3306)/%s?parseTime=true", + sourceProfile.ConnCloudSQL.Mysql.User, sourceProfile.ConnCloudSQL.Mysql.Db) + + db, err := sql.Open("mysql", dbURI) + if err != nil { + return nil, fmt.Errorf("sql.Open: %w", err) + } + return mysql.InfoSchemaImpl{ + DbName: sourceProfile.ConnCloudSQL.Mysql.Db, + Db: db, + SourceProfile: sourceProfile, + TargetProfile: targetProfile, + }, nil + case constants.POSTGRES: + d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithIAMAuthN()) + if err != nil { + return nil, fmt.Errorf("cloudsqlconn.NewDialer: %w", err) + } + var opts []cloudsqlconn.DialOption + + dsn := fmt.Sprintf("user=%s database=%s", sourceProfile.ConnCloudSQL.Pg.User, sourceProfile.ConnCloudSQL.Pg.Db) + config, err := pgx.ParseConfig(dsn) + if err != nil { + return nil, err + } + instanceName := fmt.Sprintf("%s:%s:%s", sourceProfile.ConnCloudSQL.Pg.Project, sourceProfile.ConnCloudSQL.Pg.Region, sourceProfile.ConnCloudSQL.Pg.InstanceName) + config.DialFunc = func(ctx context.Context, network, instance string) (net.Conn, error) { + return d.Dial(ctx, instanceName, opts...) 
+ } + dbURI := stdlib.RegisterConnConfig(config) + db, err := sql.Open("pgx", dbURI) + if err != nil { + return nil, fmt.Errorf("sql.Open: %w", err) + } + temp := false + return postgres.InfoSchemaImpl{ + Db: db, + SourceProfile: sourceProfile, + TargetProfile: targetProfile, + IsSchemaUnique: &temp, //this is a workaround to set a bool pointer + }, nil + default: + return nil, fmt.Errorf("driver %s not supported", driver) + } +} + + +func (gi *GetInfoImpl) GetInfoSchema(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) { + connectionConfig, err := connectionConfig(sourceProfile) + if err != nil { + return nil, err + } + driver := sourceProfile.Driver + switch driver { + case constants.MYSQL: + db, err := sql.Open(driver, connectionConfig.(string)) + dbName := getDbNameFromSQLConnectionStr(driver, connectionConfig.(string)) + if err != nil { + return nil, err + } + return mysql.InfoSchemaImpl{ + DbName: dbName, + Db: db, + SourceProfile: sourceProfile, + TargetProfile: targetProfile, + }, nil + case constants.POSTGRES: + db, err := sql.Open(driver, connectionConfig.(string)) + if err != nil { + return nil, err + } + temp := false + return postgres.InfoSchemaImpl{ + Db: db, + SourceProfile: sourceProfile, + TargetProfile: targetProfile, + IsSchemaUnique: &temp, //this is a workaround to set a bool pointer + }, nil + case constants.DYNAMODB: + mySession := session.Must(session.NewSession()) + dydbClient := dydb.New(mySession, connectionConfig.(*aws.Config)) + var dydbStreamsClient *dynamodbstreams.DynamoDBStreams + if sourceProfile.Conn.Streaming { + newSession := session.Must(session.NewSession()) + dydbStreamsClient = dynamodbstreams.New(newSession, connectionConfig.(*aws.Config)) + } + return dynamodb.InfoSchemaImpl{ + DynamoClient: dydbClient, + SampleSize: profiles.GetSchemaSampleSize(sourceProfile), + DynamoStreamsClient: dydbStreamsClient, + }, nil + case constants.SQLSERVER: + db, err := sql.Open(driver, 
connectionConfig.(string)) + dbName := getDbNameFromSQLConnectionStr(driver, connectionConfig.(string)) + if err != nil { + return nil, err + } + return sqlserver.InfoSchemaImpl{DbName: dbName, Db: db}, nil + case constants.ORACLE: + db, err := sql.Open(driver, connectionConfig.(string)) + dbName := getDbNameFromSQLConnectionStr(driver, connectionConfig.(string)) + if err != nil { + return nil, err + } + return oracle.InfoSchemaImpl{DbName: strings.ToUpper(dbName), Db: db, SourceProfile: sourceProfile, TargetProfile: targetProfile}, nil + default: + return nil, fmt.Errorf("driver %s not supported", driver) + } +} \ No newline at end of file diff --git a/conversion/mocks.go b/conversion/mocks.go new file mode 100644 index 0000000000..65f53f0181 --- /dev/null +++ b/conversion/mocks.go @@ -0,0 +1,83 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package conversion handles initial setup for the command line tool +// and web APIs. + +// TODO:(searce) Organize code in go style format to make this file more readable. 
+// +// public constants first +// key public type definitions next (although often it makes sense to put them next to public functions that use them) +// then public functions (and relevant type definitions) +// and helper functions and other non-public definitions last (generally in order of importance) +package conversion + +import ( + "context" + + sp "cloud.google.com/go/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/csv" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer" + "github.com/stretchr/testify/mock" +) + + +type MockGetInfo struct { + mock.Mock +} + +func (mgi *MockGetInfo) getInfoSchemaForShard(shardConnInfo profiles.DirectConnectionConfig, driver string, targetProfile profiles.TargetProfile, s profiles.SourceProfileDialectInterface, g GetInfoInterface) (common.InfoSchema, error) { + args := mgi.Called(shardConnInfo, driver, targetProfile, s, g) + return args.Get(0).(common.InfoSchema), args.Error(1) +} +func (mgi *MockGetInfo) GetInfoSchemaFromCloudSQL(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) { + args := mgi.Called(sourceProfile, targetProfile) + return args.Get(0).(common.InfoSchema), args.Error(1) +} +func (mgi *MockGetInfo) GetInfoSchema(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) { + args := mgi.Called(sourceProfile, targetProfile) + return args.Get(0).(common.InfoSchema), args.Error(1) +} + +type MockSchemaFromSource struct { + mock.Mock +} +func (msads *MockSchemaFromSource) schemaFromDatabase(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, getInfo GetInfoInterface, 
processSchema common.ProcessSchemaInterface) (*internal.Conv, error) { + args := msads.Called(sourceProfile, targetProfile, getInfo, processSchema) + return args.Get(0).(*internal.Conv), args.Error(1) +} +func (msads *MockSchemaFromSource) SchemaFromDump(driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error) { + args := msads.Called(driver, spDialect, ioHelper, processDump) + return args.Get(0).(*internal.Conv), args.Error(1) +} + +type MockDataFromSource struct { + mock.Mock +} +func (msads *MockDataFromSource) dataFromDatabase(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, getInfo GetInfoInterface, dataFromDb DataFromDatabaseInterface, snapshotMigration SnapshotMigrationInterface) (*writer.BatchWriter, error) { + args := msads.Called(ctx, sourceProfile, targetProfile, config, conv, client, getInfo, dataFromDb, snapshotMigration) + return args.Get(0).(*writer.BatchWriter), args.Error(1) +} +func (msads *MockDataFromSource) dataFromDump(driver string, config writer.BatchWriterConfig, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool, processDump ProcessDumpByDialectInterface, populateDataConv PopulateDataConvInterface) (*writer.BatchWriter, error) { + args := msads.Called(driver, config, ioHelper, client, conv, dataOnly, processDump, populateDataConv) + return args.Get(0).(*writer.BatchWriter), args.Error(1) +} +func (msads *MockDataFromSource) dataFromCSV(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, pdc PopulateDataConvInterface, csv csv.CsvInterface) (*writer.BatchWriter, error) { + args := msads.Called(ctx, sourceProfile, targetProfile, config, conv, client, pdc, csv) + return args.Get(0).(*writer.BatchWriter), 
args.Error(1) +} \ No newline at end of file diff --git a/conversion/snapshot_migration.go b/conversion/snapshot_migration.go new file mode 100644 index 0000000000..f0a4d5bd19 --- /dev/null +++ b/conversion/snapshot_migration.go @@ -0,0 +1,66 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package conversion handles initial setup for the command line tool +// and web APIs. + +// TODO:(searce) Organize code in go style format to make this file more readable. 
+// +// public constants first +// key public type definitions next (although often it makes sense to put them next to public functions that use them) +// then public functions (and relevant type definitions) +// and helper functions and other non-public definitions last (generally in order of importance) + +package conversion + +import ( + "fmt" + + sp "cloud.google.com/go/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer" +) + +type SnapshotMigrationInterface interface { + performSnapshotMigration(config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, infoSchema common.InfoSchema, additionalAttributes internal.AdditionalDataAttributes, infoSchemaI common.InfoSchemaInterface, populateDataConv PopulateDataConvInterface) *writer.BatchWriter + snapshotMigrationHandler(sourceProfile profiles.SourceProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, infoSchema common.InfoSchema) (*writer.BatchWriter, error) +} +type SnapshotMigrationImpl struct {} + +func (sm *SnapshotMigrationImpl) performSnapshotMigration(config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, infoSchema common.InfoSchema, additionalAttributes internal.AdditionalDataAttributes, infoSchemaI common.InfoSchemaInterface, populateDataConv PopulateDataConvInterface) *writer.BatchWriter { + infoSchemaI.SetRowStats(conv, infoSchema) + totalRows := conv.Rows() + if !conv.Audit.DryRun { + conv.Audit.Progress = *internal.NewProgress(totalRows, "Writing data to Spanner", internal.Verbose(), false, int(internal.DataWriteInProgress)) + } + batchWriter := populateDataConv.populateDataConv(conv, config, client) + infoSchemaI.ProcessData(conv, 
infoSchema, additionalAttributes) + batchWriter.Flush() + return batchWriter +} + +func (sm *SnapshotMigrationImpl) snapshotMigrationHandler(sourceProfile profiles.SourceProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, infoSchema common.InfoSchema) (*writer.BatchWriter, error) { + switch sourceProfile.Driver { + // Skip snapshot migration via Spanner migration tool for mysql and oracle since dataflow job will job will handle this from backfilled data. + case constants.MYSQL, constants.ORACLE, constants.POSTGRES: + return &writer.BatchWriter{}, nil + case constants.DYNAMODB: + return sm.performSnapshotMigration(config, conv, client, infoSchema, internal.AdditionalDataAttributes{ShardId: ""}, &common.InfoSchemaImpl{}, &PopulateDataConvImpl{}), nil + default: + return &writer.BatchWriter{}, fmt.Errorf("streaming migration not supported for driver %s", sourceProfile.Driver) + } +} \ No newline at end of file diff --git a/conversion/store_files.go b/conversion/store_files.go new file mode 100644 index 0000000000..76d42b0519 --- /dev/null +++ b/conversion/store_files.go @@ -0,0 +1,274 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package conversion handles initial setup for the command line tool +// and web APIs. + +// TODO:(searce) Organize code in go style format to make this file more readable. 
+// +// public constants first +// key public type definitions next (although often it makes sense to put them next to public functions that use them) +// then public functions (and relevant type definitions) +// and helper functions and other non-public definitions last (generally in order of importance) +package conversion + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + "strings" + "time" + + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer" +) + +// WriteSchemaFile writes DDL statements in a file. It includes CREATE TABLE +// statements and ALTER TABLE statements to add foreign keys. +// The parameter name should end with a .txt. +func WriteSchemaFile(conv *internal.Conv, now time.Time, name string, out *os.File, driver string) { + f, err := os.Create(name) + if err != nil { + fmt.Fprintf(out, "Can't create schema file %s: %v\n", name, err) + return + } + + // The schema file we write out below is optimized for reading. It includes comments, foreign keys + // and doesn't add backticks around table and column names. This file is + // intended for explanatory and documentation purposes, and is not strictly + // legal Cloud Spanner DDL (Cloud Spanner doesn't currently support comments). 
+ spDDL := conv.SpSchema.GetDDL(ddl.Config{Comments: true, ProtectIds: false, Tables: true, ForeignKeys: true, SpDialect: conv.SpDialect, Source: driver}) + if len(spDDL) == 0 { + spDDL = []string{"\n-- Schema is empty -- no tables found\n"} + } + l := []string{ + fmt.Sprintf("-- Schema generated %s\n", now.Format("2006-01-02 15:04:05")), + strings.Join(spDDL, ";\n\n"), + "\n", + } + if _, err := f.WriteString(strings.Join(l, "")); err != nil { + fmt.Fprintf(out, "Can't write out schema file: %v\n", err) + return + } + fmt.Fprintf(out, "Wrote schema to file '%s'.\n", name) + + // Convert . to .ddl.. + nameSplit := strings.Split(name, ".") + nameSplit = append(nameSplit[:len(nameSplit)-1], "ddl", nameSplit[len(nameSplit)-1]) + name = strings.Join(nameSplit, ".") + f, err = os.Create(name) + if err != nil { + fmt.Fprintf(out, "Can't create legal schema ddl file %s: %v\n", name, err) + return + } + + // We change 'Comments' to false and 'ProtectIds' to true below to write out a + // schema file that is a legal Cloud Spanner DDL. + spDDL = conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: true, SpDialect: conv.SpDialect, Source: driver}) + if len(spDDL) == 0 { + spDDL = []string{"\n-- Schema is empty -- no tables found\n"} + } + l = []string{ + strings.Join(spDDL, ";\n\n"), + "\n", + } + if _, err = f.WriteString(strings.Join(l, "")); err != nil { + fmt.Fprintf(out, "Can't write out legal schema ddl file: %v\n", err) + return + } + fmt.Fprintf(out, "Wrote legal schema ddl to file '%s'.\n", name) +} + +// WriteSessionFile writes conv struct to a file in JSON format. +func WriteSessionFile(conv *internal.Conv, name string, out *os.File) { + f, err := os.Create(name) + if err != nil { + fmt.Fprintf(out, "Can't create session file %s: %v\n", name, err) + return + } + // Session file will basically contain 'conv' struct in JSON format. + // It contains all the information for schema and data conversion state. 
+ convJSON, err := json.MarshalIndent(conv, "", " ") + if err != nil { + fmt.Fprintf(out, "Can't encode session state to JSON: %v\n", err) + return + } + if _, err := f.Write(convJSON); err != nil { + fmt.Fprintf(out, "Can't write out session file: %v\n", err) + return + } + fmt.Fprintf(out, "Wrote session to file '%s'.\n", name) +} + +// WriteConvGeneratedFiles creates a directory labeled downloads with the current timestamp +// where it writes the sessionfile, report summary and DDLs then returns the directory where it writes. +func WriteConvGeneratedFiles(conv *internal.Conv, dbName string, driver string, BytesRead int64, out *os.File) (string, error) { + now := time.Now() + dirPath := "spanner_migration_tool_output/" + dbName + "/" + err := os.MkdirAll(dirPath, os.ModePerm) + if err != nil { + fmt.Fprintf(out, "Can't create directory %s: %v\n", dirPath, err) + return "", err + } + schemaFileName := dirPath + dbName + "_schema.txt" + WriteSchemaFile(conv, now, schemaFileName, out, driver) + reportFileName := dirPath + dbName + Report(driver, nil, BytesRead, "", conv, reportFileName, dbName, out) + sessionFileName := dirPath + dbName + ".session.json" + WriteSessionFile(conv, sessionFileName, out) + return dirPath, nil +} + +// ReadSessionFile reads a session JSON file and +// unmarshal it's content into *internal.Conv. +func ReadSessionFile(conv *internal.Conv, sessionJSON string) error { + s, err := ioutil.ReadFile(sessionJSON) + if err != nil { + return err + } + err = json.Unmarshal(s, &conv) + if err != nil { + return err + } + return nil +} + +// WriteBadData prints summary stats about bad rows and writes detailed info +// to file 'name'. 
+func WriteBadData(bw *writer.BatchWriter, conv *internal.Conv, banner, name string, out *os.File) { + badConversions := conv.BadRows() + badWrites := utils.SumMapValues(bw.DroppedRowsByTable()) + + badDataStreaming := int64(0) + if conv.Audit.StreamingStats.Streaming { + badDataStreaming = getBadStreamingDataCount(conv) + } + + if badConversions == 0 && badWrites == 0 && badDataStreaming == 0 { + os.Remove(name) // Cleanup bad-data file from previous run. + return + } + f, err := os.Create(name) + if err != nil { + fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) + return + } + f.WriteString(banner) + maxRows := 100 + if badConversions > 0 { + l := conv.SampleBadRows(maxRows) + if int64(len(l)) < badConversions { + f.WriteString("A sample of rows that generated conversion errors:\n") + } else { + f.WriteString("Rows that generated conversion errors:\n") + } + for _, r := range l { + _, err := f.WriteString(" " + r + "\n") + if err != nil { + fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) + return + } + } + } + if badWrites > 0 { + l := bw.SampleBadRows(maxRows) + if int64(len(l)) < badWrites { + f.WriteString("A sample of rows that successfully converted but couldn't be written to Spanner:\n") + } else { + f.WriteString("Rows that successfully converted but couldn't be written to Spanner:\n") + } + for _, r := range l { + _, err := f.WriteString(" " + r + "\n") + if err != nil { + fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) + return + } + } + } + if badDataStreaming > 0 { + err = writeBadStreamingData(conv, f) + if err != nil { + fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) + return + } + } + + fmt.Fprintf(out, "See file '%s' for details of bad rows\n", name) +} + + +// writeBadStreamingData writes sample of bad records and dropped records during streaming +// migration process to bad data file. 
+func writeBadStreamingData(conv *internal.Conv, f *os.File) error { + f.WriteString("\nBad data encountered during streaming migration:\n\n") + + stats := (conv.Audit.StreamingStats) + + badRecords := int64(0) + for _, x := range stats.BadRecords { + badRecords += utils.SumMapValues(x) + } + droppedRecords := int64(0) + for _, x := range stats.DroppedRecords { + droppedRecords += utils.SumMapValues(x) + } + + if badRecords > 0 { + l := stats.SampleBadRecords + if int64(len(l)) < badRecords { + f.WriteString("A sample of records that generated conversion errors:\n") + } else { + f.WriteString("Records that generated conversion errors:\n") + } + for _, r := range l { + _, err := f.WriteString(" " + r + "\n") + if err != nil { + return err + } + } + f.WriteString("\n") + } + if droppedRecords > 0 { + l := stats.SampleBadWrites + if int64(len(l)) < droppedRecords { + f.WriteString("A sample of records that successfully converted but couldn't be written to Spanner:\n") + } else { + f.WriteString("Records that successfully converted but couldn't be written to Spanner:\n") + } + for _, r := range l { + _, err := f.WriteString(" " + r + "\n") + if err != nil { + return err + } + } + } + return nil +} + +// getBadStreamingDataCount returns the total sum of bad and dropped records during +// streaming migration process. 
+func getBadStreamingDataCount(conv *internal.Conv) int64 { + badDataCount := int64(0) + + for _, x := range conv.Audit.StreamingStats.BadRecords { + badDataCount += utils.SumMapValues(x) + } + for _, x := range conv.Audit.StreamingStats.DroppedRecords { + badDataCount += utils.SumMapValues(x) + } + return badDataCount +} diff --git a/conversion/validations.go b/conversion/validations.go new file mode 100644 index 0000000000..38da6d11fa --- /dev/null +++ b/conversion/validations.go @@ -0,0 +1,108 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package conversion handles initial setup for the command line tool +// and web APIs. + +// TODO:(searce) Organize code in go style format to make this file more readable. 
+// +// public constants first +// key public type definitions next (although often it makes sense to put them next to public functions that use them) +// then public functions (and relevant type definitions) +// and helper functions and other non-public definitions last (generally in order of importance) + +package conversion + +import ( + "context" + "fmt" + "strings" + "time" + + sp "cloud.google.com/go/spanner" + database "cloud.google.com/go/spanner/admin/database/apiv1" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/spanner" + adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" +) + +// VerifyDb checks whether the db exists and if it does, verifies if the schema is what we currently support. +func VerifyDb(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string) (dbExists bool, err error) { + dbExists, err = CheckExistingDb(ctx, adminClient, dbURI) + if err != nil { + return dbExists, err + } + if dbExists { + err = ValidateDDL(ctx, adminClient, dbURI) + } + return dbExists, err +} + +// CheckExistingDb checks whether the database with dbURI exists or not. +// If API call doesn't respond then user is informed after every 5 minutes on command line. +func CheckExistingDb(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string) (bool, error) { + gotResponse := make(chan bool) + var err error + go func() { + _, err = adminClient.GetDatabase(ctx, &adminpb.GetDatabaseRequest{Name: dbURI}) + gotResponse <- true + }() + for { + select { + case <-time.After(5 * time.Minute): + fmt.Println("WARNING! 
API call not responding: make sure that spanner api endpoint is configured properly") + case <-gotResponse: + if err != nil { + if utils.ContainsAny(strings.ToLower(err.Error()), []string{"database not found"}) { + return false, nil + } + return false, fmt.Errorf("can't get database info: %s", err) + } + return true, nil + } + } +} + +// ValidateTables validates that all the tables in the database are empty. +// It returns the name of the first non-empty table if found, and an empty string otherwise. +func ValidateTables(ctx context.Context, client *sp.Client, spDialect string) (string, error) { + infoSchema := spanner.InfoSchemaImpl{Client: client, Ctx: ctx, SpDialect: spDialect} + tables, err := infoSchema.GetTables() + if err != nil { + return "", err + } + for _, table := range tables { + count, err := infoSchema.GetRowCount(table) + if err != nil { + return "", err + } + if count != 0 { + return table.Name, nil + } + } + return "", nil +} + +// ValidateDDL verifies if an existing DB's ddl follows what is supported by Spanner migration tool. Currently, +// we only support empty schema when db already exists. 
+func ValidateDDL(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string) error { + dbDdl, err := adminClient.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{Database: dbURI}) + if err != nil { + return fmt.Errorf("can't fetch database ddl: %v", err) + } + if len(dbDdl.Statements) != 0 { + return fmt.Errorf("spanner migration tool supports writing to existing databases only if they have an empty schema") + } + return nil +} \ No newline at end of file diff --git a/sources/common/dbdump.go b/sources/common/dbdump.go index 868df90b1f..5e60fd032d 100644 --- a/sources/common/dbdump.go +++ b/sources/common/dbdump.go @@ -34,9 +34,11 @@ func ProcessDbDump(conv *internal.Conv, r *internal.Reader, dbDump DbDump) error return err } if conv.SchemaMode() { - initPrimaryKeyOrder(conv) - initIndexOrder(conv) - SchemaToSpannerDDL(conv, dbDump.GetToDdl()) + utilsOrder := UtilsOrderImpl{} + utilsOrder.initPrimaryKeyOrder(conv) + utilsOrder.initIndexOrder(conv) + schemaToSpanner := SchemaToSpannerImpl{} + schemaToSpanner.SchemaToSpannerDDL(conv, dbDump.GetToDdl()) conv.AddPrimaryKeys() } return nil diff --git a/sources/common/infoschema.go b/sources/common/infoschema.go index dc7cb233c4..8845c7bfe8 100644 --- a/sources/common/infoschema.go +++ b/sources/common/infoschema.go @@ -58,18 +58,33 @@ type FkConstraint struct { Cols []string } +type InfoSchemaInterface interface { + GenerateSrcSchema(conv *internal.Conv, infoSchema InfoSchema, numWorkers int) (int, error) + ProcessData(conv *internal.Conv, infoSchema InfoSchema, additionalAttributes internal.AdditionalDataAttributes) + SetRowStats(conv *internal.Conv, infoSchema InfoSchema) + processTable(conv *internal.Conv, table SchemaAndName, infoSchema InfoSchema) (schema.Table, error) + GetIncludedSrcTablesFromConv(conv *internal.Conv) (schemaToTablesMap map[string]internal.SchemaDetails, err error) +} +type InfoSchemaImpl struct {} + +type ProcessSchemaInterface interface { + ProcessSchema(conv 
*internal.Conv, infoSchema InfoSchema, numWorkers int, attributes internal.AdditionalSchemaAttributes, s SchemaToSpannerInterface, uo UtilsOrderInterface, is InfoSchemaInterface) error +} + +type ProcessSchemaImpl struct {} + // ProcessSchema performs schema conversion for source database // 'db'. Information schema tables are a broadly supported ANSI standard, // and we use them to obtain source database's schema information. -func ProcessSchema(conv *internal.Conv, infoSchema InfoSchema, numWorkers int, attributes internal.AdditionalSchemaAttributes) error { +func (ps* ProcessSchemaImpl) ProcessSchema(conv *internal.Conv, infoSchema InfoSchema, numWorkers int, attributes internal.AdditionalSchemaAttributes, s SchemaToSpannerInterface, uo UtilsOrderInterface, is InfoSchemaInterface) error { - tableCount, err := GenerateSrcSchema(conv, infoSchema, numWorkers) + tableCount, err := is.GenerateSrcSchema(conv, infoSchema, numWorkers) if err != nil { return err } - initPrimaryKeyOrder(conv) - initIndexOrder(conv) - SchemaToSpannerDDL(conv, infoSchema.GetToDdl()) + uo.initPrimaryKeyOrder(conv) + uo.initIndexOrder(conv) + s.SchemaToSpannerDDL(conv, infoSchema.GetToDdl()) if tableCount != len(conv.SpSchema) { fmt.Printf("Failed to load all the source tables, source table count: %v, processed tables:%v. Please retry connecting to the source database to load tables.\n", tableCount, len(conv.SpSchema)) return fmt.Errorf("failed to load all the source tables, source table count: %v, processed tables:%v. 
Please retry connecting to the source database to load tables.", tableCount, len(conv.SpSchema)) @@ -82,7 +97,7 @@ func ProcessSchema(conv *internal.Conv, infoSchema InfoSchema, numWorkers int, a return nil } -func GenerateSrcSchema(conv *internal.Conv, infoSchema InfoSchema, numWorkers int) (int, error) { +func (is *InfoSchemaImpl) GenerateSrcSchema(conv *internal.Conv, infoSchema InfoSchema, numWorkers int) (int, error) { tables, err := infoSchema.GetTables() fmt.Println("fetched tables", tables) if err != nil { @@ -94,7 +109,7 @@ func GenerateSrcSchema(conv *internal.Conv, infoSchema InfoSchema, numWorkers in } asyncProcessTable := func(t SchemaAndName, mutex *sync.Mutex) TaskResult[SchemaAndName] { - table, e := processTable(conv, t, infoSchema) + table, e := is.processTable(conv, t, infoSchema) mutex.Lock() conv.SrcSchema[table.Id] = table mutex.Unlock() @@ -117,7 +132,7 @@ func GenerateSrcSchema(conv *internal.Conv, infoSchema InfoSchema, numWorkers in // (based on the source and Spanner schemas), and write it to Spanner. // If we can't get/process data for a table, we skip that table and process // the remaining tables. -func ProcessData(conv *internal.Conv, infoSchema InfoSchema, additionalAttributes internal.AdditionalDataAttributes) { +func (is *InfoSchemaImpl) ProcessData(conv *internal.Conv, infoSchema InfoSchema, additionalAttributes internal.AdditionalDataAttributes) { // Tables are ordered in alphabetical order with one exception: interleaved // tables appear after the population of their parent table. tableIds := ddl.GetSortedTableIdsBySpName(conv.SpSchema) @@ -145,7 +160,7 @@ func ProcessData(conv *internal.Conv, infoSchema InfoSchema, additionalAttribute } // SetRowStats populates conv with the number of rows in each table. 
-func SetRowStats(conv *internal.Conv, infoSchema InfoSchema) { +func (is *InfoSchemaImpl) SetRowStats(conv *internal.Conv, infoSchema InfoSchema) { tables, err := infoSchema.GetTables() if err != nil { conv.Unexpected(fmt.Sprintf("Couldn't get list of table: %s", err)) @@ -162,7 +177,7 @@ func SetRowStats(conv *internal.Conv, infoSchema InfoSchema) { } } -func processTable(conv *internal.Conv, table SchemaAndName, infoSchema InfoSchema) (schema.Table, error) { +func (is *InfoSchemaImpl) processTable(conv *internal.Conv, table SchemaAndName, infoSchema InfoSchema) (schema.Table, error) { var t schema.Table fmt.Println("processing schema for table", table) tblId := internal.GenerateTableId() @@ -209,7 +224,7 @@ func processTable(conv *internal.Conv, table SchemaAndName, infoSchema InfoSchem // getIncludedSrcTablesFromConv fetches the list of tables // from the source database that need to be migrated. -func GetIncludedSrcTablesFromConv(conv *internal.Conv) (schemaToTablesMap map[string]internal.SchemaDetails, err error) { +func (is *InfoSchemaImpl) GetIncludedSrcTablesFromConv(conv *internal.Conv) (schemaToTablesMap map[string]internal.SchemaDetails, err error) { schemaToTablesMap = make(map[string]internal.SchemaDetails) for spTable := range conv.SpSchema { //lookup the spanner table in the source tables via ID diff --git a/sources/common/mocks.go b/sources/common/mocks.go new file mode 100644 index 0000000000..85a38d3524 --- /dev/null +++ b/sources/common/mocks.go @@ -0,0 +1,71 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" + "github.com/stretchr/testify/mock" +) + +type MockInfoSchema struct { + mock.Mock +} + +func (mis *MockInfoSchema) GenerateSrcSchema(conv *internal.Conv, infoSchema InfoSchema, numWorkers int) (int, error) { + args := mis.Called(conv, infoSchema, numWorkers) + return args.Get(0).(int), args.Error(1) +} +func (mis *MockInfoSchema) ProcessData(conv *internal.Conv, infoSchema InfoSchema, additionalAttributes internal.AdditionalDataAttributes) {} +func (mis *MockInfoSchema) SetRowStats(conv *internal.Conv, infoSchema InfoSchema) {} +func (mis *MockInfoSchema) processTable(conv *internal.Conv, table SchemaAndName, infoSchema InfoSchema) (schema.Table, error) { + args := mis.Called(conv, table, infoSchema) + return args.Get(0).(schema.Table), args.Error(1) +} +func (mis *MockInfoSchema) GetIncludedSrcTablesFromConv(conv *internal.Conv) (schemaToTablesMap map[string]internal.SchemaDetails, err error) { + args := mis.Called(conv) + return args.Get(0).(map[string]internal.SchemaDetails), args.Error(1) +} + +type MockUtilsOrder struct { + mock.Mock +} + +func (muo *MockUtilsOrder) initPrimaryKeyOrder(conv *internal.Conv) {} + +func (muo *MockUtilsOrder) initIndexOrder(conv *internal.Conv) {} + +type MockSchemaToSpanner struct { + mock.Mock +} + +func (mss *MockSchemaToSpanner) SchemaToSpannerDDL(conv *internal.Conv, toddl ToDdl) error { + args := mss.Called(conv, toddl) + return args.Error(0) +} + +func (mss *MockSchemaToSpanner) SchemaToSpannerDDLHelper(conv *internal.Conv, toddl ToDdl, srcTable schema.Table, isRestore bool) error { + args := mss.Called(conv, toddl, srcTable, isRestore) + return args.Error(0) +} + +type MockProcessSchema struct { + mock.Mock +} + +func (mps *MockProcessSchema) 
ProcessSchema(conv *internal.Conv, infoSchema InfoSchema, numWorkers int, attributes internal.AdditionalSchemaAttributes, s SchemaToSpannerInterface, uo UtilsOrderInterface, is InfoSchemaInterface) error { + args := mps.Called(conv, infoSchema, numWorkers, attributes, s, uo, is) + return args.Error(0) +} \ No newline at end of file diff --git a/sources/common/toddl.go b/sources/common/toddl.go index 163336b6d6..86d6509a13 100644 --- a/sources/common/toddl.go +++ b/sources/common/toddl.go @@ -49,20 +49,27 @@ type ToDdl interface { ToSpannerType(conv *internal.Conv, spType string, srcType schema.Type) (ddl.Type, []internal.SchemaIssue) } +type SchemaToSpannerInterface interface { + SchemaToSpannerDDL(conv *internal.Conv, toddl ToDdl) error + SchemaToSpannerDDLHelper(conv *internal.Conv, toddl ToDdl, srcTable schema.Table, isRestore bool) error +} + +type SchemaToSpannerImpl struct {} + // SchemaToSpannerDDL performs schema conversion from the source DB schema to // Spanner. It uses the source schema in conv.SrcSchema, and writes // the Spanner schema to conv.SpSchema. 
-func SchemaToSpannerDDL(conv *internal.Conv, toddl ToDdl) error { +func (ss *SchemaToSpannerImpl) SchemaToSpannerDDL(conv *internal.Conv, toddl ToDdl) error { tableIds := GetSortedTableIdsBySrcName(conv.SrcSchema) for _, tableId := range tableIds { srcTable := conv.SrcSchema[tableId] - SchemaToSpannerDDLHelper(conv, toddl, srcTable, false) + ss.SchemaToSpannerDDLHelper(conv, toddl, srcTable, false) } internal.ResolveRefs(conv) return nil } -func SchemaToSpannerDDLHelper(conv *internal.Conv, toddl ToDdl, srcTable schema.Table, isRestore bool) error { +func (ss *SchemaToSpannerImpl) SchemaToSpannerDDLHelper(conv *internal.Conv, toddl ToDdl, srcTable schema.Table, isRestore bool) error { spTableName, err := internal.GetSpannerTable(conv, srcTable.Id) if err != nil { conv.Unexpected(fmt.Sprintf("Couldn't map source table %s to Spanner: %s", srcTable.Name, err)) @@ -226,7 +233,8 @@ func cvtIndexes(conv *internal.Conv, tableId string, srcIndexes []schema.Index, } func SrcTableToSpannerDDL(conv *internal.Conv, toddl ToDdl, srcTable schema.Table) error { - err := SchemaToSpannerDDLHelper(conv, toddl, srcTable, true) + schemaToSpanner := SchemaToSpannerImpl{} + err := schemaToSpanner.SchemaToSpannerDDLHelper(conv, toddl, srcTable, true) if err != nil { return err } diff --git a/sources/common/utils.go b/sources/common/utils.go index 8b7a7a50d1..820db08916 100644 --- a/sources/common/utils.go +++ b/sources/common/utils.go @@ -25,6 +25,12 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" ) +type UtilsOrderInterface interface { + initPrimaryKeyOrder(conv *internal.Conv) + initIndexOrder(conv *internal.Conv) +} +type UtilsOrderImpl struct {} + // ToNotNull returns true if a column is not nullable and false if it is. 
func ToNotNull(conv *internal.Conv, isNullable string) bool { switch isNullable { @@ -68,7 +74,7 @@ func GetSortedTableIdsBySrcName(srcSchema map[string]schema.Table) []string { return sortedTableIds } -func initPrimaryKeyOrder(conv *internal.Conv) { +func (uo *UtilsOrderImpl) initPrimaryKeyOrder(conv *internal.Conv) { for k, table := range conv.SrcSchema { for i := range table.PrimaryKeys { conv.SrcSchema[k].PrimaryKeys[i].Order = i + 1 @@ -76,7 +82,7 @@ func initPrimaryKeyOrder(conv *internal.Conv) { } } -func initIndexOrder(conv *internal.Conv) { +func (uo *UtilsOrderImpl) initIndexOrder(conv *internal.Conv) { for k, table := range conv.SrcSchema { for i, index := range table.Indexes { for j := range index.Keys { diff --git a/sources/csv/data.go b/sources/csv/data.go index 959754c8e8..69363a0906 100644 --- a/sources/csv/data.go +++ b/sources/csv/data.go @@ -35,8 +35,16 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" ) +type CsvInterface interface { + GetCSVFiles(conv *internal.Conv, sourceProfile profiles.SourceProfile) (tables []utils.ManifestTable, err error) + SetRowStats(conv *internal.Conv, tables []utils.ManifestTable, delimiter rune) error + ProcessCSV(conv *internal.Conv, tables []utils.ManifestTable, nullStr string, delimiter rune) error +} + +type CsvImpl struct {} + // GetCSVFiles finds the appropriate files paths and downloads gcs files in any. -func GetCSVFiles(conv *internal.Conv, sourceProfile profiles.SourceProfile) (tables []utils.ManifestTable, err error) { +func (c *CsvImpl) GetCSVFiles(conv *internal.Conv, sourceProfile profiles.SourceProfile) (tables []utils.ManifestTable, err error) { // If manifest file not provided, we assume the csvs exist in the same directory // in table_name.csv format. if sourceProfile.Csv.Manifest == "" { @@ -120,7 +128,7 @@ func VerifyManifest(conv *internal.Conv, tables []utils.ManifestTable) error { } // SetRowStats calculates the number of rows per table. 
-func SetRowStats(conv *internal.Conv, tables []utils.ManifestTable, delimiter rune) error { +func (c *CsvImpl) SetRowStats(conv *internal.Conv, tables []utils.ManifestTable, delimiter rune) error { for _, table := range tables { for _, filePath := range table.File_patterns { csvFile, err := os.Open(filePath) @@ -184,7 +192,7 @@ func getCSVDataRowCount(r *csvReader.Reader, colNames []string) (int64, error) { // ProcessCSV writes data across the tables provided in the manifest file. Each table's data can be provided // across multiple CSV files hence, the manifest accepts a list of file paths in the input. -func ProcessCSV(conv *internal.Conv, tables []utils.ManifestTable, nullStr string, delimiter rune) error { +func (c *CsvImpl) ProcessCSV(conv *internal.Conv, tables []utils.ManifestTable, nullStr string, delimiter rune) error { tableIds := ddl.GetSortedTableIdsBySpName(conv.SpSchema) nameToFiles := map[string][]string{} for _, table := range tables { diff --git a/sources/csv/data_test.go b/sources/csv/data_test.go index a03e222623..843abc2416 100644 --- a/sources/csv/data_test.go +++ b/sources/csv/data_test.go @@ -109,10 +109,11 @@ func cleanupCSVs() { } func TestSetRowStats(t *testing.T) { + csv := CsvImpl{} conv := buildConv(getCreateTable()) writeCSVs(t) defer cleanupCSVs() - SetRowStats(conv, getManifestTables(), ',') + csv.SetRowStats(conv, getManifestTables(), ',') assert.Equal(t, map[string]int64{ALL_TYPES_TABLE: 1, SINGERS_TABLE: 2}, conv.Stats.Rows) } @@ -128,7 +129,8 @@ func TestProcessCSV(t *testing.T) { func(table string, cols []string, vals []interface{}) { rows = append(rows, spannerData{table: table, cols: cols, vals: vals}) }) - err := ProcessCSV(conv, tables, "", ',') + csv := CsvImpl{} + err := csv.ProcessCSV(conv, tables, "", ',') assert.Nil(t, err) assert.Equal(t, []spannerData{ { diff --git a/sources/dynamodb/schema_test.go b/sources/dynamodb/schema_test.go index f3182a763f..749072cf81 100644 --- a/sources/dynamodb/schema_test.go +++ 
b/sources/dynamodb/schema_test.go @@ -188,7 +188,8 @@ func TestProcessSchema(t *testing.T) { sampleSize := int64(10000) conv := internal.MakeConv() - err := common.ProcessSchema(conv, InfoSchemaImpl{client, nil, sampleSize}, 1, internal.AdditionalSchemaAttributes{}) + processSchema := common.ProcessSchemaImpl{} + err := processSchema.ProcessSchema(conv, InfoSchemaImpl{client, nil, sampleSize}, 1, internal.AdditionalSchemaAttributes{}, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) assert.Nil(t, err) expectedSchema := map[string]ddl.CreateTable{ @@ -288,7 +289,8 @@ func TestProcessSchema_FullDataTypes(t *testing.T) { sampleSize := int64(10000) conv := internal.MakeConv() - err := common.ProcessSchema(conv, InfoSchemaImpl{client, nil, sampleSize}, 1, internal.AdditionalSchemaAttributes{}) + processSchema := common.ProcessSchemaImpl{} + err := processSchema.ProcessSchema(conv, InfoSchemaImpl{client, nil, sampleSize}, 1, internal.AdditionalSchemaAttributes{}, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) assert.Nil(t, err) expectedSchema := map[string]ddl.CreateTable{ @@ -371,7 +373,8 @@ func TestProcessData(t *testing.T) { func(table string, cols []string, vals []interface{}) { rows = append(rows, spannerData{table: table, cols: cols, vals: vals}) }) - common.ProcessData(conv, InfoSchemaImpl{client, nil, 10}, internal.AdditionalDataAttributes{}) + commonInfoSchema := common.InfoSchemaImpl{} + commonInfoSchema.ProcessData(conv, InfoSchemaImpl{client, nil, 10}, internal.AdditionalDataAttributes{}) assert.Equal(t, []spannerData{ { @@ -962,7 +965,8 @@ func TestSetRowStats(t *testing.T) { describeTableOutputs: describeTableOutputs, } - common.SetRowStats(conv, InfoSchemaImpl{client, nil, 10}) + commonInfoSchema := common.InfoSchemaImpl{} + commonInfoSchema.SetRowStats(conv, InfoSchemaImpl{client, nil, 10}) assert.Equal(t, tableItemCountA, conv.Stats.Rows[tableNameA]) assert.Equal(t, tableItemCountB, 
conv.Stats.Rows[tableNameB]) diff --git a/sources/dynamodb/toddl_test.go b/sources/dynamodb/toddl_test.go index 7588c70968..7bf11d4d49 100644 --- a/sources/dynamodb/toddl_test.go +++ b/sources/dynamodb/toddl_test.go @@ -58,7 +58,8 @@ func TestToSpannerType(t *testing.T) { } conv.SrcSchema[name] = srcSchema conv.Audit = audit - assert.Nil(t, common.SchemaToSpannerDDL(conv, ToDdlImpl{})) + schemaToSpanner := common.SchemaToSpannerImpl{} + assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[name] dropComments(&actual) // Don't test comment. expected := ddl.CreateTable{ @@ -118,7 +119,8 @@ func TestToSpannerPostgreSQLDialectType(t *testing.T) { } conv.SrcSchema["t1"] = srcSchema conv.Audit = audit - assert.Nil(t, common.SchemaToSpannerDDL(conv, ToDdlImpl{})) + schemaToSpanner := common.SchemaToSpannerImpl{} + assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema["t1"] dropComments(&actual) // Don't test comment. 
expected := ddl.CreateTable{ diff --git a/sources/mysql/infoschema.go b/sources/mysql/infoschema.go index 48f03f2078..d54dbac3d7 100644 --- a/sources/mysql/infoschema.go +++ b/sources/mysql/infoschema.go @@ -366,7 +366,8 @@ func (isi InfoSchemaImpl) StartChangeDataCapture(ctx context.Context, conv *inte schemaDetails map[string]internal.SchemaDetails err error ) - schemaDetails, err = common.GetIncludedSrcTablesFromConv(conv) + commonInfoSchema := common.InfoSchemaImpl{} + schemaDetails, err = commonInfoSchema.GetIncludedSrcTablesFromConv(conv) streamingCfg, err := streaming.ReadStreamingConfig(isi.SourceProfile.Conn.Mysql.StreamingConfig, isi.TargetProfile.Conn.Sp.Dbname, schemaDetails) if err != nil { return nil, fmt.Errorf("error reading streaming config: %v", err) diff --git a/sources/mysql/infoschema_test.go b/sources/mysql/infoschema_test.go index 4c99c63f28..eca4bfd2bd 100644 --- a/sources/mysql/infoschema_test.go +++ b/sources/mysql/infoschema_test.go @@ -210,7 +210,8 @@ func TestProcessSchemaMYSQL(t *testing.T) { db := mkMockDB(t, ms) conv := internal.MakeConv() isi := InfoSchemaImpl{"test", db, profiles.SourceProfile{}, profiles.TargetProfile{}} - _, err := common.GenerateSrcSchema(conv, isi, 1) + commonInfoSchema := common.InfoSchemaImpl{} + _, err := commonInfoSchema.GenerateSrcSchema(conv, isi, 1) assert.Nil(t, err) expectedSchema := map[string]schema.Table{ "cart": schema.Table{Name: "cart", Schema: "test", ColIds: []string{"productid", "userid", "quantity"}, ColDefs: map[string]schema.Column{ @@ -308,7 +309,8 @@ func TestProcessData(t *testing.T) { rows = append(rows, spannerData{table: table, cols: cols, vals: vals}) }) isi := InfoSchemaImpl{"test", db, profiles.SourceProfile{}, profiles.TargetProfile{}} - common.ProcessData(conv, isi, internal.AdditionalDataAttributes{}) + commonInfoSchema := common.InfoSchemaImpl{} + commonInfoSchema.ProcessData(conv, isi, internal.AdditionalDataAttributes{}) assert.Equal(t, []spannerData{ spannerData{table: 
"te_st", cols: []string{"a_a", "Ab", "Ac_"}, vals: []interface{}{float64(42.3), int64(3), "cat"}}, @@ -366,7 +368,8 @@ func TestProcessData_MultiCol(t *testing.T) { db := mkMockDB(t, ms) conv := internal.MakeConv() isi := InfoSchemaImpl{"test", db, profiles.SourceProfile{}, profiles.TargetProfile{}} - err := common.ProcessSchema(conv, isi, 1, internal.AdditionalSchemaAttributes{}) + processSchema := common.ProcessSchemaImpl{} + err := processSchema.ProcessSchema(conv, isi, 1, internal.AdditionalSchemaAttributes{}, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) assert.Nil(t, err) expectedSchema := map[string]ddl.CreateTable{ "test": ddl.CreateTable{ @@ -382,7 +385,7 @@ func TestProcessData_MultiCol(t *testing.T) { } internal.AssertSpSchema(conv, t, expectedSchema, stripSchemaComments(conv.SpSchema)) columnLevelIssues := map[string][]internal.SchemaIssue{ - "c48": []internal.SchemaIssue { + "c48": []internal.SchemaIssue{ 2, }, } @@ -399,7 +402,8 @@ func TestProcessData_MultiCol(t *testing.T) { func(table string, cols []string, vals []interface{}) { rows = append(rows, spannerData{table: table, cols: cols, vals: vals}) }) - common.ProcessData(conv, isi, internal.AdditionalDataAttributes{}) + commonInfoSchema := common.InfoSchemaImpl{} + commonInfoSchema.ProcessData(conv, isi, internal.AdditionalDataAttributes{}) assert.Equal(t, []spannerData{ {table: "test", cols: []string{"a", "b", "synth_id"}, vals: []interface{}{"cat", float64(42.3), "0"}}, {table: "test", cols: []string{"a", "c", "synth_id"}, vals: []interface{}{"dog", int64(22), "-9223372036854775808"}}}, @@ -453,7 +457,8 @@ func TestProcessSchema_Sharded(t *testing.T) { db := mkMockDB(t, ms) conv := internal.MakeConv() isi := InfoSchemaImpl{"test", db, profiles.SourceProfile{}, profiles.TargetProfile{}} - err := common.ProcessSchema(conv, isi, 1, internal.AdditionalSchemaAttributes{IsSharded: true}) + processSchema := common.ProcessSchemaImpl{} + err := 
processSchema.ProcessSchema(conv, isi, 1, internal.AdditionalSchemaAttributes{IsSharded: true}, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) assert.Nil(t, err) expectedSchema := map[string]ddl.CreateTable{ "test": { @@ -492,7 +497,8 @@ func TestSetRowStats(t *testing.T) { conv := internal.MakeConv() conv.SetDataMode() isi := InfoSchemaImpl{"test", db, profiles.SourceProfile{}, profiles.TargetProfile{}} - common.SetRowStats(conv, isi) + commonInfoSchema := common.InfoSchemaImpl{} + commonInfoSchema.SetRowStats(conv, isi) assert.Equal(t, int64(5), conv.Stats.Rows["test1"]) assert.Equal(t, int64(142), conv.Stats.Rows["test2"]) assert.Equal(t, int64(0), conv.Unexpecteds()) diff --git a/sources/mysql/toddl_test.go b/sources/mysql/toddl_test.go index 8ef626e62d..b7e0fb0534 100644 --- a/sources/mysql/toddl_test.go +++ b/sources/mysql/toddl_test.go @@ -172,7 +172,8 @@ func TestToSpannerType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c14"}}, } conv.UsedNames = map[string]bool{"ref_table": true, "ref_table2": true} - assert.Nil(t, common.SchemaToSpannerDDL(conv, ToDdlImpl{})) + schemaToSpanner := common.SchemaToSpannerImpl{} + assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] dropComments(&actual) // Don't test comment. 
expected := ddl.CreateTable{ @@ -202,7 +203,8 @@ func TestToSpannerType(t *testing.T) { "c2": []internal.SchemaIssue{internal.Widened}, } assert.Equal(t, expectedIssues, conv.SchemaIssues[tableId].ColumnLevelIssues) - tableList, _ := common.GetIncludedSrcTablesFromConv(conv) + commonInfoSchema := common.InfoSchemaImpl{} + tableList, _ := commonInfoSchema.GetIncludedSrcTablesFromConv(conv) keys := make([]string, 0, len(tableList)) for k := range tableList { keys = append(keys, k) @@ -263,7 +265,8 @@ func TestToSpannerPostgreSQLDialectType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c14"}}, } conv.UsedNames = map[string]bool{"ref_table": true, "ref_table2": true} - assert.Nil(t, common.SchemaToSpannerDDL(conv, ToDdlImpl{})) + schemaToSpanner := common.SchemaToSpannerImpl{} + assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] dropComments(&actual) // Don't test comment. expected := ddl.CreateTable{ diff --git a/sources/oracle/infoschema.go b/sources/oracle/infoschema.go index dc6d26b998..3a6cbbb51b 100644 --- a/sources/oracle/infoschema.go +++ b/sources/oracle/infoschema.go @@ -430,7 +430,9 @@ func (isi InfoSchemaImpl) StartChangeDataCapture(ctx context.Context, conv *inte schemaDetails map[string]internal.SchemaDetails err error ) - schemaDetails, err = common.GetIncludedSrcTablesFromConv(conv) + + commonInfoSchema := common.InfoSchemaImpl{} + schemaDetails, err = commonInfoSchema.GetIncludedSrcTablesFromConv(conv) streamingCfg, err := streaming.ReadStreamingConfig(isi.SourceProfile.Conn.Oracle.StreamingConfig, isi.TargetProfile.Conn.Sp.Dbname, schemaDetails) if err != nil { return nil, fmt.Errorf("error reading streaming config: %v", err) diff --git a/sources/oracle/infoschema_test.go b/sources/oracle/infoschema_test.go index b088eef190..4c81307ccf 100644 --- a/sources/oracle/infoschema_test.go +++ b/sources/oracle/infoschema_test.go @@ -158,7 +158,8 @@ func TestProcessSchemaOracle(t *testing.T) { } db := 
mkMockDB(t, ms) conv := internal.MakeConv() - err := common.ProcessSchema(conv, InfoSchemaImpl{"test", db, profiles.SourceProfile{}, profiles.TargetProfile{}}, 1, internal.AdditionalSchemaAttributes{}) + processSchema := common.ProcessSchemaImpl{} + err := processSchema.ProcessSchema(conv, InfoSchemaImpl{"test", db, profiles.SourceProfile{}, profiles.TargetProfile{}}, 1, internal.AdditionalSchemaAttributes{}, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) assert.Nil(t, err) expectedSchema := map[string]ddl.CreateTable{ "USER": { diff --git a/sources/oracle/toddl_test.go b/sources/oracle/toddl_test.go index e8817ad0f9..fc4bceb527 100644 --- a/sources/oracle/toddl_test.go +++ b/sources/oracle/toddl_test.go @@ -158,7 +158,8 @@ func TestToSpannerType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c12"}}, } conv.UsedNames = map[string]bool{"ref_table": true, "ref_table2": true} - assert.Nil(t, common.SchemaToSpannerDDL(conv, ToDdlImpl{})) + schemaToSpanner := common.SchemaToSpannerImpl{} + assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] dropComments(&actual) // Don't test comment. expected := ddl.CreateTable{ @@ -244,7 +245,8 @@ func TestToSpannerPostgreSQLDialectType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c15"}}, } conv.UsedNames = map[string]bool{"ref_table": true, "ref_table2": true} - assert.Nil(t, common.SchemaToSpannerDDL(conv, ToDdlImpl{})) + schemaToSpanner := common.SchemaToSpannerImpl{} + assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] dropComments(&actual) // Don't test comment. 
expected := ddl.CreateTable{ diff --git a/sources/postgres/infoschema.go b/sources/postgres/infoschema.go index 35bd9dd7de..bd5756d3f2 100644 --- a/sources/postgres/infoschema.go +++ b/sources/postgres/infoschema.go @@ -65,7 +65,8 @@ func (isi InfoSchemaImpl) StartChangeDataCapture(ctx context.Context, conv *inte schemaDetails map[string]internal.SchemaDetails err error ) - schemaDetails, err = common.GetIncludedSrcTablesFromConv(conv) + commonInfoSchema := common.InfoSchemaImpl{} + schemaDetails, err = commonInfoSchema.GetIncludedSrcTablesFromConv(conv) if err != nil { err = fmt.Errorf("error fetching the tableList to setup datastream migration, defaulting to all tables: %v", err) } diff --git a/sources/postgres/infoschema_test.go b/sources/postgres/infoschema_test.go index 46639ea02f..381ceecc49 100644 --- a/sources/postgres/infoschema_test.go +++ b/sources/postgres/infoschema_test.go @@ -229,7 +229,8 @@ func TestProcessSchema(t *testing.T) { } db := mkMockDB(t, ms) conv := internal.MakeConv() - err := common.ProcessSchema(conv, InfoSchemaImpl{db, profiles.SourceProfile{}, profiles.TargetProfile{}, newFalsePtr()}, 1, internal.AdditionalSchemaAttributes{}) + processSchema := common.ProcessSchemaImpl{} + err := processSchema.ProcessSchema(conv, InfoSchemaImpl{db, profiles.SourceProfile{}, profiles.TargetProfile{}, newFalsePtr()}, 1, internal.AdditionalSchemaAttributes{}, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) assert.Nil(t, err) expectedSchema := map[string]ddl.CreateTable{ "user": ddl.CreateTable{ @@ -364,7 +365,8 @@ func TestProcessData(t *testing.T) { func(table string, cols []string, vals []interface{}) { rows = append(rows, spannerData{table: table, cols: cols, vals: vals}) }) - common.ProcessData(conv, InfoSchemaImpl{db, profiles.SourceProfile{}, profiles.TargetProfile{}, newFalsePtr()}, internal.AdditionalDataAttributes{}) + commonInfoSchema := common.InfoSchemaImpl{} + commonInfoSchema.ProcessData(conv, 
InfoSchemaImpl{db, profiles.SourceProfile{}, profiles.TargetProfile{}, newFalsePtr()}, internal.AdditionalDataAttributes{}) assert.Equal(t, []spannerData{ @@ -505,7 +507,8 @@ func TestConvertSqlRow_MultiCol(t *testing.T) { } db := mkMockDB(t, ms) conv := internal.MakeConv() - err := common.ProcessSchema(conv, InfoSchemaImpl{db, profiles.SourceProfile{}, profiles.TargetProfile{}, newFalsePtr()}, 1, internal.AdditionalSchemaAttributes{}) + processSchema := common.ProcessSchemaImpl{} + err := processSchema.ProcessSchema(conv, InfoSchemaImpl{db, profiles.SourceProfile{}, profiles.TargetProfile{}, newFalsePtr()}, 1, internal.AdditionalSchemaAttributes{}, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) assert.Nil(t, err) conv.SetDataMode() var rows []spannerData @@ -513,7 +516,8 @@ func TestConvertSqlRow_MultiCol(t *testing.T) { func(table string, cols []string, vals []interface{}) { rows = append(rows, spannerData{table: table, cols: cols, vals: vals}) }) - common.ProcessData(conv, InfoSchemaImpl{db, profiles.SourceProfile{}, profiles.TargetProfile{}, newFalsePtr()}, internal.AdditionalDataAttributes{}) + commonInfoSchema := common.InfoSchemaImpl{} + commonInfoSchema.ProcessData(conv, InfoSchemaImpl{db, profiles.SourceProfile{}, profiles.TargetProfile{}, newFalsePtr()}, internal.AdditionalDataAttributes{}) assert.Equal(t, []spannerData{ {table: "test", cols: []string{"a", "b", "synth_id"}, vals: []interface{}{"cat", float64(42.3), "0"}}, {table: "test", cols: []string{"a", "c", "synth_id"}, vals: []interface{}{"dog", int64(22), "-9223372036854775808"}}}, @@ -540,7 +544,8 @@ func TestSetRowStats(t *testing.T) { db := mkMockDB(t, ms) conv := internal.MakeConv() conv.SetDataMode() - common.SetRowStats(conv, InfoSchemaImpl{db, profiles.SourceProfile{}, profiles.TargetProfile{}, newFalsePtr()}) + commonInfoSchema := common.InfoSchemaImpl{} + commonInfoSchema.SetRowStats(conv, InfoSchemaImpl{db, profiles.SourceProfile{}, 
profiles.TargetProfile{}, newFalsePtr()}) assert.Equal(t, int64(5), conv.Stats.Rows["test1"]) assert.Equal(t, int64(142), conv.Stats.Rows["test2"]) assert.Equal(t, int64(0), conv.Unexpecteds()) diff --git a/sources/postgres/toddl_test.go b/sources/postgres/toddl_test.go index b3566efd08..961058a0b6 100644 --- a/sources/postgres/toddl_test.go +++ b/sources/postgres/toddl_test.go @@ -162,7 +162,8 @@ func TestToSpannerType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c10"}}, } conv.UsedNames = map[string]bool{"ref_table": true, "ref_table2": true} - assert.Nil(t, common.SchemaToSpannerDDL(conv, ToDdlImpl{})) + schemaToSpanner := common.SchemaToSpannerImpl{} + assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] dropComments(&actual) // Don't test comment. expected := ddl.CreateTable{ @@ -242,7 +243,8 @@ func TestToExperimentalSpannerType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c10"}}, } conv.UsedNames = map[string]bool{"ref_table": true, "ref_table2": true} - assert.Nil(t, common.SchemaToSpannerDDL(conv, ToDdlImpl{})) + schemaToSpanner := common.SchemaToSpannerImpl{} + assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] dropComments(&actual) // Don't test comment. 
expected := ddl.CreateTable{ diff --git a/sources/sqlserver/infoschema_test.go b/sources/sqlserver/infoschema_test.go index c8ec79d84f..051b99935c 100644 --- a/sources/sqlserver/infoschema_test.go +++ b/sources/sqlserver/infoschema_test.go @@ -248,7 +248,8 @@ func TestProcessSchema(t *testing.T) { } db := mkMockDB(t, ms) conv := internal.MakeConv() - err := common.ProcessSchema(conv, InfoSchemaImpl{"test", db}, 1, internal.AdditionalSchemaAttributes{}) + processSchema := common.ProcessSchemaImpl{} + err := processSchema.ProcessSchema(conv, InfoSchemaImpl{"test", db}, 1, internal.AdditionalSchemaAttributes{}, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) assert.Nil(t, err) expectedSchema := map[string]ddl.CreateTable{ "user": { diff --git a/sources/sqlserver/toddl_test.go b/sources/sqlserver/toddl_test.go index 4484279957..216fd769d6 100644 --- a/sources/sqlserver/toddl_test.go +++ b/sources/sqlserver/toddl_test.go @@ -164,7 +164,8 @@ func TestToSpannerType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c19"}}, } conv.UsedNames = map[string]bool{"ref_table": true, "ref_table2": true} - assert.Nil(t, common.SchemaToSpannerDDL(conv, ToDdlImpl{})) + schemaToSpanner := common.SchemaToSpannerImpl{} + assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] dropComments(&actual) // Don't test comment. expected := ddl.CreateTable{ @@ -262,7 +263,8 @@ func TestToSpannerPostgreSQLDialectType(t *testing.T) { PrimaryKeys: []schema.Key{{ColId: "c19"}}, } conv.UsedNames = map[string]bool{"ref_table": true, "ref_table2": true} - assert.Nil(t, common.SchemaToSpannerDDL(conv, ToDdlImpl{})) + schemaToSpanner := common.SchemaToSpannerImpl{} + assert.Nil(t, schemaToSpanner.SchemaToSpannerDDL(conv, ToDdlImpl{})) actual := conv.SpSchema[tableId] dropComments(&actual) // Don't test comment. 
expected := ddl.CreateTable{ diff --git a/webv2/web.go b/webv2/web.go index 70f987a0f5..fe842855f6 100644 --- a/webv2/web.go +++ b/webv2/web.go @@ -253,16 +253,17 @@ func convertSchemaSQL(w http.ResponseWriter, r *http.Request) { additionalSchemaAttributes := internal.AdditionalSchemaAttributes{ IsSharded: sessionState.IsSharded, } + processSchema := common.ProcessSchemaImpl{} switch sessionState.Driver { case constants.MYSQL: - err = common.ProcessSchema(conv, mysql.InfoSchemaImpl{DbName: sessionState.DbName, Db: sessionState.SourceDB}, common.DefaultWorkers, additionalSchemaAttributes) + err = processSchema.ProcessSchema(conv, mysql.InfoSchemaImpl{DbName: sessionState.DbName, Db: sessionState.SourceDB}, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) case constants.POSTGRES: temp := false - err = common.ProcessSchema(conv, postgres.InfoSchemaImpl{Db: sessionState.SourceDB, IsSchemaUnique: &temp}, common.DefaultWorkers, additionalSchemaAttributes) + err = processSchema.ProcessSchema(conv, postgres.InfoSchemaImpl{Db: sessionState.SourceDB, IsSchemaUnique: &temp}, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) case constants.SQLSERVER: - err = common.ProcessSchema(conv, sqlserver.InfoSchemaImpl{DbName: sessionState.DbName, Db: sessionState.SourceDB}, common.DefaultWorkers, additionalSchemaAttributes) + err = processSchema.ProcessSchema(conv, sqlserver.InfoSchemaImpl{DbName: sessionState.DbName, Db: sessionState.SourceDB}, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) case constants.ORACLE: - err = common.ProcessSchema(conv, oracle.InfoSchemaImpl{DbName: strings.ToUpper(sessionState.DbName), Db: sessionState.SourceDB}, common.DefaultWorkers, additionalSchemaAttributes) + err = processSchema.ProcessSchema(conv, 
oracle.InfoSchemaImpl{DbName: strings.ToUpper(sessionState.DbName), Db: sessionState.SourceDB}, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) default: http.Error(w, fmt.Sprintf("Driver : '%s' is not supported", sessionState.Driver), http.StatusBadRequest) return @@ -598,7 +599,8 @@ func convertSchemaDump(w http.ResponseWriter, r *http.Request) { n := profiles.NewSourceProfileImpl{} sourceProfile, _ := profiles.NewSourceProfile("", dc.Config.Driver, &n) sourceProfile.Driver = dc.Config.Driver - conv, err := conversion.SchemaFromDump(sourceProfile.Driver, dc.SpannerDetails.Dialect, &utils.IOStreams{In: f, Out: os.Stdout}) + schemaFromSource := conversion.SchemaFromSourceImpl{} + conv, err := schemaFromSource.SchemaFromDump(sourceProfile.Driver, dc.SpannerDetails.Dialect, &utils.IOStreams{In: f, Out: os.Stdout}, &conversion.ProcessDumpByDialectImpl{}) if err != nil { http.Error(w, fmt.Sprintf("Schema Conversion Error : %v", err), http.StatusNotFound) return From d9ea76286fed92204965f27812ac5efbe816a26f Mon Sep 17 00:00:00 2001 From: Manit Gupta Date: Thu, 8 Feb 2024 17:37:59 +0530 Subject: [PATCH 05/15] Add CODEOWNERS file (#764) --- .github/CODEOWNERS | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .github/CODEOWNERS diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000..e7546324fd --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,3 @@ +# review when someone opens a pull request + +* @GoogleCloudPlatform/spanner-migrations-team From cd1a94289c618f1ff85bf6da9f89429c2ee3b4c0 Mon Sep 17 00:00:00 2001 From: Astha Mohta <35952883+asthamohta@users.noreply.github.com> Date: Fri, 9 Feb 2024 13:59:02 +0530 Subject: [PATCH 06/15] feat: add support for create, update and createOrUpdate in spanner Accessor (#765) * tests test * [feat] RR Create API 1: Add dataflow accessor (#745) * Add dataflow accessor * Add enable streaming engine struct tag 
Mofe Unmarshall Method to acc2 due ot storage dependency * Moved dataflow utils to accessor and creates types.go * Create dataflowutils package * Renamed testing package for dataflow util * Added unit tests * Added empty test files for clients * Move test to same package * Add tests for dataflow client * Update fake for client test * Make dataflow accessor interface and struct to make it testable * Remove interface from accessor package * Add dataflow accessor interface * Add comments to dataflow client and comments on unit tests * Move all dataflow dependencies to accessors and remove dataflow utils * Create dataflow client interface for accessor method to make it unit testable * tests tests * common testing * change * change * changes on comments * change * accessor mysql conn * changes * change * change * accessor for spanner * change * tests * change * changes * changes * change * change --------- Co-authored-by: Deep1998 --- accessors/clients/spanner/admin/interface.go | 5 + accessors/clients/spanner/admin/mocks.go | 5 + accessors/spanner/spanner_accessor.go | 233 ++++++++- accessors/spanner/spanner_accessor_test.go | 448 ++++++++++++++++++ cmd/utils.go | 56 ++- conversion/conversion.go | 202 -------- conversion/validations.go | 56 --- .../spanner/spanner_accessor_test.go | 12 +- testing/conversion/conversion_test.go | 36 +- 9 files changed, 763 insertions(+), 290 deletions(-) diff --git a/accessors/clients/spanner/admin/interface.go b/accessors/clients/spanner/admin/interface.go index d0a2ab41b4..47f4d138fa 100644 --- a/accessors/clients/spanner/admin/interface.go +++ b/accessors/clients/spanner/admin/interface.go @@ -26,6 +26,7 @@ type AdminClient interface { GetDatabase(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) CreateDatabase(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (CreateDatabaseOperation, error) UpdateDatabaseDdl(ctx context.Context, req 
*databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (UpdateDatabaseDdlOperation, error) + GetDatabaseDdl(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) } // Use this interface instead of database.CreateDatabaseOperation to support mocking. @@ -88,3 +89,7 @@ type UpdateDatabaseDdlImpl struct { func (c *UpdateDatabaseDdlImpl) Wait(ctx context.Context, opts ...gax.CallOption) error { return c.dbo.Wait(ctx, opts...) } + +func (c *AdminClientImpl) GetDatabaseDdl(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return c.adminClient.GetDatabaseDdl(ctx, req, opts...) +} \ No newline at end of file diff --git a/accessors/clients/spanner/admin/mocks.go b/accessors/clients/spanner/admin/mocks.go index 5d84ddeadd..2cc5f51c2b 100644 --- a/accessors/clients/spanner/admin/mocks.go +++ b/accessors/clients/spanner/admin/mocks.go @@ -26,6 +26,7 @@ type AdminClientMock struct { GetDatabaseMock func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) CreateDatabaseMock func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (CreateDatabaseOperation, error) UpdateDatabaseDdlMock func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (UpdateDatabaseDdlOperation, error) + GetDatabaseDdlMock func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) } func (acm *AdminClientMock) GetDatabase(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { @@ -40,6 +41,10 @@ func (acm *AdminClientMock) UpdateDatabaseDdl(ctx context.Context, req *database return acm.UpdateDatabaseDdlMock(ctx, req, opts...) 
} +func (acm *AdminClientMock) GetDatabaseDdl(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return acm.GetDatabaseDdlMock(ctx, req, opts...) +} + // Mock that implements the CreateDatabaseOperation interface. // Pass in unit tests where CreateDatabaseOperation is an input parameter. type CreateDatabaseOperationMock struct { diff --git a/accessors/spanner/spanner_accessor.go b/accessors/spanner/spanner_accessor.go index fd5c2cae22..092a29dd3d 100644 --- a/accessors/spanner/spanner_accessor.go +++ b/accessors/spanner/spanner_accessor.go @@ -17,6 +17,7 @@ import ( "context" "fmt" "strings" + "sync" "time" "cloud.google.com/go/spanner" @@ -25,10 +26,25 @@ import ( spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin" spannerclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/client" spinstanceadmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/instanceadmin" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" + "go.uber.org/zap" "google.golang.org/api/iterator" + adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" ) +var ( + // Set the maximum number of concurrent workers during foreign key creation. + // This number should not be too high so as to not hit the AdminQuota limit. + // AdminQuota limits are mentioned here: https://cloud.google.com/spanner/quotas#administrative_limits + // If facing a quota limit error, consider reducing this value. 
+ MaxWorkers = 50 +) + + // The SpannerAccessor provides methods that internally use a spanner client (can be adminClient/databaseclient/instanceclient etc). // Methods should only contain generic logic here that can be used by multiple workflows. type SpannerAccessor interface { @@ -47,6 +63,18 @@ type SpannerAccessor interface { ValidateChangeStreamOptions(ctx context.Context, changeStreamName, dbURI string) error // Create a change stream with default options. CreateChangeStream(ctx context.Context, adminClient spanneradmin.AdminClient, changeStreamName, dbURI string) error + // Create new Database using conv + CreateDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string, migrationType string) error + // Update Database using conv + UpdateDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string) error + // Updates an existing Spanner database or create a new one if one does not exist using Conv + CreateOrUpdateDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI, driver string, conv *internal.Conv, migrationType string) error + // Check whether the db exists and if it does, verify if the schema is what we currently support. + VerifyDb(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (dbExists bool, err error) + // Verify if an existing DB's ddl follows what is supported by Spanner migration tool. Currently, we only support empty schema when db already exists. + ValidateDDL(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) error + // UpdateDDLForeignKeys updates the Spanner database with foreign key constraints using ALTER TABLE statements. + UpdateDDLForeignKeys(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string, migrationType string) } // This implements the SpannerAccessor interface. 
This is the primary implementation that should be used in all places other than tests. @@ -70,7 +98,7 @@ func (sp *SpannerAccessorImpl) CheckExistingDb(ctx context.Context, adminClient for { select { case <-time.After(5 * time.Minute): - fmt.Println("WARNING! API call not responding: make sure that spanner api endpoint is configured properly") + logger.Log.Debug("WARNING! API call not responding: make sure that spanner api endpoint is configured properly") case <-gotResponse: if err != nil { if utils.ContainsAny(strings.ToLower(err.Error()), []string{"database not found"}) { @@ -197,7 +225,208 @@ func (sp *SpannerAccessorImpl) CreateChangeStream(ctx context.Context, adminClie if err := op.Wait(ctx); err != nil { return fmt.Errorf("could not update database ddl: %v", err) } else { - fmt.Println("Successfully created changestream", changeStreamName) + logger.Log.Debug("Successfully created changestream", zap.String("changeStreamName", changeStreamName)) + } + return nil +} + +// CreateDatabase returns a newly create Spanner DB. +// It automatically determines an appropriate project, selects a +// Spanner instance to use, generates a new Spanner DB name, +// and call into the Spanner admin interface to create the new DB. +func (sp *SpannerAccessorImpl) CreateDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string, migrationType string) error { + project, instance, dbName := utils.ParseDbURI(dbURI) + // The schema we send to Spanner excludes comments (since Cloud + // Spanner DDL doesn't accept them), and protects table and col names + // using backticks (to avoid any issues with Spanner reserved words). + // Foreign Keys are set to false since we create them post data migration. 
+ req := &adminpb.CreateDatabaseRequest{ + Parent: fmt.Sprintf("projects/%s/instances/%s", project, instance), + } + if conv.SpDialect == constants.DIALECT_POSTGRESQL { + // PostgreSQL dialect doesn't support: + // a) backticks around the database name, and + // b) DDL statements as part of a CreateDatabase operation (so schema + // must be set using a separate UpdateDatabase operation). + req.CreateStatement = "CREATE DATABASE \"" + dbName + "\"" + req.DatabaseDialect = adminpb.DatabaseDialect_POSTGRESQL + } else { + req.CreateStatement = "CREATE DATABASE `" + dbName + "`" + if migrationType == constants.DATAFLOW_MIGRATION { + req.ExtraStatements = conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: true, SpDialect: conv.SpDialect, Source: driver}) + } else { + req.ExtraStatements = conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: false, SpDialect: conv.SpDialect, Source: driver}) + } + + } + + op, err := adminClient.CreateDatabase(ctx, req) + if err != nil { + return fmt.Errorf("can't build CreateDatabaseRequest: %w", utils.AnalyzeError(err, dbURI)) + } + if _, err := op.Wait(ctx); err != nil { + return fmt.Errorf("createDatabase call failed: %w", utils.AnalyzeError(err, dbURI)) + } + + if conv.SpDialect == constants.DIALECT_POSTGRESQL { + // Update schema separately for PG databases. + return sp.UpdateDatabase(ctx, adminClient, dbURI, conv, driver) + } + return nil +} + +// UpdateDatabase updates an existing spanner database. +func (sp *SpannerAccessorImpl) UpdateDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string) error { + // The schema we send to Spanner excludes comments (since Cloud + // Spanner DDL doesn't accept them), and protects table and col names + // using backticks (to avoid any issues with Spanner reserved words). + // Foreign Keys are set to false since we create them post data migration. 
+ schema := conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: false, SpDialect: conv.SpDialect, Source: driver}) + req := &adminpb.UpdateDatabaseDdlRequest{ + Database: dbURI, + Statements: schema, + } + // Update queries for postgres as target db return response after more + // than 1 min for large schemas, therefore, timeout is specified as 5 minutes + ctx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + op, err := adminClient.UpdateDatabaseDdl(ctx, req) + if err != nil { + return fmt.Errorf("can't build UpdateDatabaseDdlRequest: %w", utils.AnalyzeError(err, dbURI)) + } + if err := op.Wait(ctx); err != nil { + return fmt.Errorf("UpdateDatabaseDdl call failed: %w", utils.AnalyzeError(err, dbURI)) } return nil } + +// CreatesOrUpdatesDatabase updates an existing Spanner database or creates a new one if one does not exist. +func (sp *SpannerAccessorImpl) CreateOrUpdateDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI, driver string, conv *internal.Conv, migrationType string) error { + dbExists, err := sp.VerifyDb(ctx, adminClient, dbURI) + if err != nil { + return err + } + if dbExists { + if conv.SpDialect != constants.DIALECT_POSTGRESQL && migrationType == constants.DATAFLOW_MIGRATION { + return fmt.Errorf("spanner migration tool does not support minimal downtime schema/schema-and-data migrations to an existing database") + } + err := sp.UpdateDatabase(ctx, adminClient, dbURI, conv, driver) + if err != nil { + return fmt.Errorf("can't update database schema: %v", err) + } + } else { + err := sp.CreateDatabase(ctx, adminClient, dbURI, conv, driver, migrationType) + if err != nil { + return fmt.Errorf("can't create database: %v", err) + } + } + return nil +} + +// VerifyDb checks whether the db exists and if it does, verifies if the schema is what we currently support. 
+func (sp *SpannerAccessorImpl) VerifyDb(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (dbExists bool, err error) { + dbExists, err = sp.CheckExistingDb(ctx, adminClient, dbURI) + if err != nil { + return dbExists, err + } + if dbExists { + err = sp.ValidateDDL(ctx, adminClient, dbURI) + } + return dbExists, err +} + +// ValidateDDL verifies if an existing DB's ddl follows what is supported by Spanner migration tool. Currently, +// we only support empty schema when db already exists. +func (sp *SpannerAccessorImpl) ValidateDDL(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) error { + dbDdl, err := adminClient.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{Database: dbURI}) + if err != nil { + return fmt.Errorf("can't fetch database ddl: %v", err) + } + if len(dbDdl.Statements) != 0 { + return fmt.Errorf("spanner migration tool supports writing to existing databases only if they have an empty schema") + } + return nil +} + + +// UpdateDDLForeignKeys updates the Spanner database with foreign key +// constraints using ALTER TABLE statements. +func (sp *SpannerAccessorImpl) UpdateDDLForeignKeys(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string, migrationType string) { + + if conv.SpDialect != constants.DIALECT_POSTGRESQL && migrationType == constants.DATAFLOW_MIGRATION { + //foreign keys were applied as part of CreateDatabase + return + } + + // The schema we send to Spanner excludes comments (since Cloud + // Spanner DDL doesn't accept them), and protects table and col names + // using backticks (to avoid any issues with Spanner reserved words). + fkStmts := conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: false, ForeignKeys: true, SpDialect: conv.SpDialect, Source: driver}) + if len(fkStmts) == 0 { + return + } + if len(fkStmts) > 50 { + logger.Log.Warn(` + Warning: Large number of foreign keys detected. 
Spanner can take a long amount of + time to create foreign keys (over 5 mins per batch of Foreign Keys even with no data). + Spanner migration tool does not have control over a single foreign key creation time. The number + of concurrent Foreign Key Creation Requests sent to spanner can be increased by + tweaking the MaxWorkers variable (https://github.com/GoogleCloudPlatform/spanner-migration-tool/blob/master/conversion/conversion.go#L89). + However, setting it to a very high value might lead to exceeding the admin quota limit. Spanner migration tool tries to stay under the + admin quota limit by spreading the FK creation requests over time.`) + } + msg := fmt.Sprintf("Updating schema of database %s with foreign key constraints ...", dbURI) + conv.Audit.Progress = *internal.NewProgress(int64(len(fkStmts)), msg, internal.Verbose(), true, int(internal.ForeignKeyUpdateInProgress)) + + workers := make(chan int, MaxWorkers) + for i := 1; i <= MaxWorkers; i++ { + workers <- i + } + var progressMutex sync.Mutex + progress := int64(0) + + // We dispatch parallel foreign key create requests to ensure the backfill runs in parallel to reduce overall time. + // This cuts down the time taken to a third (approx) compared to Serial and Batched creation. We also do not want to create + // too many requests and get throttled due to network or hitting catalog memory limits. + // Ensure atmost `MaxWorkers` go routines run in parallel that each update the ddl with one foreign key statement. + for _, fkStmt := range fkStmts { + workerID := <-workers + go func(fkStmt string, workerID int) { + defer func() { + // Locking the progress reporting otherwise progress results displayed could be in random order. 
+ progressMutex.Lock() + progress++ + conv.Audit.Progress.MaybeReport(progress) + progressMutex.Unlock() + workers <- workerID + }() + internal.VerbosePrintf("Submitting new FK create request: %s\n", fkStmt) + logger.Log.Debug("Submitting new FK create request", zap.String("fkStmt", fkStmt)) + + op, err := adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ + Database: dbURI, + Statements: []string{fkStmt}, + }) + if err != nil { + logger.Log.Debug("Can't add foreign key with statement:"+fkStmt+"\n due to error:"+err.Error()+" Skipping this foreign key...\n") + conv.Unexpected(fmt.Sprintf("Can't add foreign key with statement %s: %s", fkStmt, err)) + return + } + if err := op.Wait(ctx); err != nil { + logger.Log.Debug("Can't add foreign key with statement:"+fkStmt+"\n due to error:"+err.Error()+" Skipping this foreign key...\n") + conv.Unexpected(fmt.Sprintf("Can't add foreign key with statement %s: %s", fkStmt, err)) + return + } + internal.VerbosePrintln("Updated schema with statement: " + fkStmt) + logger.Log.Debug("Updated schema with statement", zap.String("fkStmt", fkStmt)) + }(fkStmt, workerID) + // Send out an FK creation request every second, with total of maxWorkers request being present in a batch. + time.Sleep(time.Second) + } + // Wait for all the goroutines to finish. 
+ for i := 1; i <= MaxWorkers; i++ { + <-workers + } + conv.Audit.Progress.UpdateProgress("Foreign key update complete.", 100, internal.ForeignKeyUpdateComplete) + conv.Audit.Progress.Done() +} diff --git a/accessors/spanner/spanner_accessor_test.go b/accessors/spanner/spanner_accessor_test.go index 9f93050496..ed674a387b 100644 --- a/accessors/spanner/spanner_accessor_test.go +++ b/accessors/spanner/spanner_accessor_test.go @@ -23,7 +23,9 @@ import ( "cloud.google.com/go/spanner/admin/instance/apiv1/instancepb" spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin" spinstanceadmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/instanceadmin" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/googleapis/gax-go/v2" "github.com/stretchr/testify/assert" "go.uber.org/zap" @@ -334,3 +336,449 @@ func TestSpannerAccessorImpl_GetSpannerLeaderLocation(t *testing.T) { assert.Equal(t, tc.want, got, tc.name) } } + +func TestSpannerAccessorImpl_CreateDatabase(t *testing.T) { + testCases := []struct { + name string + acm spanneradmin.AdminClientMock + dialect string + migrationType string + expectError bool + }{ + { + name: "GoogleSql Dataflow", + acm: spanneradmin.AdminClientMock{ + CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) { + return &spanneradmin.CreateDatabaseOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { return nil, nil }, + }, nil + }, + }, + expectError: false, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + { + name: "Pg Dataflow", + acm: spanneradmin.AdminClientMock{ + CreateDatabaseMock: func(ctx context.Context, req 
*databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) { + return &spanneradmin.CreateDatabaseOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { return nil, nil }, + }, nil + }, + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return nil }, + }, nil + }, + }, + expectError: false, + dialect: "postgresql", + migrationType: "dataflow", + }, + { + name: "GoogleSql bulk", + acm: spanneradmin.AdminClientMock{ + CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) { + return &spanneradmin.CreateDatabaseOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { return nil, nil }, + }, nil + }, + }, + expectError: false, + dialect: "google_standard_sql", + migrationType: "bulk", + }, + { + name: "Pg bulk", + acm: spanneradmin.AdminClientMock{ + CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) { + return &spanneradmin.CreateDatabaseOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { return nil, nil }, + }, nil + }, + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return nil }, + }, nil + }, + }, + expectError: false, + dialect: "postgresql", + migrationType: "bulk", + }, + { + 
name: "GoogleSql Dataflow create database error", + acm: spanneradmin.AdminClientMock{ + CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) { + return nil, fmt.Errorf("error") + }, + }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + { + name: "Pg Dataflow update error", + acm: spanneradmin.AdminClientMock{ + CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) { + return &spanneradmin.CreateDatabaseOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { return nil, nil }, + }, nil + }, + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return nil, fmt.Errorf("error") + }, + }, + expectError: true, + dialect: "postgresql", + migrationType: "dataflow", + }, + { + name: "GoogleSql Dataflow operation error", + acm: spanneradmin.AdminClientMock{ + CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) { + return &spanneradmin.CreateDatabaseOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { return nil, fmt.Errorf("error") }, + }, nil + }, + }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + } + ctx := context.Background() + spA := SpannerAccessorImpl{} + for _, tc := range testCases { + dbURI := "projects/project-id/instances/instance-id/databases/database-id" + conv := internal.MakeConv() + conv.SpDialect = tc.dialect + err := spA.CreateDatabase(ctx, &tc.acm, dbURI, conv, "", tc.migrationType) + assert.Equal(t, tc.expectError, err != nil, tc.name) + } +} + 
+func TestSpannerAccessorImpl_CreateOrUpdateDatabase(t *testing.T) { + testCases := []struct { + name string + acm spanneradmin.AdminClientMock + dialect string + migrationType string + expectError bool + }{ + { + name: "GoogleSql Dataflow db does not exist", + acm: spanneradmin.AdminClientMock{ + CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) { + return &spanneradmin.CreateDatabaseOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { return nil, nil }, + }, nil + }, + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return nil, fmt.Errorf("database not found") + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{}, nil + }, + }, + expectError: false, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + { + name: "GoogleSql Dataflow db exists", + acm: spanneradmin.AdminClientMock{ + CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) { + return &spanneradmin.CreateDatabaseOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { return nil, nil }, + }, nil + }, + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return &databasepb.Database{DatabaseDialect: databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL}, nil + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{}, 
nil + }, + }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + { + name: "Postgres Dataflow db exists", + acm: spanneradmin.AdminClientMock{ + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return &databasepb.Database{DatabaseDialect: databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL}, nil + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{}, nil + }, + }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + { + name: "Postgres bulk db exists", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return nil }, + }, nil + }, + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return &databasepb.Database{DatabaseDialect: databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL}, nil + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{}, nil + }, + }, + expectError: false, + dialect: "google_standard_sql", + migrationType: "bulk", + }, + { + name: "GoogleSql Dataflow db get database error", + acm: spanneradmin.AdminClientMock{ + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return nil, fmt.Errorf("error") + }, + GetDatabaseDdlMock: func(ctx context.Context, req 
*databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{}, nil + }, + }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + { + name: "GoogleSql Dataflow db ddl statements nto empty", + acm: spanneradmin.AdminClientMock{ + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return &databasepb.Database{DatabaseDialect: databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL}, nil + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{Statements: []string{"string"}}, nil + }, + }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + { + name: "GoogleSql Dataflow db get database ddl error", + acm: spanneradmin.AdminClientMock{ + CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) { + return &spanneradmin.CreateDatabaseOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { return nil, nil }, + }, nil + }, + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return &databasepb.Database{DatabaseDialect: databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL}, nil + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return nil, fmt.Errorf("error") + }, + }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + { + name: "Postgres bulk db exists update ddl error", + acm: spanneradmin.AdminClientMock{ + 
UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return nil, fmt.Errorf("error") + }, + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return &databasepb.Database{DatabaseDialect: databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL}, nil + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{}, nil + }, + }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "bulk", + }, + { + name: "GoogleSql Dataflow db does not exist create error", + acm: spanneradmin.AdminClientMock{ + CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) { + return nil, fmt.Errorf("error") + }, + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return nil, fmt.Errorf("database not found") + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{}, nil + }, + }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + } + ctx := context.Background() + spA := SpannerAccessorImpl{} + for _, tc := range testCases { + dbURI := "projects/project-id/instances/instance-id/databases/database-id" + conv := internal.MakeConv() + conv.SpDialect = tc.dialect + err := spA.CreateOrUpdateDatabase(ctx, &tc.acm, dbURI, "", conv, tc.migrationType) + assert.Equal(t, tc.expectError, err != nil, tc.name) + } +} + +func TestSpannerAccessorImpl_UpdateDatabase(t *testing.T) { + 
testCases := []struct { + name string + acm spanneradmin.AdminClientMock + expectError bool + }{ + { + name: "Update Database successful", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return nil }, + }, nil + }, + }, + expectError: false, + }, + { + name: "Update Database request error", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return nil, fmt.Errorf("Error") + }, + }, + expectError: true, + }, + { + name: "Update Database operation error", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return fmt.Errorf("error") }, + }, nil + }, + }, + expectError: true, + }, + } + ctx := context.Background() + spA := SpannerAccessorImpl{} + for _, tc := range testCases { + dbURI := "projects/project-id/instances/instance-id/databases/database-id" + conv := internal.MakeConv() + err := spA.UpdateDatabase(ctx, &tc.acm, dbURI, conv, "") + assert.Equal(t, tc.expectError, err != nil, tc.name) + } +} + + +func TestSpannerAccessorImpl_UpdateDDLForeignKey(t *testing.T) { + schemaWithStatements:=map[string]ddl.CreateTable{ + "table_id" : { + Name: "table1", + Id: "table_id", + }, + "table_id2" : { + Name: "table2", + Id: "table_id2", + ParentId: "table1", + ForeignKeys: []ddl.Foreignkey{ + { + Name: "fk", + ColIds: []string{"columns"}, + ReferTableId:"table1", + 
ReferColumnIds:[]string{"column"}, + Id:"table_id", + }, + }, + }, + } + testCases := []struct { + name string + acm spanneradmin.AdminClientMock + dialect string + migrationType string + spSchema ddl.Schema + }{ + { + name: "Update DDL ForeignKey successful pg dataflow", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return nil }, + }, nil + }, + }, + dialect: "postgresql", + spSchema: schemaWithStatements, + migrationType: "dataflow", + }, + { + name: "Update DDL ForeignKey successful pg dataflow no statement", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return nil }, + }, nil + }, + }, + dialect: "postgresql", + spSchema: map[string]ddl.CreateTable{}, + migrationType: "dataflow", + }, + { + name: "Update DDL ForeignKey successful google_standard_sql", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return nil }, + }, nil + }, + }, + dialect: "google_standard_sql", + spSchema: schemaWithStatements, + migrationType: "dataflow", + }, + { + name: "Update DDL ForeignKey update database error", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts 
...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return nil, fmt.Errorf("error") + }, + }, + dialect: "postgresql", + spSchema: schemaWithStatements, + migrationType: "dataflow", + }, + { + name: "Update DDL ForeignKey operation error", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return fmt.Errorf("error") }, + }, nil + }, + }, + dialect: "postgresql", + spSchema: schemaWithStatements, + migrationType: "dataflow", + }, + } + ctx := context.Background() + spA := SpannerAccessorImpl{} + for _, tc := range testCases { + dbURI := "projects/project-id/instances/instance-id/databases/database-id" + conv := internal.MakeConv() + conv.SpDialect = tc.dialect + conv.SpSchema = tc.spSchema + spA.UpdateDDLForeignKeys(ctx, &tc.acm, dbURI, conv, "", tc.migrationType) + } +} \ No newline at end of file diff --git a/cmd/utils.go b/cmd/utils.go index d8e4604ae5..16c4add784 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -16,17 +16,24 @@ package cmd import ( "context" + "encoding/base64" "fmt" "time" sp "cloud.google.com/go/spanner" database "cloud.google.com/go/spanner/admin/database/apiv1" + spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin" + spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/metrics" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" "github.com/GoogleCloudPlatform/spanner-migration-tool/conversion" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" 
"github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer" "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/helpers" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" ) var ( @@ -40,6 +47,16 @@ const ( completionPercentage = 100 ) +func metricsPopulation(ctx context.Context, driver string, conv *internal.Conv) { + if !conv.Audit.SkipMetricsPopulation { + // Adding migration metadata to the outgoing context. + migrationData := metrics.GetMigrationData(conv, driver, constants.SchemaConv) + serializedMigrationData, _ := proto.Marshal(migrationData) + migrationMetadataValue := base64.StdEncoding.EncodeToString(serializedMigrationData) + ctx = metadata.AppendToOutgoingContext(ctx, constants.MigrationMetadataKey, migrationMetadataValue) + } +} + // CreateDatabaseClient creates new database client and admin client. func CreateDatabaseClient(ctx context.Context, targetProfile profiles.TargetProfile, driver, dbName string, ioHelper utils.IOStreams) (*database.DatabaseAdminClient, *sp.Client, string, error) { if targetProfile.Conn.Sp.Dbname == "" { @@ -139,13 +156,19 @@ func MigrateDatabase(ctx context.Context, targetProfile profiles.TargetProfile, func migrateSchema(ctx context.Context, targetProfile profiles.TargetProfile, sourceProfile profiles.SourceProfile, ioHelper *utils.IOStreams, conv *internal.Conv, dbURI string, adminClient *database.DatabaseAdminClient) error { - err := conversion.CreateOrUpdateDatabase(ctx, adminClient, dbURI, sourceProfile.Driver, conv, ioHelper.Out, sourceProfile.Config.ConfigType) - if err != nil { - err = fmt.Errorf("can't create/update database: %v", err) - return err - } - conv.Audit.Progress.UpdateProgress("Schema migration complete.", completionPercentage, internal.SchemaMigrationComplete) - return nil + spA := spanneraccessor.SpannerAccessorImpl{} + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) + if err != nil { 
+ return err + } + err = spA.CreateOrUpdateDatabase(ctx, adminClientImpl, dbURI, sourceProfile.Driver, conv, sourceProfile.Config.ConfigType) + if err != nil { + err = fmt.Errorf("can't create/update database: %v", err) + return err + } + metricsPopulation(ctx, sourceProfile.Driver, conv) + conv.Audit.Progress.UpdateProgress("Schema migration complete.", completionPercentage, internal.SchemaMigrationComplete) + return nil } func migrateData(ctx context.Context, targetProfile profiles.TargetProfile, sourceProfile profiles.SourceProfile, @@ -170,21 +193,29 @@ func migrateData(ctx context.Context, targetProfile profiles.TargetProfile, sour } conv.Audit.Progress.UpdateProgress("Data migration complete.", completionPercentage, internal.DataMigrationComplete) if !cmd.SkipForeignKeys { - if err = conversion.UpdateDDLForeignKeys(ctx, adminClient, dbURI, conv, ioHelper.Out, sourceProfile.Driver, sourceProfile.Config.ConfigType); err != nil { - err = fmt.Errorf("can't perform update schema on db %s with foreign keys: %v", dbURI, err) + spA := spanneraccessor.SpannerAccessorImpl{} + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) + if err != nil { return bw, err } + spA.UpdateDDLForeignKeys(ctx, adminClientImpl, dbURI, conv, sourceProfile.Driver, sourceProfile.Config.ConfigType) } return bw, nil } func migrateSchemaAndData(ctx context.Context, targetProfile profiles.TargetProfile, sourceProfile profiles.SourceProfile, ioHelper *utils.IOStreams, conv *internal.Conv, dbURI string, adminClient *database.DatabaseAdminClient, client *sp.Client, cmd *SchemaAndDataCmd) (*writer.BatchWriter, error) { - err := conversion.CreateOrUpdateDatabase(ctx, adminClient, dbURI, sourceProfile.Driver, conv, ioHelper.Out, sourceProfile.Config.ConfigType) + spA := spanneraccessor.SpannerAccessorImpl{} + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) + if err != nil { + return nil, err + } + err = spA.CreateOrUpdateDatabase(ctx, adminClientImpl, dbURI, 
sourceProfile.Driver, conv, sourceProfile.Config.ConfigType) if err != nil { err = fmt.Errorf("can't create/update database: %v", err) return nil, err } + metricsPopulation(ctx, sourceProfile.Driver, conv) conv.Audit.Progress.UpdateProgress("Schema migration complete.", completionPercentage, internal.SchemaMigrationComplete) convImpl := &conversion.ConvImpl{} bw, err := convImpl.DataConv(ctx, sourceProfile, targetProfile, ioHelper, client, conv, true, cmd.WriteLimit, &conversion.DataFromSourceImpl{}) @@ -195,10 +226,7 @@ func migrateSchemaAndData(ctx context.Context, targetProfile profiles.TargetProf conv.Audit.Progress.UpdateProgress("Data migration complete.", completionPercentage, internal.DataMigrationComplete) if !cmd.SkipForeignKeys { - if err = conversion.UpdateDDLForeignKeys(ctx, adminClient, dbURI, conv, ioHelper.Out, sourceProfile.Driver, sourceProfile.Config.ConfigType); err != nil { - err = fmt.Errorf("can't perform update schema on db %s with foreign keys: %v", dbURI, err) - return bw, err - } + spA.UpdateDDLForeignKeys(ctx, adminClientImpl, dbURI, conv, sourceProfile.Driver, sourceProfile.Config.ConfigType) } return bw, nil } diff --git a/conversion/conversion.go b/conversion/conversion.go index b9ec31f5e3..9a13b44486 100644 --- a/conversion/conversion.go +++ b/conversion/conversion.go @@ -26,40 +26,25 @@ package conversion import ( "bufio" "context" - "encoding/base64" "encoding/json" "fmt" "os" "strings" "sync" - "time" datastream "cloud.google.com/go/datastream/apiv1" sp "cloud.google.com/go/spanner" - database "cloud.google.com/go/spanner/admin/database/apiv1" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" - "github.com/GoogleCloudPlatform/spanner-migration-tool/common/metrics" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal/reports" - 
"github.com/GoogleCloudPlatform/spanner-migration-tool/logger" "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/csv" - "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer" - "go.uber.org/zap" - adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" - "google.golang.org/grpc/metadata" - "google.golang.org/protobuf/proto" ) var ( - // Set the maximum number of concurrent workers during foreign key creation. - // This number should not be too high so as to not hit the AdminQuota limit. - // AdminQuota limits are mentioned here: https://cloud.google.com/spanner/quotas#administrative_limits - // If facing a quota limit error, consider reducing this value. - MaxWorkers = 50 once sync.Once datastreamClient *datastream.Client ) @@ -168,190 +153,3 @@ func Report(driver string, badWrites map[string]int64, BytesRead int64, banner s fmt.Fprintf(out, "See file '%s' for details of the schema and data conversions.\n", reportFileName) } } - -// CreatesOrUpdatesDatabase updates an existing Spanner database or creates a new one if one does not exist. -func CreateOrUpdateDatabase(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI, driver string, conv *internal.Conv, out *os.File, migrationType string) error { - dbExists, err := VerifyDb(ctx, adminClient, dbURI) - if err != nil { - return err - } - if !conv.Audit.SkipMetricsPopulation { - // Adding migration metadata to the outgoing context. 
- migrationData := metrics.GetMigrationData(conv, driver, constants.SchemaConv) - serializedMigrationData, _ := proto.Marshal(migrationData) - migrationMetadataValue := base64.StdEncoding.EncodeToString(serializedMigrationData) - ctx = metadata.AppendToOutgoingContext(ctx, constants.MigrationMetadataKey, migrationMetadataValue) - } - if dbExists { - if conv.SpDialect != constants.DIALECT_POSTGRESQL && migrationType == constants.DATAFLOW_MIGRATION { - return fmt.Errorf("spanner migration tool does not support minimal downtime schema/schema-and-data migrations to an existing database") - } - err := UpdateDatabase(ctx, adminClient, dbURI, conv, out, driver) - if err != nil { - return fmt.Errorf("can't update database schema: %v", err) - } - } else { - err := CreateDatabase(ctx, adminClient, dbURI, conv, out, driver, migrationType) - if err != nil { - return fmt.Errorf("can't create database: %v", err) - } - } - return nil -} - -// CreateDatabase returns a newly create Spanner DB. -// It automatically determines an appropriate project, selects a -// Spanner instance to use, generates a new Spanner DB name, -// and call into the Spanner admin interface to create the new DB. -func CreateDatabase(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string, conv *internal.Conv, out *os.File, driver string, migrationType string) error { - project, instance, dbName := utils.ParseDbURI(dbURI) - fmt.Fprintf(out, "Creating new database %s in instance %s with default permissions ... \n", dbName, instance) - // The schema we send to Spanner excludes comments (since Cloud - // Spanner DDL doesn't accept them), and protects table and col names - // using backticks (to avoid any issues with Spanner reserved words). - // Foreign Keys are set to false since we create them post data migration. 
- req := &adminpb.CreateDatabaseRequest{ - Parent: fmt.Sprintf("projects/%s/instances/%s", project, instance), - } - if conv.SpDialect == constants.DIALECT_POSTGRESQL { - // PostgreSQL dialect doesn't support: - // a) backticks around the database name, and - // b) DDL statements as part of a CreateDatabase operation (so schema - // must be set using a separate UpdateDatabase operation). - req.CreateStatement = "CREATE DATABASE \"" + dbName + "\"" - req.DatabaseDialect = adminpb.DatabaseDialect_POSTGRESQL - } else { - req.CreateStatement = "CREATE DATABASE `" + dbName + "`" - if migrationType == constants.DATAFLOW_MIGRATION { - req.ExtraStatements = conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: true, SpDialect: conv.SpDialect, Source: driver}) - } else { - req.ExtraStatements = conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: false, SpDialect: conv.SpDialect, Source: driver}) - } - - } - - op, err := adminClient.CreateDatabase(ctx, req) - if err != nil { - return fmt.Errorf("can't build CreateDatabaseRequest: %w", utils.AnalyzeError(err, dbURI)) - } - if _, err := op.Wait(ctx); err != nil { - return fmt.Errorf("createDatabase call failed: %w", utils.AnalyzeError(err, dbURI)) - } - fmt.Fprintf(out, "Created database successfully.\n") - - if conv.SpDialect == constants.DIALECT_POSTGRESQL { - // Update schema separately for PG databases. - return UpdateDatabase(ctx, adminClient, dbURI, conv, out, driver) - } - return nil -} - -// UpdateDatabase updates an existing spanner database. -func UpdateDatabase(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string, conv *internal.Conv, out *os.File, driver string) error { - fmt.Fprintf(out, "Updating schema for %s with default permissions ... 
\n", dbURI) - // The schema we send to Spanner excludes comments (since Cloud - // Spanner DDL doesn't accept them), and protects table and col names - // using backticks (to avoid any issues with Spanner reserved words). - // Foreign Keys are set to false since we create them post data migration. - schema := conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: false, SpDialect: conv.SpDialect, Source: driver}) - req := &adminpb.UpdateDatabaseDdlRequest{ - Database: dbURI, - Statements: schema, - } - // Update queries for postgres as target db return response after more - // than 1 min for large schemas, therefore, timeout is specified as 5 minutes - ctx, cancel := context.WithTimeout(ctx, 5*time.Minute) - defer cancel() - op, err := adminClient.UpdateDatabaseDdl(ctx, req) - if err != nil { - return fmt.Errorf("can't build UpdateDatabaseDdlRequest: %w", utils.AnalyzeError(err, dbURI)) - } - if err := op.Wait(ctx); err != nil { - return fmt.Errorf("UpdateDatabaseDdl call failed: %w", utils.AnalyzeError(err, dbURI)) - } - fmt.Fprintf(out, "Updated schema successfully.\n") - return nil -} - -// UpdateDDLForeignKeys updates the Spanner database with foreign key -// constraints using ALTER TABLE statements. -func UpdateDDLForeignKeys(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string, conv *internal.Conv, out *os.File, driver string, migrationType string) error { - - if conv.SpDialect != constants.DIALECT_POSTGRESQL && migrationType == constants.DATAFLOW_MIGRATION { - //foreign keys were applied as part of CreateDatabase - return nil - } - - // The schema we send to Spanner excludes comments (since Cloud - // Spanner DDL doesn't accept them), and protects table and col names - // using backticks (to avoid any issues with Spanner reserved words). 
- fkStmts := conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: false, ForeignKeys: true, SpDialect: conv.SpDialect, Source: driver}) - if len(fkStmts) == 0 { - return nil - } - if len(fkStmts) > 50 { - fmt.Println(` -Warning: Large number of foreign keys detected. Spanner can take a long amount of -time to create foreign keys (over 5 mins per batch of Foreign Keys even with no data). -Spanner migration tool does not have control over a single foreign key creation time. The number -of concurrent Foreign Key Creation Requests sent to spanner can be increased by -tweaking the MaxWorkers variable (https://github.com/GoogleCloudPlatform/spanner-migration-tool/blob/master/conversion/conversion.go#L89). -However, setting it to a very high value might lead to exceeding the admin quota limit. Spanner migration tool tries to stay under the -admin quota limit by spreading the FK creation requests over time.`) - } - msg := fmt.Sprintf("Updating schema of database %s with foreign key constraints ...", dbURI) - conv.Audit.Progress = *internal.NewProgress(int64(len(fkStmts)), msg, internal.Verbose(), true, int(internal.ForeignKeyUpdateInProgress)) - - workers := make(chan int, MaxWorkers) - for i := 1; i <= MaxWorkers; i++ { - workers <- i - } - var progressMutex sync.Mutex - progress := int64(0) - - // We dispatch parallel foreign key create requests to ensure the backfill runs in parallel to reduce overall time. - // This cuts down the time taken to a third (approx) compared to Serial and Batched creation. We also do not want to create - // too many requests and get throttled due to network or hitting catalog memory limits. - // Ensure atmost `MaxWorkers` go routines run in parallel that each update the ddl with one foreign key statement. - for _, fkStmt := range fkStmts { - workerID := <-workers - go func(fkStmt string, workerID int) { - defer func() { - // Locking the progress reporting otherwise progress results displayed could be in random order. 
- progressMutex.Lock() - progress++ - conv.Audit.Progress.MaybeReport(progress) - progressMutex.Unlock() - workers <- workerID - }() - internal.VerbosePrintf("Submitting new FK create request: %s\n", fkStmt) - logger.Log.Debug("Submitting new FK create request", zap.String("fkStmt", fkStmt)) - - op, err := adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ - Database: dbURI, - Statements: []string{fkStmt}, - }) - if err != nil { - fmt.Printf("Cannot submit request for create foreign key with statement: %s\n due to error: %s. Skipping this foreign key...\n", fkStmt, err) - conv.Unexpected(fmt.Sprintf("Can't add foreign key with statement %s: %s", fkStmt, err)) - return - } - if err := op.Wait(ctx); err != nil { - fmt.Printf("Can't add foreign key with statement: %s\n due to error: %s. Skipping this foreign key...\n", fkStmt, err) - conv.Unexpected(fmt.Sprintf("Can't add foreign key with statement %s: %s", fkStmt, err)) - return - } - internal.VerbosePrintln("Updated schema with statement: " + fkStmt) - logger.Log.Debug("Updated schema with statement", zap.String("fkStmt", fkStmt)) - }(fkStmt, workerID) - // Send out an FK creation request every second, with total of maxWorkers request being present in a batch. - time.Sleep(time.Second) - } - // Wait for all the goroutines to finish. 
- for i := 1; i <= MaxWorkers; i++ { - <-workers - } - conv.Audit.Progress.UpdateProgress("Foreign key update complete.", 100, internal.ForeignKeyUpdateComplete) - conv.Audit.Progress.Done() - return nil -} diff --git a/conversion/validations.go b/conversion/validations.go index 38da6d11fa..01722d1174 100644 --- a/conversion/validations.go +++ b/conversion/validations.go @@ -26,54 +26,11 @@ package conversion import ( "context" - "fmt" - "strings" - "time" sp "cloud.google.com/go/spanner" - database "cloud.google.com/go/spanner/admin/database/apiv1" - "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/spanner" - adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" ) -// VerifyDb checks whether the db exists and if it does, verifies if the schema is what we currently support. -func VerifyDb(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string) (dbExists bool, err error) { - dbExists, err = CheckExistingDb(ctx, adminClient, dbURI) - if err != nil { - return dbExists, err - } - if dbExists { - err = ValidateDDL(ctx, adminClient, dbURI) - } - return dbExists, err -} - -// CheckExistingDb checks whether the database with dbURI exists or not. -// If API call doesn't respond then user is informed after every 5 minutes on command line. -func CheckExistingDb(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string) (bool, error) { - gotResponse := make(chan bool) - var err error - go func() { - _, err = adminClient.GetDatabase(ctx, &adminpb.GetDatabaseRequest{Name: dbURI}) - gotResponse <- true - }() - for { - select { - case <-time.After(5 * time.Minute): - fmt.Println("WARNING! 
API call not responding: make sure that spanner api endpoint is configured properly") - case <-gotResponse: - if err != nil { - if utils.ContainsAny(strings.ToLower(err.Error()), []string{"database not found"}) { - return false, nil - } - return false, fmt.Errorf("can't get database info: %s", err) - } - return true, nil - } - } -} - // ValidateTables validates that all the tables in the database are empty. // It returns the name of the first non-empty table if found, and an empty string otherwise. func ValidateTables(ctx context.Context, client *sp.Client, spDialect string) (string, error) { @@ -93,16 +50,3 @@ func ValidateTables(ctx context.Context, client *sp.Client, spDialect string) (s } return "", nil } - -// ValidateDDL verifies if an existing DB's ddl follows what is supported by Spanner migration tool. Currently, -// we only support empty schema when db already exists. -func ValidateDDL(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string) error { - dbDdl, err := adminClient.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{Database: dbURI}) - if err != nil { - return fmt.Errorf("can't fetch database ddl: %v", err) - } - if len(dbDdl.Statements) != 0 { - return fmt.Errorf("spanner migration tool supports writing to existing databases only if they have an empty schema") - } - return nil -} \ No newline at end of file diff --git a/testing/accessors/spanner/spanner_accessor_test.go b/testing/accessors/spanner/spanner_accessor_test.go index 311d2de0eb..9d37ca63ca 100644 --- a/testing/accessors/spanner/spanner_accessor_test.go +++ b/testing/accessors/spanner/spanner_accessor_test.go @@ -31,7 +31,6 @@ import ( spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin" spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" - "github.com/GoogleCloudPlatform/spanner-migration-tool/conversion" 
"github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" "github.com/stretchr/testify/assert" @@ -103,8 +102,13 @@ func dropDatabase(t *testing.T, dbPath string) { func TestCheckExistingDb(t *testing.T) { onlyRunForEmulatorTest(t) - dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, "check-db-exists") - err := conversion.CreateDatabase(ctx, databaseAdmin, dbURI, internal.MakeConv(), os.Stdout, "", constants.BULK_MIGRATION) + spA := spanneraccessor.SpannerAccessorImpl{} + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) + if err != nil { + t.Fatal(err) + } + dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, "check-db-exists") + err = spA.CreateDatabase(ctx, adminClientImpl, dbURI, internal.MakeConv(), "", constants.BULK_MIGRATION) if err != nil { t.Fatal(err) } @@ -116,11 +120,9 @@ func TestCheckExistingDb(t *testing.T) { {"check-db-exists", true}, {"check-db-does-not-exist", false}, } - adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) if err != nil { t.Fatal(err) } - spA := spanneraccessor.SpannerAccessorImpl{} for _, tc := range testCases { dbExists, err := spA.CheckExistingDb(ctx, adminClientImpl, fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, tc.dbName)) assert.Nil(t, err) diff --git a/testing/conversion/conversion_test.go b/testing/conversion/conversion_test.go index ba08519a71..90ff83881b 100644 --- a/testing/conversion/conversion_test.go +++ b/testing/conversion/conversion_test.go @@ -29,7 +29,6 @@ import ( "time" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" - "github.com/GoogleCloudPlatform/spanner-migration-tool/conversion" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" @@ -38,6 +37,8 @@ import 
( database "cloud.google.com/go/spanner/admin/database/apiv1" databasepb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" + spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner" + spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin" ) var ( @@ -194,16 +195,19 @@ func TestUpdateDDLForeignKeys(t *testing.T) { } for _, tc := range testCases { + spA := spanneraccessor.SpannerAccessorImpl{} + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) + if err != nil { + t.Fatal(err) + } dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, tc.dbName) conv := BuildConv(t, tc.numCols, tc.numFks, false) - err := conversion.CreateDatabase(ctx, databaseAdmin, dbURI, conv, os.Stdout, "", constants.BULK_MIGRATION) + err = spA.CreateDatabase(ctx, adminClientImpl, dbURI, conv, "", constants.BULK_MIGRATION) if err != nil { t.Fatal(err) } - conversion.MaxWorkers = tc.numWorkers - if err = conversion.UpdateDDLForeignKeys(ctx, databaseAdmin, dbURI, conv, os.Stdout, "", constants.BULK_MIGRATION); err != nil { - t.Fatalf("\nCan't perform update operation on db %s with foreign keys: %v\n", tc.dbName, err) - } + spanneraccessor.MaxWorkers = tc.numWorkers + spA.UpdateDDLForeignKeys(ctx, adminClientImpl, dbURI, conv, "", constants.BULK_MIGRATION) checkResults(t, dbURI, tc.numFks) // Drop the database later. 
@@ -226,13 +230,18 @@ func TestVerifyDb(t *testing.T) { for _, tc := range testCases { dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, tc.dbName) + spA := spanneraccessor.SpannerAccessorImpl{} + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) + if err != nil { // both branches below use the client, so fail before branching + t.Fatal(err) + } if tc.dbExists { - err := conversion.CreateDatabase(ctx, databaseAdmin, dbURI, BuildConv(t, 2, 0, tc.emptySchema), os.Stdout, "", constants.BULK_MIGRATION) + err = spA.CreateDatabase(ctx, adminClientImpl, dbURI, BuildConv(t, 2, 0, tc.emptySchema), "", constants.BULK_MIGRATION) if err != nil { t.Fatal(err) } defer dropDatabase(t, dbURI) - dbExists, err := conversion.VerifyDb(ctx, databaseAdmin, dbURI) + dbExists, err := spA.VerifyDb(ctx, adminClientImpl, dbURI) assert.True(t, dbExists) if tc.emptySchema { assert.Nil(t, err) @@ -240,7 +249,7 @@ func TestVerifyDb(t *testing.T) { assert.NotNil(t, err) } } else { - dbExists, err := conversion.VerifyDb(ctx, databaseAdmin, dbURI) + dbExists, err := spA.VerifyDb(ctx, adminClientImpl, dbURI) assert.Nil(t, err) assert.False(t, dbExists) } @@ -261,12 +270,17 @@ func TestValidateDDL(t *testing.T) { for _, tc := range testCases { dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, tc.dbName) - err := conversion.CreateDatabase(ctx, databaseAdmin, dbURI, BuildConv(t, 2, 0, tc.emptySchema), os.Stdout, "", constants.BULK_MIGRATION) + spA := spanneraccessor.SpannerAccessorImpl{} + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) + if err != nil { + t.Fatal(err) + } + err = spA.CreateDatabase(ctx, adminClientImpl, dbURI, BuildConv(t, 2, 0, tc.emptySchema), "", constants.BULK_MIGRATION) if err != nil { t.Fatal(err) } defer dropDatabase(t, dbURI) - err = conversion.ValidateDDL(ctx, databaseAdmin, dbURI) + err = spA.ValidateDDL(ctx, adminClientImpl, dbURI) if tc.emptySchema { assert.Nil(t, err) } else { From 7f5095dd4c8e6c22075ea4e18fdf34845ce2b93a Mon Sep 17 00:00:00
2001 From: Manit Gupta Date: Mon, 12 Feb 2024 13:33:47 +0530 Subject: [PATCH 07/15] chore: Refactor web.go into 3 files (#763) * Refactor web.go into api/schema.go * Copy init() fn to schema.go --- webv2/api/common.go | 37 + webv2/api/rules.go | 411 ++++ webv2/api/rules_test.go | 780 +++++++ webv2/api/schema.go | 1536 ++++++++++++++ webv2/{web_test.go => api/schema_test.go} | 1950 ++++++------------ webv2/config.json | 4 +- webv2/routes.go | 44 +- webv2/types/types.go | 164 ++ webv2/web.go | 2265 +-------------------- webv2/web_startup_test.go | 4 +- 10 files changed, 3648 insertions(+), 3547 deletions(-) create mode 100644 webv2/api/common.go create mode 100644 webv2/api/rules.go create mode 100644 webv2/api/rules_test.go create mode 100644 webv2/api/schema.go rename webv2/{web_test.go => api/schema_test.go} (75%) create mode 100644 webv2/types/types.go diff --git a/webv2/api/common.go b/webv2/api/common.go new file mode 100644 index 0000000000..0526ba1f2f --- /dev/null +++ b/webv2/api/common.go @@ -0,0 +1,37 @@ +package api + +import ( + "fmt" + "strings" + + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/index" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/session" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/utilities" +) + +// dropSecondaryIndexHelper removes the secondary index idxId from table tableId in the +// session's Spanner schema: it releases the index name in Conv.UsedNames, calls +// index.RemoveIndexIssues for it and calls session.UpdateSessionFile. It returns an +// error when either id is empty or the table has no index with that id. +func dropSecondaryIndexHelper(tableId, idxId string) error { + if tableId == "" || idxId == "" { + return fmt.Errorf("Table id or index id is empty") + } + sessionState := session.GetSessionState() + sp := sessionState.Conv.SpSchema[tableId] + position := -1 + for i, idx := range sp.Indexes { + if idxId == idx.Id { + position = i + break + } + } + // position is either -1 or a valid offset into sp.Indexes, so one check is enough; name the id that was not found. + if position < 0 { + return fmt.Errorf("No secondary index found with id %s in table %s", idxId, tableId) + } + + usedNames := sessionState.Conv.UsedNames + delete(usedNames, strings.ToLower(sp.Indexes[position].Name)) + index.RemoveIndexIssues(tableId, sp.Indexes[position]) + + sp.Indexes =
utilities.RemoveSecondaryIndex(sp.Indexes, position) + sessionState.Conv.SpSchema[tableId] = sp + session.UpdateSessionFile() + return nil +} \ No newline at end of file diff --git a/webv2/api/rules.go b/webv2/api/rules.go new file mode 100644 index 0000000000..f7962a1b85 --- /dev/null +++ b/webv2/api/rules.go @@ -0,0 +1,411 @@ +package api + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "strconv" + "strings" + + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/index" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/primarykey" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/session" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/table" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/types" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/utilities" +) + +// ApplyRule is the HTTP handler that applies a schema rule sent as a JSON internal.Rule +// in the request body. It handles four rule types: GlobalDataTypeChange, AddIndex, +// EditColumnMaxLength and AddShardIdPrimaryKey; any other type is rejected. On success +// the rule gets a generated id, is appended to Conv.Rules, and the updated conv is
+// written back as JSON. Conv.ConvLock is held while the rule is applied. +func ApplyRule(w http.ResponseWriter, r *http.Request) { + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) + return + } + var rule internal.Rule + err = json.Unmarshal(reqBody, &rule) + if err != nil { + http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) + return + } + + sessionState := session.GetSessionState() + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + if rule.Type == constants.GlobalDataTypeChange { + d, err := json.Marshal(rule.Data) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + typeMap := map[string]string{} + err = json.Unmarshal(d, &typeMap) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + setGlobalDataType(typeMap) + } else if rule.Type == constants.AddIndex { + d, err := json.Marshal(rule.Data) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + newIdx := ddl.CreateIndex{} + err = json.Unmarshal(d, &newIdx) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + addedIndex, err := addIndex(newIdx) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + rule.Data = addedIndex + } else if rule.Type == constants.EditColumnMaxLength { + d, err := json.Marshal(rule.Data) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + var colMaxLength types.ColMaxLength + err = json.Unmarshal(d, &colMaxLength) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + setSpColMaxLength(colMaxLength, rule.AssociatedObjects) + } else if rule.Type == constants.AddShardIdPrimaryKey { + d, err := json.Marshal(rule.Data) + if err != nil { +
http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + var shardIdPrimaryKey types.ShardIdPrimaryKey + err = json.Unmarshal(d, &shardIdPrimaryKey) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + tableName := checkInterleaving() + if tableName != "" { + http.Error(w, fmt.Sprintf("Rule cannot be added because some tables, eg: %v are interleaved. Please remove interleaving and try again.", tableName), http.StatusBadRequest) + return + } + setShardIdColumnAsPrimaryKey(shardIdPrimaryKey.AddedAtTheStart) + addShardIdColumnToForeignKeys(shardIdPrimaryKey.AddedAtTheStart) + } else { + http.Error(w, "Invalid rule type", http.StatusInternalServerError) + return + } + + ruleId := internal.GenerateRuleId() + rule.Id = ruleId + + sessionState.Conv.Rules = append(sessionState.Conv.Rules, rule) + session.UpdateSessionFile() + convm := session.ConvWithMetadata{ + SessionMetadata: sessionState.SessionMetadata, + Conv: *sessionState.Conv, + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(convm) +} + +// DropRule is the HTTP handler that deletes the rule identified by the "id" form value +// and reverts the schema change that rule made. It responds with 400 when the id is +// empty or no rule has that id. Conv.ConvLock is held while the rule is reverted. +func DropRule(w http.ResponseWriter, r *http.Request) { + ruleId := r.FormValue("id") + if ruleId == "" { + http.Error(w, "Rule id is empty", http.StatusBadRequest) + return + } + sessionState := session.GetSessionState() + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + conv := sessionState.Conv + var rule internal.Rule + position := -1 + + // Locate the rule; the loop variable is not called r so the request stays visible. + for i, existing := range conv.Rules { + if existing.Id == ruleId { + rule = existing + position = i + break + } + } + if position == -1 { + http.Error(w, "Rule to be deleted not found", http.StatusBadRequest) + return + } + + if rule.Type == constants.AddIndex { + if rule.Enabled { + d, err := json.Marshal(rule.Data) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + var index ddl.CreateIndex + err = json.Unmarshal(d, &index) + if err != nil { + http.Error(w, "Invalid rule
data", http.StatusInternalServerError) + return + } + tableId := index.TableId + indexId := index.Id + err = dropSecondaryIndexHelper(tableId, indexId) + if err != nil { + http.Error(w, fmt.Sprintf("%v", err), http.StatusBadRequest) + return + } + } + } else if rule.Type == constants.GlobalDataTypeChange { + d, err := json.Marshal(rule.Data) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + typeMap := map[string]string{} + err = json.Unmarshal(d, &typeMap) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + revertGlobalDataType(typeMap) + } else if rule.Type == constants.EditColumnMaxLength { + d, err := json.Marshal(rule.Data) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + var colMaxLength types.ColMaxLength + err = json.Unmarshal(d, &colMaxLength) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + revertSpColMaxLength(colMaxLength, rule.AssociatedObjects) + } else if rule.Type == constants.AddShardIdPrimaryKey { + d, err := json.Marshal(rule.Data) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + var shardIdPrimaryKey types.ShardIdPrimaryKey + err = json.Unmarshal(d, &shardIdPrimaryKey) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + tableName := checkInterleaving() + if tableName != "" { + http.Error(w, fmt.Sprintf("Rule cannot be deleted because some tables, eg: %v are interleaved. 
Please remove interleaving and try again.", tableName), http.StatusBadRequest) + return + } + revertShardIdColumnAsPrimaryKey(shardIdPrimaryKey.AddedAtTheStart) + removeShardIdColumnFromForeignKeys(shardIdPrimaryKey.AddedAtTheStart) + } else { + http.Error(w, "Invalid rule type", http.StatusInternalServerError) + return + } + + sessionState.Conv.Rules = append(conv.Rules[:position], conv.Rules[position+1:]...) + if len(sessionState.Conv.Rules) == 0 { + sessionState.Conv.Rules = nil + } + session.UpdateSessionFile() + convm := session.ConvWithMetadata{ + SessionMetadata: sessionState.SessionMetadata, + Conv: *sessionState.Conv, + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(convm) +} + +// setGlobalDataType changes Spanner types globally. +// It takes a map from source type to Spanner type and updates +// the Spanner schema accordingly. +func setGlobalDataType(typeMap map[string]string) { + sessionState := session.GetSessionState() + + // Redo the source-to-Spanner type mapping using typeMap (the mapping specified in the http request). + // We drive this process by iterating over the Spanner schema because we want to preserve all + // other customizations that have been performed via the UI (dropping columns, renaming columns + // etc). In particular, note that we can't just blindly redo schema conversion (using an appropriate + // version of 'toDDL' with the new typeMap). + for tableId, spSchema := range sessionState.Conv.SpSchema { + for colId := range spSchema.ColDefs { + srcColDef := sessionState.Conv.SrcSchema[tableId].ColDefs[colId] + // If the srcCol's type is in the map, then recalculate the Spanner type + // for this column using the map. Otherwise, leave the ColDef for this + // column as is. Note that per-column type overrides could be lost in + // this process -- the mapping in typeMap always takes precedence.
+ if _, found := typeMap[srcColDef.Type.Name]; found { + utilities.UpdateDataType(sessionState.Conv, typeMap[srcColDef.Type.Name], tableId, colId) + } + } + common.ComputeNonKeyColumnSize(sessionState.Conv, tableId) + } +} +
+// addIndex validates newIndex.Name as a Spanner identifier, verifies through
+// utilities.CanRename that the name is not already used by an existing table,
+// secondary index or foreign key constraint, and then appends the index, with
+// a freshly generated id, to the Spanner schema of its table. It returns the
+// index as stored, or a zero ddl.CreateIndex and an error when a check fails.
+func addIndex(newIndex ddl.CreateIndex) (ddl.CreateIndex, error) {
+	newNames := []string{newIndex.Name}
+
+	// The name must be a valid Spanner identifier.
+	valid, invalidNames := utilities.CheckSpannerNamesValidity(newNames)
+	if !valid {
+		return ddl.CreateIndex{}, fmt.Errorf("following names are not valid Spanner identifiers: %s", strings.Join(invalidNames, ","))
+	}
+	// The name must not collide with an existing table, secondary index or foreign key constraint.
+	if ok, err := utilities.CanRename(newNames, newIndex.TableId); !ok {
+		return ddl.CreateIndex{}, err
+	}
+
+	sessionState := session.GetSessionState()
+	sp := sessionState.Conv.SpSchema[newIndex.TableId]
+
+	newIndexes := []ddl.CreateIndex{newIndex}
+	index.CheckIndexSuggestion(newIndexes, sp)
+	for i := range newIndexes {
+		newIndexes[i].Id = internal.GenerateIndexesId()
+	}
+
+	sessionState.Conv.UsedNames[strings.ToLower(newIndex.Name)] = true
+	sp.Indexes = append(sp.Indexes, newIndexes...)
+ sessionState.Conv.SpSchema[newIndex.TableId] = sp + return newIndexes[0], nil +} +
+// setSpColMaxLength applies an EditColumnMaxLength rule. Columns whose Spanner
+// type name equals spColMaxLength.SpDataType and whose length is still
+// ddl.MaxLength get the length given in spColMaxLength.SpColMaxLength.
+// associatedObjects is either the id of a single table or the "every table"
+// sentinel. The non-key column size of each visited table is recomputed.
+func setSpColMaxLength(spColMaxLength types.ColMaxLength, associatedObjects string) {
+	sessionState := session.GetSessionState()
+	// This function used to match "All table" while revertSpColMaxLength matched
+	// "All tables"; both spellings are accepted so that a rule applied under
+	// either one takes the same branch when it is reverted.
+	if associatedObjects == "All tables" || associatedObjects == "All table" {
+		for tId := range sessionState.Conv.SpSchema {
+			for _, colDef := range sessionState.Conv.SpSchema[tId].ColDefs {
+				if colDef.T.Name == spColMaxLength.SpDataType {
+					spColDef := colDef
+					if spColDef.T.Len == ddl.MaxLength {
+						// NOTE(review): the parse error is discarded, so a malformed SpColMaxLength is not reported here.
+						spColDef.T.Len, _ = strconv.ParseInt(spColMaxLength.SpColMaxLength, 10, 64)
+					}
+					sessionState.Conv.SpSchema[tId].ColDefs[colDef.Id] = spColDef
+				}
+			}
+			common.ComputeNonKeyColumnSize(sessionState.Conv, tId)
+		}
+	} else {
+		for _, colDef := range sessionState.Conv.SpSchema[associatedObjects].ColDefs {
+			if colDef.T.Name == spColMaxLength.SpDataType {
+				spColDef := colDef
+				if spColDef.T.Len == ddl.MaxLength {
+					table.UpdateColumnSize(spColMaxLength.SpColMaxLength, associatedObjects, colDef.Id, sessionState.Conv)
+				}
+			}
+		}
+		common.ComputeNonKeyColumnSize(sessionState.Conv, associatedObjects)
+	}
+}
+
+// revertSpColMaxLength undoes an EditColumnMaxLength rule by calling utilities.UpdateMaxColumnLen +// for every column whose Spanner type name equals spColMaxLength.SpDataType, in one table or, +// for the "All tables" / "All table" sentinel, in every table. +func revertSpColMaxLength(spColMaxLength types.ColMaxLength, associatedObjects string) { + sessionState := session.GetSessionState() + spColLen, _ := strconv.ParseInt(spColMaxLength.SpColMaxLength, 10, 64) + if associatedObjects == "All tables" || associatedObjects == "All table" { + for tId := range sessionState.Conv.SpSchema { + for colId, colDef := range sessionState.Conv.SpSchema[tId].ColDefs { + if colDef.T.Name == spColMaxLength.SpDataType { + utilities.UpdateMaxColumnLen(sessionState.Conv, spColMaxLength.SpDataType, tId, colId, spColLen) + } + } + common.ComputeNonKeyColumnSize(sessionState.Conv, tId) + } + } else { + for colId, colDef := range sessionState.Conv.SpSchema[associatedObjects].ColDefs { + if colDef.T.Name == spColMaxLength.SpDataType { + utilities.UpdateMaxColumnLen(sessionState.Conv, spColMaxLength.SpDataType, associatedObjects, colId, spColLen) + } + } + common.ComputeNonKeyColumnSize(sessionState.Conv,
associatedObjects) + } +} + +// revertGlobalDataType revert back the spanner type to default +// when the rule that is used to apply the data-type change is deleted. +// It takes a map from source type to Spanner type and updates +// the Spanner schema accordingly. +func revertGlobalDataType(typeMap map[string]string) { + sessionState := session.GetSessionState() + + for tableId, spSchema := range sessionState.Conv.SpSchema { + for colId, colDef := range spSchema.ColDefs { + srcColDef, found := sessionState.Conv.SrcSchema[tableId].ColDefs[colId] + if !found { + continue + } + spType, found := typeMap[srcColDef.Type.Name] + + if !found { + continue + } + + if colDef.T.Name == spType { + utilities.UpdateDataType(sessionState.Conv, "", tableId, colId) + } + } + common.ComputeNonKeyColumnSize(sessionState.Conv, tableId) + } +} +
+// removeShardIdColumnFromForeignKeys strips one column from every foreign key
+// of every table in the session's Spanner schema: the first entry of ColIds
+// and ReferColumnIds when isAddedAtFirst is true, the last entry otherwise.
+// DropRule calls it when an AddShardIdPrimaryKey rule is deleted.
+func removeShardIdColumnFromForeignKeys(isAddedAtFirst bool) {
+	sessionState := session.GetSessionState()
+	// spTable rather than table: the file imports a package named table.
+	for tableId, spTable := range sessionState.Conv.SpSchema {
+		for i, fk := range spTable.ForeignKeys {
+			// Slicing an empty column list would panic; such a key has nothing to strip.
+			if len(fk.ColIds) == 0 || len(fk.ReferColumnIds) == 0 {
+				continue
+			}
+			if isAddedAtFirst {
+				fk.ColIds = fk.ColIds[1:]
+				fk.ReferColumnIds = fk.ReferColumnIds[1:]
+			} else {
+				fk.ColIds = fk.ColIds[:len(fk.ColIds)-1]
+				fk.ReferColumnIds = fk.ReferColumnIds[:len(fk.ReferColumnIds)-1]
+			}
+			sessionState.Conv.SpSchema[tableId].ForeignKeys[i] = fk
+		}
+	}
+}
+
+// revertShardIdColumnAsPrimaryKey rebuilds the primary key of every table
+// without its ShardIdColumn and submits each one through
+// primarykey.UpdatePrimaryKeyAndSessionFile. When isAddedAtFirst is true the
+// Order of every remaining key column is reduced by one.
+func revertShardIdColumnAsPrimaryKey(isAddedAtFirst bool) {
+	// The order shift is the same for every key column, so compute it once.
+	decrement := 0
+	if isAddedAtFirst {
+		decrement = 1
+	}
+	sessionState := session.GetSessionState()
+	// spTable and pk avoid shadowing the imported table and index packages.
+	for _, spTable := range sessionState.Conv.SpSchema {
+		pkRequest := primarykey.PrimaryKeyRequest{
+			TableId: spTable.Id,
+			Columns: []ddl.IndexKey{},
+		}
+		for _, pk := range spTable.PrimaryKeys {
+			if pk.ColId == spTable.ShardIdColumn {
+				continue
+			}
+			pkRequest.Columns = append(pkRequest.Columns, ddl.IndexKey{ColId: pk.ColId, Order: pk.Order - decrement, Desc: pk.Desc})
+		}
+		primarykey.UpdatePrimaryKeyAndSessionFile(pkRequest)
+	}
+}
+
+// checkInterleaving returns the Name of a table in the session's Spanner schema whose +// ParentId is non-empty, or "" when no table has one. +func checkInterleaving() string { +
sessionState := session.GetSessionState() + for _, spSchema := range sessionState.Conv.SpSchema { + if spSchema.ParentId != "" { + return spSchema.Name + } + } + return "" +} \ No newline at end of file diff --git a/webv2/api/rules_test.go b/webv2/api/rules_test.go new file mode 100644 index 0000000000..83daefcc2d --- /dev/null +++ b/webv2/api/rules_test.go @@ -0,0 +1,780 @@ +package api_test + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/proto/migration" + "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/api" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/session" + "github.com/stretchr/testify/assert" +) + + +func TestApplyRule(t *testing.T) { + tcAddIndex := []struct { + name string + input internal.Rule + statusCode int64 + conv *internal.Conv + expectedConv *internal.Conv + }{ + { + name: "Add Index with unique name", + input: internal.Rule{ + Name: "rule-index1", + ObjectType: "Table", + AssociatedObjects: "t1", + Enabled: true, + Type: constants.AddIndex, + Data: ddl.CreateIndex{ + Name: "idx3", + TableId: "t1", + Unique: false, + Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}, + }, + }, + statusCode: http.StatusOK, + conv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + Indexes: []ddl.CreateIndex{ + {Name: "idx1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false}}}, + {Name: "idx2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}}, + }}, + Audit: internal.Audit{ + MigrationType: 
migration.MigrationData_SCHEMA_ONLY.Enum(), + }, + UsedNames: map[string]bool{"table1": true, "idx1": true, "idx2": true}, + }, + expectedConv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + Indexes: []ddl.CreateIndex{ + {Name: "idx1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false}}}, + {Name: "idx2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}, + {Id: "i1", Name: "idx3", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}, + }, + }}, + }, + }, + { + name: "New name conflicts with an existing table", + input: internal.Rule{ + Name: "rule-index1", + ObjectType: "Table", + AssociatedObjects: "t1", + Enabled: true, + Type: constants.AddIndex, + Data: map[string]interface{}{ + "Name": "table1", + "TableId": "t1", + "Unique": false, + "Keys": []interface{}{map[string]interface{}{"ColId": "c2", "Desc": false}}, + }, + }, + statusCode: http.StatusInternalServerError, + conv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + Indexes: []ddl.CreateIndex{{Name: "idx1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false}}}, + {Name: "idx2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}}, + }}, + Audit: internal.Audit{ + MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), + }, + UsedNames: map[string]bool{"table1": true, "idx1": true, "idx2": true}, + }, + }, + { + name: "New name conflicts with an existing index", + input: internal.Rule{ + Name: "rule-index1", + ObjectType: "Table", + AssociatedObjects: "t1", + Enabled: true, + Type: constants.AddIndex, + Data: map[string]interface{}{ + "Name": "idx2", + "TableId": "t1", + "Unique": false, + "Keys": []interface{}{map[string]interface{}{"ColId": "c2", "Desc": false}}, + }, + }, + statusCode: 
http.StatusInternalServerError, + conv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + Indexes: []ddl.CreateIndex{{Name: "idx1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false}}}, + {Name: "idx2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}}, + }}, + Audit: internal.Audit{ + MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), + }, + UsedNames: map[string]bool{"table1": true, "idx1": true, "idx2": true}, + }, + }, + { + name: "Invalid input", + input: internal.Rule{ + Name: "rule-index1", + ObjectType: "Table", + AssociatedObjects: "t1", + Enabled: true, + Type: constants.AddIndex, + Data: []string{"test1"}, + }, + statusCode: http.StatusInternalServerError, + conv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + Indexes: []ddl.CreateIndex{{Name: "idx1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false}}}, + {Name: "idx2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}}, + }}, + Audit: internal.Audit{ + MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), + }, + UsedNames: map[string]bool{"table1": true, "idx1": true, "idx2": true}, + }, + }, + } + for _, tc := range tcAddIndex { + sessionState := session.GetSessionState() + + sessionState.Driver = constants.MYSQL + sessionState.Conv = tc.conv + + inputBytes, err := json.Marshal(tc.input) + if err != nil { + t.Fatal(err) + } + buffer := bytes.NewBuffer(inputBytes) + + req, err := http.NewRequest("POST", "/applyrule", buffer) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + handler := http.HandlerFunc(api.ApplyRule) + handler.ServeHTTP(rr, req) + var res *internal.Conv + json.Unmarshal(rr.Body.Bytes(), &res) + if status := rr.Code; int64(status) != 
tc.statusCode { + t.Errorf("%s : handler returned wrong status code: got %v want %v", + tc.name, status, tc.statusCode) + } + if tc.statusCode == http.StatusOK { + tc.expectedConv.Rules = internal.MakeConv().Rules + tc.expectedConv.Rules = append(tc.expectedConv.Rules, tc.input) + + // Marshall and unmarshall the data field of rule with its proper type i.e ddl.CreateIndex. + // Else unmarshalling data field of rule as interface convert int to float64. + // In this particular case, order of index-key would be unmarshall to float64 instead of int. + dataBytes, err := json.Marshal(res.Rules[0].Data) + assert.Equal(t, err, nil) + var data ddl.CreateIndex + json.Unmarshal(dataBytes, &data) + + // Removing random ids before comparison. + addedRule := res.Rules[0] + data.Id = "" + addedRule.Data = data + addedRule.Id = "" + res.Rules[0] = addedRule + + assert.Equal(t, tc.expectedConv, res) + } + } + + tcSetGlobalDataTypePostgres := []struct { + name string + payload string + statusCode int64 + expectedSchema ddl.CreateTable + expectedIssues internal.TableIssues + }{ + { + name: "Test type change", + payload: `{ + "Name": "rule1", + "Type": "global_datatype_change", + "ObjectType": "Column", + "AssociatedObjects": "All Columns", + "Enabled": true, + "Data": + { + "bool":"STRING", + "int8":"STRING", + "float4":"STRING", + "varchar":"BYTES", + "numeric":"STRING", + "timestamptz":"STRING", + "bigserial":"STRING", + "bpchar":"BYTES", + "bytea":"STRING", + "date":"STRING", + "float8":"STRING", + "int4":"STRING", + "serial":"STRING", + "text":"BYTES", + "timestamp":"STRING" + } + }`, + statusCode: http.StatusOK, + expectedSchema: ddl.CreateTable{ + Name: "table1", + Id: "t1", + ColIds: []string{"c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13", "c14", "c15", "c16"}, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.String, Len: 
ddl.MaxLength}}, + "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c4": {Name: "d", Id: "c4", T: ddl.Type{Name: ddl.Bytes, Len: 6}}, + "c5": {Name: "e", Id: "c5", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c6": {Name: "f", Id: "c6", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c7": {Name: "g", Id: "c7", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c8": {Name: "h", Id: "c8", T: ddl.Type{Name: ddl.Bytes, Len: int64(1)}}, + "c9": {Name: "i", Id: "c9", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c10": {Name: "j", Id: "c10", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c11": {Name: "k", Id: "c11", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c12": {Name: "l", Id: "c12", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c13": {Name: "m", Id: "c13", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c14": {Name: "n", Id: "c14", T: ddl.Type{Name: ddl.Bytes, Len: ddl.MaxLength}}, + "c15": {Name: "o", Id: "c15", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c16": {Name: "p", Id: "c16", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, + }, + expectedIssues: internal.TableIssues{ + ColumnLevelIssues: map[string][]internal.SchemaIssue{ + "c1": {internal.Widened}, + "c2": {internal.Widened}, + "c3": {internal.Widened}, + "c5": {internal.Widened}, + "c6": {internal.Widened}, + "c7": {internal.Widened, internal.Serial}, + "c10": {internal.Widened}, + "c11": {internal.Widened}, + "c12": {internal.Widened}, + "c13": {internal.Widened, internal.Serial}, + "c15": {internal.Widened}, + "c16": {internal.Widened}, + }, + }, + }, + { + name: "Test type change 2", + payload: `{ + "Name": "rule1", + "Type": "global_datatype_change", + "ObjectType": "Column", + "AssociatedObjects": "All Columns", + "Enabled": true, + "Data": + { + "bool":"INT64", + "int8":"STRING", + "float4":"STRING" + } + }`, + statusCode: http.StatusOK, + 
expectedSchema: ddl.CreateTable{ + Name: "table1", + Id: "t1", + ColIds: []string{"c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13", "c14", "c15", "c16"}, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, + "c4": {Name: "d", Id: "c4", T: ddl.Type{Name: ddl.String, Len: int64(6)}}, + "c5": {Name: "e", Id: "c5", T: ddl.Type{Name: ddl.Numeric}}, + "c6": {Name: "f", Id: "c6", T: ddl.Type{Name: ddl.Timestamp}}, + "c7": {Name: "g", Id: "c7", T: ddl.Type{Name: ddl.Int64}}, + "c8": {Name: "h", Id: "c8", T: ddl.Type{Name: ddl.String, Len: int64(1)}}, + "c9": {Name: "i", Id: "c9", T: ddl.Type{Name: ddl.Bytes, Len: ddl.MaxLength}}, + "c10": {Name: "j", Id: "c10", T: ddl.Type{Name: ddl.Date}}, + "c11": {Name: "k", Id: "c11", T: ddl.Type{Name: ddl.Float64}}, + "c12": {Name: "l", Id: "c12", T: ddl.Type{Name: ddl.Int64}}, + "c13": {Name: "m", Id: "c13", T: ddl.Type{Name: ddl.Int64}}, + "c14": {Name: "n", Id: "c14", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c15": {Name: "o", Id: "c15", T: ddl.Type{Name: ddl.Timestamp}}, + "c16": {Name: "p", Id: "c16", T: ddl.Type{Name: ddl.Int64}}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, + }, + expectedIssues: internal.TableIssues{ + ColumnLevelIssues: map[string][]internal.SchemaIssue{ + "c1": {internal.Widened}, + "c2": {internal.Widened}, + "c3": {internal.Widened}, + "c7": {internal.Serial}, + "c12": {internal.Widened}, + "c13": {internal.Serial}, + "c15": {internal.Timestamp}, + "c16": {internal.Widened}, + }, + }, + }, + { + name: "Test bad payload data request", + payload: `{ + "Name": "rule1", + "Type": "global_datatype_change", + "ObjectType": "Column", + "AssociatedObjects": "All Columns", + "Enabled": true, + "Data": + { + "bool":"INT64", + "int8":"STRING", + 
"float4":"STRING", + } + }`, + statusCode: http.StatusBadRequest, + }, + } + for _, tc := range tcSetGlobalDataTypePostgres { + + sessionState := session.GetSessionState() + + sessionState.Driver = constants.POSTGRES + sessionState.Conv = internal.MakeConv() + buildConvPostgres(sessionState.Conv) + payload := tc.payload + req, err := http.NewRequest("POST", "/applyrule", strings.NewReader(payload)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + handler := http.HandlerFunc(api.ApplyRule) + handler.ServeHTTP(rr, req) + var res *internal.Conv + json.Unmarshal(rr.Body.Bytes(), &res) + if status := rr.Code; int64(status) != tc.statusCode { + t.Errorf("handler returned wrong status code: got %v want %v", + status, tc.statusCode) + } + + if tc.statusCode == http.StatusOK { + assert.Equal(t, tc.expectedSchema, res.SpSchema["t1"]) + assert.Equal(t, tc.expectedIssues, res.SchemaIssues["t1"]) + } + } + + tcSetGlobalDataTypeMysql := []struct { + name string + payload string + statusCode int64 + expectedSchema ddl.CreateTable + expectedIssues internal.TableIssues + }{ + { + name: "Test type change", + payload: `{ + "Name": "rule1", + "Type": "global_datatype_change", + "ObjectType": "Column", + "AssociatedObjects": "All Columns", + "Enabled": true, + "Data": + { + "bool":"STRING", + "smallint":"STRING", + "float":"STRING", + "varchar":"BYTES", + "numeric":"STRING", + "timestamp":"STRING", + "decimal":"STRING", + "json":"BYTES", + "binary":"STRING", + "blob":"STRING", + "double":"STRING", + "date":"STRING", + "time":"STRING", + "enum":"STRING", + "text":"BYTES" + } + }`, + statusCode: http.StatusOK, + expectedSchema: ddl.CreateTable{ + Name: "table1", + Id: "t1", + ColIds: []string{"c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13", "c14", "c15", "c16"}, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + 
"c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.Bytes, Len: ddl.MaxLength}}, + "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c4": {Name: "d", Id: "c4", T: ddl.Type{Name: ddl.Bytes, Len: 6}}, + "c5": {Name: "e", Id: "c5", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c6": {Name: "f", Id: "c6", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c7": {Name: "g", Id: "c7", T: ddl.Type{Name: ddl.Bytes, Len: ddl.MaxLength}}, + "c8": {Name: "h", Id: "c8", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c9": {Name: "i", Id: "c9", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c10": {Name: "j", Id: "c10", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c11": {Name: "k", Id: "c11", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c12": {Name: "l", Id: "c12", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c13": {Name: "m", Id: "c13", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c14": {Name: "n", Id: "c14", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c15": {Name: "o", Id: "c15", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c16": {Name: "p", Id: "c16", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, + }, + expectedIssues: internal.TableIssues{ + ColumnLevelIssues: map[string][]internal.SchemaIssue{ + "c1": {internal.Widened}, + "c3": {internal.Widened}, + "c5": {internal.Widened}, + "c10": {internal.Widened}, + "c11": {internal.Widened}, + "c12": {internal.Widened}, + "c13": {internal.Widened}, + "c14": {internal.Widened}, + "c15": {internal.Widened}, + "c16": {internal.Time}, + }, + }, + }, + { + name: "Test type change 2", + payload: `{ + "Name": "rule1", + "Type": "global_datatype_change", + "ObjectType": "Column", + "AssociatedObjects": "All Columns", + "Enabled": true, + "Data": + { + "bool":"INT64", + "varchar":"BYTES" + } + }`, + statusCode: http.StatusOK, + expectedSchema: ddl.CreateTable{ + Name: 
"table1", + Id: "t1", + ColIds: []string{"c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13", "c14", "c15", "c16"}, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.Int64}}, + "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, + "c4": {Name: "d", Id: "c4", T: ddl.Type{Name: ddl.Bytes, Len: 6}}, + "c5": {Name: "e", Id: "c5", T: ddl.Type{Name: ddl.Numeric}}, + "c6": {Name: "f", Id: "c6", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c7": {Name: "g", Id: "c7", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + "c8": {Name: "h", Id: "c8", T: ddl.Type{Name: ddl.Bytes, Len: ddl.MaxLength}}, + "c9": {Name: "i", Id: "c9", T: ddl.Type{Name: ddl.Bytes, Len: ddl.MaxLength}}, + "c10": {Name: "j", Id: "c10", T: ddl.Type{Name: ddl.Int64}}, + "c11": {Name: "k", Id: "c11", T: ddl.Type{Name: ddl.Float64}}, + "c12": {Name: "l", Id: "c12", T: ddl.Type{Name: ddl.Float64}}, + "c13": {Name: "m", Id: "c13", T: ddl.Type{Name: ddl.Numeric}}, + "c14": {Name: "n", Id: "c14", T: ddl.Type{Name: ddl.Date}}, + "c15": {Name: "o", Id: "c15", T: ddl.Type{Name: ddl.Timestamp}}, + "c16": {Name: "p", Id: "c16", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, + }, + expectedIssues: internal.TableIssues{ + ColumnLevelIssues: map[string][]internal.SchemaIssue{ + "c1": {internal.Widened}, + "c3": {internal.Widened}, + "c10": {internal.Widened}, + "c12": {internal.Widened}, + "c15": {internal.Time}, + }, + }, + }, + { + name: "Test bad request", + payload: `{ + "Name": "rule1", + "Type": "global_datatype_change", + "ObjectType": "Column", + "AssociatedObjects": "All Columns", + "Enabled": true, + "Data": + { + "bool":"INT64", + "smallint":"STRING", + } + }`, + statusCode: http.StatusBadRequest, + }, + } + for _, tc := range tcSetGlobalDataTypeMysql { + sessionState := 
session.GetSessionState() + + sessionState.Driver = constants.MYSQL + sessionState.Conv = internal.MakeConv() + buildConvMySQL(sessionState.Conv) + payload := tc.payload + req, err := http.NewRequest("POST", "/applyrule", strings.NewReader(payload)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + handler := http.HandlerFunc(api.ApplyRule) + handler.ServeHTTP(rr, req) + var res *internal.Conv + json.Unmarshal(rr.Body.Bytes(), &res) + if status := rr.Code; int64(status) != tc.statusCode { + t.Errorf("handler returned wrong status code: got %v want %v", + status, tc.statusCode) + } + + if tc.statusCode == http.StatusOK { + assert.Equal(t, tc.expectedSchema, res.SpSchema["t1"]) + assert.Equal(t, tc.expectedIssues, res.SchemaIssues["t1"]) + } + } +} + +func TestDropRule(t *testing.T) { + tc := []struct { + name string + ruleId string + statusCode int64 + conv *internal.Conv + expectedConv *internal.Conv + }{ + { + name: "drop a valid add index rule", + ruleId: "r101", + statusCode: http.StatusOK, + conv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + Indexes: []ddl.CreateIndex{ + {Name: "idx1", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c1", Desc: false}}}, + {Name: "idx2", Id: "i2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}, + {Name: "idx3", Id: "i3", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}, + }, + }}, + Audit: internal.Audit{ + MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), + }, + UsedNames: map[string]bool{"table1": true, "idx1": true, "idx2": true, "idx3": true}, + Rules: []internal.Rule{{ + Id: "r101", + Name: "add_index", + Type: constants.AddIndex, + ObjectType: "table", + AssociatedObjects: "t1", + Enabled: true, + Data: ddl.CreateIndex{Name: "idx3", Id: "i3", TableId: "t1", Unique: 
false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}, + }}, + }, + expectedConv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + Indexes: []ddl.CreateIndex{ + {Name: "idx1", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c1", Desc: false}}}, + {Name: "idx2", Id: "i2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}, + }, + }}, + }, + }, + { + name: "drop a vaild add global data type rule", + ruleId: "r101", + statusCode: http.StatusOK, + conv: &internal.Conv{ + SchemaIssues: map[string]internal.TableIssues{ + "t1": {}, + }, + SrcSchema: map[string]schema.Table{ + "t1": { + Name: "table1", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]schema.Column{ + "c1": {Name: "a", Type: schema.Type{Name: "bigint"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c1"}, + "c2": {Name: "b", Type: schema.Type{Name: "bigint"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c2"}, + "c3": {Name: "c", Type: schema.Type{Name: "varchar"}, NotNull: false, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c3"}, + }, + PrimaryKeys: []schema.Key{{ColId: "c1", Desc: false, Order: 1}}, + Id: "t1", + }, + }, + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + }, + PrimaryKeys: 
[]ddl.IndexKey{{ColId: "c1", Desc: false}}, + Id: "t1", + }, + }, + Audit: internal.Audit{ + MigrationType: migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), + }, + Rules: []internal.Rule{ + { + Id: "r101", + Name: "bigint to BTYES", + Type: constants.GlobalDataTypeChange, + ObjectType: "Column", + AssociatedObjects: "All Columns", + Enabled: true, + Data: map[string]string{ + "bigint": ddl.String, + }, + }, + }, + }, + expectedConv: &internal.Conv{ + SchemaIssues: map[string]internal.TableIssues{ + "t1": {}, + }, + SrcSchema: map[string]schema.Table{ + "t1": { + Name: "table1", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]schema.Column{ + "c1": {Name: "a", Type: schema.Type{Name: "bigint"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c1"}, + "c2": {Name: "b", Type: schema.Type{Name: "bigint"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c2"}, + "c3": {Name: "c", Type: schema.Type{Name: "varchar"}, NotNull: false, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c3"}, + }, + PrimaryKeys: []schema.Key{{ColId: "c1", Desc: false, Order: 1}}, + Id: "t1", + }, + }, + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false}}, + Id: "t1", + }, + }, + }, + }, + { + name: "drop rule with an invalid rule-id", + ruleId: "ABC", + statusCode: http.StatusBadRequest, + 
conv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + Indexes: []ddl.CreateIndex{ + {Name: "idx1", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c1", Desc: false}}}, + {Name: "idx2", Id: "i2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}, + {Name: "idx3", Id: "i3", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}, + }, + }}, + Audit: internal.Audit{ + MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), + }, + UsedNames: map[string]bool{"table1": true, "idx1": true, "idx2": true, "idx3": true}, + Rules: []internal.Rule{{ + Id: "r101", + Name: "add_index", + Type: constants.AddIndex, + ObjectType: "table", + AssociatedObjects: "t1", + Enabled: true, + Data: ddl.CreateIndex{Name: "idx3", Id: "i3", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}, + }}, + }, + }, + { + name: "drop a disabled valid add index rule", + ruleId: "r101", + statusCode: http.StatusOK, + conv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + Indexes: []ddl.CreateIndex{ + {Name: "idx1", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c1", Desc: false}}}, + {Name: "idx2", Id: "i2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}, + }, + }}, + Audit: internal.Audit{ + MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), + }, + UsedNames: map[string]bool{"table1": true, "idx1": true, "idx2": true}, + Rules: []internal.Rule{{ + Id: "r101", + Name: "add_index", + Type: constants.AddIndex, + ObjectType: "table", + AssociatedObjects: "t1", + Enabled: false, + Data: ddl.CreateIndex{Name: "idx3", Id: "i3", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}, + }}, + }, + expectedConv: &internal.Conv{ + 
SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "table1", + Id: "t1", + Indexes: []ddl.CreateIndex{ + {Name: "idx1", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c1", Desc: false}}}, + {Name: "idx2", Id: "i2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}, + }, + }}, + }, + }, + } + for _, tc := range tc { + sessionState := session.GetSessionState() + sessionState.Driver = constants.MYSQL + sessionState.Conv = tc.conv + payload := `{}` + req, err := http.NewRequest("POST", "/dropRule?id="+tc.ruleId, strings.NewReader(payload)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + handler := http.HandlerFunc(api.DropRule) + handler.ServeHTTP(rr, req) + var res *internal.Conv + json.Unmarshal(rr.Body.Bytes(), &res) + if status := rr.Code; int64(status) != tc.statusCode { + t.Errorf("%s : handler returned wrong status code: got %v want %v", + tc.name, status, tc.statusCode) + } + if tc.statusCode == http.StatusOK { + assert.Equal(t, tc.expectedConv, res) + } + } + +} \ No newline at end of file diff --git a/webv2/api/schema.go b/webv2/api/schema.go new file mode 100644 index 0000000000..44eeb0996d --- /dev/null +++ b/webv2/api/schema.go @@ -0,0 +1,1536 @@ +package api + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/conversion" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal/reports" + "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" + "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" + 
"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/mysql" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/oracle" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/postgres" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/sqlserver" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/config" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/helpers" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/index" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/primarykey" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/session" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/types" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/utilities" +) + +var mysqlDefaultTypeMap = make(map[string]ddl.Type) +var postgresDefaultTypeMap = make(map[string]ddl.Type) +var sqlserverDefaultTypeMap = make(map[string]ddl.Type) +var oracleDefaultTypeMap = make(map[string]ddl.Type) + +var mysqlTypeMap = make(map[string][]types.TypeIssue) +var postgresTypeMap = make(map[string][]types.TypeIssue) +var sqlserverTypeMap = make(map[string][]types.TypeIssue) +var oracleTypeMap = make(map[string][]types.TypeIssue) + +func init() { + sessionState := session.GetSessionState() + utilities.InitObjectId() + sessionState.Conv = internal.MakeConv() + config := config.TryInitializeSpannerConfig() + session.SetSessionStorageConnectionState(config.GCPProjectID, config.SpannerInstanceID) +} + +// ConvertSchemaSQL converts source database to Spanner when using +// with postgres and mysql driver. 
+func ConvertSchemaSQL(w http.ResponseWriter, r *http.Request) { + sessionState := session.GetSessionState() + if sessionState.SourceDB == nil || sessionState.DbName == "" || sessionState.Driver == "" { + http.Error(w, fmt.Sprintf("Database is not configured or Database connection is lost. Please set configuration and connect to database."), http.StatusNotFound) + return + } + conv := internal.MakeConv() + + conv.SpDialect = sessionState.Dialect + conv.IsSharded = sessionState.IsSharded + var err error + additionalSchemaAttributes := internal.AdditionalSchemaAttributes{ + IsSharded: sessionState.IsSharded, + } + processSchema := common.ProcessSchemaImpl{} + switch sessionState.Driver { + case constants.MYSQL: + err = processSchema.ProcessSchema(conv, mysql.InfoSchemaImpl{DbName: sessionState.DbName, Db: sessionState.SourceDB}, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) + case constants.POSTGRES: + temp := false + err = processSchema.ProcessSchema(conv, postgres.InfoSchemaImpl{Db: sessionState.SourceDB, IsSchemaUnique: &temp}, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) + case constants.SQLSERVER: + err = processSchema.ProcessSchema(conv, sqlserver.InfoSchemaImpl{DbName: sessionState.DbName, Db: sessionState.SourceDB}, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) + case constants.ORACLE: + err = processSchema.ProcessSchema(conv, oracle.InfoSchemaImpl{DbName: strings.ToUpper(sessionState.DbName), Db: sessionState.SourceDB}, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) + default: + http.Error(w, fmt.Sprintf("Driver : '%s' is not supported", sessionState.Driver), http.StatusBadRequest) + return + } + if err != nil { 
+ http.Error(w, fmt.Sprintf("Schema Conversion Error : %v", err), http.StatusNotFound) + return + } + + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + + sessionState.Conv = conv + + if sessionState.IsSharded { + setShardIdColumnAsPrimaryKey(true) + addShardIdColumnToForeignKeys(true) + ruleId := internal.GenerateRuleId() + rule := internal.Rule{ + Id: ruleId, + Name: ruleId, + Type: constants.AddShardIdPrimaryKey, + AssociatedObjects: "All Tables", + Data: types.ShardIdPrimaryKey{ + AddedAtTheStart: true, + }, + Enabled: true, + } + + sessionState := session.GetSessionState() + sessionState.Conv.Rules = append(sessionState.Conv.Rules, rule) + session.UpdateSessionFile() + } + + primarykey.DetectHotspot() + index.IndexSuggestion() + + sessionMetadata := session.SessionMetadata{ + SessionName: "NewSession", + DatabaseType: sessionState.Driver, + DatabaseName: sessionState.DbName, + Dialect: sessionState.Dialect, + } + + convm := session.ConvWithMetadata{ + SessionMetadata: sessionMetadata, + Conv: *sessionState.Conv, + } + sessionState.SessionMetadata = sessionMetadata + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(convm) +} + +// ConvertSchemaDump converts schema from dump file to Spanner schema for +// mysqldump and pg_dump driver. 
+func ConvertSchemaDump(w http.ResponseWriter, r *http.Request) {
+	reqBody, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError)
+		return
+	}
+	var dc types.ConvertFromDumpRequest
+	err = json.Unmarshal(reqBody, &dc)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest)
+		return
+	}
+	// NOTE(review): FilePath comes straight from the request body and is
+	// concatenated into the path without validation - consider rejecting
+	// values that escape UPLOAD_FILE_DIR.
+	f, err := os.Open(constants.UPLOAD_FILE_DIR + "/" + dc.Config.FilePath)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("Failed to open dump file : %v, no such file or directory", dc.Config.FilePath), http.StatusNotFound)
+		return
+	}
+	// The dump file is only read while this handler runs; release the
+	// descriptor on every exit path.
+	defer f.Close()
+	// We don't support Dynamodb in web hence no need to pass schema sample size here.
+	n := profiles.NewSourceProfileImpl{}
+	// NOTE(review): the error from NewSourceProfile is deliberately discarded
+	// here; only the Driver field, set explicitly below, is used afterwards.
+	sourceProfile, _ := profiles.NewSourceProfile("", dc.Config.Driver, &n)
+	sourceProfile.Driver = dc.Config.Driver
+	schemaFromSource := conversion.SchemaFromSourceImpl{}
+	conv, err := schemaFromSource.SchemaFromDump(sourceProfile.Driver, dc.SpannerDetails.Dialect, &utils.IOStreams{In: f, Out: os.Stdout}, &conversion.ProcessDumpByDialectImpl{})
+	if err != nil {
+		http.Error(w, fmt.Sprintf("Schema Conversion Error : %v", err), http.StatusNotFound)
+		return
+	}
+
+	sessionMetadata := session.SessionMetadata{
+		SessionName:  "NewSession",
+		DatabaseType: dc.Config.Driver,
+		DatabaseName: filepath.Base(dc.Config.FilePath),
+		Dialect:      dc.SpannerDetails.Dialect,
+	}
+
+	sessionState := session.GetSessionState()
+	// Lock the Conv being replaced; the deferred Unlock is bound to that same
+	// lock because its receiver is evaluated at the defer statement.
+	sessionState.Conv.ConvLock.Lock()
+	defer sessionState.Conv.ConvLock.Unlock()
+	sessionState.Conv = conv
+
+	primarykey.DetectHotspot()
+	index.IndexSuggestion()
+
+	// A dump-based session has no live source connection: clear the
+	// connection-related fields and remember the dump path instead.
+	sessionState.SessionMetadata = sessionMetadata
+	sessionState.Driver = dc.Config.Driver
+	sessionState.DbName = ""
+	sessionState.SessionFile = ""
+	sessionState.SourceDB = nil
+	sessionState.Dialect = dc.SpannerDetails.Dialect
+	sessionState.SourceDBConnDetails = session.SourceDBConnDetails{
+		Path:           constants.UPLOAD_FILE_DIR + "/" + dc.Config.FilePath,
+		ConnectionType: helpers.DUMP_MODE,
+	}
+
+	convm := session.ConvWithMetadata{
+		SessionMetadata: sessionMetadata,
+		Conv:            *conv,
+	}
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(convm)
+}
+
+// GetDDL returns the Spanner DDL for each table, keyed by the table's key in
+// SpSchema. Unlike internal/convert.go's GetDDL, it does not print tables in a
+// way that respects the parent/child ordering of interleaved tables.
+// Though foreign keys and secondary indexes are displayed, GetDDL cannot be
+// used to build DDL to send to Spanner.
+func GetDDL(w http.ResponseWriter, r *http.Request) {
+	sessionState := session.GetSessionState()
+	sessionState.Conv.ConvLock.RLock()
+	defer sessionState.Conv.ConvLock.RUnlock()
+	c := ddl.Config{Comments: true, ProtectIds: false, SpDialect: sessionState.Conv.SpDialect, Source: sessionState.Driver}
+	tables := make([]string, 0, len(sessionState.Conv.SpSchema))
+	for t := range sessionState.Conv.SpSchema {
+		tables = append(tables, t)
+	}
+	sort.Strings(tables)
+	// Named ddlByTable (not ddl) so the imported ddl package is not shadowed.
+	ddlByTable := make(map[string]string)
+	for _, t := range tables {
+		table := sessionState.Conv.SpSchema[t]
+		tableDdl := table.PrintCreateTable(sessionState.Conv.SpSchema, c) + ";"
+		// A blank line separates the CREATE TABLE statement from its indexes,
+		// and the indexes from the foreign keys.
+		if len(table.Indexes) > 0 {
+			tableDdl = tableDdl + "\n"
+		}
+		for _, idx := range table.Indexes {
+			tableDdl = tableDdl + "\n" + idx.PrintCreateIndex(table, c) + ";"
+		}
+		if len(table.ForeignKeys) > 0 {
+			tableDdl = tableDdl + "\n"
+		}
+		for _, fk := range table.ForeignKeys {
+			tableDdl = tableDdl + "\n" + fk.PrintForeignKeyAlterTable(sessionState.Conv.SpSchema, c, t) + ";"
+		}
+
+		ddlByTable[t] = tableDdl
+	}
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(ddlByTable)
+}
+
+// GetStandardTypeToPGSQLTypemap responds with ddl.STANDARD_TYPE_TO_PGSQL_TYPEMAP
+// encoded as JSON.
+func GetStandardTypeToPGSQLTypemap(w http.ResponseWriter, r *http.Request) {
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(ddl.STANDARD_TYPE_TO_PGSQL_TYPEMAP)
+}
+
+// GetPGSQLToStandardTypeTypemap responds with ddl.PGSQL_TO_STANDARD_TYPE_TYPEMAP
+// encoded as JSON.
+func GetPGSQLToStandardTypeTypemap(w http.ResponseWriter, r *http.Request) {
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(ddl.PGSQL_TO_STANDARD_TYPE_TYPEMAP)
+}
+
+// SpannerDefaultTypeMap responds with the default source-type to Spanner-type
+// map for the session's driver. It answers 404 when no conversion or driver is
+// set and 400 when the driver has no default type map.
+func SpannerDefaultTypeMap(w http.ResponseWriter, r *http.Request) {
+	sessionState := session.GetSessionState()
+
+	if sessionState.Conv == nil || sessionState.Driver == "" {
+		http.Error(w, "Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner.", http.StatusNotFound)
+		return
+	}
+	sessionState.Conv.ConvLock.Lock()
+	defer sessionState.Conv.ConvLock.Unlock()
+	initializeTypeMap()
+
+	var typeMap map[string]ddl.Type
+	// The dump drivers share the type map of their live-connection counterpart.
+	switch sessionState.Driver {
+	case constants.MYSQL, constants.MYSQLDUMP:
+		typeMap = mysqlDefaultTypeMap
+	case constants.POSTGRES, constants.PGDUMP:
+		typeMap = postgresDefaultTypeMap
+	case constants.SQLSERVER:
+		typeMap = sqlserverDefaultTypeMap
+	case constants.ORACLE:
+		typeMap = oracleDefaultTypeMap
+	default:
+		http.Error(w, fmt.Sprintf("Driver : '%s' is not supported", sessionState.Driver), http.StatusBadRequest)
+		return
+	}
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(typeMap)
+}
+
+// GetTypeMap returns the source to Spanner typemap only for the
+// source types used in current conversion.
+func GetTypeMap(w http.ResponseWriter, r *http.Request) {
+	sessionState := session.GetSessionState()
+	if sessionState.Conv == nil || sessionState.Driver == "" {
+		http.Error(w, "Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner.", http.StatusNotFound)
+		return
+	}
+	sessionState.Conv.ConvLock.Lock()
+	defer sessionState.Conv.ConvLock.Unlock()
+	var typeMap map[string][]types.TypeIssue
+	initializeTypeMap()
+	switch sessionState.Driver {
+	case constants.MYSQL, constants.MYSQLDUMP:
+		typeMap = mysqlTypeMap
+	case constants.POSTGRES, constants.PGDUMP:
+		typeMap = postgresTypeMap
+	case constants.SQLSERVER:
+		typeMap = sqlserverTypeMap
+	case constants.ORACLE:
+		typeMap = oracleTypeMap
+	default:
+		http.Error(w, fmt.Sprintf("Driver : '%s' is not supported", sessionState.Driver), http.StatusBadRequest)
+		return
+	}
+	// Filter typeMap so it contains just the types SrcSchema uses.
+	filteredTypeMap := make(map[string][]types.TypeIssue)
+	for _, srcTable := range sessionState.Conv.SrcSchema {
+		for _, colDef := range srcTable.ColDefs {
+			if _, ok := filteredTypeMap[colDef.Type.Name]; ok {
+				continue
+			}
+			// Timestamp and interval types do not have exact key in typemap.
+			// Typemap for TIMESTAMP(6), TIMESTAMP(6) WITH LOCAL TIMEZONE,TIMESTAMP(6) WITH TIMEZONE is stored into TIMESTAMP key.
+			// Same goes with interval types like INTERVAL YEAR(2) TO MONTH, INTERVAL DAY(2) TO SECOND(6) etc.
+			// If exact key not found then check with regex.
+			if _, ok := typeMap[colDef.Type.Name]; !ok {
+				if oracle.TimestampReg.MatchString(colDef.Type.Name) {
+					filteredTypeMap[colDef.Type.Name] = typeMap["TIMESTAMP"]
+				} else if oracle.IntervalReg.MatchString(colDef.Type.Name) {
+					filteredTypeMap[colDef.Type.Name] = typeMap["INTERVAL"]
+				}
+				// Types matching neither pattern are left out of the result.
+				continue
+			}
+			filteredTypeMap[colDef.Type.Name] = typeMap[colDef.Type.Name]
+		}
+	}
+	// Fill in the display name for every candidate Spanner type. The slices in
+	// filteredTypeMap are the ones stored in the package-level type map, so
+	// these writes update the shared entries as well.
+	for key, values := range filteredTypeMap {
+		for i := range values {
+			if sessionState.Dialect == constants.DIALECT_POSTGRESQL {
+				spType := ddl.Type{
+					Name: filteredTypeMap[key][i].T,
+				}
+				filteredTypeMap[key][i].DisplayT = ddl.GetPGType(spType)
+			} else {
+				filteredTypeMap[key][i].DisplayT = filteredTypeMap[key][i].T
+			}
+
+		}
+	}
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(filteredTypeMap)
+}
+
+// GetTableWithErrors checks the errors in the spanner schema
+// and returns a list of tables with errors
+func GetTableWithErrors(w http.ResponseWriter, r *http.Request) {
+	sessionState := session.GetSessionState()
+	sessionState.Conv.ConvLock.RLock()
+	defer sessionState.Conv.ConvLock.RUnlock()
+	var tableIdName []types.TableIdAndName
+	// Only tables with at least one table-level issue are reported.
+	for id, issues := range sessionState.Conv.SchemaIssues {
+		if len(issues.TableLevelIssues) != 0 {
+			t := types.TableIdAndName{
+				Id:   id,
+				Name: sessionState.Conv.SpSchema[id].Name,
+			}
+			tableIdName = append(tableIdName, t)
+		}
+	}
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(tableIdName)
+}
+
+// RestoreTables reads an internal.Tables JSON body and calls
+// restoreTableHelper for every id in TableList. The response is the
+// ConvWithMetadata returned by the last call.
+func RestoreTables(w http.ResponseWriter, r *http.Request) {
+	reqBody, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError)
+		return
+	}
+	var tables internal.Tables
+	err = json.Unmarshal(reqBody, &tables)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest)
+		return
+	}
+	var convm session.ConvWithMetadata
+	for _, tableId := range tables.TableList {
+		convm = restoreTableHelper(w, tableId)
+	}
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(convm)
+}
+
+// RestoreTable calls restoreTableHelper for the table id given in the "table"
+// form value and responds with the resulting ConvWithMetadata.
+func RestoreTable(w http.ResponseWriter, r *http.Request) {
+	tableId := r.FormValue("table")
+	convm := restoreTableHelper(w, tableId)
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(convm)
+}
+
+// DropTables reads an internal.Tables JSON body and calls dropTableHelper for
+// every id in TableList. The response is the ConvWithMetadata returned by the
+// last call.
+func DropTables(w http.ResponseWriter, r *http.Request) {
+	reqBody, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError)
+		return
+	}
+	var tables internal.Tables
+	err = json.Unmarshal(reqBody, &tables)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest)
+		return
+	}
+	var convm session.ConvWithMetadata
+	for _, tableId := range tables.TableList {
+		convm = dropTableHelper(w, tableId)
+	}
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(convm)
+}
+
+// DropTable calls dropTableHelper for the table id given in the "table" form
+// value and responds with the resulting ConvWithMetadata.
+func DropTable(w http.ResponseWriter, r *http.Request) {
+	tableId := r.FormValue("table")
+	convm := dropTableHelper(w, tableId)
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(convm)
+}
+
+// RestoreSecondaryIndex re-creates a Spanner secondary index from its source
+// index. The "tableId" and "indexId" form values select the index in
+// SrcSchema; it is converted with common.CvtIndexHelper, appended to the
+// Spanner table's indexes, and index ordering and suggestions are recomputed.
+func RestoreSecondaryIndex(w http.ResponseWriter, r *http.Request) {
+	tableId := r.FormValue("tableId")
+	indexId := r.FormValue("indexId")
+	sessionState := session.GetSessionState()
+	if sessionState.Conv == nil || sessionState.Driver == "" {
+		http.Error(w, "Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner.", http.StatusNotFound)
+		return
+	}
+	if tableId == "" {
+		http.Error(w, "Table Id is empty", http.StatusBadRequest)
+		return
+	}
+	if indexId == "" {
+		http.Error(w, "Index Id is empty", http.StatusBadRequest)
+		return
+	}
+
+	sessionState.Conv.ConvLock.Lock()
+	defer sessionState.Conv.ConvLock.Unlock()
+	var srcIndex schema.Index
+	srcIndexFound := false
+	// Named candidate (not index) so the imported index package is not shadowed.
+	for _, candidate := range sessionState.Conv.SrcSchema[tableId].Indexes {
+		if candidate.Id == indexId {
+			srcIndex = candidate
+			srcIndexFound = true
+			break
+		}
+	}
+	if !srcIndexFound {
+		http.Error(w, "Source index not found", http.StatusBadRequest)
+		return
+	}
+
+	conv := sessionState.Conv
+
+	// SpSchema stores tables by value, so the modified copy is written back.
+	spIndex := common.CvtIndexHelper(conv, tableId, srcIndex, conv.SpSchema[tableId].ColIds, conv.SpSchema[tableId].ColDefs)
+	spIndexes := conv.SpSchema[tableId].Indexes
+	spIndexes = append(spIndexes, spIndex)
+	spTable := conv.SpSchema[tableId]
+	spTable.Indexes = spIndexes
+	conv.SpSchema[tableId] = spTable
+
+	sessionState.Conv = conv
+	index.AssignInitialOrders()
+	index.IndexSuggestion()
+
+	convm := session.ConvWithMetadata{
+		SessionMetadata: sessionState.SessionMetadata,
+		Conv:            *sessionState.Conv,
+	}
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(convm)
+
+}
+
+// UpdateForeignKeys checks the new names for Spanner name validity and ensures the new names are not already used by existing tables,
+// secondary indexes or foreign key constraints. If these checks pass, the foreign key renaming is reflected in the schema; otherwise an
+// appropriate error is returned.
+func UpdateForeignKeys(w http.ResponseWriter, r *http.Request) {
+	tableId := r.FormValue("table")
+	reqBody, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError)
+		// http.Error does not end the request; stop here so a failed read is
+		// not followed by further processing and a second response.
+		return
+	}
+
+	sessionState := session.GetSessionState()
+	if sessionState.Conv == nil || sessionState.Driver == "" {
+		http.Error(w, "Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner.", http.StatusNotFound)
+		return
+	}
+	sessionState.Conv.ConvLock.Lock()
+	defer sessionState.Conv.ConvLock.Unlock()
+
+	newFKs := []ddl.Foreignkey{}
+	if err = json.Unmarshal(reqBody, &newFKs); err != nil {
+		http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest)
+		return
+	}
+
+	// Collect the names that actually change (same id, different name);
+	// only those need validity and availability checks.
+	newNames := []string{}
+	newNamesMap := map[string]bool{}
+	for _, newFk := range newFKs {
+		if len(newFk.Name) == 0 {
+			continue
+		}
+		for _, oldFk := range sessionState.Conv.SpSchema[tableId].ForeignKeys {
+			if newFk.Id == oldFk.Id && newFk.Name != oldFk.Name {
+				newNames = append(newNames, strings.ToLower(newFk.Name))
+			}
+		}
+	}
+
+	// Reject requests in which two entries share a name (case-insensitively),
+	// whether or not those names changed.
+	for _, newFk := range newFKs {
+		if len(newFk.Name) == 0 {
+			continue
+		}
+		lowered := strings.ToLower(newFk.Name)
+		if _, ok := newNamesMap[lowered]; ok {
+			http.Error(w, fmt.Sprintf("Found duplicate names in input : %s", lowered), http.StatusBadRequest)
+			return
+		}
+		newNamesMap[lowered] = true
+	}
+
+	// Check new names for Spanner name validity.
+	if ok, invalidNames := utilities.CheckSpannerNamesValidity(newNames); !ok {
+		http.Error(w, fmt.Sprintf("Following names are not valid Spanner identifiers: %s", strings.Join(invalidNames, ",")), http.StatusBadRequest)
+		return
+	}
+
+	// Check that the new names are not already used by existing tables, secondary indexes or foreign key constraints.
+	if ok, err := utilities.CanRename(newNames, tableId); !ok {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
+
+	sp := sessionState.Conv.SpSchema[tableId]
+	usedNames := sessionState.Conv.UsedNames
+
+	// Update session with renamed foreignkeys. Only existing foreign keys whose
+	// request entry still has columns and a referenced table are carried over;
+	// any other foreign key is left out of updatedFKs and therefore removed
+	// from the table below.
+	updatedFKs := []ddl.Foreignkey{}
+
+	for _, foreignKey := range sp.ForeignKeys {
+		for _, updatedForeignkey := range newFKs {
+			if foreignKey.Id == updatedForeignkey.Id && len(updatedForeignkey.ColIds) != 0 && updatedForeignkey.ReferTableId != "" {
+				// The old name is released from UsedNames before the rename.
+				delete(usedNames, strings.ToLower(foreignKey.Name))
+				foreignKey.Name = updatedForeignkey.Name
+				updatedFKs = append(updatedFKs, foreignKey)
+			}
+		}
+	}
+
+	position := -1
+
+	for i, fk := range updatedFKs {
+		// Condition to check whether FK has to be dropped
+		if len(fk.ReferColumnIds) == 0 && fk.ReferTableId == "" {
+			position = i
+			dropFkId := fk.Id
+
+			// To remove the interleavable suggestions if they exist on dropping fk
+			// NOTE(review): position indexes updatedFKs but is applied to
+			// sp.ForeignKeys here, and ColIds[0] assumes a non-empty column
+			// list - confirm both.
+			colId := sp.ForeignKeys[position].ColIds[0]
+			schemaIssue := []internal.SchemaIssue{}
+			for _, v := range sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] {
+				if v != internal.InterleavedAddColumn && v != internal.InterleavedRenameColumn && v != internal.InterleavedNotInOrder && v != internal.InterleavedChangeColumnSize {
+					schemaIssue = append(schemaIssue, v)
+				}
+			}
+			if _, ok := sessionState.Conv.SchemaIssues[tableId]; ok {
+				sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] = schemaIssue
+			}
+
+			// NOTE(review): sp.ForeignKeys is overwritten with updatedFKs right
+			// after this loop, so this assignment only matters if RemoveFk
+			// modifies updatedFKs in place - confirm intended.
+			sp.ForeignKeys = utilities.RemoveFk(updatedFKs, dropFkId)
+		}
+	}
+	sp.ForeignKeys = updatedFKs
+	sessionState.Conv.SpSchema[tableId] = sp
+	session.UpdateSessionFile()
+
+	convm := session.ConvWithMetadata{
+		SessionMetadata: sessionState.SessionMetadata,
+		Conv:            *sessionState.Conv,
+	}
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(convm)
+}
+
+// renameIndexes checks the new names for spanner name validity, ensures the new names are already not used by existing tables
+// secondary indexes or
foreign key constraints. If above checks passed then index renaming reflected in the schema else appropriate +// error thrown. +func RenameIndexes(w http.ResponseWriter, r *http.Request) { + table := r.FormValue("table") + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) + } + + renameMap := map[string]string{} + if err = json.Unmarshal(reqBody, &renameMap); err != nil { + http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) + return + } + + // Check new name for spanner name validity. + newNames := []string{} + newNamesMap := map[string]bool{} + for _, value := range renameMap { + newNames = append(newNames, strings.ToLower(value)) + newNamesMap[strings.ToLower(value)] = true + } + if len(newNames) != len(newNamesMap) { + http.Error(w, fmt.Sprintf("Found duplicate names in input : %s", strings.Join(newNames, ",")), http.StatusBadRequest) + return + } + + if ok, invalidNames := utilities.CheckSpannerNamesValidity(newNames); !ok { + http.Error(w, fmt.Sprintf("Following names are not valid Spanner identifiers: %s", strings.Join(invalidNames, ",")), http.StatusBadRequest) + return + } + + // Check that the new names are not already used by existing tables, secondary indexes or foreign key constraints. + if ok, err := utilities.CanRename(newNames, table); !ok { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + sessionState := session.GetSessionState() + + sp := sessionState.Conv.SpSchema[table] + + // Update session with renamed secondary indexes. 
+ newIndexes := []ddl.CreateIndex{} + for _, index := range sp.Indexes { + if newName, ok := renameMap[index.Id]; ok { + index.Name = newName + } + newIndexes = append(newIndexes, index) + } + sp.Indexes = newIndexes + + sessionState.Conv.SpSchema[table] = sp + session.UpdateSessionFile() + convm := session.ConvWithMetadata{ + SessionMetadata: sessionState.SessionMetadata, + Conv: *sessionState.Conv, + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(convm) +} + +// setParentTable checks whether specified table can be interleaved, and updates the schema to convert foreign +// key to interleaved table if 'update' parameter is set to true. If 'update' parameter is set to false, then return +// whether the foreign key can be converted to interleave table without updating the schema. +func SetParentTable(w http.ResponseWriter, r *http.Request) { + tableId := r.FormValue("table") + update := r.FormValue("update") == "true" + sessionState := session.GetSessionState() + + if sessionState.Conv == nil || sessionState.Driver == "" { + http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. 
Please retry converting the database to Spanner."), http.StatusNotFound) + return + } + if tableId == "" { + http.Error(w, fmt.Sprintf("Table Id is empty"), http.StatusBadRequest) + } + + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + tableInterleaveStatus := parentTableHelper(tableId, update) + + if tableInterleaveStatus.Possible { + + childPks := sessionState.Conv.SpSchema[tableId].PrimaryKeys + childindex := utilities.GetPrimaryKeyIndexFromOrder(childPks, 1) + schemaissue := []internal.SchemaIssue{} + + colId := childPks[childindex].ColId + schemaissue = sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] + if update { + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedOrder) + } else { + schemaissue = append(schemaissue, internal.InterleavedOrder) + } + + sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] = schemaissue + } else { + // Remove "Table cart can be converted as Interleaved Table" suggestion from columns + // of the table if interleaving is not possible. 
+ for _, colId := range sessionState.Conv.SpSchema[tableId].ColIds { + schemaIssue := []internal.SchemaIssue{} + for _, v := range sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] { + if v != internal.InterleavedOrder { + schemaIssue = append(schemaIssue, v) + } + } + sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] = schemaIssue + } + } + + index.IndexSuggestion() + session.UpdateSessionFile() + w.WriteHeader(http.StatusOK) + + if update { + convm := session.ConvWithMetadata{ + SessionMetadata: sessionState.SessionMetadata, + Conv: *sessionState.Conv, + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "tableInterleaveStatus": tableInterleaveStatus, + "sessionState": convm}) + } else { + json.NewEncoder(w).Encode(map[string]interface{}{ + "tableInterleaveStatus": tableInterleaveStatus, + }) + } +} + +func RemoveParentTable(w http.ResponseWriter, r *http.Request) { + tableId := r.FormValue("tableId") + sessionState := session.GetSessionState() + if sessionState.Conv == nil || sessionState.Driver == "" { + http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. 
Please retry converting the database to Spanner."), http.StatusNotFound) + return + } + if tableId == "" { + http.Error(w, fmt.Sprintf("Table Id is empty"), http.StatusBadRequest) + return + } + + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + conv := sessionState.Conv + + if conv.SpSchema[tableId].ParentId == "" { + http.Error(w, fmt.Sprintf("Table is not interleaved"), http.StatusBadRequest) + return + } + spTable := conv.SpSchema[tableId] + + var firstOrderPk ddl.IndexKey + order := 1 + + isPresent, isAddedAtFirst := hasShardIdPrimaryKeyRule() + if isAddedAtFirst { + order = 2 + } + + for _, pk := range spTable.PrimaryKeys { + if pk.Order == order { + firstOrderPk = pk + break + } + } + + spColId := conv.SpSchema[tableId].ColDefs[firstOrderPk.ColId].Id + srcCol := conv.SrcSchema[tableId].ColDefs[spColId] + interleavedFk, err := utilities.GetInterleavedFk(conv, tableId, srcCol.Id) + if err != nil { + http.Error(w, fmt.Sprintf("%v", err), http.StatusBadRequest) + return + } + + spFk, err := common.CvtForeignKeysHelper(conv, conv.SpSchema[tableId].Name, tableId, interleavedFk, true) + if err != nil { + http.Error(w, fmt.Sprintf("Foreign key conversion fail"), http.StatusBadRequest) + return + } + + if isPresent { + if isAddedAtFirst { + spFk.ColIds = append([]string{spTable.ShardIdColumn}, spFk.ColIds...) + spFk.ReferColumnIds = append([]string{sessionState.Conv.SpSchema[spTable.ParentId].ShardIdColumn}, spFk.ReferColumnIds...) 
+ } else { + spFk.ColIds = append(spFk.ColIds, spTable.ShardIdColumn) + spFk.ReferColumnIds = append(spFk.ReferColumnIds, sessionState.Conv.SpSchema[spTable.ParentId].ShardIdColumn) + } + } + + spFks := spTable.ForeignKeys + spFks = append(spFks, spFk) + spTable.ForeignKeys = spFks + spTable.ParentId = "" + conv.SpSchema[tableId] = spTable + + sessionState.Conv = conv + + convm := session.ConvWithMetadata{ + SessionMetadata: sessionState.SessionMetadata, + Conv: *sessionState.Conv, + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(convm) + +} + +func UpdateIndexes(w http.ResponseWriter, r *http.Request) { + table := r.FormValue("table") + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) + } + + newIndexes := []ddl.CreateIndex{} + if err = json.Unmarshal(reqBody, &newIndexes); err != nil { + http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) + return + } + + list := []int{} + for i := 0; i < len(newIndexes); i++ { + for j := 0; j < len(newIndexes[i].Keys); j++ { + list = append(list, newIndexes[i].Keys[j].Order) + } + } + + if utilities.DuplicateInArray(list) != -1 { + http.Error(w, fmt.Sprintf("Two Index columns can not have same order"), http.StatusBadRequest) + return + } + + sessionState := session.GetSessionState() + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + sp := sessionState.Conv.SpSchema[table] + + st := sessionState.Conv.SrcSchema[table] + + for i, ind := range sp.Indexes { + + if ind.TableId == newIndexes[0].TableId && ind.Id == newIndexes[0].Id { + + index.RemoveIndexIssues(table, sp.Indexes[i]) + + sp.Indexes[i].Keys = newIndexes[0].Keys + sp.Indexes[i].Name = newIndexes[0].Name + sp.Indexes[i].TableId = newIndexes[0].TableId + sp.Indexes[i].Unique = newIndexes[0].Unique + sp.Indexes[i].Id = newIndexes[0].Id + + break + } + } + + for i, spIndex := range sp.Indexes { + + 
for j, srcIndex := range st.Indexes { + + for k, spIndexKey := range spIndex.Keys { + + for l, srcIndexKey := range srcIndex.Keys { + + if srcIndexKey.ColId == spIndexKey.ColId { + + st.Indexes[j].Keys[l].Order = sp.Indexes[i].Keys[k].Order + } + + } + } + + } + } + + sessionState.Conv.SpSchema[table] = sp + + sessionState.Conv.SrcSchema[table] = st + + session.UpdateSessionFile() + + convm := session.ConvWithMetadata{ + SessionMetadata: sessionState.SessionMetadata, + Conv: *sessionState.Conv, + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(convm) +} + +func DropSecondaryIndex(w http.ResponseWriter, r *http.Request) { + sessionState := session.GetSessionState() + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + + table := r.FormValue("table") + reqBody, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) + } + + var dropDetail struct{ Id string } + if err = json.Unmarshal(reqBody, &dropDetail); err != nil { + http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) + return + } + if sessionState.Conv == nil || sessionState.Driver == "" { + http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) + return + } + + if table == "" || dropDetail.Id == "" { + http.Error(w, fmt.Sprintf("Table name or position is empty"), http.StatusBadRequest) + } + err = dropSecondaryIndexHelper(table, dropDetail.Id) + if err != nil { + http.Error(w, fmt.Sprintf("%v", err), http.StatusBadRequest) + return + } + + // To set enabled value to false for the rule associated with the dropped index. 
+ indexId := dropDetail.Id + for i, rule := range sessionState.Conv.Rules { + if rule.Type == constants.AddIndex { + d, err := json.Marshal(rule.Data) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + var index ddl.CreateIndex + err = json.Unmarshal(d, &index) + if err != nil { + http.Error(w, "Invalid rule data", http.StatusInternalServerError) + return + } + if index.Id == indexId { + sessionState.Conv.Rules[i].Enabled = false + break + } + } + } + + convm := session.ConvWithMetadata{ + SessionMetadata: sessionState.SessionMetadata, + Conv: *sessionState.Conv, + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(convm) +} + +// GetConversionRate returns table wise color coded conversion rate. +func GetConversionRate(w http.ResponseWriter, r *http.Request) { + sessionState := session.GetSessionState() + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + smt_reports := reports.AnalyzeTables(sessionState.Conv, nil) + rate := make(map[string]string) + for _, t := range smt_reports { + rate[t.SpTable], _ = reports.RateSchema(t.Cols, t.Warnings, t.Errors, t.SyntheticPKey != "", false) + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(rate) +} + +func restoreTableHelper(w http.ResponseWriter, tableId string) session.ConvWithMetadata { + sessionState := session.GetSessionState() + if sessionState.Conv == nil || sessionState.Driver == "" { + http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. 
Please retry converting the database to Spanner."), http.StatusNotFound) + } + if tableId == "" { + http.Error(w, fmt.Sprintf("Table Id is empty"), http.StatusBadRequest) + } + + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + conv := sessionState.Conv + var toddl common.ToDdl + switch sessionState.Driver { + case constants.MYSQL: + toddl = mysql.InfoSchemaImpl{}.GetToDdl() + case constants.POSTGRES: + toddl = postgres.InfoSchemaImpl{}.GetToDdl() + case constants.SQLSERVER: + toddl = sqlserver.InfoSchemaImpl{}.GetToDdl() + case constants.ORACLE: + toddl = oracle.InfoSchemaImpl{}.GetToDdl() + case constants.MYSQLDUMP: + toddl = mysql.DbDumpImpl{}.GetToDdl() + case constants.PGDUMP: + toddl = postgres.DbDumpImpl{}.GetToDdl() + default: + http.Error(w, fmt.Sprintf("Driver : '%s' is not supported", sessionState.Driver), http.StatusBadRequest) + } + + err := common.SrcTableToSpannerDDL(conv, toddl, sessionState.Conv.SrcSchema[tableId]) + if err != nil { + http.Error(w, fmt.Sprintf("Restoring spanner table fail"), http.StatusBadRequest) + } + conv.AddPrimaryKeys() + if sessionState.IsSharded { + conv.IsSharded = true + conv.AddShardIdColumn() + isPresent, isAddedAtFirst := hasShardIdPrimaryKeyRule() + if isPresent { + table := sessionState.Conv.SpSchema[tableId] + setShardIdColumnAsPrimaryKeyPerTable(isAddedAtFirst, table) + addShardIdToForeignKeyPerTable(isAddedAtFirst, table) + addShardIdToReferencedTableFks(tableId, isAddedAtFirst) + } + } + sessionState.Conv = conv + primarykey.DetectHotspot() + + convm := session.ConvWithMetadata{ + SessionMetadata: sessionState.SessionMetadata, + Conv: *sessionState.Conv, + } + return convm +} + +func parentTableHelper(tableId string, update bool) *types.TableInterleaveStatus { + tableInterleaveStatus := &types.TableInterleaveStatus{ + Possible: false, + Comment: "No valid prefix", + } + sessionState := session.GetSessionState() + + if _, found := sessionState.Conv.SyntheticPKeys[tableId]; found { + 
tableInterleaveStatus.Possible = false + tableInterleaveStatus.Comment = "Has synthetic pk" + } + + childPks := sessionState.Conv.SpSchema[tableId].PrimaryKeys + + // Search this table's foreign keys for a suitable parent table. + // If there are several possible parent tables, we pick the first one. + // TODO: Allow users to pick which parent to use if more than one. + for i, fk := range sessionState.Conv.SpSchema[tableId].ForeignKeys { + refTableId := fk.ReferTableId + + if _, found := sessionState.Conv.SyntheticPKeys[refTableId]; found { + continue + } + + if checkPrimaryKeyPrefix(tableId, refTableId, fk, tableInterleaveStatus) { + sp := sessionState.Conv.SpSchema[tableId] + + colIdNotInOrder := checkPrimaryKeyOrder(tableId, refTableId, fk) + + if update && sp.ParentId == "" && colIdNotInOrder == "" { + usedNames := sessionState.Conv.UsedNames + delete(usedNames, strings.ToLower(sp.ForeignKeys[i].Name)) + sp.ParentId = refTableId + sp.ForeignKeys = utilities.RemoveFk(sp.ForeignKeys, sp.ForeignKeys[i].Id) + } + sessionState.Conv.SpSchema[tableId] = sp + + parentpks := sessionState.Conv.SpSchema[refTableId].PrimaryKeys + if len(parentpks) >= 1 { + if colIdNotInOrder == "" { + + schemaissue := []internal.SchemaIssue{} + for _, column := range childPks { + colId := column.ColId + schemaissue = sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] + + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedNotInOrder) + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedAddColumn) + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedRenameColumn) + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedOrder) + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedChangeColumnSize) + + sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] = schemaissue + } + + tableInterleaveStatus.Possible = true + tableInterleaveStatus.Parent = 
refTableId + tableInterleaveStatus.Comment = "" + + } else { + + schemaissue := []internal.SchemaIssue{} + schemaissue = sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colIdNotInOrder] + + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedOrder) + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedAddColumn) + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedRenameColumn) + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedChangeColumnSize) + + schemaissue = append(schemaissue, internal.InterleavedNotInOrder) + + sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colIdNotInOrder] = schemaissue + } + } + } + } + + return tableInterleaveStatus +} + +func checkPrimaryKeyOrder(tableId string, refTableId string, fk ddl.Foreignkey) string { + sessionState := session.GetSessionState() + childPks := sessionState.Conv.SpSchema[tableId].PrimaryKeys + parentPks := sessionState.Conv.SpSchema[refTableId].PrimaryKeys + childTable := sessionState.Conv.SpSchema[tableId] + parentTable := sessionState.Conv.SpSchema[refTableId] + for i := 0; i < len(parentPks); i++ { + + for j := 0; j < len(childPks); j++ { + + for k := 0; k < len(fk.ReferColumnIds); k++ { + + if childTable.ColDefs[fk.ColIds[k]].Name == parentTable.ColDefs[fk.ReferColumnIds[k]].Name && + parentTable.ColDefs[parentPks[i].ColId].Name == childTable.ColDefs[childPks[j].ColId].Name && + parentTable.ColDefs[parentPks[i].ColId].T.Name == childTable.ColDefs[childPks[j].ColId].T.Name && + parentTable.ColDefs[parentPks[i].ColId].T.Len == childTable.ColDefs[childPks[j].ColId].T.Len && + parentTable.ColDefs[parentPks[i].ColId].Name == parentTable.ColDefs[fk.ReferColumnIds[k]].Name && + childTable.ColDefs[childPks[j].ColId].Name == parentTable.ColDefs[fk.ReferColumnIds[k]].Name { + if parentPks[i].Order != childPks[j].Order { + return childPks[j].ColId + } + } + } + + } + + } + return "" +} + +func 
checkPrimaryKeyPrefix(tableId string, refTableId string, fk ddl.Foreignkey, tableInterleaveStatus *types.TableInterleaveStatus) bool { + + sessionState := session.GetSessionState() + childTable := sessionState.Conv.SpSchema[tableId] + parentTable := sessionState.Conv.SpSchema[refTableId] + childPks := sessionState.Conv.SpSchema[tableId].PrimaryKeys + parentPks := sessionState.Conv.SpSchema[refTableId].PrimaryKeys + possibleInterleave := false + + flag := false + for _, key := range parentPks { + flag = false + for _, colId := range fk.ReferColumnIds { + if key.ColId == colId { + flag = true + } + } + if !flag { + break + } + } + if flag { + possibleInterleave = true + } + + if !possibleInterleave { + removeInterleaveSuggestions(fk.ColIds, tableId) + return false + } + + childPkColIds := []string{} + for _, k := range childPks { + childPkColIds = append(childPkColIds, k.ColId) + } + + interleaved := []ddl.IndexKey{} + + for i := 0; i < len(parentPks); i++ { + + for j := 0; j < len(childPks); j++ { + + for k := 0; k < len(fk.ReferColumnIds); k++ { + + if childTable.ColDefs[fk.ColIds[k]].Name == parentTable.ColDefs[fk.ReferColumnIds[k]].Name && + parentTable.ColDefs[parentPks[i].ColId].Name == childTable.ColDefs[childPks[j].ColId].Name && + parentTable.ColDefs[parentPks[i].ColId].T.Name == childTable.ColDefs[childPks[j].ColId].T.Name && + parentTable.ColDefs[parentPks[i].ColId].T.Len == childTable.ColDefs[childPks[j].ColId].T.Len && + parentTable.ColDefs[parentPks[i].ColId].Name == parentTable.ColDefs[fk.ReferColumnIds[k]].Name && + childTable.ColDefs[childPks[j].ColId].Name == parentTable.ColDefs[fk.ReferColumnIds[k]].Name { + + interleaved = append(interleaved, parentPks[i]) + } + } + + } + + } + + if len(interleaved) == len(parentPks) { + return true + } + + diff := []ddl.IndexKey{} + + if len(interleaved) == 0 { + + for i := 0; i < len(parentPks); i++ { + + for j := 0; j < len(childPks); j++ { + + if parentTable.ColDefs[parentPks[i].ColId].Name != 
childTable.ColDefs[childPks[j].ColId].Name || parentTable.ColDefs[parentPks[i].ColId].T.Len != childTable.ColDefs[childPks[j].ColId].T.Len { + diff = append(diff, parentPks[i]) + } + + } + } + + } + + canInterleavedOnAdd := []string{} + canInterleavedOnRename := []string{} + canInterLeaveOnChangeInColumnSize := []string{} + + fkReferColNames := []string{} + childPkColNames := []string{} + for _, colId := range fk.ReferColumnIds { + fkReferColNames = append(fkReferColNames, parentTable.ColDefs[colId].Name) + } + for _, colId := range childPkColIds { + childPkColNames = append(childPkColNames, childTable.ColDefs[colId].Name) + } + + for i := 0; i < len(diff); i++ { + + parentColIndex := utilities.IsColumnPresent(fkReferColNames, parentTable.ColDefs[diff[i].ColId].Name) + if parentColIndex == -1 { + continue + } + childColIndex := utilities.IsColumnPresent(childPkColNames, childTable.ColDefs[fk.ColIds[parentColIndex]].Name) + if childColIndex == -1 { + canInterleavedOnAdd = append(canInterleavedOnAdd, fk.ColIds[parentColIndex]) + } else { + if parentTable.ColDefs[diff[i].ColId].Name == childTable.ColDefs[fk.ColIds[parentColIndex]].Name { + canInterLeaveOnChangeInColumnSize = append(canInterLeaveOnChangeInColumnSize, fk.ColIds[parentColIndex]) + } else { + canInterleavedOnRename = append(canInterleavedOnRename, fk.ColIds[parentColIndex]) + } + + } + } + + if len(canInterLeaveOnChangeInColumnSize) > 0 { + updateInterleaveSuggestion(canInterLeaveOnChangeInColumnSize, tableId, internal.InterleavedChangeColumnSize) + } else if len(canInterleavedOnRename) > 0 { + updateInterleaveSuggestion(canInterleavedOnRename, tableId, internal.InterleavedRenameColumn) + } else if len(canInterleavedOnAdd) > 0 { + updateInterleaveSuggestion(canInterleavedOnAdd, tableId, internal.InterleavedAddColumn) + } + + return false +} + +func updateInterleaveSuggestion(colIds []string, tableId string, issue internal.SchemaIssue) { + sessionState := session.GetSessionState() + + for i := 0; i < 
len(colIds); i++ { + + schemaissue := []internal.SchemaIssue{} + + schemaissue = sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colIds[i]] + + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedOrder) + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedNotInOrder) + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedAddColumn) + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedRenameColumn) + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedChangeColumnSize) + + schemaissue = append(schemaissue, issue) + + if sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues == nil { + + s := map[string][]internal.SchemaIssue{ + colIds[i]: schemaissue, + } + sessionState.Conv.SchemaIssues[tableId] = internal.TableIssues{ + ColumnLevelIssues: s, + } + } else { + sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colIds[i]] = schemaissue + } + } +} + +func removeInterleaveSuggestions(colIds []string, tableId string) { + sessionState := session.GetSessionState() + + for i := 0; i < len(colIds); i++ { + + schemaissue := []internal.SchemaIssue{} + + schemaissue = sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colIds[i]] + + if len(schemaissue) == 0 { + continue + } + + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedOrder) + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedNotInOrder) + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedAddColumn) + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedRenameColumn) + schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedChangeColumnSize) + + if sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues == nil { + + s := map[string][]internal.SchemaIssue{ + colIds[i]: schemaissue, + } + sessionState.Conv.SchemaIssues[tableId] = internal.TableIssues{ + 
ColumnLevelIssues: s, + } + } else { + sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colIds[i]] = schemaissue + } + + } +} + +func hasShardIdPrimaryKeyRule() (bool, bool) { + sessionState := session.GetSessionState() + for _, rule := range sessionState.Conv.Rules { + if rule.Type == constants.AddShardIdPrimaryKey { + v := rule.Data.(types.ShardIdPrimaryKey) + return true, v.AddedAtTheStart + } + } + return false, false +} + +func dropTableHelper(w http.ResponseWriter, tableId string) session.ConvWithMetadata { + sessionState := session.GetSessionState() + if sessionState.Conv == nil || sessionState.Driver == "" { + http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) + return session.ConvWithMetadata{} + } + if tableId == "" { + http.Error(w, fmt.Sprintf("Table Id is empty"), http.StatusBadRequest) + } + sessionState.Conv.ConvLock.Lock() + defer sessionState.Conv.ConvLock.Unlock() + spSchema := sessionState.Conv.SpSchema + issues := sessionState.Conv.SchemaIssues + syntheticPkey := sessionState.Conv.SyntheticPKeys + + //remove deleted name from usedName + usedNames := sessionState.Conv.UsedNames + delete(usedNames, strings.ToLower(sessionState.Conv.SpSchema[tableId].Name)) + for _, index := range sessionState.Conv.SpSchema[tableId].Indexes { + delete(usedNames, index.Name) + } + for _, fk := range sessionState.Conv.SpSchema[tableId].ForeignKeys { + delete(usedNames, fk.Name) + } + + delete(spSchema, tableId) + issues[tableId] = internal.TableIssues{ + TableLevelIssues: []internal.SchemaIssue{}, + ColumnLevelIssues: map[string][]internal.SchemaIssue{}, + } + delete(syntheticPkey, tableId) + + //drop reference foreign key + for tableName, spTable := range spSchema { + fks := []ddl.Foreignkey{} + for _, fk := range spTable.ForeignKeys { + if fk.ReferTableId != tableId { + fks = append(fks, fk) + } else { + delete(usedNames, fk.Name) + } + + } + 
spTable.ForeignKeys = fks + spSchema[tableName] = spTable + } + + //remove interleave that are interleaved on the drop table as parent + for id, spTable := range spSchema { + if spTable.ParentId == tableId { + spTable.ParentId = "" + spSchema[id] = spTable + } + } + + //remove interleavable suggestion on droping the parent table + for tableName, tableIssues := range issues { + for colName, colIssues := range tableIssues.ColumnLevelIssues { + updatedColIssues := []internal.SchemaIssue{} + for _, val := range colIssues { + if val != internal.InterleavedOrder { + updatedColIssues = append(updatedColIssues, val) + } + } + if len(updatedColIssues) == 0 { + delete(issues[tableName].ColumnLevelIssues, colName) + } else { + issues[tableName].ColumnLevelIssues[colName] = updatedColIssues + } + } + } + + sessionState.Conv.SpSchema = spSchema + sessionState.Conv.SchemaIssues = issues + sessionState.Conv.UsedNames = usedNames + + convm := session.ConvWithMetadata{ + SessionMetadata: sessionState.SessionMetadata, + Conv: *sessionState.Conv, + } + return convm +} + +func addShardIdToReferencedTableFks(tableId string, isAddedAtFirst bool) { + sessionState := session.GetSessionState() + for _, table := range sessionState.Conv.SpSchema { + for i, fk := range table.ForeignKeys { + if fk.ReferTableId == tableId { + referredTableShardIdColumn := sessionState.Conv.SpSchema[fk.ReferTableId].ShardIdColumn + if isAddedAtFirst { + fk.ColIds = append([]string{table.ShardIdColumn}, fk.ColIds...) + fk.ReferColumnIds = append([]string{referredTableShardIdColumn}, fk.ReferColumnIds...) + } else { + fk.ColIds = append(fk.ColIds, table.ShardIdColumn) + fk.ReferColumnIds = append(fk.ReferColumnIds, referredTableShardIdColumn) + } + sessionState.Conv.SpSchema[table.Id].ForeignKeys[i] = fk + } + } + } +} + +func initializeTypeMap() { + sessionState := session.GetSessionState() + var toddl common.ToDdl + // Initialize mysqlTypeMap. 
+ toddl = mysql.InfoSchemaImpl{}.GetToDdl() + for _, srcTypeName := range []string{"bool", "boolean", "varchar", "char", "text", "tinytext", "mediumtext", "longtext", "set", "enum", "json", "bit", "binary", "varbinary", "blob", "tinyblob", "mediumblob", "longblob", "tinyint", "smallint", "mediumint", "int", "integer", "bigint", "double", "float", "numeric", "decimal", "date", "datetime", "timestamp", "time", "year", "geometrycollection", "multipoint", "multilinestring", "multipolygon", "point", "linestring", "polygon", "geometry"} { + var l []types.TypeIssue + srcType := schema.MakeType() + srcType.Name = srcTypeName + for _, spType := range []string{ddl.Bool, ddl.Bytes, ddl.Date, ddl.Float64, ddl.Int64, ddl.String, ddl.Timestamp, ddl.Numeric, ddl.JSON} { + ty, issues := toddl.ToSpannerType(sessionState.Conv, spType, srcType) + l = addTypeToList(ty.Name, spType, issues, l) + } + if srcTypeName == "tinyint" { + l = append(l, types.TypeIssue{T: ddl.Bool, Brief: "Only tinyint(1) can be converted to BOOL, for any other mods it will be converted to INT64"}) + } + ty, _ := toddl.ToSpannerType(sessionState.Conv, "", srcType) + mysqlDefaultTypeMap[srcTypeName] = ty + mysqlTypeMap[srcTypeName] = l + } + // Initialize postgresTypeMap. 
+ toddl = postgres.InfoSchemaImpl{}.GetToDdl() + for _, srcTypeName := range []string{"bool", "boolean", "bigserial", "bpchar", "character", "bytea", "date", "float8", "double precision", "float4", "real", "int8", "bigint", "int4", "integer", "int2", "smallint", "numeric", "serial", "text", "timestamptz", "timestamp with time zone", "timestamp", "timestamp without time zone", "varchar", "character varying"} { + var l []types.TypeIssue + srcType := schema.MakeType() + srcType.Name = srcTypeName + for _, spType := range []string{ddl.Bool, ddl.Bytes, ddl.Date, ddl.Float64, ddl.Int64, ddl.String, ddl.Timestamp, ddl.Numeric, ddl.JSON} { + ty, issues := toddl.ToSpannerType(sessionState.Conv, spType, srcType) + l = addTypeToList(ty.Name, spType, issues, l) + } + ty, _ := toddl.ToSpannerType(sessionState.Conv, "", srcType) + postgresDefaultTypeMap[srcTypeName] = ty + postgresTypeMap[srcTypeName] = l + } + + // Initialize sqlserverTypeMap. + toddl = sqlserver.InfoSchemaImpl{}.GetToDdl() + for _, srcTypeName := range []string{"int", "tinyint", "smallint", "bigint", "bit", "float", "real", "numeric", "decimal", "money", "smallmoney", "char", "nchar", "varchar", "nvarchar", "text", "ntext", "date", "datetime", "datetime2", "smalldatetime", "datetimeoffset", "time", "timestamp", "rowversion", "binary", "varbinary", "image", "xml", "geography", "geometry", "uniqueidentifier", "sql_variant", "hierarchyid"} { + var l []types.TypeIssue + srcType := schema.MakeType() + srcType.Name = srcTypeName + for _, spType := range []string{ddl.Bool, ddl.Bytes, ddl.Date, ddl.Float64, ddl.Int64, ddl.String, ddl.Timestamp, ddl.Numeric, ddl.JSON} { + ty, issues := toddl.ToSpannerType(sessionState.Conv, spType, srcType) + l = addTypeToList(ty.Name, spType, issues, l) + } + ty, _ := toddl.ToSpannerType(sessionState.Conv, "", srcType) + sqlserverDefaultTypeMap[srcTypeName] = ty + sqlserverTypeMap[srcTypeName] = l + } + + // Initialize oracleTypeMap. 
+ toddl = oracle.InfoSchemaImpl{}.GetToDdl() + for _, srcTypeName := range []string{"NUMBER", "BFILE", "BLOB", "CHAR", "CLOB", "DATE", "BINARY_DOUBLE", "BINARY_FLOAT", "FLOAT", "LONG", "RAW", "LONG RAW", "NCHAR", "NVARCHAR2", "VARCHAR", "VARCHAR2", "NCLOB", "ROWID", "UROWID", "XMLTYPE", "TIMESTAMP", "INTERVAL", "SDO_GEOMETRY"} { + var l []types.TypeIssue + srcType := schema.MakeType() + srcType.Name = srcTypeName + for _, spType := range []string{ddl.Bool, ddl.Bytes, ddl.Date, ddl.Float64, ddl.Int64, ddl.String, ddl.Timestamp, ddl.Numeric, ddl.JSON} { + ty, issues := toddl.ToSpannerType(sessionState.Conv, spType, srcType) + l = addTypeToList(ty.Name, spType, issues, l) + } + ty, _ := toddl.ToSpannerType(sessionState.Conv, "", srcType) + oracleDefaultTypeMap[srcTypeName] = ty + oracleTypeMap[srcTypeName] = l + } +} + +func addTypeToList(convertedType string, spType string, issues []internal.SchemaIssue, l []types.TypeIssue) []types.TypeIssue { + if convertedType == spType { + if len(issues) > 0 { + var briefs []string + for _, issue := range issues { + briefs = append(briefs, reports.IssueDB[issue].Brief) + } + l = append(l, types.TypeIssue{T: spType, Brief: strings.Join(briefs, ", ")}) + } else { + l = append(l, types.TypeIssue{T: spType}) + } + } + return l +} + +func setShardIdColumnAsPrimaryKey(isAddedAtFirst bool) { + sessionState := session.GetSessionState() + for _, table := range sessionState.Conv.SpSchema { + setShardIdColumnAsPrimaryKeyPerTable(isAddedAtFirst, table) + } +} + +func setShardIdColumnAsPrimaryKeyPerTable(isAddedAtFirst bool, table ddl.CreateTable) { + pkRequest := primarykey.PrimaryKeyRequest{ + TableId: table.Id, + Columns: []ddl.IndexKey{}, + } + increment := 0 + if isAddedAtFirst { + increment = 1 + pkRequest.Columns = append(pkRequest.Columns, ddl.IndexKey{ColId: table.ShardIdColumn, Order: 1}) + } + for index := range table.PrimaryKeys { + pk := table.PrimaryKeys[index] + pkRequest.Columns = append(pkRequest.Columns, 
ddl.IndexKey{ColId: pk.ColId, Order: pk.Order + increment, Desc: pk.Desc}) + } + if !isAddedAtFirst { + size := len(table.PrimaryKeys) + pkRequest.Columns = append(pkRequest.Columns, ddl.IndexKey{ColId: table.ShardIdColumn, Order: size + 1}) + } + primarykey.UpdatePrimaryKeyAndSessionFile(pkRequest) +} + +func addShardIdColumnToForeignKeys(isAddedAtFirst bool) { + sessionState := session.GetSessionState() + for _, table := range sessionState.Conv.SpSchema { + addShardIdToForeignKeyPerTable(isAddedAtFirst, table) + } +} + +func addShardIdToForeignKeyPerTable(isAddedAtFirst bool, table ddl.CreateTable) { + sessionState := session.GetSessionState() + for i, fk := range table.ForeignKeys { + referredTableShardIdColumn := sessionState.Conv.SpSchema[fk.ReferTableId].ShardIdColumn + if isAddedAtFirst { + fk.ColIds = append([]string{table.ShardIdColumn}, fk.ColIds...) + fk.ReferColumnIds = append([]string{referredTableShardIdColumn}, fk.ReferColumnIds...) + } else { + fk.ColIds = append(fk.ColIds, table.ShardIdColumn) + fk.ReferColumnIds = append(fk.ReferColumnIds, referredTableShardIdColumn) + } + sessionState.Conv.SpSchema[table.Id].ForeignKeys[i] = fk + } +} diff --git a/webv2/web_test.go b/webv2/api/schema_test.go similarity index 75% rename from webv2/web_test.go rename to webv2/api/schema_test.go index ea1629eaae..1486d86f0c 100644 --- a/webv2/web_test.go +++ b/webv2/api/schema_test.go @@ -1,18 +1,4 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package webv2 +package api_test import ( "bytes" @@ -30,7 +16,9 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/proto/migration" "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/api" "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/session" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/types" "github.com/stretchr/testify/assert" "go.uber.org/zap" ) @@ -40,12 +28,15 @@ func init() { } func TestGetTypeMapNoDriver(t *testing.T) { + sessionState := session.GetSessionState() + sessionState.Driver = "" + sessionState.Conv = nil req, err := http.NewRequest("GET", "/typemap", nil) if err != nil { t.Fatal(err) } rr := httptest.NewRecorder() - handler := http.HandlerFunc(getTypeMap) + handler := http.HandlerFunc(api.GetTypeMap) handler.ServeHTTP(rr, req) status := rr.Code @@ -67,15 +58,15 @@ func TestGetTypeMapPostgres(t *testing.T) { t.Fatal(err) } rr := httptest.NewRecorder() - handler := http.HandlerFunc(getTypeMap) + handler := http.HandlerFunc(api.GetTypeMap) handler.ServeHTTP(rr, req) - var typemap map[string][]typeIssue + var typemap map[string][]types.TypeIssue json.Unmarshal(rr.Body.Bytes(), &typemap) if status := rr.Code; status != http.StatusOK { t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) } - expectedTypemap := map[string][]typeIssue{ + expectedTypemap := map[string][]types.TypeIssue{ "bool": { {T: ddl.Bool, DisplayT: ddl.Bool}, {T: ddl.Int64, Brief: reports.IssueDB[internal.Widened].Brief, DisplayT: ddl.Int64}, @@ -138,7 +129,7 @@ func TestGetConversionPostgres(t *testing.T) { t.Fatal(err) } rr := httptest.NewRecorder() - handler := http.HandlerFunc(getConversionRate) + handler := http.HandlerFunc(api.GetConversionRate) 
handler.ServeHTTP(rr, req) var result map[string]string json.Unmarshal(rr.Body.Bytes(), &result) @@ -162,15 +153,15 @@ func TestGetTypeMapMySQL(t *testing.T) { t.Fatal(err) } rr := httptest.NewRecorder() - handler := http.HandlerFunc(getTypeMap) + handler := http.HandlerFunc(api.GetTypeMap) handler.ServeHTTP(rr, req) - var typemap map[string][]typeIssue + var typemap map[string][]types.TypeIssue json.Unmarshal(rr.Body.Bytes(), &typemap) if status := rr.Code; status != http.StatusOK { t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) } - expectedTypemap := map[string][]typeIssue{ + expectedTypemap := map[string][]types.TypeIssue{ "bool": { {T: ddl.Bool, DisplayT: ddl.Bool}, {T: ddl.Int64, Brief: reports.IssueDB[internal.Widened].Brief, DisplayT: ddl.Int64}, @@ -235,7 +226,7 @@ func TestGetConversionMySQL(t *testing.T) { t.Fatal(err) } rr := httptest.NewRecorder() - handler := http.HandlerFunc(getConversionRate) + handler := http.HandlerFunc(api.GetConversionRate) handler.ServeHTTP(rr, req) var result map[string]string json.Unmarshal(rr.Body.Bytes(), &result) @@ -292,7 +283,7 @@ func TestGetDDL(t *testing.T) { } req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() - handler := http.HandlerFunc(getDDL) + handler := http.HandlerFunc(api.GetDDL) handler.ServeHTTP(rr, req) var res map[string]string json.Unmarshal(rr.Body.Bytes(), &res) @@ -304,441 +295,245 @@ func TestGetDDL(t *testing.T) { assert.Equal(t, tc.expectedDDL, res) } } - } -// todo update SetParentTable with case III suggest interleve table column. 
-func TestSetParentTable(t *testing.T) { - tests := []struct { - name string - ct *internal.Conv - table string - statusCode int64 - expectedResponse *TableInterleaveStatus - expectedFKs []ddl.Foreignkey - parentTable string +func TestDropForeignKey(t *testing.T) { + tc := []struct { + name string + table string + input interface{} + statusCode int64 + conv *internal.Conv + expectedConv *internal.Conv }{ { - name: "no conv provided", - statusCode: http.StatusNotFound, - }, - { - name: "no table name provided", - statusCode: http.StatusBadRequest, - ct: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{"t1": { - Name: "t1", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": ddl.ColumnDef{Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": ddl.ColumnDef{Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": ddl.ColumnDef{Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}}, - PrimaryKeys: []ddl.IndexKey{ddl.IndexKey{ColId: "c1", Desc: false}}, - ForeignKeys: []ddl.Foreignkey{ddl.Foreignkey{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t1", ReferColumnIds: []string{"c1"}}, - ddl.Foreignkey{Name: "fk2", ColIds: []string{"c3"}, ReferTableId: "t2", ReferColumnIds: []string{"c2"}}}, - }}, - SchemaIssues: map[string]internal.TableIssues{ + name: "Test drop valid FK success", + table: "t1", + input: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c2"}, ReferTableId: "reft1", ReferColumnIds: []string{"ref_c1"}, Id: "f1"}, + {Name: "", ColIds: []string{}, ReferTableId: "", ReferColumnIds: []string{}, Id: "f2"}}, + statusCode: http.StatusOK, + conv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ "t1": { - ColumnLevelIssues: make(map[string][]internal.SchemaIssue), - }, - }, + ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c2"}, ReferTableId: "reft1", ReferColumnIds: []string{"ref_c1"}, Id: "f1"}, + {Name: "fk2", ColIds: []string{"c3", "c4"}, ReferTableId: 
"reft2", ReferColumnIds: []string{"ref_c2", "ref_c3"}, Id: "f2"}}, + }}, Audit: internal.Audit{ MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), }, }, + expectedConv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c2"}, ReferTableId: "reft1", ReferColumnIds: []string{"ref_c1"}, Id: "f1"}}, + }}, + }, }, + } + for _, tc := range tc { + sessionState := session.GetSessionState() + + sessionState.Driver = constants.MYSQL + sessionState.Conv = tc.conv + + inputBytes, err := json.Marshal(tc.input) + if err != nil { + t.Fatal(err) + } + buffer := bytes.NewBuffer(inputBytes) + + req, err := http.NewRequest("POST", "/update/fks?table="+tc.table, buffer) + + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + handler := http.HandlerFunc(api.UpdateForeignKeys) + handler.ServeHTTP(rr, req) + var res *internal.Conv + json.Unmarshal(rr.Body.Bytes(), &res) + if status := rr.Code; int64(status) != tc.statusCode { + t.Errorf("handler returned wrong status code: got %v want %v", + status, tc.statusCode) + } + if tc.statusCode == http.StatusOK { + assert.Equal(t, tc.expectedConv, res) + } + } +} + +func TestUpdateIndexes(t *testing.T) { + tc := []struct { + name string + tableId string + input []ddl.CreateIndex + statusCode int64 + conv *internal.Conv + expectedConv *internal.Conv + }{ { - name: "table with synthetic PK", - ct: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{"t1": { - Name: "t1", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - "synth_id": {Name: "synth_id", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + name: "Add a valid index key", + tableId: "t1", + 
input: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, + statusCode: http.StatusOK, + conv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}}, + }}, + SrcSchema: map[string]schema.Table{ + "t1": { + Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}}}}, }, - PrimaryKeys: []ddl.IndexKey{{ColId: "synth_id", Desc: false}}, - ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t1", ReferColumnIds: []string{"c1"}}, - {Name: "fk2", ColIds: []string{"c3"}, ReferTableId: "t2", ReferColumnIds: []string{"c2"}}}, - }}, - SyntheticPKeys: map[string]internal.SyntheticPKey{"t1": internal.SyntheticPKey{ColId: "synth_id"}}, + }, Audit: internal.Audit{ MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), }, - SchemaIssues: map[string]internal.TableIssues{ + UsedNames: map[string]bool{"t1": true, "idx": true}, + }, + expectedConv: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ "t1": { - ColumnLevelIssues: make(map[string][]internal.SchemaIssue), + Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, + }}, + SrcSchema: map[string]schema.Table{ + "t1": { + Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}}}}, }, }, }, - table: "t1", - statusCode: http.StatusOK, - expectedResponse: &TableInterleaveStatus{Possible: false, Comment: "Has synthetic pk"}, }, { - name: "no valid prefix 1", - ct: &internal.Conv{ + name: "Change the order of two index keys", + tableId: "t1", + input: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", 
Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 2}, {ColId: "c3", Desc: true, Order: 1}}}}, + statusCode: http.StatusOK, + conv: &internal.Conv{ SpSchema: map[string]ddl.CreateTable{ "t1": { - Name: "t1", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false}}, - ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c1"}}}, - }, - "t2": { - Name: "t2", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - "synth_id": {Name: "synth_id", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "synth_id", Desc: false}}, - }, - }, - SyntheticPKeys: map[string]internal.SyntheticPKey{"t2": internal.SyntheticPKey{ColId: "synth_id"}}, - SchemaIssues: map[string]internal.TableIssues{ + Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, + }}, + SrcSchema: map[string]schema.Table{ "t1": { - ColumnLevelIssues: make(map[string][]internal.SchemaIssue), + Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, }, }, Audit: internal.Audit{ MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), }, + UsedNames: map[string]bool{"t1": true, "idx": true}, }, - table: "t1", - statusCode: http.StatusOK, - expectedResponse: 
&TableInterleaveStatus{Possible: false, Comment: "No valid prefix"}, - expectedFKs: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c1"}}}, - }, - { - name: "no valid prefix 2", - ct: &internal.Conv{ + expectedConv: &internal.Conv{ SpSchema: map[string]ddl.CreateTable{ "t1": { - Name: "t1", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}}, - ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c1"}}}, - }, - "t2": { - Name: "t2", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, - }, - }, - SchemaIssues: map[string]internal.TableIssues{ + Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 2}, {ColId: "c3", Desc: true, Order: 1}}}}, + }}, + SrcSchema: map[string]schema.Table{ "t1": { - ColumnLevelIssues: make(map[string][]internal.SchemaIssue), + Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 2}, {ColId: "c3", Desc: true, Order: 1}}}}, }, }, - Audit: internal.Audit{ - MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), - }, }, - table: "t1", - statusCode: http.StatusOK, - expectedResponse: &TableInterleaveStatus{Possible: false, Parent: "", Comment: 
"No valid prefix"}, - expectedFKs: []ddl.Foreignkey{{}}, }, { - name: "no valid prefix 3", - ct: &internal.Conv{ + name: "Delete an index key column", + tableId: "t1", + input: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}}, + statusCode: http.StatusOK, + conv: &internal.Conv{ SpSchema: map[string]ddl.CreateTable{ "t1": { - Name: "t1", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false}, {ColId: "c2", Desc: false}}, - ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c3"}, ReferTableId: "t2", ReferColumnIds: []string{"c3"}}}, - }, - "t2": { - Name: "t2", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false}}, - }, - }, - SchemaIssues: map[string]internal.TableIssues{ + Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, + }}, + SrcSchema: map[string]schema.Table{ "t1": { - ColumnLevelIssues: make(map[string][]internal.SchemaIssue), + Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, }, }, Audit: internal.Audit{ MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), }, + UsedNames: map[string]bool{"t1": true, "idx": true}, }, - table: "t1", - 
statusCode: http.StatusOK, - expectedResponse: &TableInterleaveStatus{Possible: false, Comment: "No valid prefix"}, - expectedFKs: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c3"}, ReferTableId: "t2", ReferColumnIds: []string{"c3"}}}, - }, - { - name: "interleave possible on changing primary key order", - ct: &internal.Conv{ + expectedConv: &internal.Conv{ SpSchema: map[string]ddl.CreateTable{ "t1": { - Name: "t1", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 2}, {ColId: "c2", Desc: false, Order: 2}}, - ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c4"}}}, - }, - "t2": { - Name: "t2", - ColIds: []string{"cc4", "c5", "c6"}, - ColDefs: map[string]ddl.ColumnDef{"c4": {Name: "d", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c5": {Name: "e", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c6": {Name: "f", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, - }, - }, - SchemaIssues: map[string]internal.TableIssues{ + Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}}, + }}, + SrcSchema: map[string]schema.Table{ "t1": { - ColumnLevelIssues: map[string][]internal.SchemaIssue{ - "c1": {internal.InterleavedNotInOrder}, - }, + Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, }, }, - Audit: internal.Audit{ - MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), - }, }, - 
table: "t1", - statusCode: http.StatusOK, - expectedResponse: &TableInterleaveStatus{Possible: false, Parent: "", Comment: "No valid prefix"}, - expectedFKs: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c4"}}}, - parentTable: "", }, { - name: "successful interleave", - ct: &internal.Conv{ + name: "Test rename indexes name", + tableId: "t1", + input: []ddl.CreateIndex{{Name: "idx_new", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}}, + statusCode: http.StatusOK, + conv: &internal.Conv{ SpSchema: map[string]ddl.CreateTable{ "t1": { - Name: "t1", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, - ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c1"}}}, - }, - "t2": { - Name: "t2", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}}, - }, - }, - SchemaIssues: map[string]internal.TableIssues{ + Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}}, + }}, + SrcSchema: map[string]schema.Table{ "t1": { - ColumnLevelIssues: make(map[string][]internal.SchemaIssue), + Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, 
Order: 1}}}}, }, }, Audit: internal.Audit{ MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), }, + UsedNames: map[string]bool{"t1": true, "idx": true}, }, - table: "t1", - statusCode: http.StatusOK, - expectedResponse: &TableInterleaveStatus{Possible: true, Parent: "t2"}, - expectedFKs: []ddl.Foreignkey{}, - parentTable: "t2", - }, - { - name: "successful interleave with same primary key", - ct: &internal.Conv{ + expectedConv: &internal.Conv{ SpSchema: map[string]ddl.CreateTable{ "t1": { - Name: "t1", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, - ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1", "c2"}, ReferTableId: "t2", ReferColumnIds: []string{"c1", "c2"}}}, - }, - "t2": { - Name: "t2", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, - }, - }, - SchemaIssues: map[string]internal.TableIssues{ + Indexes: []ddl.CreateIndex{{Name: "idx_new", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}}, + }}, + SrcSchema: map[string]schema.Table{ "t1": { - ColumnLevelIssues: make(map[string][]internal.SchemaIssue), + Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}}}}, }, }, - Audit: internal.Audit{ - MigrationType: 
migration.MigrationData_SCHEMA_ONLY.Enum(), - }, }, - table: "t1", - statusCode: http.StatusOK, - expectedResponse: &TableInterleaveStatus{Possible: true, Parent: "t2"}, - expectedFKs: []ddl.Foreignkey{}, - parentTable: "t2", }, { - name: "successful interleave with multiple fks refering multiple tables", - ct: &internal.Conv{ + name: "Two Index key columns can not have same order", + tableId: "t1", + input: []ddl.CreateIndex{{Name: "idx_new", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 1}}}}, + statusCode: http.StatusBadRequest, + conv: &internal.Conv{ SpSchema: map[string]ddl.CreateTable{ "t1": { - Name: "t1", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, - ForeignKeys: []ddl.Foreignkey{ - {Name: "fk1", ColIds: []string{"c3"}, ReferTableId: "t3", ReferColumnIds: []string{"c3"}}, - {Name: "fk1", ColIds: []string{"c1", "c2"}, ReferTableId: "t2", ReferColumnIds: []string{"c1", "c2"}}}, - }, - "t2": { - Name: "t2", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, - }, - "t3": { - Name: "t3", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, 
NotNull: true}, - "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c3", Desc: false, Order: 1}}, - }, - }, - SchemaIssues: map[string]internal.TableIssues{ + Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, + }}, + SrcSchema: map[string]schema.Table{ "t1": { - ColumnLevelIssues: make(map[string][]internal.SchemaIssue), + Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, }, }, Audit: internal.Audit{ MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), }, - }, - table: "t1", - statusCode: http.StatusOK, - expectedResponse: &TableInterleaveStatus{Possible: true, Parent: "t2"}, - expectedFKs: []ddl.Foreignkey{ddl.Foreignkey{Name: "fk1", ColIds: []string{"c1", "c2"}, ReferTableId: "t2", ReferColumnIds: []string{"c1", "c2"}, Id: ""}}, - parentTable: "t2", - }, - } - for _, tc := range tests { - sessionState := session.GetSessionState() - - sessionState.Driver = constants.MYSQL - sessionState.Conv = tc.ct - update := true - req, err := http.NewRequest("GET", fmt.Sprintf("/setparent?table=%s&update=%v", tc.table, update), nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - handler := http.HandlerFunc(setParentTable) - handler.ServeHTTP(rr, req) - - type ParentTableSetResponse struct { - TableInterleaveStatus *TableInterleaveStatus `json:"tableInterleaveStatus"` - SessionState *internal.Conv `json:"sessionState"` - } - - var res *TableInterleaveStatus - - if update { - parentTableResponse := &ParentTableSetResponse{} - json.Unmarshal(rr.Body.Bytes(), parentTableResponse) - res = parentTableResponse.TableInterleaveStatus - } else { - res = &TableInterleaveStatus{} - 
json.Unmarshal(rr.Body.Bytes(), res) - } - - if status := rr.Code; int64(status) != tc.statusCode { - t.Errorf("%s\nhandler returned wrong status code: got %v want %v", - tc.name, status, tc.statusCode) - } - if tc.statusCode == http.StatusOK { - assert.Equal(t, tc.expectedResponse, res, tc.name) - } - if tc.parentTable != "" { - assert.Equal(t, tc.parentTable, sessionState.Conv.SpSchema[tc.table].ParentId, tc.name) - assert.Equal(t, tc.expectedFKs, sessionState.Conv.SpSchema[tc.table].ForeignKeys, tc.name) - } - } -} - -func TestDropForeignKey(t *testing.T) { - tc := []struct { - name string - table string - input interface{} - statusCode int64 - conv *internal.Conv - expectedConv *internal.Conv - }{ - { - name: "Test drop valid FK success", - table: "t1", - input: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c2"}, ReferTableId: "reft1", ReferColumnIds: []string{"ref_c1"}, Id: "f1"}, - {Name: "", ColIds: []string{}, ReferTableId: "", ReferColumnIds: []string{}, Id: "f2"}}, - statusCode: http.StatusOK, - conv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c2"}, ReferTableId: "reft1", ReferColumnIds: []string{"ref_c1"}, Id: "f1"}, - {Name: "fk2", ColIds: []string{"c3", "c4"}, ReferTableId: "reft2", ReferColumnIds: []string{"ref_c2", "ref_c3"}, Id: "f2"}}, - }}, - Audit: internal.Audit{ - MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), - }, + UsedNames: map[string]bool{"t1": true, "idx": true}, }, expectedConv: &internal.Conv{ SpSchema: map[string]ddl.CreateTable{ "t1": { - ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c2"}, ReferTableId: "reft1", ReferColumnIds: []string{"ref_c1"}, Id: "f1"}}, + Indexes: []ddl.CreateIndex{{Name: "idx_new", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, }}, + SrcSchema: map[string]schema.Table{ + "t1": { + Indexes: 
[]schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, + }, + }, }, }, } + for _, tc := range tc { sessionState := session.GetSessionState() @@ -751,20 +546,19 @@ func TestDropForeignKey(t *testing.T) { } buffer := bytes.NewBuffer(inputBytes) - req, err := http.NewRequest("POST", "/update/fks?table="+tc.table, buffer) - + req, err := http.NewRequest("POST", "/update/indexes?table="+tc.tableId, buffer) if err != nil { t.Fatal(err) } req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() - handler := http.HandlerFunc(updateForeignKeys) + handler := http.HandlerFunc(api.UpdateIndexes) handler.ServeHTTP(rr, req) var res *internal.Conv json.Unmarshal(rr.Body.Bytes(), &res) if status := rr.Code; int64(status) != tc.statusCode { - t.Errorf("handler returned wrong status code: got %v want %v", - status, tc.statusCode) + t.Errorf("%s : handler returned wrong status code: got %v want %v", + tc.name, status, tc.statusCode) } if tc.statusCode == http.StatusOK { assert.Equal(t, tc.expectedConv, res) @@ -772,229 +566,26 @@ func TestDropForeignKey(t *testing.T) { } } -func TestUpdateIndexes(t *testing.T) { +func TestRenameIndexes(t *testing.T) { tc := []struct { name string - tableId string - input []ddl.CreateIndex + table string + input interface{} statusCode int64 conv *internal.Conv expectedConv *internal.Conv }{ { - name: "Add a valid index key", - tableId: "t1", - input: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, + name: "Test rename indexes", + table: "t1", + input: map[string]string{ + "i1": "idx_new", + }, statusCode: http.StatusOK, conv: &internal.Conv{ SpSchema: map[string]ddl.CreateTable{ "t1": { - Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}}, 
- }}, - SrcSchema: map[string]schema.Table{ - "t1": { - Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}}}}, - }, - }, - Audit: internal.Audit{ - MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), - }, - UsedNames: map[string]bool{"t1": true, "idx": true}, - }, - expectedConv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, - }}, - SrcSchema: map[string]schema.Table{ - "t1": { - Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}}}}, - }, - }, - }, - }, - { - name: "Change the order of two index keys", - tableId: "t1", - input: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 2}, {ColId: "c3", Desc: true, Order: 1}}}}, - statusCode: http.StatusOK, - conv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, - }}, - SrcSchema: map[string]schema.Table{ - "t1": { - Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, - }, - }, - Audit: internal.Audit{ - MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), - }, - UsedNames: map[string]bool{"t1": true, "idx": true}, - }, - expectedConv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 2}, {ColId: "c3", Desc: true, Order: 1}}}}, - }}, - SrcSchema: map[string]schema.Table{ - "t1": { - Indexes: 
[]schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 2}, {ColId: "c3", Desc: true, Order: 1}}}}, - }, - }, - }, - }, - { - name: "Delete an index key column", - tableId: "t1", - input: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}}, - statusCode: http.StatusOK, - conv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, - }}, - SrcSchema: map[string]schema.Table{ - "t1": { - Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, - }, - }, - Audit: internal.Audit{ - MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), - }, - UsedNames: map[string]bool{"t1": true, "idx": true}, - }, - expectedConv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}}, - }}, - SrcSchema: map[string]schema.Table{ - "t1": { - Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, - }, - }, - }, - }, - { - name: "Test rename indexes name", - tableId: "t1", - input: []ddl.CreateIndex{{Name: "idx_new", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}}, - statusCode: http.StatusOK, - conv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}}, - }}, - SrcSchema: map[string]schema.Table{ - "t1": { - Indexes: []schema.Index{{Name: "idx", 
Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}}}}, - }, - }, - Audit: internal.Audit{ - MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), - }, - UsedNames: map[string]bool{"t1": true, "idx": true}, - }, - expectedConv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Indexes: []ddl.CreateIndex{{Name: "idx_new", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}}, - }}, - SrcSchema: map[string]schema.Table{ - "t1": { - Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}}}}, - }, - }, - }, - }, - { - name: "Two Index key columns can not have same order", - tableId: "t1", - input: []ddl.CreateIndex{{Name: "idx_new", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 1}}}}, - statusCode: http.StatusBadRequest, - conv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, - }}, - SrcSchema: map[string]schema.Table{ - "t1": { - Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, - }, - }, - Audit: internal.Audit{ - MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), - }, - UsedNames: map[string]bool{"t1": true, "idx": true}, - }, - expectedConv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Indexes: []ddl.CreateIndex{{Name: "idx_new", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}, {ColId: "c3", Desc: true, Order: 2}}}}, - }}, - SrcSchema: map[string]schema.Table{ - "t1": { - Indexes: []schema.Index{{Name: "idx", Id: "i1", Keys: []schema.Key{{ColId: "c2", Desc: false, Order: 1}, {ColId: 
"c3", Desc: true, Order: 2}}}}, - }, - }, - }, - }, - } - - for _, tc := range tc { - sessionState := session.GetSessionState() - - sessionState.Driver = constants.MYSQL - sessionState.Conv = tc.conv - - inputBytes, err := json.Marshal(tc.input) - if err != nil { - t.Fatal(err) - } - buffer := bytes.NewBuffer(inputBytes) - - req, err := http.NewRequest("POST", "/update/indexes?table="+tc.tableId, buffer) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - handler := http.HandlerFunc(updateIndexes) - handler.ServeHTTP(rr, req) - var res *internal.Conv - json.Unmarshal(rr.Body.Bytes(), &res) - if status := rr.Code; int64(status) != tc.statusCode { - t.Errorf("%s : handler returned wrong status code: got %v want %v", - tc.name, status, tc.statusCode) - } - if tc.statusCode == http.StatusOK { - assert.Equal(t, tc.expectedConv, res) - } - } -} - -func TestRenameIndexes(t *testing.T) { - tc := []struct { - name string - table string - input interface{} - statusCode int64 - conv *internal.Conv - expectedConv *internal.Conv - }{ - { - name: "Test rename indexes", - table: "t1", - input: map[string]string{ - "i1": "idx_new", - }, - statusCode: http.StatusOK, - conv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false}}}}, + Indexes: []ddl.CreateIndex{{Name: "idx", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false}}}}, }}, Audit: internal.Audit{ MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), @@ -1230,7 +821,7 @@ func TestRenameIndexes(t *testing.T) { } req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() - handler := http.HandlerFunc(renameIndexes) + handler := http.HandlerFunc(api.RenameIndexes) handler.ServeHTTP(rr, req) var res *internal.Conv json.Unmarshal(rr.Body.Bytes(), &res) @@ 
-1489,7 +1080,7 @@ func TestRenameForeignKeys(t *testing.T) { } req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() - handler := http.HandlerFunc(updateForeignKeys) + handler := http.HandlerFunc(api.UpdateForeignKeys) handler.ServeHTTP(rr, req) var res *internal.Conv json.Unmarshal(rr.Body.Bytes(), &res) @@ -1617,7 +1208,7 @@ func TestDropSecondaryIndex(t *testing.T) { } req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() - handler := http.HandlerFunc(dropSecondaryIndex) + handler := http.HandlerFunc(api.DropSecondaryIndex) handler.ServeHTTP(rr, req) var res *internal.Conv json.Unmarshal(rr.Body.Bytes(), &res) @@ -1738,7 +1329,7 @@ func TestRestoreSecondaryIndex(t *testing.T) { } req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() - handler := http.HandlerFunc(restoreSecondaryIndex) + handler := http.HandlerFunc(api.RestoreSecondaryIndex) handler.ServeHTTP(rr, req) var res *internal.Conv json.Unmarshal(rr.Body.Bytes(), &res) @@ -1824,7 +1415,7 @@ func TestDropTable(t *testing.T) { req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() - handler := http.HandlerFunc(dropTable) + handler := http.HandlerFunc(api.DropTable) handler.ServeHTTP(rr, req) res := &internal.Conv{} @@ -1918,7 +1509,7 @@ func TestRestoreTable(t *testing.T) { req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() - handler := http.HandlerFunc(restoreTable) + handler := http.HandlerFunc(api.RestoreTable) handler.ServeHTTP(rr, req) res := &internal.Conv{} @@ -1956,781 +1547,421 @@ func TestRestoreTable(t *testing.T) { } -func TestRemoveParentTable(t *testing.T) { - tc := []struct { +// todo update SetParentTable with case III suggest interleve table column. 
+func TestSetParentTable(t *testing.T) { + tests := []struct { name string - tableId string + ct *internal.Conv + table string statusCode int64 - conv *internal.Conv - expectedSpSchema ddl.Schema + expectedResponse *types.TableInterleaveStatus + expectedFKs []ddl.Foreignkey + parentTable string }{ { - name: "Remove interleaving with valid table id", - tableId: "t1", - statusCode: http.StatusOK, - conv: &internal.Conv{ + name: "no conv provided", + statusCode: http.StatusNotFound, + }, + { + name: "no table name provided", + statusCode: http.StatusBadRequest, + ct: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{"t1": { + Name: "t1", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{"c1": ddl.ColumnDef{Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": ddl.ColumnDef{Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": ddl.ColumnDef{Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}}, + PrimaryKeys: []ddl.IndexKey{ddl.IndexKey{ColId: "c1", Desc: false}}, + ForeignKeys: []ddl.Foreignkey{ddl.Foreignkey{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t1", ReferColumnIds: []string{"c1"}}, + ddl.Foreignkey{Name: "fk2", ColIds: []string{"c3"}, ReferTableId: "t2", ReferColumnIds: []string{"c2"}}}, + }}, SchemaIssues: map[string]internal.TableIssues{ - "t1": {}, - "t2": {}, - }, - SrcSchema: map[string]schema.Table{ "t1": { - Name: "table1", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]schema.Column{ - "c1": {Name: "a", Type: schema.Type{Name: "bigint"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c1"}, - "c2": {Name: "b", Type: schema.Type{Name: "bigint"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c2"}, - "c3": {Name: "c", Type: schema.Type{Name: "varchar"}, 
NotNull: false, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c3"}, - }, - PrimaryKeys: []schema.Key{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, - ForeignKeys: []schema.ForeignKey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c4"}, Id: "f1"}}, - Id: "t1", - }, - - "t2": { - Name: "table2", - ColIds: []string{"c4", "c5"}, - ColDefs: map[string]schema.Column{ - "c4": {Name: "a", Type: schema.Type{Name: "bigint"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: true, AutoIncrement: false}, Id: "c4"}, - "c5": {Name: "d", Type: schema.Type{Name: "varchar"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c5"}, - }, - Id: "t2", - PrimaryKeys: []schema.Key{{ColId: "c4", Desc: false, Order: 1}}, + ColumnLevelIssues: make(map[string][]internal.SchemaIssue), }, }, - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Name: "table1", - ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{ - "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, - Id: "t1", - ParentId: "t2", - }, - "t2": { - Name: "table2", - ColIds: []string{"c4", "c5"}, - ColDefs: map[string]ddl.ColumnDef{ - "c4": {Name: "a", Id: "c4", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c5": {Name: "d", Id: "c5", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c4", Desc: false, Order: 1}}, - Id: "t2", - }}, 
Audit: internal.Audit{ - MigrationType: migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), + MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), }, - UsedNames: map[string]bool{"table1": true, "table2": true}, }, - expectedSpSchema: ddl.Schema{ - "t1": { - Name: "table1", + }, + { + name: "table with synthetic PK", + ct: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{"t1": { + Name: "t1", ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{ - "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + "synth_id": {Name: "synth_id", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, - ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c4"}, Id: "f1"}}, - Id: "t1", - ParentId: "", + PrimaryKeys: []ddl.IndexKey{{ColId: "synth_id", Desc: false}}, + ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t1", ReferColumnIds: []string{"c1"}}, + {Name: "fk2", ColIds: []string{"c3"}, ReferTableId: "t2", ReferColumnIds: []string{"c2"}}}, + }}, + SyntheticPKeys: map[string]internal.SyntheticPKey{"t1": internal.SyntheticPKey{ColId: "synth_id"}}, + Audit: internal.Audit{ + MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), }, - "t2": { - Name: "table2", - ColIds: []string{"c4", "c5"}, - ColDefs: map[string]ddl.ColumnDef{ - "c4": {Name: "a", Id: "c4", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - 
"c5": {Name: "d", Id: "c5", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + SchemaIssues: map[string]internal.TableIssues{ + "t1": { + ColumnLevelIssues: make(map[string][]internal.SchemaIssue), }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c4", Desc: false, Order: 1}}, - Id: "t2", }, }, + table: "t1", + statusCode: http.StatusOK, + expectedResponse: &types.TableInterleaveStatus{Possible: false, Comment: "Has synthetic pk"}, }, - - {name: "Remove interleaving with invalid table id", - tableId: "A", - statusCode: http.StatusBadRequest, - conv: &internal.Conv{ - SchemaIssues: map[string]internal.TableIssues{ - "t1": {}, - "t2": {}, - }, - SrcSchema: map[string]schema.Table{ + { + name: "no valid prefix 1", + ct: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ "t1": { - Name: "table1", + Name: "t1", ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]schema.Column{ - "c1": {Name: "a", Type: schema.Type{Name: "bigint"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c1"}, - "c2": {Name: "b", Type: schema.Type{Name: "bigint"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c2"}, - "c3": {Name: "c", Type: schema.Type{Name: "varchar"}, NotNull: false, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c3"}, + ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, }, - PrimaryKeys: []schema.Key{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, - ForeignKeys: []schema.ForeignKey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: 
"t2", ReferColumnIds: []string{"c4"}, Id: "f1"}}, - Id: "t1", + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false}}, + ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c1"}}}, }, - "t2": { - Name: "table2", - ColIds: []string{"c4", "c5"}, - ColDefs: map[string]schema.Column{ - "c4": {Name: "a", Type: schema.Type{Name: "bigint"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: true, AutoIncrement: false}, Id: "c4"}, - "c5": {Name: "d", Type: schema.Type{Name: "varchar"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c5"}, + Name: "t2", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + "synth_id": {Name: "synth_id", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, }, - Id: "t2", - PrimaryKeys: []schema.Key{{ColId: "c4", Desc: false, Order: 1}}, + PrimaryKeys: []ddl.IndexKey{{ColId: "synth_id", Desc: false}}, + }, + }, + SyntheticPKeys: map[string]internal.SyntheticPKey{"t2": internal.SyntheticPKey{ColId: "synth_id"}}, + SchemaIssues: map[string]internal.TableIssues{ + "t1": { + ColumnLevelIssues: make(map[string][]internal.SchemaIssue), }, }, + Audit: internal.Audit{ + MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), + }, + }, + table: "t1", + statusCode: http.StatusOK, + expectedResponse: &types.TableInterleaveStatus{Possible: false, Comment: "No valid prefix"}, + expectedFKs: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c1"}}}, + }, + { + name: "no valid prefix 2", + ct: &internal.Conv{ SpSchema: map[string]ddl.CreateTable{ 
"t1": { - Name: "table1", + Name: "t1", ColIds: []string{"c1", "c2", "c3"}, - ColDefs: map[string]ddl.ColumnDef{ - "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, - Id: "t1", - ParentId: "t2", + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}}, + ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c1"}}}, }, "t2": { - Name: "table2", - ColIds: []string{"c4", "c5"}, - ColDefs: map[string]ddl.ColumnDef{ - "c4": {Name: "a", Id: "c4", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c5": {Name: "d", Id: "c5", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + Name: "t2", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c4", Desc: false, Order: 1}}, - Id: "t2", - }}, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, + }, + }, + SchemaIssues: map[string]internal.TableIssues{ + "t1": { + ColumnLevelIssues: make(map[string][]internal.SchemaIssue), + }, + }, Audit: internal.Audit{ - MigrationType: migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), + MigrationType: 
migration.MigrationData_SCHEMA_ONLY.Enum(), }, - UsedNames: map[string]bool{"table1": true, "table2": true}, }, - expectedSpSchema: ddl.Schema{}, + table: "t1", + statusCode: http.StatusOK, + expectedResponse: &types.TableInterleaveStatus{Possible: false, Parent: "", Comment: "No valid prefix"}, + expectedFKs: []ddl.Foreignkey{{}}, }, - } - - for _, tc := range tc { - sessionState := session.GetSessionState() - sessionState.Driver = constants.MYSQL - - sessionState.Conv = tc.conv - payload := `{}` - req, err := http.NewRequest("POST", "/drop/removeParent?tableId="+tc.tableId, strings.NewReader(payload)) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - handler := http.HandlerFunc(removeParentTable) - handler.ServeHTTP(rr, req) - var res *internal.Conv - json.Unmarshal(rr.Body.Bytes(), &res) - if status := rr.Code; int64(status) != tc.statusCode { - t.Errorf("handler returned wrong status code: got %v want %v", - status, tc.statusCode) - } - if tc.statusCode == http.StatusOK { - assert.Equal(t, tc.expectedSpSchema, res.SpSchema) - } - } -} - -func TestApplyRule(t *testing.T) { - tcAddIndex := []struct { - name string - input internal.Rule - statusCode int64 - conv *internal.Conv - expectedConv *internal.Conv - }{ { - name: "Add Index with unique name", - input: internal.Rule{ - Name: "rule-index1", - ObjectType: "Table", - AssociatedObjects: "t1", - Enabled: true, - Type: constants.AddIndex, - Data: ddl.CreateIndex{ - Name: "idx3", - TableId: "t1", - Unique: false, - Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}, - }, - }, - statusCode: http.StatusOK, - conv: &internal.Conv{ + name: "no valid prefix 3", + ct: &internal.Conv{ SpSchema: map[string]ddl.CreateTable{ "t1": { - Name: "table1", - Id: "t1", - Indexes: []ddl.CreateIndex{ - {Name: "idx1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false}}}, - {Name: "idx2", TableId: "t1", Unique: false, Keys: 
[]ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}}, - }}, + Name: "t1", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false}, {ColId: "c2", Desc: false}}, + ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c3"}, ReferTableId: "t2", ReferColumnIds: []string{"c3"}}}, + }, + "t2": { + Name: "t2", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false}}, + }, + }, + SchemaIssues: map[string]internal.TableIssues{ + "t1": { + ColumnLevelIssues: make(map[string][]internal.SchemaIssue), + }, + }, Audit: internal.Audit{ MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), }, - UsedNames: map[string]bool{"table1": true, "idx1": true, "idx2": true}, - }, - expectedConv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Name: "table1", - Id: "t1", - Indexes: []ddl.CreateIndex{ - {Name: "idx1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false}}}, - {Name: "idx2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}, - {Id: "i1", Name: "idx3", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}, - }, - }}, }, + table: "t1", + statusCode: http.StatusOK, + expectedResponse: &types.TableInterleaveStatus{Possible: false, Comment: "No valid prefix"}, + expectedFKs: []ddl.Foreignkey{{Name: "fk1", ColIds: 
[]string{"c3"}, ReferTableId: "t2", ReferColumnIds: []string{"c3"}}}, }, { - name: "New name conflicts with an existing table", - input: internal.Rule{ - Name: "rule-index1", - ObjectType: "Table", - AssociatedObjects: "t1", - Enabled: true, - Type: constants.AddIndex, - Data: map[string]interface{}{ - "Name": "table1", - "TableId": "t1", - "Unique": false, - "Keys": []interface{}{map[string]interface{}{"ColId": "c2", "Desc": false}}, - }, - }, - statusCode: http.StatusInternalServerError, - conv: &internal.Conv{ + name: "interleave possible on changing primary key order", + ct: &internal.Conv{ SpSchema: map[string]ddl.CreateTable{ "t1": { - Name: "table1", - Id: "t1", - Indexes: []ddl.CreateIndex{{Name: "idx1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false}}}, - {Name: "idx2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}}, - }}, + Name: "t1", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 2}, {ColId: "c2", Desc: false, Order: 2}}, + ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c4"}}}, + }, + "t2": { + Name: "t2", + ColIds: []string{"cc4", "c5", "c6"}, + ColDefs: map[string]ddl.ColumnDef{"c4": {Name: "d", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c5": {Name: "e", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c6": {Name: "f", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, + }, + }, + SchemaIssues: map[string]internal.TableIssues{ + "t1": { + 
ColumnLevelIssues: map[string][]internal.SchemaIssue{ + "c1": {internal.InterleavedNotInOrder}, + }, + }, + }, Audit: internal.Audit{ MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), }, - UsedNames: map[string]bool{"table1": true, "idx1": true, "idx2": true}, }, + table: "t1", + statusCode: http.StatusOK, + expectedResponse: &types.TableInterleaveStatus{Possible: false, Parent: "", Comment: "No valid prefix"}, + expectedFKs: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c4"}}}, + parentTable: "", }, { - name: "New name conflicts with an existing index", - input: internal.Rule{ - Name: "rule-index1", - ObjectType: "Table", - AssociatedObjects: "t1", - Enabled: true, - Type: constants.AddIndex, - Data: map[string]interface{}{ - "Name": "idx2", - "TableId": "t1", - "Unique": false, - "Keys": []interface{}{map[string]interface{}{"ColId": "c2", "Desc": false}}, - }, - }, - statusCode: http.StatusInternalServerError, - conv: &internal.Conv{ + name: "successful interleave", + ct: &internal.Conv{ SpSchema: map[string]ddl.CreateTable{ "t1": { - Name: "table1", - Id: "t1", - Indexes: []ddl.CreateIndex{{Name: "idx1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false}}}, - {Name: "idx2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}}, - }}, + Name: "t1", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, + ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c1"}}}, + }, + "t2": { + Name: "t2", + ColIds: 
[]string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}}, + }, + }, + SchemaIssues: map[string]internal.TableIssues{ + "t1": { + ColumnLevelIssues: make(map[string][]internal.SchemaIssue), + }, + }, Audit: internal.Audit{ MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), }, - UsedNames: map[string]bool{"table1": true, "idx1": true, "idx2": true}, }, + table: "t1", + statusCode: http.StatusOK, + expectedResponse: &types.TableInterleaveStatus{Possible: true, Parent: "t2"}, + expectedFKs: []ddl.Foreignkey{}, + parentTable: "t2", }, { - name: "Invalid input", - input: internal.Rule{ - Name: "rule-index1", - ObjectType: "Table", - AssociatedObjects: "t1", - Enabled: true, - Type: constants.AddIndex, - Data: []string{"test1"}, - }, - statusCode: http.StatusInternalServerError, - conv: &internal.Conv{ + name: "successful interleave with same primary key", + ct: &internal.Conv{ SpSchema: map[string]ddl.CreateTable{ "t1": { - Name: "table1", - Id: "t1", - Indexes: []ddl.CreateIndex{{Name: "idx1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false}}}, - {Name: "idx2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}}, - }}, + Name: "t1", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, + ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: 
[]string{"c1", "c2"}, ReferTableId: "t2", ReferColumnIds: []string{"c1", "c2"}}}, + }, + "t2": { + Name: "t2", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, + }, + }, + SchemaIssues: map[string]internal.TableIssues{ + "t1": { + ColumnLevelIssues: make(map[string][]internal.SchemaIssue), + }, + }, Audit: internal.Audit{ MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), }, - UsedNames: map[string]bool{"table1": true, "idx1": true, "idx2": true}, }, + table: "t1", + statusCode: http.StatusOK, + expectedResponse: &types.TableInterleaveStatus{Possible: true, Parent: "t2"}, + expectedFKs: []ddl.Foreignkey{}, + parentTable: "t2", }, - } - for _, tc := range tcAddIndex { - sessionState := session.GetSessionState() - - sessionState.Driver = constants.MYSQL - sessionState.Conv = tc.conv - - inputBytes, err := json.Marshal(tc.input) - if err != nil { - t.Fatal(err) - } - buffer := bytes.NewBuffer(inputBytes) - - req, err := http.NewRequest("POST", "/applyrule", buffer) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - handler := http.HandlerFunc(applyRule) - handler.ServeHTTP(rr, req) - var res *internal.Conv - json.Unmarshal(rr.Body.Bytes(), &res) - if status := rr.Code; int64(status) != tc.statusCode { - t.Errorf("%s : handler returned wrong status code: got %v want %v", - tc.name, status, tc.statusCode) - } - if tc.statusCode == http.StatusOK { - tc.expectedConv.Rules = internal.MakeConv().Rules - tc.expectedConv.Rules = append(tc.expectedConv.Rules, tc.input) - - // Marshall and unmarshall the data field of rule with its proper type i.e 
ddl.CreateIndex. - // Else unmarshalling data field of rule as interface convert int to float64. - // In this particular case, order of index-key would be unmarshall to float64 instead of int. - dataBytes, err := json.Marshal(res.Rules[0].Data) - assert.Equal(t, err, nil) - var data ddl.CreateIndex - json.Unmarshal(dataBytes, &data) - - // Removing random ids before comparison. - addedRule := res.Rules[0] - data.Id = "" - addedRule.Data = data - addedRule.Id = "" - res.Rules[0] = addedRule - - assert.Equal(t, tc.expectedConv, res) - } - } - - tcSetGlobalDataTypePostgres := []struct { - name string - payload string - statusCode int64 - expectedSchema ddl.CreateTable - expectedIssues internal.TableIssues - }{ { - name: "Test type change", - payload: `{ - "Name": "rule1", - "Type": "global_datatype_change", - "ObjectType": "Column", - "AssociatedObjects": "All Columns", - "Enabled": true, - "Data": - { - "bool":"STRING", - "int8":"STRING", - "float4":"STRING", - "varchar":"BYTES", - "numeric":"STRING", - "timestamptz":"STRING", - "bigserial":"STRING", - "bpchar":"BYTES", - "bytea":"STRING", - "date":"STRING", - "float8":"STRING", - "int4":"STRING", - "serial":"STRING", - "text":"BYTES", - "timestamp":"STRING" - } - }`, - statusCode: http.StatusOK, - expectedSchema: ddl.CreateTable{ - Name: "table1", - Id: "t1", - ColIds: []string{"c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13", "c14", "c15", "c16"}, - ColDefs: map[string]ddl.ColumnDef{ - "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c4": {Name: "d", Id: "c4", T: ddl.Type{Name: ddl.Bytes, Len: 6}}, - "c5": {Name: "e", Id: "c5", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c6": {Name: "f", Id: "c6", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c7": {Name: "g", Id: "c7", 
T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c8": {Name: "h", Id: "c8", T: ddl.Type{Name: ddl.Bytes, Len: int64(1)}}, - "c9": {Name: "i", Id: "c9", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c10": {Name: "j", Id: "c10", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c11": {Name: "k", Id: "c11", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c12": {Name: "l", Id: "c12", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c13": {Name: "m", Id: "c13", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c14": {Name: "n", Id: "c14", T: ddl.Type{Name: ddl.Bytes, Len: ddl.MaxLength}}, - "c15": {Name: "o", Id: "c15", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c16": {Name: "p", Id: "c16", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }, - expectedIssues: internal.TableIssues{ - ColumnLevelIssues: map[string][]internal.SchemaIssue{ - "c1": {internal.Widened}, - "c2": {internal.Widened}, - "c3": {internal.Widened}, - "c5": {internal.Widened}, - "c6": {internal.Widened}, - "c7": {internal.Widened, internal.Serial}, - "c10": {internal.Widened}, - "c11": {internal.Widened}, - "c12": {internal.Widened}, - "c13": {internal.Widened, internal.Serial}, - "c15": {internal.Widened}, - "c16": {internal.Widened}, + name: "successful interleave with multiple fks refering multiple tables", + ct: &internal.Conv{ + SpSchema: map[string]ddl.CreateTable{ + "t1": { + Name: "t1", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, + ForeignKeys: []ddl.Foreignkey{ + {Name: "fk1", ColIds: []string{"c3"}, ReferTableId: "t3", 
ReferColumnIds: []string{"c3"}}, + {Name: "fk1", ColIds: []string{"c1", "c2"}, ReferTableId: "t2", ReferColumnIds: []string{"c1", "c2"}}}, + }, + "t2": { + Name: "t2", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, + }, + "t3": { + Name: "t3", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{"c1": {Name: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c3", Desc: false, Order: 1}}, + }, }, - }, - }, - { - name: "Test type change 2", - payload: `{ - "Name": "rule1", - "Type": "global_datatype_change", - "ObjectType": "Column", - "AssociatedObjects": "All Columns", - "Enabled": true, - "Data": - { - "bool":"INT64", - "int8":"STRING", - "float4":"STRING" - } - }`, - statusCode: http.StatusOK, - expectedSchema: ddl.CreateTable{ - Name: "table1", - Id: "t1", - ColIds: []string{"c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13", "c14", "c15", "c16"}, - ColDefs: map[string]ddl.ColumnDef{ - "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, - "c4": {Name: "d", Id: "c4", T: ddl.Type{Name: ddl.String, Len: int64(6)}}, - "c5": {Name: "e", Id: "c5", T: ddl.Type{Name: ddl.Numeric}}, - "c6": {Name: "f", Id: "c6", T: ddl.Type{Name: ddl.Timestamp}}, - "c7": {Name: "g", Id: "c7", T: ddl.Type{Name: 
ddl.Int64}}, - "c8": {Name: "h", Id: "c8", T: ddl.Type{Name: ddl.String, Len: int64(1)}}, - "c9": {Name: "i", Id: "c9", T: ddl.Type{Name: ddl.Bytes, Len: ddl.MaxLength}}, - "c10": {Name: "j", Id: "c10", T: ddl.Type{Name: ddl.Date}}, - "c11": {Name: "k", Id: "c11", T: ddl.Type{Name: ddl.Float64}}, - "c12": {Name: "l", Id: "c12", T: ddl.Type{Name: ddl.Int64}}, - "c13": {Name: "m", Id: "c13", T: ddl.Type{Name: ddl.Int64}}, - "c14": {Name: "n", Id: "c14", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c15": {Name: "o", Id: "c15", T: ddl.Type{Name: ddl.Timestamp}}, - "c16": {Name: "p", Id: "c16", T: ddl.Type{Name: ddl.Int64}}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }, - expectedIssues: internal.TableIssues{ - ColumnLevelIssues: map[string][]internal.SchemaIssue{ - "c1": {internal.Widened}, - "c2": {internal.Widened}, - "c3": {internal.Widened}, - "c7": {internal.Serial}, - "c12": {internal.Widened}, - "c13": {internal.Serial}, - "c15": {internal.Timestamp}, - "c16": {internal.Widened}, + SchemaIssues: map[string]internal.TableIssues{ + "t1": { + ColumnLevelIssues: make(map[string][]internal.SchemaIssue), + }, }, - }, - }, - { - name: "Test bad payload data request", - payload: `{ - "Name": "rule1", - "Type": "global_datatype_change", - "ObjectType": "Column", - "AssociatedObjects": "All Columns", - "Enabled": true, - "Data": - { - "bool":"INT64", - "int8":"STRING", - "float4":"STRING", - } - }`, - statusCode: http.StatusBadRequest, + Audit: internal.Audit{ + MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), + }, + }, + table: "t1", + statusCode: http.StatusOK, + expectedResponse: &types.TableInterleaveStatus{Possible: true, Parent: "t2"}, + expectedFKs: []ddl.Foreignkey{ddl.Foreignkey{Name: "fk1", ColIds: []string{"c1", "c2"}, ReferTableId: "t2", ReferColumnIds: []string{"c1", "c2"}, Id: ""}}, + parentTable: "t2", }, } - for _, tc := range tcSetGlobalDataTypePostgres { - + for _, tc := range tests { sessionState := 
session.GetSessionState() - sessionState.Driver = constants.POSTGRES - sessionState.Conv = internal.MakeConv() - buildConvPostgres(sessionState.Conv) - payload := tc.payload - req, err := http.NewRequest("POST", "/applyrule", strings.NewReader(payload)) + sessionState.Driver = constants.MYSQL + sessionState.Conv = tc.ct + update := true + req, err := http.NewRequest("GET", fmt.Sprintf("/setparent?table=%s&update=%v", tc.table, update), nil) if err != nil { t.Fatal(err) } req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() - handler := http.HandlerFunc(applyRule) + handler := http.HandlerFunc(api.SetParentTable) handler.ServeHTTP(rr, req) - var res *internal.Conv - json.Unmarshal(rr.Body.Bytes(), &res) - if status := rr.Code; int64(status) != tc.statusCode { - t.Errorf("handler returned wrong status code: got %v want %v", - status, tc.statusCode) - } - if tc.statusCode == http.StatusOK { - assert.Equal(t, tc.expectedSchema, res.SpSchema["t1"]) - assert.Equal(t, tc.expectedIssues, res.SchemaIssues["t1"]) + type ParentTableSetResponse struct { + TableInterleaveStatus *types.TableInterleaveStatus `json:"tableInterleaveStatus"` + SessionState *internal.Conv `json:"sessionState"` } - } - tcSetGlobalDataTypeMysql := []struct { - name string - payload string - statusCode int64 - expectedSchema ddl.CreateTable - expectedIssues internal.TableIssues - }{ - { - name: "Test type change", - payload: `{ - "Name": "rule1", - "Type": "global_datatype_change", - "ObjectType": "Column", - "AssociatedObjects": "All Columns", - "Enabled": true, - "Data": - { - "bool":"STRING", - "smallint":"STRING", - "float":"STRING", - "varchar":"BYTES", - "numeric":"STRING", - "timestamp":"STRING", - "decimal":"STRING", - "json":"BYTES", - "binary":"STRING", - "blob":"STRING", - "double":"STRING", - "date":"STRING", - "time":"STRING", - "enum":"STRING", - "text":"BYTES" - } - }`, - statusCode: http.StatusOK, - expectedSchema: ddl.CreateTable{ - Name: "table1", - Id: "t1", 
- ColIds: []string{"c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13", "c14", "c15", "c16"}, - ColDefs: map[string]ddl.ColumnDef{ - "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.Bytes, Len: ddl.MaxLength}}, - "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c4": {Name: "d", Id: "c4", T: ddl.Type{Name: ddl.Bytes, Len: 6}}, - "c5": {Name: "e", Id: "c5", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c6": {Name: "f", Id: "c6", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c7": {Name: "g", Id: "c7", T: ddl.Type{Name: ddl.Bytes, Len: ddl.MaxLength}}, - "c8": {Name: "h", Id: "c8", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c9": {Name: "i", Id: "c9", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c10": {Name: "j", Id: "c10", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c11": {Name: "k", Id: "c11", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c12": {Name: "l", Id: "c12", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c13": {Name: "m", Id: "c13", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c14": {Name: "n", Id: "c14", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c15": {Name: "o", Id: "c15", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c16": {Name: "p", Id: "c16", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }, - expectedIssues: internal.TableIssues{ - ColumnLevelIssues: map[string][]internal.SchemaIssue{ - "c1": {internal.Widened}, - "c3": {internal.Widened}, - "c5": {internal.Widened}, - "c10": {internal.Widened}, - "c11": {internal.Widened}, - "c12": {internal.Widened}, - "c13": {internal.Widened}, - "c14": {internal.Widened}, - "c15": {internal.Widened}, - "c16": {internal.Time}, - }, - }, - }, - { - name: "Test type change 2", - payload: `{ - "Name": "rule1", - "Type": 
"global_datatype_change", - "ObjectType": "Column", - "AssociatedObjects": "All Columns", - "Enabled": true, - "Data": - { - "bool":"INT64", - "varchar":"BYTES" - } - }`, - statusCode: http.StatusOK, - expectedSchema: ddl.CreateTable{ - Name: "table1", - Id: "t1", - ColIds: []string{"c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13", "c14", "c15", "c16"}, - ColDefs: map[string]ddl.ColumnDef{ - "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.Int64}}, - "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.Int64}}, - "c4": {Name: "d", Id: "c4", T: ddl.Type{Name: ddl.Bytes, Len: 6}}, - "c5": {Name: "e", Id: "c5", T: ddl.Type{Name: ddl.Numeric}}, - "c6": {Name: "f", Id: "c6", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c7": {Name: "g", Id: "c7", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - "c8": {Name: "h", Id: "c8", T: ddl.Type{Name: ddl.Bytes, Len: ddl.MaxLength}}, - "c9": {Name: "i", Id: "c9", T: ddl.Type{Name: ddl.Bytes, Len: ddl.MaxLength}}, - "c10": {Name: "j", Id: "c10", T: ddl.Type{Name: ddl.Int64}}, - "c11": {Name: "k", Id: "c11", T: ddl.Type{Name: ddl.Float64}}, - "c12": {Name: "l", Id: "c12", T: ddl.Type{Name: ddl.Float64}}, - "c13": {Name: "m", Id: "c13", T: ddl.Type{Name: ddl.Numeric}}, - "c14": {Name: "n", Id: "c14", T: ddl.Type{Name: ddl.Date}}, - "c15": {Name: "o", Id: "c15", T: ddl.Type{Name: ddl.Timestamp}}, - "c16": {Name: "p", Id: "c16", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}}, - }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1"}}, - }, - expectedIssues: internal.TableIssues{ - ColumnLevelIssues: map[string][]internal.SchemaIssue{ - "c1": {internal.Widened}, - "c3": {internal.Widened}, - "c10": {internal.Widened}, - "c12": {internal.Widened}, - "c15": {internal.Time}, - }, - }, - }, - { - name: "Test bad request", - payload: `{ - "Name": "rule1", - "Type": "global_datatype_change", - "ObjectType": "Column", - 
"AssociatedObjects": "All Columns", - "Enabled": true, - "Data": - { - "bool":"INT64", - "smallint":"STRING", - } - }`, - statusCode: http.StatusBadRequest, - }, - } - for _, tc := range tcSetGlobalDataTypeMysql { - sessionState := session.GetSessionState() + var res *types.TableInterleaveStatus - sessionState.Driver = constants.MYSQL - sessionState.Conv = internal.MakeConv() - buildConvMySQL(sessionState.Conv) - payload := tc.payload - req, err := http.NewRequest("POST", "/applyrule", strings.NewReader(payload)) - if err != nil { - t.Fatal(err) + if update { + parentTableResponse := &ParentTableSetResponse{} + json.Unmarshal(rr.Body.Bytes(), parentTableResponse) + res = parentTableResponse.TableInterleaveStatus + } else { + res = &types.TableInterleaveStatus{} + json.Unmarshal(rr.Body.Bytes(), res) } - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - handler := http.HandlerFunc(applyRule) - handler.ServeHTTP(rr, req) - var res *internal.Conv - json.Unmarshal(rr.Body.Bytes(), &res) + if status := rr.Code; int64(status) != tc.statusCode { - t.Errorf("handler returned wrong status code: got %v want %v", - status, tc.statusCode) + t.Errorf("%s\nhandler returned wrong status code: got %v want %v", + tc.name, status, tc.statusCode) } - if tc.statusCode == http.StatusOK { - assert.Equal(t, tc.expectedSchema, res.SpSchema["t1"]) - assert.Equal(t, tc.expectedIssues, res.SchemaIssues["t1"]) + assert.Equal(t, tc.expectedResponse, res, tc.name) + } + if tc.parentTable != "" { + assert.Equal(t, tc.parentTable, sessionState.Conv.SpSchema[tc.table].ParentId, tc.name) + assert.Equal(t, tc.expectedFKs, sessionState.Conv.SpSchema[tc.table].ForeignKeys, tc.name) } } } -func TestDropRule(t *testing.T) { +func TestRemoveParentTable(t *testing.T) { tc := []struct { - name string - ruleId string - statusCode int64 - conv *internal.Conv - expectedConv *internal.Conv + name string + tableId string + statusCode int64 + conv *internal.Conv + 
expectedSpSchema ddl.Schema }{ { - name: "drop a valid add index rule", - ruleId: "r101", - statusCode: http.StatusOK, - conv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Name: "table1", - Id: "t1", - Indexes: []ddl.CreateIndex{ - {Name: "idx1", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c1", Desc: false}}}, - {Name: "idx2", Id: "i2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}, - {Name: "idx3", Id: "i3", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}, - }, - }}, - Audit: internal.Audit{ - MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), - }, - UsedNames: map[string]bool{"table1": true, "idx1": true, "idx2": true, "idx3": true}, - Rules: []internal.Rule{{ - Id: "r101", - Name: "add_index", - Type: constants.AddIndex, - ObjectType: "table", - AssociatedObjects: "t1", - Enabled: true, - Data: ddl.CreateIndex{Name: "idx3", Id: "i3", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}, - }}, - }, - expectedConv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Name: "table1", - Id: "t1", - Indexes: []ddl.CreateIndex{ - {Name: "idx1", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c1", Desc: false}}}, - {Name: "idx2", Id: "i2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}, - }, - }}, - }, - }, - { - name: "drop a vaild add global data type rule", - ruleId: "r101", + name: "Remove interleaving with valid table id", + tableId: "t1", statusCode: http.StatusOK, conv: &internal.Conv{ SchemaIssues: map[string]internal.TableIssues{ "t1": {}, + "t2": {}, }, SrcSchema: map[string]schema.Table{ "t1": { @@ -2741,9 +1972,21 @@ func TestDropRule(t *testing.T) { "c2": {Name: "b", Type: schema.Type{Name: "bigint"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: 
false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c2"}, "c3": {Name: "c", Type: schema.Type{Name: "varchar"}, NotNull: false, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c3"}, }, - PrimaryKeys: []schema.Key{{ColId: "c1", Desc: false, Order: 1}}, + PrimaryKeys: []schema.Key{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, + ForeignKeys: []schema.ForeignKey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c4"}, Id: "f1"}}, Id: "t1", }, + + "t2": { + Name: "table2", + ColIds: []string{"c4", "c5"}, + ColDefs: map[string]schema.Column{ + "c4": {Name: "a", Type: schema.Type{Name: "bigint"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: true, AutoIncrement: false}, Id: "c4"}, + "c5": {Name: "d", Type: schema.Type{Name: "varchar"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c5"}, + }, + Id: "t2", + PrimaryKeys: []schema.Key{{ColId: "c4", Desc: false, Order: 1}}, + }, }, SpSchema: map[string]ddl.CreateTable{ "t1": { @@ -2751,33 +1994,62 @@ func TestDropRule(t *testing.T) { ColIds: []string{"c1", "c2", "c3"}, ColDefs: map[string]ddl.ColumnDef{ "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, - "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false}}, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, Id: "t1", + ParentId: "t2", }, - }, + "t2": { + Name: "table2", + 
ColIds: []string{"c4", "c5"}, + ColDefs: map[string]ddl.ColumnDef{ + "c4": {Name: "a", Id: "c4", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c5": {Name: "d", Id: "c5", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c4", Desc: false, Order: 1}}, + Id: "t2", + }}, Audit: internal.Audit{ MigrationType: migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), }, - Rules: []internal.Rule{ - { - Id: "r101", - Name: "bigint to BTYES", - Type: constants.GlobalDataTypeChange, - ObjectType: "Column", - AssociatedObjects: "All Columns", - Enabled: true, - Data: map[string]string{ - "bigint": ddl.String, - }, + UsedNames: map[string]bool{"table1": true, "table2": true}, + }, + expectedSpSchema: ddl.Schema{ + "t1": { + Name: "table1", + ColIds: []string{"c1", "c2", "c3"}, + ColDefs: map[string]ddl.ColumnDef{ + "c1": {Name: "a", Id: "c1", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, + }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, + ForeignKeys: []ddl.Foreignkey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c4"}, Id: "f1"}}, + Id: "t1", + ParentId: "", + }, + "t2": { + Name: "table2", + ColIds: []string{"c4", "c5"}, + ColDefs: map[string]ddl.ColumnDef{ + "c4": {Name: "a", Id: "c4", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c5": {Name: "d", Id: "c5", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c4", Desc: false, Order: 1}}, + Id: "t2", }, }, - expectedConv: &internal.Conv{ + }, + + {name: "Remove interleaving with invalid table id", + tableId: "A", + statusCode: http.StatusBadRequest, + conv: &internal.Conv{ SchemaIssues: map[string]internal.TableIssues{ "t1": {}, + 
"t2": {}, }, SrcSchema: map[string]schema.Table{ "t1": { @@ -2788,9 +2060,21 @@ func TestDropRule(t *testing.T) { "c2": {Name: "b", Type: schema.Type{Name: "bigint"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c2"}, "c3": {Name: "c", Type: schema.Type{Name: "varchar"}, NotNull: false, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c3"}, }, - PrimaryKeys: []schema.Key{{ColId: "c1", Desc: false, Order: 1}}, + PrimaryKeys: []schema.Key{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, + ForeignKeys: []schema.ForeignKey{{Name: "fk1", ColIds: []string{"c1"}, ReferTableId: "t2", ReferColumnIds: []string{"c4"}, Id: "f1"}}, Id: "t1", }, + + "t2": { + Name: "table2", + ColIds: []string{"c4", "c5"}, + ColDefs: map[string]schema.Column{ + "c4": {Name: "a", Type: schema.Type{Name: "bigint"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: true, AutoIncrement: false}, Id: "c4"}, + "c5": {Name: "d", Type: schema.Type{Name: "varchar"}, NotNull: true, Ignored: schema.Ignored{Check: false, Identity: false, Default: false, Exclusion: false, ForeignKey: false, AutoIncrement: false}, Id: "c5"}, + }, + Id: "t2", + PrimaryKeys: []schema.Key{{ColId: "c4", Desc: false, Order: 1}}, + }, }, SpSchema: map[string]ddl.CreateTable{ "t1": { @@ -2801,107 +2085,53 @@ func TestDropRule(t *testing.T) { "c2": {Name: "b", Id: "c2", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, "c3": {Name: "c", Id: "c3", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, }, - PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false}}, + PrimaryKeys: []ddl.IndexKey{{ColId: "c1", Desc: false, Order: 1}, {ColId: "c2", Desc: false, Order: 2}}, Id: "t1", + ParentId: "t2", }, - }, - }, - }, - { - name: "drop rule with an 
invalid rule-id", - ruleId: "ABC", - statusCode: http.StatusBadRequest, - conv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Name: "table1", - Id: "t1", - Indexes: []ddl.CreateIndex{ - {Name: "idx1", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c1", Desc: false}}}, - {Name: "idx2", Id: "i2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}, - {Name: "idx3", Id: "i3", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}, - }, - }}, - Audit: internal.Audit{ - MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), - }, - UsedNames: map[string]bool{"table1": true, "idx1": true, "idx2": true, "idx3": true}, - Rules: []internal.Rule{{ - Id: "r101", - Name: "add_index", - Type: constants.AddIndex, - ObjectType: "table", - AssociatedObjects: "t1", - Enabled: true, - Data: ddl.CreateIndex{Name: "idx3", Id: "i3", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}, - }}, - }, - }, - { - name: "drop a disabled valid add index rule", - ruleId: "r101", - statusCode: http.StatusOK, - conv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Name: "table1", - Id: "t1", - Indexes: []ddl.CreateIndex{ - {Name: "idx1", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c1", Desc: false}}}, - {Name: "idx2", Id: "i2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}, + "t2": { + Name: "table2", + ColIds: []string{"c4", "c5"}, + ColDefs: map[string]ddl.ColumnDef{ + "c4": {Name: "a", Id: "c4", T: ddl.Type{Name: ddl.Int64}, NotNull: true}, + "c5": {Name: "d", Id: "c5", T: ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, NotNull: true}, }, + PrimaryKeys: []ddl.IndexKey{{ColId: "c4", Desc: false, Order: 1}}, + Id: "t2", }}, Audit: internal.Audit{ - MigrationType: migration.MigrationData_SCHEMA_ONLY.Enum(), + 
MigrationType: migration.MigrationData_MIGRATION_TYPE_UNSPECIFIED.Enum(), }, - UsedNames: map[string]bool{"table1": true, "idx1": true, "idx2": true}, - Rules: []internal.Rule{{ - Id: "r101", - Name: "add_index", - Type: constants.AddIndex, - ObjectType: "table", - AssociatedObjects: "t1", - Enabled: false, - Data: ddl.CreateIndex{Name: "idx3", Id: "i3", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c2", Desc: false, Order: 1}}}, - }}, - }, - expectedConv: &internal.Conv{ - SpSchema: map[string]ddl.CreateTable{ - "t1": { - Name: "table1", - Id: "t1", - Indexes: []ddl.CreateIndex{ - {Name: "idx1", Id: "i1", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c1", Desc: false}}}, - {Name: "idx2", Id: "i2", TableId: "t1", Unique: false, Keys: []ddl.IndexKey{{ColId: "c3", Desc: false}, {ColId: "c4", Desc: false}}}, - }, - }}, + UsedNames: map[string]bool{"table1": true, "table2": true}, }, + expectedSpSchema: ddl.Schema{}, }, } + for _, tc := range tc { sessionState := session.GetSessionState() sessionState.Driver = constants.MYSQL + sessionState.Conv = tc.conv payload := `{}` - req, err := http.NewRequest("POST", "/dropRule?id="+tc.ruleId, strings.NewReader(payload)) + req, err := http.NewRequest("POST", "/drop/removeParent?tableId="+tc.tableId, strings.NewReader(payload)) if err != nil { t.Fatal(err) } req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() - handler := http.HandlerFunc(dropRule) + handler := http.HandlerFunc(api.RemoveParentTable) handler.ServeHTTP(rr, req) var res *internal.Conv json.Unmarshal(rr.Body.Bytes(), &res) if status := rr.Code; int64(status) != tc.statusCode { - t.Errorf("%s : handler returned wrong status code: got %v want %v", - tc.name, status, tc.statusCode) + t.Errorf("handler returned wrong status code: got %v want %v", + status, tc.statusCode) } if tc.statusCode == http.StatusOK { - assert.Equal(t, tc.expectedConv, res) + assert.Equal(t, tc.expectedSpSchema, res.SpSchema) } } - } func 
buildConvMySQL(conv *internal.Conv) { diff --git a/webv2/config.json b/webv2/config.json index 8c5d5e5443..d3bdeb50b9 100644 --- a/webv2/config.json +++ b/webv2/config.json @@ -1,4 +1,4 @@ { - "GCPProjectID": "", - "SpannerInstanceID": "" + "GCPProjectID": "span-cloud-testing", + "SpannerInstanceID": "deep-heavy-100gb" } \ No newline at end of file diff --git a/webv2/routes.go b/webv2/routes.go index f45d34db5f..67e3ab3217 100644 --- a/webv2/routes.go +++ b/webv2/routes.go @@ -24,7 +24,7 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/session" "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/summary" "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/table" - + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/api" "github.com/gorilla/mux" ) @@ -33,39 +33,39 @@ func getRoutes() *mux.Router { frontendRoot, _ := fs.Sub(FrontendDir, "ui/dist/ui") frontendStatic := http.FileServer(http.FS(frontendRoot)) router.HandleFunc("/connect", databaseConnection).Methods("POST") - router.HandleFunc("/convert/infoschema", convertSchemaSQL).Methods("GET") - router.HandleFunc("/convert/dump", convertSchemaDump).Methods("POST") + router.HandleFunc("/convert/infoschema", api.ConvertSchemaSQL).Methods("GET") + router.HandleFunc("/convert/dump", api.ConvertSchemaDump).Methods("POST") router.HandleFunc("/convert/session", loadSession).Methods("POST") - router.HandleFunc("/ddl", getDDL).Methods("GET") - router.HandleFunc("/conversion", getConversionRate).Methods("GET") - router.HandleFunc("/typemap", getTypeMap).Methods("GET") + router.HandleFunc("/ddl", api.GetDDL).Methods("GET") + router.HandleFunc("/conversion", api.GetConversionRate).Methods("GET") + router.HandleFunc("/typemap", api.GetTypeMap).Methods("GET") router.HandleFunc("/report", getReportFile).Methods("GET") router.HandleFunc("/downloadStructuredReport", getDStructuredReport).Methods("GET") router.HandleFunc("/downloadTextReport", getDTextReport).Methods("GET") 
router.HandleFunc("/downloadDDL", getDSpannerDDL).Methods("GET") router.HandleFunc("/schema", getSchemaFile).Methods("GET") - router.HandleFunc("/applyrule", applyRule).Methods("POST") - router.HandleFunc("/dropRule", dropRule).Methods("POST") + router.HandleFunc("/applyrule", api.ApplyRule).Methods("POST") + router.HandleFunc("/dropRule", api.DropRule).Methods("POST") router.HandleFunc("/typemap/table", table.UpdateTableSchema).Methods("POST") router.HandleFunc("/typemap/reviewTableSchema", table.ReviewTableSchema).Methods("POST") - router.HandleFunc("/typemap/GetStandardTypeToPGSQLTypemap", getStandardTypeToPGSQLTypemap).Methods("GET") - router.HandleFunc("/typemap/GetPGSQLToStandardTypeTypemap", getPGSQLToStandardTypeTypemap).Methods("GET") - router.HandleFunc("/spannerDefaultTypeMap", spannerDefaultTypeMap).Methods("GET") + router.HandleFunc("/typemap/GetStandardTypeToPGSQLTypemap", api.GetStandardTypeToPGSQLTypemap).Methods("GET") + router.HandleFunc("/typemap/GetPGSQLToStandardTypeTypemap", api.GetPGSQLToStandardTypeTypemap).Methods("GET") + router.HandleFunc("/spannerDefaultTypeMap", api.SpannerDefaultTypeMap).Methods("GET") - router.HandleFunc("/setparent", setParentTable).Methods("GET") - router.HandleFunc("/removeParent", removeParentTable).Methods("POST") + router.HandleFunc("/setparent", api.SetParentTable).Methods("GET") + router.HandleFunc("/removeParent", api.RemoveParentTable).Methods("POST") // TODO:(searce) take constraint names themselves which are guaranteed to be unique for Spanner. 
- router.HandleFunc("/drop/secondaryindex", dropSecondaryIndex).Methods("POST") - router.HandleFunc("/restore/secondaryIndex", restoreSecondaryIndex).Methods("POST") + router.HandleFunc("/drop/secondaryindex", api.DropSecondaryIndex).Methods("POST") + router.HandleFunc("/restore/secondaryIndex", api.RestoreSecondaryIndex).Methods("POST") - router.HandleFunc("/restore/table", restoreTable).Methods("POST") - router.HandleFunc("/restore/tables", restoreTables).Methods("POST") - router.HandleFunc("/drop/table", dropTable).Methods("POST") - router.HandleFunc("/drop/tables", dropTables).Methods("POST") + router.HandleFunc("/restore/table", api.RestoreTable).Methods("POST") + router.HandleFunc("/restore/tables", api.RestoreTables).Methods("POST") + router.HandleFunc("/drop/table", api.DropTable).Methods("POST") + router.HandleFunc("/drop/tables", api.DropTables).Methods("POST") - router.HandleFunc("/update/fks", updateForeignKeys).Methods("POST") - router.HandleFunc("/update/indexes", updateIndexes).Methods("POST") + router.HandleFunc("/update/fks", api.UpdateForeignKeys).Methods("POST") + router.HandleFunc("/update/indexes", api.UpdateIndexes).Methods("POST") // Session Management router.HandleFunc("/IsOffline", session.IsOfflineSession).Methods("GET") @@ -115,7 +115,7 @@ func getRoutes() *mux.Router { router.HandleFunc("/GetSourceProfileConfig", getSourceProfileConfig).Methods("GET") router.HandleFunc("/uploadFile", uploadFile).Methods("POST") - router.HandleFunc("/GetTableWithErrors", getTableWithErrors).Methods("GET") + router.HandleFunc("/GetTableWithErrors", api.GetTableWithErrors).Methods("GET") router.HandleFunc("/ping", getBackendHealth).Methods("GET") router.PathPrefix("/").Handler(frontendStatic) diff --git a/webv2/types/types.go b/webv2/types/types.go new file mode 100644 index 0000000000..1184fff9aa --- /dev/null +++ b/webv2/types/types.go @@ -0,0 +1,164 @@ +package types + +import ( + "database/sql" + 
"github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" +) + +// TODO:(searce) organize this file according to go style guidelines: generally +// have public constants and public type definitions first, then public +// functions, and finally helper functions (usually in order of importance). + +// DriverConfig contains the parameters needed to make a direct database connection. It is +// used to communicate via HTTP with the frontend. +type DriverConfig struct { + Driver string `json:"Driver"` + IsSharded bool `json:"IsSharded"` + Host string `json:"Host"` + Port string `json:"Port"` + Database string `json:"Database"` + User string `json:"User"` + Password string `json:"Password"` + Dialect string `json:"Dialect"` + DataShardId string `json:"DataShardId"` +} + +type DriverConfigs struct { + DbConfigs []DriverConfig `json:"DbConfigs"` + IsRestoredSession string `json:"IsRestoredSession"` +} + +type ShardedDataflowConfig struct { + MigrationProfile profiles.SourceProfileConfig +} + +type SessionSummary struct { + DatabaseType string + ConnectionDetail string + SourceTableCount int + SpannerTableCount int + SourceIndexCount int + SpannerIndexCount int + ConnectionType string + SourceDatabaseName string + Region string + NodeCount int + ProcessingUnits int + Instance string + Dialect string + IsSharded bool +} + +type ProgressDetails struct { + Progress int + ErrorMessage string + ProgressStatus int +} + +type MigrationDetails struct { + TargetDetails TargetDetails `json:"TargetDetails"` + DatastreamConfig profiles.DatastreamConfig `json:"DatastreamConfig"` + GcsConfig profiles.GcsConfig `json:"GcsConfig"` + DataflowConfig profiles.DataflowConfig `json:"DataflowConfig"` + MigrationMode string `json:"MigrationMode"` + MigrationType string `json:"MigrationType"` + IsSharded bool `json:"IsSharded"` + SkipForeignKeys bool `json:"skipForeignKeys"` +} + +type TargetDetails struct { + TargetDB string
`json:"TargetDB"` + SourceConnectionProfileName string `json:"SourceConnProfile"` + TargetConnectionProfileName string `json:"TargetConnProfile"` + ReplicationSlot string `json:"ReplicationSlot"` + Publication string `json:"Publication"` +} + +type ColMaxLength struct { + SpDataType string `json:"spDataType"` + SpColMaxLength string `json:"spColMaxLength"` +} + +type TableIdAndName struct { + Id string `json:"Id"` + Name string `json:"Name"` +} + +type ShardIdPrimaryKey struct { + AddedAtTheStart bool `json:"AddedAtTheStart"` +} + +// dumpConfig contains the parameters needed to run the tool using dump approach. It is +// used to communicate via HTTP with the frontend. +type DumpConfig struct { + Driver string `json:"Driver"` + FilePath string `json:"Path"` +} + +type SpannerDetails struct { + Dialect string `json:"Dialect"` +} + +type ConvertFromDumpRequest struct { + Config DumpConfig `json:"Config"` + SpannerDetails SpannerDetails `json:"SpannerDetails"` +} + +// SessionState stores information for the current migration session. +type SessionState struct { + sourceDB *sql.DB // Connection to source database in case of direct connection + dbName string // Name of source database + driver string // Name of Spanner migration tool driver in use + conv *internal.Conv // Current conversion state + sessionFile string // Path to session file +} + +// Type and issue. 
+type TypeIssue struct { + T string + Brief string + DisplayT string +} + +type ResourceDetails struct { + ResourceType string `json:"ResourceType"` + ResourceName string `json:"ResourceName"` + ResourceUrl string `json:"ResourceUrl"` + GcloudCmd string `json:"GcloudCmd"` +} +type GeneratedResources struct { + MigrationJobId string `json:"MigrationJobId"` + DatabaseName string `json:"DatabaseName"` + DatabaseUrl string `json:"DatabaseUrl"` + BucketName string `json:"BucketName"` + BucketUrl string `json:"BucketUrl"` + //Used for single instance migration flow + DataStreamJobName string `json:"DataStreamJobName"` + DataStreamJobUrl string `json:"DataStreamJobUrl"` + DataflowJobName string `json:"DataflowJobName"` + DataflowJobUrl string `json:"DataflowJobUrl"` + DataflowGcloudCmd string `json:"DataflowGcloudCmd"` + PubsubTopicName string `json:"PubsubTopicName"` + PubsubTopicUrl string `json:"PubsubTopicUrl"` + PubsubSubscriptionName string `json:"PubsubSubscriptionName"` + PubsubSubscriptionUrl string `json:"PubsubSubscriptionUrl"` + MonitoringDashboardName string `json:"MonitoringDashboardName"` + MonitoringDashboardUrl string `json:"MonitoringDashboardUrl"` + AggMonitoringDashboardName string `json:"AggMonitoringDashboardName"` + AggMonitoringDashboardUrl string `json:"AggMonitoringDashboardUrl"` + //Used for sharded migration flow + ShardToShardResourcesMap map[string][]ResourceDetails `json:"ShardToShardResourcesMap"` +} + +type DropDetail struct { + Name string `json:"Name"` +} + +// TableInterleaveStatus stores data regarding interleave status. 
+type TableInterleaveStatus struct { + Possible bool + Parent string + Comment string +} \ No newline at end of file diff --git a/webv2/web.go b/webv2/web.go index fe842855f6..f7fdde3391 100644 --- a/webv2/web.go +++ b/webv2/web.go @@ -30,7 +30,6 @@ import ( "net/http" "os" "path/filepath" - "sort" "strconv" "strings" "time" @@ -47,18 +46,12 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" "github.com/GoogleCloudPlatform/spanner-migration-tool/proto/migration" - "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/mysql" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/oracle" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/postgres" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/sqlserver" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/GoogleCloudPlatform/spanner-migration-tool/streaming" "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/config" helpers "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/helpers" "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/profile" - "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/table" + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/types" utilities "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/utilities" "github.com/pkg/browser" instancepb "google.golang.org/genproto/googleapis/spanner/admin/instance/v1" @@ -83,98 +76,6 @@ import ( // 6) Update schema conv after setting global datatypes and return conv. (setTypeMap) // 7) Add rateConversion() in schema conversion, ddl and report APIs. 
// 8) Add an overview in summary report API -var mysqlTypeMap = make(map[string][]typeIssue) -var postgresTypeMap = make(map[string][]typeIssue) -var sqlserverTypeMap = make(map[string][]typeIssue) -var oracleTypeMap = make(map[string][]typeIssue) - -var mysqlDefaultTypeMap = make(map[string]ddl.Type) -var postgresDefaultTypeMap = make(map[string]ddl.Type) -var sqlserverDefaultTypeMap = make(map[string]ddl.Type) -var oracleDefaultTypeMap = make(map[string]ddl.Type) - -// TODO:(searce) organize this file according to go style guidelines: generally -// have public constants and public type definitions first, then public -// functions, and finally helper functions (usually in order of importance). - -// driverConfig contains the parameters needed to make a direct database connection. It is -// used to communicate via HTTP with the frontend. -type driverConfig struct { - Driver string `json:"Driver"` - IsSharded bool `json:"IsSharded"` - Host string `json:"Host"` - Port string `json:"Port"` - Database string `json:"Database"` - User string `json:"User"` - Password string `json:"Password"` - Dialect string `json:"Dialect"` - DataShardId string `json:"DataShardId"` -} - -type driverConfigs struct { - DbConfigs []driverConfig `json:"DbConfigs"` - IsRestoredSession string `json:"IsRestoredSession"` -} - -type shardedDataflowConfig struct { - MigrationProfile profiles.SourceProfileConfig -} - -type sessionSummary struct { - DatabaseType string - ConnectionDetail string - SourceTableCount int - SpannerTableCount int - SourceIndexCount int - SpannerIndexCount int - ConnectionType string - SourceDatabaseName string - Region string - NodeCount int - ProcessingUnits int - Instance string - Dialect string - IsSharded bool -} - -type progressDetails struct { - Progress int - ErrorMessage string - ProgressStatus int -} - -type migrationDetails struct { - TargetDetails targetDetails `json:"TargetDetails"` - DatastreamConfig profiles.DatastreamConfig `json:"DatastreamConfig"` - 
GcsConfig profiles.GcsConfig `json:"GcsConfig"` - DataflowConfig profiles.DataflowConfig `json:"DataflowConfig"` - MigrationMode string `json:"MigrationMode"` - MigrationType string `json:"MigrationType"` - IsSharded bool `json:"IsSharded"` - SkipForeignKeys bool `json:"skipForeignKeys"` -} - -type targetDetails struct { - TargetDB string `json:"TargetDB"` - SourceConnectionProfileName string `json:"SourceConnProfile"` - TargetConnectionProfileName string `json:"TargetConnProfile"` - ReplicationSlot string `json:"ReplicationSlot"` - Publication string `json:"Publication"` -} - -type ColMaxLength struct { - SpDataType string `json:"spDataType"` - SpColMaxLength string `json:"spColMaxLength"` -} - -type TableIdAndName struct { - Id string `json:"Id"` - Name string `json:"Name"` -} - -type ShardIdPrimaryKey struct { - AddedAtTheStart bool `json:"AddedAtTheStart"` -} // databaseConnection creates connection with database func databaseConnection(w http.ResponseWriter, r *http.Request) { @@ -183,7 +84,7 @@ func databaseConnection(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) return } - var config driverConfig + var config types.DriverConfig err = json.Unmarshal(reqBody, &config) if err != nil { http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) @@ -237,86 +138,6 @@ func databaseConnection(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -// convertSchemaSQL converts source database to Spanner when using -// with postgres and mysql driver. -func convertSchemaSQL(w http.ResponseWriter, r *http.Request) { - sessionState := session.GetSessionState() - if sessionState.SourceDB == nil || sessionState.DbName == "" || sessionState.Driver == "" { - http.Error(w, fmt.Sprintf("Database is not configured or Database connection is lost. 
Please set configuration and connect to database."), http.StatusNotFound) - return - } - conv := internal.MakeConv() - - conv.SpDialect = sessionState.Dialect - conv.IsSharded = sessionState.IsSharded - var err error - additionalSchemaAttributes := internal.AdditionalSchemaAttributes{ - IsSharded: sessionState.IsSharded, - } - processSchema := common.ProcessSchemaImpl{} - switch sessionState.Driver { - case constants.MYSQL: - err = processSchema.ProcessSchema(conv, mysql.InfoSchemaImpl{DbName: sessionState.DbName, Db: sessionState.SourceDB}, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) - case constants.POSTGRES: - temp := false - err = processSchema.ProcessSchema(conv, postgres.InfoSchemaImpl{Db: sessionState.SourceDB, IsSchemaUnique: &temp}, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) - case constants.SQLSERVER: - err = processSchema.ProcessSchema(conv, sqlserver.InfoSchemaImpl{DbName: sessionState.DbName, Db: sessionState.SourceDB}, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) - case constants.ORACLE: - err = processSchema.ProcessSchema(conv, oracle.InfoSchemaImpl{DbName: strings.ToUpper(sessionState.DbName), Db: sessionState.SourceDB}, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) - default: - http.Error(w, fmt.Sprintf("Driver : '%s' is not supported", sessionState.Driver), http.StatusBadRequest) - return - } - if err != nil { - http.Error(w, fmt.Sprintf("Schema Conversion Error : %v", err), http.StatusNotFound) - return - } - - sessionState.Conv.ConvLock.Lock() - defer sessionState.Conv.ConvLock.Unlock() - - sessionState.Conv = conv - - if sessionState.IsSharded { - setShardIdColumnAsPrimaryKey(true) - 
addShardIdColumnToForeignKeys(true) - ruleId := internal.GenerateRuleId() - rule := internal.Rule{ - Id: ruleId, - Name: ruleId, - Type: constants.AddShardIdPrimaryKey, - AssociatedObjects: "All Tables", - Data: ShardIdPrimaryKey{ - AddedAtTheStart: true, - }, - Enabled: true, - } - - sessionState := session.GetSessionState() - sessionState.Conv.Rules = append(sessionState.Conv.Rules, rule) - session.UpdateSessionFile() - } - - primarykey.DetectHotspot() - index.IndexSuggestion() - - sessionMetadata := session.SessionMetadata{ - SessionName: "NewSession", - DatabaseType: sessionState.Driver, - DatabaseName: sessionState.DbName, - Dialect: sessionState.Dialect, - } - - convm := session.ConvWithMetadata{ - SessionMetadata: sessionMetadata, - Conv: *sessionState.Conv, - } - sessionState.SessionMetadata = sessionMetadata - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(convm) -} - // dumpConfig contains the parameters needed to run the tool using dump approach. It is // used to communicate via HTTP with the frontend. 
type dumpConfig struct { @@ -340,7 +161,7 @@ func setSourceDBDetailsForDump(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) return } - var dc dumpConfig + var dc types.DumpConfig err = json.Unmarshal(reqBody, &dc) if err != nil { http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) @@ -439,7 +260,7 @@ func setShardsSourceDBDetailsForDataflow(w http.ResponseWriter, r *http.Request) http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) return } - var srcConfig shardedDataflowConfig + var srcConfig types.ShardedDataflowConfig err = json.Unmarshal(reqBody, &srcConfig) if err != nil { http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) @@ -471,7 +292,7 @@ func setShardsSourceDBDetailsForBulk(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) return } - var shardConfigs driverConfigs + var shardConfigs types.DriverConfigs err = json.Unmarshal(reqBody, &shardConfigs) if err != nil { http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) @@ -526,7 +347,7 @@ func setSourceDBDetailsForDirectConnect(w http.ResponseWriter, r *http.Request) http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) return } - var config driverConfig + var config types.DriverConfig err = json.Unmarshal(reqBody, &config) if err != nil { http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) @@ -576,70 +397,6 @@ func setSourceDBDetailsForDirectConnect(w http.ResponseWriter, r *http.Request) w.WriteHeader(http.StatusOK) } -// convertSchemaDump converts schema from dump file to Spanner schema for -// mysqldump and pg_dump driver. 
-func convertSchemaDump(w http.ResponseWriter, r *http.Request) { - reqBody, err := ioutil.ReadAll(r.Body) - if err != nil { - http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) - return - } - var dc convertFromDumpRequest - err = json.Unmarshal(reqBody, &dc) - if err != nil { - http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) - return - } - f, err := os.Open(constants.UPLOAD_FILE_DIR + "/" + dc.Config.FilePath) - if err != nil { - http.Error(w, fmt.Sprintf("Failed to open dump file : %v, no such file or directory", dc.Config.FilePath), http.StatusNotFound) - return - } - // We don't support Dynamodb in web hence no need to pass schema sample size here. - n := profiles.NewSourceProfileImpl{} - sourceProfile, _ := profiles.NewSourceProfile("", dc.Config.Driver, &n) - sourceProfile.Driver = dc.Config.Driver - schemaFromSource := conversion.SchemaFromSourceImpl{} - conv, err := schemaFromSource.SchemaFromDump(sourceProfile.Driver, dc.SpannerDetails.Dialect, &utils.IOStreams{In: f, Out: os.Stdout}, &conversion.ProcessDumpByDialectImpl{}) - if err != nil { - http.Error(w, fmt.Sprintf("Schema Conversion Error : %v", err), http.StatusNotFound) - return - } - - sessionMetadata := session.SessionMetadata{ - SessionName: "NewSession", - DatabaseType: dc.Config.Driver, - DatabaseName: filepath.Base(dc.Config.FilePath), - Dialect: dc.SpannerDetails.Dialect, - } - - sessionState := session.GetSessionState() - sessionState.Conv.ConvLock.Lock() - defer sessionState.Conv.ConvLock.Unlock() - sessionState.Conv = conv - - primarykey.DetectHotspot() - index.IndexSuggestion() - - sessionState.SessionMetadata = sessionMetadata - sessionState.Driver = dc.Config.Driver - sessionState.DbName = "" - sessionState.SessionFile = "" - sessionState.SourceDB = nil - sessionState.Dialect = dc.SpannerDetails.Dialect - sessionState.SourceDBConnDetails = session.SourceDBConnDetails{ - Path: constants.UPLOAD_FILE_DIR + "/" + 
dc.Config.FilePath, - ConnectionType: helpers.DUMP_MODE, - } - - convm := session.ConvWithMetadata{ - SessionMetadata: sessionMetadata, - Conv: *conv, - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(convm) -} - // loadSession load seesion file to Spanner migration tool. func loadSession(w http.ResponseWriter, r *http.Request) { sessionState := session.GetSessionState() @@ -742,1425 +499,122 @@ func fetchLastLoadedSessionDetails(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(convm) } -// getDDL returns the Spanner DDL for each table in alphabetical order. -// Unlike internal/convert.go's GetDDL, it does not print tables in a way that -// respects the parent/child ordering of interleaved tables. -// Though foreign keys and secondary indexes are displayed, getDDL cannot be used to -// build DDL to send to Spanner. -func getDDL(w http.ResponseWriter, r *http.Request) { +// getSchemaFile generates schema file and returns file path. +func getSchemaFile(w http.ResponseWriter, r *http.Request) { + ioHelper := &utils.IOStreams{In: os.Stdin, Out: os.Stdout} + var err error + now := time.Now() + filePrefix, err := utilities.GetFilePrefix(now) + if err != nil { + http.Error(w, fmt.Sprintf("Can not get file prefix : %v", err), http.StatusInternalServerError) + } + schemaFileName := "frontend/" + filePrefix + "schema.txt" + sessionState := session.GetSessionState() sessionState.Conv.ConvLock.RLock() defer sessionState.Conv.ConvLock.RUnlock() - c := ddl.Config{Comments: true, ProtectIds: false, SpDialect: sessionState.Conv.SpDialect, Source: sessionState.Driver} - var tables []string - for t := range sessionState.Conv.SpSchema { - tables = append(tables, t) - } - sort.Strings(tables) - ddl := make(map[string]string) - for _, t := range tables { - table := sessionState.Conv.SpSchema[t] - tableDdl := table.PrintCreateTable(sessionState.Conv.SpSchema, c) + ";" - if len(table.Indexes) > 0 { - tableDdl = tableDdl + "\n" - } - for _, index := range 
table.Indexes { - tableDdl = tableDdl + "\n" + index.PrintCreateIndex(table, c) + ";" - } - if len(table.ForeignKeys) > 0 { - tableDdl = tableDdl + "\n" - } - for _, fk := range table.ForeignKeys { - tableDdl = tableDdl + "\n" + fk.PrintForeignKeyAlterTable(sessionState.Conv.SpSchema, c, t) + ";" - } - - ddl[t] = tableDdl + conversion.WriteSchemaFile(sessionState.Conv, now, schemaFileName, ioHelper.Out, sessionState.Driver) + schemaAbsPath, err := filepath.Abs(schemaFileName) + if err != nil { + http.Error(w, fmt.Sprintf("Can not create absolute path : %v", err), http.StatusInternalServerError) } w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(ddl) -} - -func getStandardTypeToPGSQLTypemap(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(ddl.STANDARD_TYPE_TO_PGSQL_TYPEMAP) -} - -func getPGSQLToStandardTypeTypemap(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(ddl.PGSQL_TO_STANDARD_TYPE_TYPEMAP) + w.Write([]byte(schemaAbsPath)) } -func spannerDefaultTypeMap(w http.ResponseWriter, r *http.Request) { - sessionState := session.GetSessionState() - - if sessionState.Conv == nil || sessionState.Driver == "" { - http.Error(w, "Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner.", http.StatusNotFound) - return +// getReportFile generates report file and returns file path. 
+func getReportFile(w http.ResponseWriter, r *http.Request) { + ioHelper := &utils.IOStreams{In: os.Stdin, Out: os.Stdout} + var err error + now := time.Now() + filePrefix, err := utilities.GetFilePrefix(now) + if err != nil { + http.Error(w, fmt.Sprintf("Can not get file prefix : %v", err), http.StatusInternalServerError) } + reportFileName := "frontend/" + filePrefix + sessionState := session.GetSessionState() sessionState.Conv.ConvLock.Lock() defer sessionState.Conv.ConvLock.Unlock() - initializeTypeMap() - - var typeMap map[string]ddl.Type - switch sessionState.Driver { - case constants.MYSQL, constants.MYSQLDUMP: - typeMap = mysqlDefaultTypeMap - case constants.POSTGRES, constants.PGDUMP: - typeMap = postgresDefaultTypeMap - case constants.SQLSERVER: - typeMap = sqlserverDefaultTypeMap - case constants.ORACLE: - typeMap = oracleDefaultTypeMap - default: - http.Error(w, fmt.Sprintf("Driver : '%s' is not supported", sessionState.Driver), http.StatusBadRequest) - return + conversion.Report(sessionState.Driver, nil, ioHelper.BytesRead, "", sessionState.Conv, reportFileName, sessionState.DbName, ioHelper.Out) + reportAbsPath, err := filepath.Abs(reportFileName) + if err != nil { + http.Error(w, fmt.Sprintf("Can not create absolute path : %v", err), http.StatusInternalServerError) } w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(typeMap) + w.Write([]byte(reportAbsPath)) } -// getTypeMap returns the source to Spanner typemap only for the -// source types used in current conversion. -func getTypeMap(w http.ResponseWriter, r *http.Request) { +// generates a downloadable structured report and send it as a JSON response +func getDStructuredReport(w http.ResponseWriter, r *http.Request) { sessionState := session.GetSessionState() - - if sessionState.Conv == nil || sessionState.Driver == "" { - http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. 
Please retry converting the database to Spanner."), http.StatusNotFound) - return - } sessionState.Conv.ConvLock.Lock() defer sessionState.Conv.ConvLock.Unlock() - var typeMap map[string][]typeIssue - initializeTypeMap() - switch sessionState.Driver { - case constants.MYSQL, constants.MYSQLDUMP: - typeMap = mysqlTypeMap - case constants.POSTGRES, constants.PGDUMP: - typeMap = postgresTypeMap - case constants.SQLSERVER: - typeMap = sqlserverTypeMap - case constants.ORACLE: - typeMap = oracleTypeMap - default: - http.Error(w, fmt.Sprintf("Driver : '%s' is not supported", sessionState.Driver), http.StatusBadRequest) - return - } - // Filter typeMap so it contains just the types SrcSchema uses. - filteredTypeMap := make(map[string][]typeIssue) - for _, srcTable := range sessionState.Conv.SrcSchema { - for _, colDef := range srcTable.ColDefs { - if _, ok := filteredTypeMap[colDef.Type.Name]; ok { - continue - } - // Timestamp and interval types do not have exact key in typemap. - // Typemap for TIMESTAMP(6), TIMESTAMP(6) WITH LOCAL TIMEZONE,TIMESTAMP(6) WITH TIMEZONE is stored into TIMESTAMP key. - // Same goes with interval types like INTERVAL YEAR(2) TO MONTH, INTERVAL DAY(2) TO SECOND(6) etc. - // If exact key not found then check with regex. 
- if _, ok := typeMap[colDef.Type.Name]; !ok { - if oracle.TimestampReg.MatchString(colDef.Type.Name) { - filteredTypeMap[colDef.Type.Name] = typeMap["TIMESTAMP"] - } else if oracle.IntervalReg.MatchString(colDef.Type.Name) { - filteredTypeMap[colDef.Type.Name] = typeMap["INTERVAL"] - } - continue - } - filteredTypeMap[colDef.Type.Name] = typeMap[colDef.Type.Name] - } - } - for key, values := range filteredTypeMap { - for i := range values { - if sessionState.Dialect == constants.DIALECT_POSTGRESQL { - spType := ddl.Type{ - Name: filteredTypeMap[key][i].T, - } - filteredTypeMap[key][i].DisplayT = ddl.GetPGType(spType) - } else { - filteredTypeMap[key][i].DisplayT = filteredTypeMap[key][i].T - } - - } - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(filteredTypeMap) -} - -// getTableWithErrors checks the errors in the spanner schema -// and returns a list of tables with errors -func getTableWithErrors(w http.ResponseWriter, r *http.Request) { - sessionState := session.GetSessionState() - sessionState.Conv.ConvLock.RLock() - defer sessionState.Conv.ConvLock.RUnlock() - var tableIdName []TableIdAndName - for id, issues := range sessionState.Conv.SchemaIssues { - if len(issues.TableLevelIssues) != 0 { - t := TableIdAndName{ - Id: id, - Name: sessionState.Conv.SpSchema[id].Name, - } - tableIdName = append(tableIdName, t) - } - } + structuredReport := reports.GenerateStructuredReport(sessionState.Driver, sessionState.DbName, sessionState.Conv, nil, true, true) w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(tableIdName) + json.NewEncoder(w).Encode(structuredReport) } -// applyRule allows to add rules that changes the schema -// currently it supports two types of operations viz. 
SetGlobalDataType and AddIndex -func applyRule(w http.ResponseWriter, r *http.Request) { - reqBody, err := ioutil.ReadAll(r.Body) - if err != nil { - http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) - return - } - var rule internal.Rule - err = json.Unmarshal(reqBody, &rule) - if err != nil { - http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) - return - } - +// generates a downloadable text report and send it as a JSON response +func getDTextReport(w http.ResponseWriter, r *http.Request) { sessionState := session.GetSessionState() sessionState.Conv.ConvLock.Lock() defer sessionState.Conv.ConvLock.Unlock() - if rule.Type == constants.GlobalDataTypeChange { - d, err := json.Marshal(rule.Data) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - typeMap := map[string]string{} - err = json.Unmarshal(d, &typeMap) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - setGlobalDataType(typeMap) - } else if rule.Type == constants.AddIndex { - d, err := json.Marshal(rule.Data) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - newIdx := ddl.CreateIndex{} - err = json.Unmarshal(d, &newIdx) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - addedIndex, err := addIndex(newIdx) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - rule.Data = addedIndex - } else if rule.Type == constants.EditColumnMaxLength { - d, err := json.Marshal(rule.Data) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - var colMaxLength ColMaxLength - err = json.Unmarshal(d, &colMaxLength) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - setSpColMaxLength(colMaxLength, 
rule.AssociatedObjects) - } else if rule.Type == constants.AddShardIdPrimaryKey { - d, err := json.Marshal(rule.Data) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - var shardIdPrimaryKey ShardIdPrimaryKey - err = json.Unmarshal(d, &shardIdPrimaryKey) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - tableName := checkInterleaving() - if tableName != "" { - http.Error(w, fmt.Sprintf("Rule cannot be added because some tables, eg: %v are interleaved. Please remove interleaving and try again.", tableName), http.StatusBadRequest) - return - } - setShardIdColumnAsPrimaryKey(shardIdPrimaryKey.AddedAtTheStart) - addShardIdColumnToForeignKeys(shardIdPrimaryKey.AddedAtTheStart) - } else { - http.Error(w, "Invalid rule type", http.StatusInternalServerError) - return - } - - ruleId := internal.GenerateRuleId() - rule.Id = ruleId - - sessionState.Conv.Rules = append(sessionState.Conv.Rules, rule) - session.UpdateSessionFile() - convm := session.ConvWithMetadata{ - SessionMetadata: sessionState.SessionMetadata, - Conv: *sessionState.Conv, - } + structuredReport := reports.GenerateStructuredReport(sessionState.Driver, sessionState.DbName, sessionState.Conv, nil, true, true) + // creates a new buffer + buffer := bytes.NewBuffer([]byte{}) + // initializes buffered writer that writes data to buffer + wb := bufio.NewWriter(buffer) + reports.GenerateTextReport(structuredReport, wb) + // flushes buffered data to writer + wb.Flush() + // introduces a byte slice to represent the content of buffer + data := buffer.Bytes() + // converts byte slice to corressponding string representation + decodedString := string(data) w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(convm) -} - -func dropRule(w http.ResponseWriter, r *http.Request) { - ruleId := r.FormValue("id") - if ruleId == "" { - http.Error(w, fmt.Sprint("Rule id is empty"), http.StatusBadRequest) - return - } - 
sessionState := session.GetSessionState() - sessionState.Conv.ConvLock.Lock() - defer sessionState.Conv.ConvLock.Unlock() - conv := sessionState.Conv - var rule internal.Rule - position := -1 - - for i, r := range conv.Rules { - if r.Id == ruleId { - rule = r - position = i - break - } - } - if position == -1 { - http.Error(w, fmt.Sprint("Rule to be deleted not found"), http.StatusBadRequest) - return - } - - if rule.Type == constants.AddIndex { - if rule.Enabled { - d, err := json.Marshal(rule.Data) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - var index ddl.CreateIndex - err = json.Unmarshal(d, &index) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - tableId := index.TableId - indexId := index.Id - err = dropSecondaryIndexHelper(tableId, indexId) - if err != nil { - http.Error(w, fmt.Sprintf("%v", err), http.StatusBadRequest) - return - } - } - } else if rule.Type == constants.GlobalDataTypeChange { - d, err := json.Marshal(rule.Data) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - typeMap := map[string]string{} - err = json.Unmarshal(d, &typeMap) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - revertGlobalDataType(typeMap) - } else if rule.Type == constants.EditColumnMaxLength { - d, err := json.Marshal(rule.Data) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - var colMaxLength ColMaxLength - err = json.Unmarshal(d, &colMaxLength) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - revertSpColMaxLength(colMaxLength, rule.AssociatedObjects) - } else if rule.Type == constants.AddShardIdPrimaryKey { - d, err := json.Marshal(rule.Data) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - var 
shardIdPrimaryKey ShardIdPrimaryKey - err = json.Unmarshal(d, &shardIdPrimaryKey) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - tableName := checkInterleaving() - if tableName != "" { - http.Error(w, fmt.Sprintf("Rule cannot be deleted because some tables, eg: %v are interleaved. Please remove interleaving and try again.", tableName), http.StatusBadRequest) - return - } - revertShardIdColumnAsPrimaryKey(shardIdPrimaryKey.AddedAtTheStart) - removeShardIdColumnFromForeignKeys(shardIdPrimaryKey.AddedAtTheStart) - } else { - http.Error(w, "Invalid rule type", http.StatusInternalServerError) - return - } - - sessionState.Conv.Rules = append(conv.Rules[:position], conv.Rules[position+1:]...) - if len(sessionState.Conv.Rules) == 0 { - sessionState.Conv.Rules = nil - } - session.UpdateSessionFile() - convm := session.ConvWithMetadata{ - SessionMetadata: sessionState.SessionMetadata, - Conv: *sessionState.Conv, - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(convm) - -} - -func checkInterleaving() string { - sessionState := session.GetSessionState() - for _, spSchema := range sessionState.Conv.SpSchema { - if spSchema.ParentId != "" { - return spSchema.Name - } - } - return "" -} - -func addShardIdToForeignKeyPerTable(isAddedAtFirst bool, table ddl.CreateTable) { - sessionState := session.GetSessionState() - for i, fk := range table.ForeignKeys { - referredTableShardIdColumn := sessionState.Conv.SpSchema[fk.ReferTableId].ShardIdColumn - if isAddedAtFirst { - fk.ColIds = append([]string{table.ShardIdColumn}, fk.ColIds...) - fk.ReferColumnIds = append([]string{referredTableShardIdColumn}, fk.ReferColumnIds...) 
- } else { - fk.ColIds = append(fk.ColIds, table.ShardIdColumn) - fk.ReferColumnIds = append(fk.ReferColumnIds, referredTableShardIdColumn) - } - sessionState.Conv.SpSchema[table.Id].ForeignKeys[i] = fk - } -} - -func addShardIdColumnToForeignKeys(isAddedAtFirst bool) { - sessionState := session.GetSessionState() - for _, table := range sessionState.Conv.SpSchema { - addShardIdToForeignKeyPerTable(isAddedAtFirst, table) - } -} - -func removeShardIdColumnFromForeignKeys(isAddedAtFirst bool) { - sessionState := session.GetSessionState() - for tableId, table := range sessionState.Conv.SpSchema { - for i, fk := range table.ForeignKeys { - - if isAddedAtFirst { - fk.ColIds = fk.ColIds[1:] - fk.ReferColumnIds = fk.ReferColumnIds[1:] - } else { - fk.ColIds = fk.ColIds[:len(fk.ColIds)-1] - fk.ReferColumnIds = fk.ReferColumnIds[:len(fk.ReferColumnIds)-1] - } - sessionState.Conv.SpSchema[tableId].ForeignKeys[i] = fk - } - } -} - -func setShardIdColumnAsPrimaryKeyPerTable(isAddedAtFirst bool, table ddl.CreateTable) { - pkRequest := primarykey.PrimaryKeyRequest{ - TableId: table.Id, - Columns: []ddl.IndexKey{}, - } - increment := 0 - if isAddedAtFirst { - increment = 1 - pkRequest.Columns = append(pkRequest.Columns, ddl.IndexKey{ColId: table.ShardIdColumn, Order: 1}) - } - for index := range table.PrimaryKeys { - pk := table.PrimaryKeys[index] - pkRequest.Columns = append(pkRequest.Columns, ddl.IndexKey{ColId: pk.ColId, Order: pk.Order + increment, Desc: pk.Desc}) - } - if !isAddedAtFirst { - size := len(table.PrimaryKeys) - pkRequest.Columns = append(pkRequest.Columns, ddl.IndexKey{ColId: table.ShardIdColumn, Order: size + 1}) - } - primarykey.UpdatePrimaryKeyAndSessionFile(pkRequest) -} - -func setShardIdColumnAsPrimaryKey(isAddedAtFirst bool) { - sessionState := session.GetSessionState() - for _, table := range sessionState.Conv.SpSchema { - setShardIdColumnAsPrimaryKeyPerTable(isAddedAtFirst, table) - } -} - -func revertShardIdColumnAsPrimaryKey(isAddedAtFirst bool) { - 
sessionState := session.GetSessionState() - for _, table := range sessionState.Conv.SpSchema { - pkRequest := primarykey.PrimaryKeyRequest{ - TableId: table.Id, - Columns: []ddl.IndexKey{}, - } - for index := range table.PrimaryKeys { - pk := table.PrimaryKeys[index] - if pk.ColId != table.ShardIdColumn { - decrement := 0 - if isAddedAtFirst { - decrement = 1 - } - pkRequest.Columns = append(pkRequest.Columns, ddl.IndexKey{ColId: pk.ColId, Order: pk.Order - decrement, Desc: pk.Desc}) - } - } - primarykey.UpdatePrimaryKeyAndSessionFile(pkRequest) - } -} - -// setGlobalDataType allows to change Spanner type globally. -// It takes a map from source type to Spanner type and updates -// the Spanner schema accordingly. -func setGlobalDataType(typeMap map[string]string) { - sessionState := session.GetSessionState() - - // Redo source-to-Spanner typeMap using t (the mapping specified in the http request). - // We drive this process by iterating over the Spanner schema because we want to preserve all - // other customizations that have been performed via the UI (dropping columns, renaming columns - // etc). In particular, note that we can't just blindly redo schema conversion (using an appropriate - // version of 'toDDL' with the new typeMap). - for tableId, spSchema := range sessionState.Conv.SpSchema { - for colId := range spSchema.ColDefs { - srcColDef := sessionState.Conv.SrcSchema[tableId].ColDefs[colId] - // If the srcCol's type is in the map, then recalculate the Spanner type - // for this column using the map. Otherwise, leave the ColDef for this - // column as is. Note that per-column type overrides could be lost in - // this process -- the mapping in typeMap always takes precendence. 
- if _, found := typeMap[srcColDef.Type.Name]; found { - utilities.UpdateDataType(sessionState.Conv, typeMap[srcColDef.Type.Name], tableId, colId) - } - } - common.ComputeNonKeyColumnSize(sessionState.Conv, tableId) - } -} - -func setSpColMaxLength(spColMaxLength ColMaxLength, associatedObjects string) { - sessionState := session.GetSessionState() - if associatedObjects == "All table" { - for tId := range sessionState.Conv.SpSchema { - for _, colDef := range sessionState.Conv.SpSchema[tId].ColDefs { - if colDef.T.Name == spColMaxLength.SpDataType { - spColDef := colDef - if spColDef.T.Len == ddl.MaxLength { - spColDef.T.Len, _ = strconv.ParseInt(spColMaxLength.SpColMaxLength, 10, 64) - } - sessionState.Conv.SpSchema[tId].ColDefs[colDef.Id] = spColDef - } - } - common.ComputeNonKeyColumnSize(sessionState.Conv, tId) - } - } else { - for _, colDef := range sessionState.Conv.SpSchema[associatedObjects].ColDefs { - if colDef.T.Name == spColMaxLength.SpDataType { - spColDef := colDef - if spColDef.T.Len == ddl.MaxLength { - table.UpdateColumnSize(spColMaxLength.SpColMaxLength, associatedObjects, colDef.Id, sessionState.Conv) - } - } - } - common.ComputeNonKeyColumnSize(sessionState.Conv, associatedObjects) - } -} - -func revertSpColMaxLength(spColMaxLength ColMaxLength, associatedObjects string) { - sessionState := session.GetSessionState() - spColLen, _ := strconv.ParseInt(spColMaxLength.SpColMaxLength, 10, 64) - if associatedObjects == "All tables" { - for tId := range sessionState.Conv.SpSchema { - for colId, colDef := range sessionState.Conv.SpSchema[tId].ColDefs { - if colDef.T.Name == spColMaxLength.SpDataType { - utilities.UpdateMaxColumnLen(sessionState.Conv, spColMaxLength.SpDataType, tId, colId, spColLen) - } - } - common.ComputeNonKeyColumnSize(sessionState.Conv, tId) - } - } else { - for colId, colDef := range sessionState.Conv.SpSchema[associatedObjects].ColDefs { - if colDef.T.Name == spColMaxLength.SpDataType { - 
utilities.UpdateMaxColumnLen(sessionState.Conv, spColMaxLength.SpDataType, associatedObjects, colId, spColLen) - } - } - common.ComputeNonKeyColumnSize(sessionState.Conv, associatedObjects) - } -} - -// revertGlobalDataType revert back the spanner type to default -// when the rule that is used to apply the data-type change is deleted. -// It takes a map from source type to Spanner type and updates -// the Spanner schema accordingly. -func revertGlobalDataType(typeMap map[string]string) { - sessionState := session.GetSessionState() - - for tableId, spSchema := range sessionState.Conv.SpSchema { - for colId, colDef := range spSchema.ColDefs { - srcColDef, found := sessionState.Conv.SrcSchema[tableId].ColDefs[colId] - if !found { - continue - } - spType, found := typeMap[srcColDef.Type.Name] - - if !found { - continue - } - - if colDef.T.Name == spType { - utilities.UpdateDataType(sessionState.Conv, "", tableId, colId) - } - } - common.ComputeNonKeyColumnSize(sessionState.Conv, tableId) - } -} - -// addIndex checks the new name for spanner name validity, ensures the new name is already not used by existing tables -// secondary indexes or foreign key constraints. If above checks passed then new indexes are added to the schema else appropriate -// error thrown. -func addIndex(newIndex ddl.CreateIndex) (ddl.CreateIndex, error) { - // Check new name for spanner name validity. - newNames := []string{} - newNames = append(newNames, newIndex.Name) - - if ok, invalidNames := utilities.CheckSpannerNamesValidity(newNames); !ok { - return ddl.CreateIndex{}, fmt.Errorf("following names are not valid Spanner identifiers: %s", strings.Join(invalidNames, ",")) - } - // Check that the new names are not already used by existing tables, secondary indexes or foreign key constraints. 
- if ok, err := utilities.CanRename(newNames, newIndex.TableId); !ok { - return ddl.CreateIndex{}, err - } - - sessionState := session.GetSessionState() - sp := sessionState.Conv.SpSchema[newIndex.TableId] - - newIndexes := []ddl.CreateIndex{newIndex} - index.CheckIndexSuggestion(newIndexes, sp) - for i := 0; i < len(newIndexes); i++ { - newIndexes[i].Id = internal.GenerateIndexesId() - } - - sessionState.Conv.UsedNames[strings.ToLower(newIndex.Name)] = true - sp.Indexes = append(sp.Indexes, newIndexes...) - sessionState.Conv.SpSchema[newIndex.TableId] = sp - return newIndexes[0], nil -} - -// getConversionRate returns table wise color coded conversion rate. -func getConversionRate(w http.ResponseWriter, r *http.Request) { - sessionState := session.GetSessionState() - sessionState.Conv.ConvLock.Lock() - defer sessionState.Conv.ConvLock.Unlock() - smt_reports := reports.AnalyzeTables(sessionState.Conv, nil) - rate := make(map[string]string) - for _, t := range smt_reports { - rate[t.SpTable], _ = reports.RateSchema(t.Cols, t.Warnings, t.Errors, t.SyntheticPKey != "", false) - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(rate) -} - -// getSchemaFile generates schema file and returns file path. 
-func getSchemaFile(w http.ResponseWriter, r *http.Request) { - ioHelper := &utils.IOStreams{In: os.Stdin, Out: os.Stdout} - var err error - now := time.Now() - filePrefix, err := utilities.GetFilePrefix(now) - if err != nil { - http.Error(w, fmt.Sprintf("Can not get file prefix : %v", err), http.StatusInternalServerError) - } - schemaFileName := "frontend/" + filePrefix + "schema.txt" - - sessionState := session.GetSessionState() - sessionState.Conv.ConvLock.RLock() - defer sessionState.Conv.ConvLock.RUnlock() - conversion.WriteSchemaFile(sessionState.Conv, now, schemaFileName, ioHelper.Out, sessionState.Driver) - schemaAbsPath, err := filepath.Abs(schemaFileName) - if err != nil { - http.Error(w, fmt.Sprintf("Can not create absolute path : %v", err), http.StatusInternalServerError) - } - w.WriteHeader(http.StatusOK) - w.Write([]byte(schemaAbsPath)) -} - -// getReportFile generates report file and returns file path. -func getReportFile(w http.ResponseWriter, r *http.Request) { - ioHelper := &utils.IOStreams{In: os.Stdin, Out: os.Stdout} - var err error - now := time.Now() - filePrefix, err := utilities.GetFilePrefix(now) - if err != nil { - http.Error(w, fmt.Sprintf("Can not get file prefix : %v", err), http.StatusInternalServerError) - } - reportFileName := "frontend/" + filePrefix - sessionState := session.GetSessionState() - sessionState.Conv.ConvLock.Lock() - defer sessionState.Conv.ConvLock.Unlock() - conversion.Report(sessionState.Driver, nil, ioHelper.BytesRead, "", sessionState.Conv, reportFileName, sessionState.DbName, ioHelper.Out) - reportAbsPath, err := filepath.Abs(reportFileName) - if err != nil { - http.Error(w, fmt.Sprintf("Can not create absolute path : %v", err), http.StatusInternalServerError) - } - w.WriteHeader(http.StatusOK) - w.Write([]byte(reportAbsPath)) -} - -// generates a downloadable structured report and send it as a JSON response -func getDStructuredReport(w http.ResponseWriter, r *http.Request) { - sessionState := 
session.GetSessionState() - sessionState.Conv.ConvLock.Lock() - defer sessionState.Conv.ConvLock.Unlock() - structuredReport := reports.GenerateStructuredReport(sessionState.Driver, sessionState.DbName, sessionState.Conv, nil, true, true) - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(structuredReport) -} - -// generates a downloadable text report and send it as a JSON response -func getDTextReport(w http.ResponseWriter, r *http.Request) { - sessionState := session.GetSessionState() - sessionState.Conv.ConvLock.Lock() - defer sessionState.Conv.ConvLock.Unlock() - structuredReport := reports.GenerateStructuredReport(sessionState.Driver, sessionState.DbName, sessionState.Conv, nil, true, true) - // creates a new buffer - buffer := bytes.NewBuffer([]byte{}) - // initializes buffered writer that writes data to buffer - wb := bufio.NewWriter(buffer) - reports.GenerateTextReport(structuredReport, wb) - // flushes buffered data to writer - wb.Flush() - // introduces a byte slice to represent the content of buffer - data := buffer.Bytes() - // converts byte slice to corressponding string representation - decodedString := string(data) - w.WriteHeader(http.StatusOK) - w.Header().Set("Content-Type", "text/plain") - json.NewEncoder(w).Encode(decodedString) + w.Header().Set("Content-Type", "text/plain") + json.NewEncoder(w).Encode(decodedString) } // generates a downloadable DDL(spanner) and send it as a JSON response -func getDSpannerDDL(w http.ResponseWriter, r *http.Request) { - sessionState := session.GetSessionState() - sessionState.Conv.ConvLock.RLock() - defer sessionState.Conv.ConvLock.RUnlock() - conv := sessionState.Conv - now := time.Now() - spDDL := conv.SpSchema.GetDDL(ddl.Config{Comments: true, ProtectIds: false, Tables: true, ForeignKeys: true, SpDialect: conv.SpDialect, Source: sessionState.Driver}) - if len(spDDL) == 0 { - spDDL = []string{"\n-- Schema is empty -- no tables found\n"} - } - l := []string{ - fmt.Sprintf("-- Schema generated %s\n", 
now.Format("2006-01-02 15:04:05")), - strings.Join(spDDL, ";\n\n"), - "\n", - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(strings.Join(l, "")) -} - -// getIssueDescription maps IssueDB's Category to corresponding CategoryDescription(if present), -// or to the Brief if not present and pass the map to frontend to be used in assessment report UI -func getIssueDescription(w http.ResponseWriter, r *http.Request) { - var issuesMap = make(map[string]string) - for _, issue := range reports.IssueDB { - if issue.CategoryDescription == "" { - issuesMap[issue.Category] = issue.Brief - } else { - issuesMap[issue.Category] = issue.CategoryDescription - } - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(issuesMap) -} - -// TableInterleaveStatus stores data regarding interleave status. -type TableInterleaveStatus struct { - Possible bool - Parent string - Comment string -} - -func getBackendHealth(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) -} - -// setParentTable checks whether specified table can be interleaved, and updates the schema to convert foreign -// key to interleaved table if 'update' parameter is set to true. If 'update' parameter is set to false, then return -// whether the foreign key can be converted to interleave table without updating the schema. -func setParentTable(w http.ResponseWriter, r *http.Request) { - tableId := r.FormValue("table") - update := r.FormValue("update") == "true" - sessionState := session.GetSessionState() - - if sessionState.Conv == nil || sessionState.Driver == "" { - http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. 
Please retry converting the database to Spanner."), http.StatusNotFound) - return - } - if tableId == "" { - http.Error(w, fmt.Sprintf("Table Id is empty"), http.StatusBadRequest) - } - - sessionState.Conv.ConvLock.Lock() - defer sessionState.Conv.ConvLock.Unlock() - tableInterleaveStatus := parentTableHelper(tableId, update) - - if tableInterleaveStatus.Possible { - - childPks := sessionState.Conv.SpSchema[tableId].PrimaryKeys - childindex := utilities.GetPrimaryKeyIndexFromOrder(childPks, 1) - schemaissue := []internal.SchemaIssue{} - - colId := childPks[childindex].ColId - schemaissue = sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] - if update { - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedOrder) - } else { - schemaissue = append(schemaissue, internal.InterleavedOrder) - } - - sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] = schemaissue - } else { - // Remove "Table cart can be converted as Interleaved Table" suggestion from columns - // of the table if interleaving is not possible. 
- for _, colId := range sessionState.Conv.SpSchema[tableId].ColIds { - schemaIssue := []internal.SchemaIssue{} - for _, v := range sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] { - if v != internal.InterleavedOrder { - schemaIssue = append(schemaIssue, v) - } - } - sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] = schemaIssue - } - } - - index.IndexSuggestion() - session.UpdateSessionFile() - w.WriteHeader(http.StatusOK) - - if update { - convm := session.ConvWithMetadata{ - SessionMetadata: sessionState.SessionMetadata, - Conv: *sessionState.Conv, - } - json.NewEncoder(w).Encode(map[string]interface{}{ - "tableInterleaveStatus": tableInterleaveStatus, - "sessionState": convm}) - } else { - json.NewEncoder(w).Encode(map[string]interface{}{ - "tableInterleaveStatus": tableInterleaveStatus, - }) - } -} - -func parentTableHelper(tableId string, update bool) *TableInterleaveStatus { - tableInterleaveStatus := &TableInterleaveStatus{ - Possible: false, - Comment: "No valid prefix", - } - sessionState := session.GetSessionState() - - if _, found := sessionState.Conv.SyntheticPKeys[tableId]; found { - tableInterleaveStatus.Possible = false - tableInterleaveStatus.Comment = "Has synthetic pk" - } - - childPks := sessionState.Conv.SpSchema[tableId].PrimaryKeys - - // Search this table's foreign keys for a suitable parent table. - // If there are several possible parent tables, we pick the first one. - // TODO: Allow users to pick which parent to use if more than one. 
- for i, fk := range sessionState.Conv.SpSchema[tableId].ForeignKeys { - refTableId := fk.ReferTableId - - if _, found := sessionState.Conv.SyntheticPKeys[refTableId]; found { - continue - } - - if checkPrimaryKeyPrefix(tableId, refTableId, fk, tableInterleaveStatus) { - sp := sessionState.Conv.SpSchema[tableId] - - colIdNotInOrder := checkPrimaryKeyOrder(tableId, refTableId, fk) - - if update && sp.ParentId == "" && colIdNotInOrder == "" { - usedNames := sessionState.Conv.UsedNames - delete(usedNames, strings.ToLower(sp.ForeignKeys[i].Name)) - sp.ParentId = refTableId - sp.ForeignKeys = utilities.RemoveFk(sp.ForeignKeys, sp.ForeignKeys[i].Id) - } - sessionState.Conv.SpSchema[tableId] = sp - - parentpks := sessionState.Conv.SpSchema[refTableId].PrimaryKeys - if len(parentpks) >= 1 { - if colIdNotInOrder == "" { - - schemaissue := []internal.SchemaIssue{} - for _, column := range childPks { - colId := column.ColId - schemaissue = sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] - - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedNotInOrder) - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedAddColumn) - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedRenameColumn) - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedOrder) - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedChangeColumnSize) - - sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] = schemaissue - } - - tableInterleaveStatus.Possible = true - tableInterleaveStatus.Parent = refTableId - tableInterleaveStatus.Comment = "" - - } else { - - schemaissue := []internal.SchemaIssue{} - schemaissue = sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colIdNotInOrder] - - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedOrder) - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedAddColumn) - 
schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedRenameColumn) - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedChangeColumnSize) - - schemaissue = append(schemaissue, internal.InterleavedNotInOrder) - - sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colIdNotInOrder] = schemaissue - } - } - } - } - - return tableInterleaveStatus -} - -func hasShardIdPrimaryKeyRule() (bool, bool) { - sessionState := session.GetSessionState() - for _, rule := range sessionState.Conv.Rules { - if rule.Type == constants.AddShardIdPrimaryKey { - v := rule.Data.(ShardIdPrimaryKey) - return true, v.AddedAtTheStart - } - } - return false, false -} - -func removeParentTable(w http.ResponseWriter, r *http.Request) { - tableId := r.FormValue("tableId") - sessionState := session.GetSessionState() - if sessionState.Conv == nil || sessionState.Driver == "" { - http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) - return - } - if tableId == "" { - http.Error(w, fmt.Sprintf("Table Id is empty"), http.StatusBadRequest) - return - } - - sessionState.Conv.ConvLock.Lock() - defer sessionState.Conv.ConvLock.Unlock() - conv := sessionState.Conv - - if conv.SpSchema[tableId].ParentId == "" { - http.Error(w, fmt.Sprintf("Table is not interleaved"), http.StatusBadRequest) - return - } - spTable := conv.SpSchema[tableId] - - var firstOrderPk ddl.IndexKey - order := 1 - - isPresent, isAddedAtFirst := hasShardIdPrimaryKeyRule() - if isAddedAtFirst { - order = 2 - } - - for _, pk := range spTable.PrimaryKeys { - if pk.Order == order { - firstOrderPk = pk - break - } - } - - spColId := conv.SpSchema[tableId].ColDefs[firstOrderPk.ColId].Id - srcCol := conv.SrcSchema[tableId].ColDefs[spColId] - interleavedFk, err := utilities.GetInterleavedFk(conv, tableId, srcCol.Id) - if err != nil { - http.Error(w, fmt.Sprintf("%v", err), 
http.StatusBadRequest) - return - } - - spFk, err := common.CvtForeignKeysHelper(conv, conv.SpSchema[tableId].Name, tableId, interleavedFk, true) - if err != nil { - http.Error(w, fmt.Sprintf("Foreign key conversion fail"), http.StatusBadRequest) - return - } - - if isPresent { - if isAddedAtFirst { - spFk.ColIds = append([]string{spTable.ShardIdColumn}, spFk.ColIds...) - spFk.ReferColumnIds = append([]string{sessionState.Conv.SpSchema[spTable.ParentId].ShardIdColumn}, spFk.ReferColumnIds...) - } else { - spFk.ColIds = append(spFk.ColIds, spTable.ShardIdColumn) - spFk.ReferColumnIds = append(spFk.ReferColumnIds, sessionState.Conv.SpSchema[spTable.ParentId].ShardIdColumn) - } - } - - spFks := spTable.ForeignKeys - spFks = append(spFks, spFk) - spTable.ForeignKeys = spFks - spTable.ParentId = "" - conv.SpSchema[tableId] = spTable - - sessionState.Conv = conv - - convm := session.ConvWithMetadata{ - SessionMetadata: sessionState.SessionMetadata, - Conv: *sessionState.Conv, - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(convm) - -} - -type DropDetail struct { - Name string `json:"Name"` -} - -func restoreTables(w http.ResponseWriter, r *http.Request) { - reqBody, err := ioutil.ReadAll(r.Body) - if err != nil { - http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) - return - } - var tables internal.Tables - err = json.Unmarshal(reqBody, &tables) - if err != nil { - http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) - return - } - var convm session.ConvWithMetadata - for _, tableId := range tables.TableList { - convm = restoreTableHelper(w, tableId) - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(convm) -} - -func restoreTableHelper(w http.ResponseWriter, tableId string) session.ConvWithMetadata { - sessionState := session.GetSessionState() - if sessionState.Conv == nil || sessionState.Driver == "" { - http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not 
configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) - } - if tableId == "" { - http.Error(w, fmt.Sprintf("Table Id is empty"), http.StatusBadRequest) - } - - sessionState.Conv.ConvLock.Lock() - defer sessionState.Conv.ConvLock.Unlock() - conv := sessionState.Conv - var toddl common.ToDdl - switch sessionState.Driver { - case constants.MYSQL: - toddl = mysql.InfoSchemaImpl{}.GetToDdl() - case constants.POSTGRES: - toddl = postgres.InfoSchemaImpl{}.GetToDdl() - case constants.SQLSERVER: - toddl = sqlserver.InfoSchemaImpl{}.GetToDdl() - case constants.ORACLE: - toddl = oracle.InfoSchemaImpl{}.GetToDdl() - case constants.MYSQLDUMP: - toddl = mysql.DbDumpImpl{}.GetToDdl() - case constants.PGDUMP: - toddl = postgres.DbDumpImpl{}.GetToDdl() - default: - http.Error(w, fmt.Sprintf("Driver : '%s' is not supported", sessionState.Driver), http.StatusBadRequest) - } - - err := common.SrcTableToSpannerDDL(conv, toddl, sessionState.Conv.SrcSchema[tableId]) - if err != nil { - http.Error(w, fmt.Sprintf("Restoring spanner table fail"), http.StatusBadRequest) - } - conv.AddPrimaryKeys() - if sessionState.IsSharded { - conv.IsSharded = true - conv.AddShardIdColumn() - isPresent, isAddedAtFirst := hasShardIdPrimaryKeyRule() - if isPresent { - table := sessionState.Conv.SpSchema[tableId] - setShardIdColumnAsPrimaryKeyPerTable(isAddedAtFirst, table) - addShardIdToForeignKeyPerTable(isAddedAtFirst, table) - addShardIdToReferencedTableFks(tableId, isAddedAtFirst) - } - } - sessionState.Conv = conv - primarykey.DetectHotspot() - - convm := session.ConvWithMetadata{ - SessionMetadata: sessionState.SessionMetadata, - Conv: *sessionState.Conv, - } - return convm -} - -func addShardIdToReferencedTableFks(tableId string, isAddedAtFirst bool) { - sessionState := session.GetSessionState() - for _, table := range sessionState.Conv.SpSchema { - for i, fk := range table.ForeignKeys { - if fk.ReferTableId == tableId { - referredTableShardIdColumn := 
sessionState.Conv.SpSchema[fk.ReferTableId].ShardIdColumn - if isAddedAtFirst { - fk.ColIds = append([]string{table.ShardIdColumn}, fk.ColIds...) - fk.ReferColumnIds = append([]string{referredTableShardIdColumn}, fk.ReferColumnIds...) - } else { - fk.ColIds = append(fk.ColIds, table.ShardIdColumn) - fk.ReferColumnIds = append(fk.ReferColumnIds, referredTableShardIdColumn) - } - sessionState.Conv.SpSchema[table.Id].ForeignKeys[i] = fk - } - } - } -} - -func restoreTable(w http.ResponseWriter, r *http.Request) { - tableId := r.FormValue("table") - convm := restoreTableHelper(w, tableId) - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(convm) -} - -func dropTables(w http.ResponseWriter, r *http.Request) { - reqBody, err := ioutil.ReadAll(r.Body) - if err != nil { - http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) - return - } - var tables internal.Tables - err = json.Unmarshal(reqBody, &tables) - if err != nil { - http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) - return - } - var convm session.ConvWithMetadata - for _, tableId := range tables.TableList { - convm = dropTableHelper(w, tableId) - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(convm) -} - -func dropTableHelper(w http.ResponseWriter, tableId string) session.ConvWithMetadata { - sessionState := session.GetSessionState() - if sessionState.Conv == nil || sessionState.Driver == "" { - http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. 
Please retry converting the database to Spanner."), http.StatusNotFound) - return session.ConvWithMetadata{} - } - if tableId == "" { - http.Error(w, fmt.Sprintf("Table Id is empty"), http.StatusBadRequest) - } - sessionState.Conv.ConvLock.Lock() - defer sessionState.Conv.ConvLock.Unlock() - spSchema := sessionState.Conv.SpSchema - issues := sessionState.Conv.SchemaIssues - syntheticPkey := sessionState.Conv.SyntheticPKeys - - //remove deleted name from usedName - usedNames := sessionState.Conv.UsedNames - delete(usedNames, strings.ToLower(sessionState.Conv.SpSchema[tableId].Name)) - for _, index := range sessionState.Conv.SpSchema[tableId].Indexes { - delete(usedNames, index.Name) - } - for _, fk := range sessionState.Conv.SpSchema[tableId].ForeignKeys { - delete(usedNames, fk.Name) - } - - delete(spSchema, tableId) - issues[tableId] = internal.TableIssues{ - TableLevelIssues: []internal.SchemaIssue{}, - ColumnLevelIssues: map[string][]internal.SchemaIssue{}, - } - delete(syntheticPkey, tableId) - - //drop reference foreign key - for tableName, spTable := range spSchema { - fks := []ddl.Foreignkey{} - for _, fk := range spTable.ForeignKeys { - if fk.ReferTableId != tableId { - fks = append(fks, fk) - } else { - delete(usedNames, fk.Name) - } - - } - spTable.ForeignKeys = fks - spSchema[tableName] = spTable - } - - //remove interleave that are interleaved on the drop table as parent - for id, spTable := range spSchema { - if spTable.ParentId == tableId { - spTable.ParentId = "" - spSchema[id] = spTable - } - } - - //remove interleavable suggestion on droping the parent table - for tableName, tableIssues := range issues { - for colName, colIssues := range tableIssues.ColumnLevelIssues { - updatedColIssues := []internal.SchemaIssue{} - for _, val := range colIssues { - if val != internal.InterleavedOrder { - updatedColIssues = append(updatedColIssues, val) - } - } - if len(updatedColIssues) == 0 { - delete(issues[tableName].ColumnLevelIssues, colName) - } else { - 
issues[tableName].ColumnLevelIssues[colName] = updatedColIssues - } - } - } - - sessionState.Conv.SpSchema = spSchema - sessionState.Conv.SchemaIssues = issues - sessionState.Conv.UsedNames = usedNames - - convm := session.ConvWithMetadata{ - SessionMetadata: sessionState.SessionMetadata, - Conv: *sessionState.Conv, - } - return convm -} - -func dropTable(w http.ResponseWriter, r *http.Request) { - tableId := r.FormValue("table") - convm := dropTableHelper(w, tableId) - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(convm) -} - -func restoreSecondaryIndex(w http.ResponseWriter, r *http.Request) { - tableId := r.FormValue("tableId") - indexId := r.FormValue("indexId") - sessionState := session.GetSessionState() - if sessionState.Conv == nil || sessionState.Driver == "" { - http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) - return - } - if tableId == "" { - http.Error(w, fmt.Sprintf("Table Id is empty"), http.StatusBadRequest) - return - } - if indexId == "" { - http.Error(w, fmt.Sprintf("Index Id is empty"), http.StatusBadRequest) - return - } - - sessionState.Conv.ConvLock.Lock() - defer sessionState.Conv.ConvLock.Unlock() - var srcIndex schema.Index - srcIndexFound := false - for _, index := range sessionState.Conv.SrcSchema[tableId].Indexes { - if index.Id == indexId { - srcIndex = index - srcIndexFound = true - break - } - } - if !srcIndexFound { - http.Error(w, fmt.Sprintf("Source index not found"), http.StatusBadRequest) - return - } - - conv := sessionState.Conv - - spIndex := common.CvtIndexHelper(conv, tableId, srcIndex, conv.SpSchema[tableId].ColIds, conv.SpSchema[tableId].ColDefs) - spIndexes := conv.SpSchema[tableId].Indexes - spIndexes = append(spIndexes, spIndex) - spTable := conv.SpSchema[tableId] - spTable.Indexes = spIndexes - conv.SpSchema[tableId] = spTable - - sessionState.Conv = conv - index.AssignInitialOrders() - 
index.IndexSuggestion() - - convm := session.ConvWithMetadata{ - SessionMetadata: sessionState.SessionMetadata, - Conv: *sessionState.Conv, - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(convm) - -} - -// renameForeignKeys checks the new names for spanner name validity, ensures the new names are already not used by existing tables -// secondary indexes or foreign key constraints. If above checks passed then foreignKey renaming reflected in the schema else appropriate -// error thrown. -func updateForeignKeys(w http.ResponseWriter, r *http.Request) { - tableId := r.FormValue("table") - reqBody, err := ioutil.ReadAll(r.Body) - if err != nil { - http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) - } - - sessionState := session.GetSessionState() - if sessionState.Conv == nil || sessionState.Driver == "" { - http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) - return - } - sessionState.Conv.ConvLock.Lock() - defer sessionState.Conv.ConvLock.Unlock() - - newFKs := []ddl.Foreignkey{} - if err = json.Unmarshal(reqBody, &newFKs); err != nil { - http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) - return - } - - // Check new name for spanner name validity. 
- newNames := []string{} - newNamesMap := map[string]bool{} - for _, newFk := range newFKs { - if len(newFk.Name) == 0 { - continue - } - for _, oldFk := range sessionState.Conv.SpSchema[tableId].ForeignKeys { - if newFk.Id == oldFk.Id && newFk.Name != oldFk.Name && newFk.Name != "" { - newNames = append(newNames, strings.ToLower(newFk.Name)) - } - } - } - - for _, newFk := range newFKs { - if len(newFk.Name) == 0 { - continue - } - if _, ok := newNamesMap[strings.ToLower(newFk.Name)]; ok { - http.Error(w, fmt.Sprintf("Found duplicate names in input : %s", strings.ToLower(newFk.Name)), http.StatusBadRequest) - return - } - newNamesMap[strings.ToLower(newFk.Name)] = true - } - - if ok, invalidNames := utilities.CheckSpannerNamesValidity(newNames); !ok { - http.Error(w, fmt.Sprintf("Following names are not valid Spanner identifiers: %s", strings.Join(invalidNames, ",")), http.StatusBadRequest) - return - } - - // Check that the new names are not already used by existing tables, secondary indexes or foreign key constraints. - if ok, err := utilities.CanRename(newNames, tableId); !ok { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - sp := sessionState.Conv.SpSchema[tableId] - usedNames := sessionState.Conv.UsedNames - - // Update session with renamed foreignkeys. 
- updatedFKs := []ddl.Foreignkey{} - - for _, foreignKey := range sp.ForeignKeys { - for _, updatedForeignkey := range newFKs { - if foreignKey.Id == updatedForeignkey.Id && len(updatedForeignkey.ColIds) != 0 && updatedForeignkey.ReferTableId != "" { - delete(usedNames, strings.ToLower(foreignKey.Name)) - foreignKey.Name = updatedForeignkey.Name - updatedFKs = append(updatedFKs, foreignKey) - } - } - } - - position := -1 - - for i, fk := range updatedFKs { - // Condition to check whether FK has to be dropped - if len(fk.ReferColumnIds) == 0 && fk.ReferTableId == "" { - position = i - dropFkId := fk.Id - - // To remove the interleavable suggestions if they exist on dropping fk - colId := sp.ForeignKeys[position].ColIds[0] - schemaIssue := []internal.SchemaIssue{} - for _, v := range sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] { - if v != internal.InterleavedAddColumn && v != internal.InterleavedRenameColumn && v != internal.InterleavedNotInOrder && v != internal.InterleavedChangeColumnSize { - schemaIssue = append(schemaIssue, v) - } - } - if _, ok := sessionState.Conv.SchemaIssues[tableId]; ok { - sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colId] = schemaIssue - } - - sp.ForeignKeys = utilities.RemoveFk(updatedFKs, dropFkId) - } - } - sp.ForeignKeys = updatedFKs - sessionState.Conv.SpSchema[tableId] = sp - session.UpdateSessionFile() - - convm := session.ConvWithMetadata{ - SessionMetadata: sessionState.SessionMetadata, - Conv: *sessionState.Conv, - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(convm) -} - -// renameIndexes checks the new names for spanner name validity, ensures the new names are already not used by existing tables -// secondary indexes or foreign key constraints. If above checks passed then index renaming reflected in the schema else appropriate -// error thrown. 
-func renameIndexes(w http.ResponseWriter, r *http.Request) { - table := r.FormValue("table") - reqBody, err := ioutil.ReadAll(r.Body) - if err != nil { - http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) - } - - renameMap := map[string]string{} - if err = json.Unmarshal(reqBody, &renameMap); err != nil { - http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) - return - } - - // Check new name for spanner name validity. - newNames := []string{} - newNamesMap := map[string]bool{} - for _, value := range renameMap { - newNames = append(newNames, strings.ToLower(value)) - newNamesMap[strings.ToLower(value)] = true - } - if len(newNames) != len(newNamesMap) { - http.Error(w, fmt.Sprintf("Found duplicate names in input : %s", strings.Join(newNames, ",")), http.StatusBadRequest) - return +func getDSpannerDDL(w http.ResponseWriter, r *http.Request) { + sessionState := session.GetSessionState() + sessionState.Conv.ConvLock.RLock() + defer sessionState.Conv.ConvLock.RUnlock() + conv := sessionState.Conv + now := time.Now() + spDDL := conv.SpSchema.GetDDL(ddl.Config{Comments: true, ProtectIds: false, Tables: true, ForeignKeys: true, SpDialect: conv.SpDialect, Source: sessionState.Driver}) + if len(spDDL) == 0 { + spDDL = []string{"\n-- Schema is empty -- no tables found\n"} } - - if ok, invalidNames := utilities.CheckSpannerNamesValidity(newNames); !ok { - http.Error(w, fmt.Sprintf("Following names are not valid Spanner identifiers: %s", strings.Join(invalidNames, ",")), http.StatusBadRequest) - return + l := []string{ + fmt.Sprintf("-- Schema generated %s\n", now.Format("2006-01-02 15:04:05")), + strings.Join(spDDL, ";\n\n"), + "\n", } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(strings.Join(l, "")) +} - // Check that the new names are not already used by existing tables, secondary indexes or foreign key constraints. 
- if ok, err := utilities.CanRename(newNames, table); !ok { - http.Error(w, err.Error(), http.StatusBadRequest) - return +// getIssueDescription maps IssueDB's Category to corresponding CategoryDescription(if present), +// or to the Brief if not present and pass the map to frontend to be used in assessment report UI +func getIssueDescription(w http.ResponseWriter, r *http.Request) { + var issuesMap = make(map[string]string) + for _, issue := range reports.IssueDB { + if issue.CategoryDescription == "" { + issuesMap[issue.Category] = issue.Brief + } else { + issuesMap[issue.Category] = issue.CategoryDescription + } } - sessionState := session.GetSessionState() + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(issuesMap) +} - sp := sessionState.Conv.SpSchema[table] - // Update session with renamed secondary indexes. - newIndexes := []ddl.CreateIndex{} - for _, index := range sp.Indexes { - if newName, ok := renameMap[index.Id]; ok { - index.Name = newName - } - newIndexes = append(newIndexes, index) - } - sp.Indexes = newIndexes - sessionState.Conv.SpSchema[table] = sp - session.UpdateSessionFile() - convm := session.ConvWithMetadata{ - SessionMetadata: sessionState.SessionMetadata, - Conv: *sessionState.Conv, - } +func getBackendHealth(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(convm) } // ToDo : To Remove once Rules Component updated @@ -2171,7 +625,7 @@ func getSourceDestinationSummary(w http.ResponseWriter, r *http.Request) { sessionState := session.GetSessionState() sessionState.Conv.ConvLock.RLock() defer sessionState.Conv.ConvLock.RUnlock() - var sessionSummary sessionSummary + var sessionSummary types.SessionSummary databaseType, err := helpers.GetSourceDatabaseFromDriver(sessionState.Driver) if err != nil { http.Error(w, fmt.Sprintf("Error while getting source database: %v", err), http.StatusBadRequest) @@ -2228,7 +682,7 @@ func getSourceDestinationSummary(w http.ResponseWriter, r 
*http.Request) { func updateProgress(w http.ResponseWriter, r *http.Request) { - var detail progressDetails + var detail types.ProgressDetails sessionState := session.GetSessionState() sessionState.Conv.ConvLock.RLock() defer sessionState.Conv.ConvLock.RUnlock() @@ -2251,7 +705,7 @@ func migrate(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) } - details := migrationDetails{} + details := types.MigrationDetails{} err = json.Unmarshal(reqBody, &details) if err != nil { log.Println("request's Body parse error") @@ -2304,7 +758,7 @@ func migrate(w http.ResponseWriter, r *http.Request) { } func getGeneratedResources(w http.ResponseWriter, r *http.Request) { - var generatedResources GeneratedResources + var generatedResources types.GeneratedResources sessionState := session.GetSessionState() sessionState.Conv.ConvLock.RLock() defer sessionState.Conv.ConvLock.RUnlock() @@ -2313,7 +767,7 @@ func getGeneratedResources(w http.ResponseWriter, r *http.Request) { generatedResources.DatabaseUrl = fmt.Sprintf("https://console.cloud.google.com/spanner/instances/%v/databases/%v/details/tables?project=%v", sessionState.SpannerInstanceID, sessionState.SpannerDatabaseName, sessionState.GCPProjectID) generatedResources.BucketName = sessionState.Bucket + sessionState.RootPath generatedResources.BucketUrl = fmt.Sprintf("https://console.cloud.google.com/storage/browser/%v", sessionState.Bucket+sessionState.RootPath) - generatedResources.ShardToShardResourcesMap = map[string][]ResourceDetails{} + generatedResources.ShardToShardResourcesMap = map[string][]types.ResourceDetails{} if sessionState.Conv.Audit.StreamingStats.DatastreamResources.DatastreamName != "" { generatedResources.DataStreamJobName = sessionState.Conv.Audit.StreamingStats.DatastreamResources.DatastreamName generatedResources.DataStreamJobUrl = 
fmt.Sprintf("https://console.cloud.google.com/datastream/streams/locations/%v/instances/%v?project=%v", sessionState.Region, sessionState.Conv.Audit.StreamingStats.DatastreamResources.DatastreamName, sessionState.GCPProjectID) @@ -2342,34 +796,34 @@ func getGeneratedResources(w http.ResponseWriter, r *http.Request) { for shardId, shardResources := range sessionState.Conv.Audit.StreamingStats.ShardToShardResourcesMap { //Datastream url := fmt.Sprintf("https://console.cloud.google.com/datastream/streams/locations/%v/instances/%v?project=%v", sessionState.Region, shardResources.DatastreamResources.DatastreamName, sessionState.GCPProjectID) - resourceDetails := ResourceDetails{ResourceType: constants.DATASTREAM_RESOURCE, ResourceName: shardResources.DatastreamResources.DatastreamName, ResourceUrl: url} + resourceDetails := types.ResourceDetails{ResourceType: constants.DATASTREAM_RESOURCE, ResourceName: shardResources.DatastreamResources.DatastreamName, ResourceUrl: url} generatedResources.ShardToShardResourcesMap[shardId] = append(generatedResources.ShardToShardResourcesMap[shardId], resourceDetails) //Dataflow dfId := shardResources.DataflowResources.JobId url = fmt.Sprintf("https://console.cloud.google.com/dataflow/jobs/%v/%v?project=%v", sessionState.Conv.Audit.StreamingStats.DataflowResources.Region, dfId, sessionState.GCPProjectID) - resourceDetails = ResourceDetails{ResourceType: constants.DATAFLOW_RESOURCE, ResourceName: dfId, ResourceUrl: url, GcloudCmd: shardResources.DataflowResources.GcloudCmd} + resourceDetails = types.ResourceDetails{ResourceType: constants.DATAFLOW_RESOURCE, ResourceName: dfId, ResourceUrl: url, GcloudCmd: shardResources.DataflowResources.GcloudCmd} generatedResources.ShardToShardResourcesMap[shardId] = append(generatedResources.ShardToShardResourcesMap[shardId], resourceDetails) //monitoring url = fmt.Sprintf("https://console.cloud.google.com/monitoring/dashboards/builder/%v?project=%v", shardResources.MonitoringResources.DashboardName, 
sessionState.GCPProjectID) - resourceDetails = ResourceDetails{ResourceType: constants.MONITORING_RESOURCE, ResourceName: shardResources.MonitoringResources.DashboardName, ResourceUrl: url} + resourceDetails = types.ResourceDetails{ResourceType: constants.MONITORING_RESOURCE, ResourceName: shardResources.MonitoringResources.DashboardName, ResourceUrl: url} generatedResources.ShardToShardResourcesMap[shardId] = append(generatedResources.ShardToShardResourcesMap[shardId], resourceDetails) //gcs url = fmt.Sprintf("https://console.cloud.google.com/storage/browser/%v?project=%v", shardResources.GcsResources.BucketName, sessionState.GCPProjectID) - resourceDetails = ResourceDetails{ResourceType: constants.GCS_RESOURCE, ResourceName: shardResources.GcsResources.BucketName, ResourceUrl: url} + resourceDetails = types.ResourceDetails{ResourceType: constants.GCS_RESOURCE, ResourceName: shardResources.GcsResources.BucketName, ResourceUrl: url} generatedResources.ShardToShardResourcesMap[shardId] = append(generatedResources.ShardToShardResourcesMap[shardId], resourceDetails) //pubsub topicUrl := fmt.Sprintf("https://console.cloud.google.com/cloudpubsub/topic/detail/%v?project=%v", shardResources.PubsubResources.TopicId, sessionState.GCPProjectID) - topicResourceDetails := ResourceDetails{ResourceType: constants.PUBSUB_TOPIC_RESOURCE, ResourceName: shardResources.PubsubResources.TopicId, ResourceUrl: topicUrl} + topicResourceDetails := types.ResourceDetails{ResourceType: constants.PUBSUB_TOPIC_RESOURCE, ResourceName: shardResources.PubsubResources.TopicId, ResourceUrl: topicUrl} generatedResources.ShardToShardResourcesMap[shardId] = append(generatedResources.ShardToShardResourcesMap[shardId], topicResourceDetails) subscriptionUrl := fmt.Sprintf("https://console.cloud.google.com/cloudpubsub/subscription/detail/%v?project=%v", shardResources.PubsubResources.SubscriptionId, sessionState.GCPProjectID) - subscriptionResourceDetails := ResourceDetails{ResourceType: 
constants.PUBSUB_SUB_RESOURCE, ResourceName: shardResources.PubsubResources.SubscriptionId, ResourceUrl: subscriptionUrl} + subscriptionResourceDetails := types.ResourceDetails{ResourceType: constants.PUBSUB_SUB_RESOURCE, ResourceName: shardResources.PubsubResources.SubscriptionId, ResourceUrl: subscriptionUrl} generatedResources.ShardToShardResourcesMap[shardId] = append(generatedResources.ShardToShardResourcesMap[shardId], subscriptionResourceDetails) } w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(generatedResources) } -func getSourceAndTargetProfiles(sessionState *session.SessionState, details migrationDetails) (profiles.SourceProfile, profiles.TargetProfile, utils.IOStreams, string, error) { +func getSourceAndTargetProfiles(sessionState *session.SessionState, details types.MigrationDetails) (profiles.SourceProfile, profiles.TargetProfile, utils.IOStreams, string, error) { var ( sourceProfileString string err error @@ -2422,7 +876,7 @@ func getSourceAndTargetProfiles(sessionState *session.SessionState, details migr return sourceProfile, targetProfile, ioHelper, dbName, nil } -func getSourceProfileStringForShardedMigrations(sessionState *session.SessionState, details migrationDetails) (string, error) { +func getSourceProfileStringForShardedMigrations(sessionState *session.SessionState, details types.MigrationDetails) (string, error) { fileName := sessionState.Conv.Audit.MigrationRequestId + "-sharding.cfg" if details.MigrationType != helpers.LOW_DOWNTIME_MIGRATION { err := createConfigFileForShardedBulkMigration(sessionState, details, fileName) @@ -2442,7 +896,7 @@ func getSourceProfileStringForShardedMigrations(sessionState *session.SessionSta } -func createConfigFileForShardedDataflowMigration(sessionState *session.SessionState, details migrationDetails, fileName string) error { +func createConfigFileForShardedDataflowMigration(sessionState *session.SessionState, details types.MigrationDetails, fileName string) error { sourceProfileConfig := 
sessionState.SourceProfileConfig //Set the TmpDir from the sessionState bucket which is derived from the target connection profile for _, dataShard := range sourceProfileConfig.ShardConfigurationDataflow.DataShards { @@ -2463,7 +917,7 @@ func createConfigFileForShardedDataflowMigration(sessionState *session.SessionSt return nil } -func createConfigFileForShardedBulkMigration(sessionState *session.SessionState, details migrationDetails, fileName string) error { +func createConfigFileForShardedBulkMigration(sessionState *session.SessionState, details types.MigrationDetails, fileName string) error { sourceProfileConfig := profiles.SourceProfileConfig{ ConfigType: constants.BULK_MIGRATION, ShardConfigurationBulk: profiles.ShardConfigurationBulk{ @@ -2517,7 +971,7 @@ func writeSessionFile(ctx context.Context, sessionState *session.SessionState) e return nil } -func createStreamingCfgFile(sessionState *session.SessionState, details migrationDetails, fileName string) error { +func createStreamingCfgFile(sessionState *session.SessionState, details types.MigrationDetails, fileName string) error { targetDetails, datastreamConfig, dataflowConfig := details.TargetDetails, details.DatastreamConfig, details.DataflowConfig dfLocation := sessionState.Region if dataflowConfig.Location != "" { @@ -2577,174 +1031,6 @@ func createStreamingCfgFile(sessionState *session.SessionState, details migratio return nil } -func updateIndexes(w http.ResponseWriter, r *http.Request) { - table := r.FormValue("table") - reqBody, err := ioutil.ReadAll(r.Body) - if err != nil { - http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) - } - - newIndexes := []ddl.CreateIndex{} - if err = json.Unmarshal(reqBody, &newIndexes); err != nil { - http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) - return - } - - list := []int{} - for i := 0; i < len(newIndexes); i++ { - for j := 0; j < len(newIndexes[i].Keys); j++ { - list = append(list, 
newIndexes[i].Keys[j].Order) - } - } - - if utilities.DuplicateInArray(list) != -1 { - http.Error(w, fmt.Sprintf("Two Index columns can not have same order"), http.StatusBadRequest) - return - } - - sessionState := session.GetSessionState() - sessionState.Conv.ConvLock.Lock() - defer sessionState.Conv.ConvLock.Unlock() - sp := sessionState.Conv.SpSchema[table] - - st := sessionState.Conv.SrcSchema[table] - - for i, ind := range sp.Indexes { - - if ind.TableId == newIndexes[0].TableId && ind.Id == newIndexes[0].Id { - - index.RemoveIndexIssues(table, sp.Indexes[i]) - - sp.Indexes[i].Keys = newIndexes[0].Keys - sp.Indexes[i].Name = newIndexes[0].Name - sp.Indexes[i].TableId = newIndexes[0].TableId - sp.Indexes[i].Unique = newIndexes[0].Unique - sp.Indexes[i].Id = newIndexes[0].Id - - break - } - } - - for i, spIndex := range sp.Indexes { - - for j, srcIndex := range st.Indexes { - - for k, spIndexKey := range spIndex.Keys { - - for l, srcIndexKey := range srcIndex.Keys { - - if srcIndexKey.ColId == spIndexKey.ColId { - - st.Indexes[j].Keys[l].Order = sp.Indexes[i].Keys[k].Order - } - - } - } - - } - } - - sessionState.Conv.SpSchema[table] = sp - - sessionState.Conv.SrcSchema[table] = st - - session.UpdateSessionFile() - - convm := session.ConvWithMetadata{ - SessionMetadata: sessionState.SessionMetadata, - Conv: *sessionState.Conv, - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(convm) -} - -func dropSecondaryIndex(w http.ResponseWriter, r *http.Request) { - sessionState := session.GetSessionState() - sessionState.Conv.ConvLock.Lock() - defer sessionState.Conv.ConvLock.Unlock() - - table := r.FormValue("table") - reqBody, err := ioutil.ReadAll(r.Body) - if err != nil { - http.Error(w, fmt.Sprintf("Body Read Error : %v", err), http.StatusInternalServerError) - } - - var dropDetail struct{ Id string } - if err = json.Unmarshal(reqBody, &dropDetail); err != nil { - http.Error(w, fmt.Sprintf("Request Body parse error : %v", err), http.StatusBadRequest) - 
return - } - if sessionState.Conv == nil || sessionState.Driver == "" { - http.Error(w, fmt.Sprintf("Schema is not converted or Driver is not configured properly. Please retry converting the database to Spanner."), http.StatusNotFound) - return - } - - if table == "" || dropDetail.Id == "" { - http.Error(w, fmt.Sprintf("Table name or position is empty"), http.StatusBadRequest) - } - err = dropSecondaryIndexHelper(table, dropDetail.Id) - if err != nil { - http.Error(w, fmt.Sprintf("%v", err), http.StatusBadRequest) - return - } - - // To set enabled value to false for the rule associated with the dropped index. - indexId := dropDetail.Id - for i, rule := range sessionState.Conv.Rules { - if rule.Type == constants.AddIndex { - d, err := json.Marshal(rule.Data) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - var index ddl.CreateIndex - err = json.Unmarshal(d, &index) - if err != nil { - http.Error(w, "Invalid rule data", http.StatusInternalServerError) - return - } - if index.Id == indexId { - sessionState.Conv.Rules[i].Enabled = false - break - } - } - } - - convm := session.ConvWithMetadata{ - SessionMetadata: sessionState.SessionMetadata, - Conv: *sessionState.Conv, - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(convm) -} - -func dropSecondaryIndexHelper(tableId, idxId string) error { - if tableId == "" || idxId == "" { - return fmt.Errorf("Table id or index id is empty") - } - sessionState := session.GetSessionState() - sp := sessionState.Conv.SpSchema[tableId] - position := -1 - for i, index := range sp.Indexes { - if idxId == index.Id { - position = i - break - } - } - if position < 0 || position >= len(sp.Indexes) { - return fmt.Errorf("No secondary index found at position %d", position) - } - - usedNames := sessionState.Conv.UsedNames - delete(usedNames, strings.ToLower(sp.Indexes[position].Name)) - index.RemoveIndexIssues(tableId, sp.Indexes[position]) - - sp.Indexes = 
utilities.RemoveSecondaryIndex(sp.Indexes, position) - sessionState.Conv.SpSchema[tableId] = sp - session.UpdateSessionFile() - return nil -} - func uploadFile(w http.ResponseWriter, r *http.Request) { r.ParseMultipartForm(10 << 20) @@ -2808,351 +1094,6 @@ func rollback(err error) error { return err } -func checkPrimaryKeyOrder(tableId string, refTableId string, fk ddl.Foreignkey) string { - sessionState := session.GetSessionState() - childPks := sessionState.Conv.SpSchema[tableId].PrimaryKeys - parentPks := sessionState.Conv.SpSchema[refTableId].PrimaryKeys - childTable := sessionState.Conv.SpSchema[tableId] - parentTable := sessionState.Conv.SpSchema[refTableId] - for i := 0; i < len(parentPks); i++ { - - for j := 0; j < len(childPks); j++ { - - for k := 0; k < len(fk.ReferColumnIds); k++ { - - if childTable.ColDefs[fk.ColIds[k]].Name == parentTable.ColDefs[fk.ReferColumnIds[k]].Name && - parentTable.ColDefs[parentPks[i].ColId].Name == childTable.ColDefs[childPks[j].ColId].Name && - parentTable.ColDefs[parentPks[i].ColId].T.Name == childTable.ColDefs[childPks[j].ColId].T.Name && - parentTable.ColDefs[parentPks[i].ColId].T.Len == childTable.ColDefs[childPks[j].ColId].T.Len && - parentTable.ColDefs[parentPks[i].ColId].Name == parentTable.ColDefs[fk.ReferColumnIds[k]].Name && - childTable.ColDefs[childPks[j].ColId].Name == parentTable.ColDefs[fk.ReferColumnIds[k]].Name { - if parentPks[i].Order != childPks[j].Order { - return childPks[j].ColId - } - } - } - - } - - } - return "" - -} - -func checkPrimaryKeyPrefix(tableId string, refTableId string, fk ddl.Foreignkey, tableInterleaveStatus *TableInterleaveStatus) bool { - - sessionState := session.GetSessionState() - childTable := sessionState.Conv.SpSchema[tableId] - parentTable := sessionState.Conv.SpSchema[refTableId] - childPks := sessionState.Conv.SpSchema[tableId].PrimaryKeys - parentPks := sessionState.Conv.SpSchema[refTableId].PrimaryKeys - possibleInterleave := false - - flag := false - for _, key := range 
parentPks { - flag = false - for _, colId := range fk.ReferColumnIds { - if key.ColId == colId { - flag = true - } - } - if !flag { - break - } - } - if flag { - possibleInterleave = true - } - - if !possibleInterleave { - removeInterleaveSuggestions(fk.ColIds, tableId) - return false - } - - childPkColIds := []string{} - for _, k := range childPks { - childPkColIds = append(childPkColIds, k.ColId) - } - - interleaved := []ddl.IndexKey{} - - for i := 0; i < len(parentPks); i++ { - - for j := 0; j < len(childPks); j++ { - - for k := 0; k < len(fk.ReferColumnIds); k++ { - - if childTable.ColDefs[fk.ColIds[k]].Name == parentTable.ColDefs[fk.ReferColumnIds[k]].Name && - parentTable.ColDefs[parentPks[i].ColId].Name == childTable.ColDefs[childPks[j].ColId].Name && - parentTable.ColDefs[parentPks[i].ColId].T.Name == childTable.ColDefs[childPks[j].ColId].T.Name && - parentTable.ColDefs[parentPks[i].ColId].T.Len == childTable.ColDefs[childPks[j].ColId].T.Len && - parentTable.ColDefs[parentPks[i].ColId].Name == parentTable.ColDefs[fk.ReferColumnIds[k]].Name && - childTable.ColDefs[childPks[j].ColId].Name == parentTable.ColDefs[fk.ReferColumnIds[k]].Name { - - interleaved = append(interleaved, parentPks[i]) - } - } - - } - - } - - if len(interleaved) == len(parentPks) { - return true - } - - diff := []ddl.IndexKey{} - - if len(interleaved) == 0 { - - for i := 0; i < len(parentPks); i++ { - - for j := 0; j < len(childPks); j++ { - - if parentTable.ColDefs[parentPks[i].ColId].Name != childTable.ColDefs[childPks[j].ColId].Name || parentTable.ColDefs[parentPks[i].ColId].T.Len != childTable.ColDefs[childPks[j].ColId].T.Len { - diff = append(diff, parentPks[i]) - } - - } - } - - } - - canInterleavedOnAdd := []string{} - canInterleavedOnRename := []string{} - canInterLeaveOnChangeInColumnSize := []string{} - - fkReferColNames := []string{} - childPkColNames := []string{} - for _, colId := range fk.ReferColumnIds { - fkReferColNames = append(fkReferColNames, 
parentTable.ColDefs[colId].Name) - } - for _, colId := range childPkColIds { - childPkColNames = append(childPkColNames, childTable.ColDefs[colId].Name) - } - - for i := 0; i < len(diff); i++ { - - parentColIndex := utilities.IsColumnPresent(fkReferColNames, parentTable.ColDefs[diff[i].ColId].Name) - if parentColIndex == -1 { - continue - } - childColIndex := utilities.IsColumnPresent(childPkColNames, childTable.ColDefs[fk.ColIds[parentColIndex]].Name) - if childColIndex == -1 { - canInterleavedOnAdd = append(canInterleavedOnAdd, fk.ColIds[parentColIndex]) - } else { - if parentTable.ColDefs[diff[i].ColId].Name == childTable.ColDefs[fk.ColIds[parentColIndex]].Name { - canInterLeaveOnChangeInColumnSize = append(canInterLeaveOnChangeInColumnSize, fk.ColIds[parentColIndex]) - } else { - canInterleavedOnRename = append(canInterleavedOnRename, fk.ColIds[parentColIndex]) - } - - } - } - - if len(canInterLeaveOnChangeInColumnSize) > 0 { - updateInterleaveSuggestion(canInterLeaveOnChangeInColumnSize, tableId, internal.InterleavedChangeColumnSize) - } else if len(canInterleavedOnRename) > 0 { - updateInterleaveSuggestion(canInterleavedOnRename, tableId, internal.InterleavedRenameColumn) - } else if len(canInterleavedOnAdd) > 0 { - updateInterleaveSuggestion(canInterleavedOnAdd, tableId, internal.InterleavedAddColumn) - } - - return false -} - -func updateInterleaveSuggestion(colIds []string, tableId string, issue internal.SchemaIssue) { - sessionState := session.GetSessionState() - - for i := 0; i < len(colIds); i++ { - - schemaissue := []internal.SchemaIssue{} - - schemaissue = sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colIds[i]] - - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedOrder) - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedNotInOrder) - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedAddColumn) - schemaissue = utilities.RemoveSchemaIssue(schemaissue, 
internal.InterleavedRenameColumn) - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedChangeColumnSize) - - schemaissue = append(schemaissue, issue) - - if sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues == nil { - - s := map[string][]internal.SchemaIssue{ - colIds[i]: schemaissue, - } - sessionState.Conv.SchemaIssues[tableId] = internal.TableIssues{ - ColumnLevelIssues: s, - } - } else { - sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colIds[i]] = schemaissue - } - } -} - -func removeInterleaveSuggestions(colIds []string, tableId string) { - sessionState := session.GetSessionState() - - for i := 0; i < len(colIds); i++ { - - schemaissue := []internal.SchemaIssue{} - - schemaissue = sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colIds[i]] - - if len(schemaissue) == 0 { - continue - } - - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedOrder) - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedNotInOrder) - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedAddColumn) - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedRenameColumn) - schemaissue = utilities.RemoveSchemaIssue(schemaissue, internal.InterleavedChangeColumnSize) - - if sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues == nil { - - s := map[string][]internal.SchemaIssue{ - colIds[i]: schemaissue, - } - sessionState.Conv.SchemaIssues[tableId] = internal.TableIssues{ - ColumnLevelIssues: s, - } - } else { - sessionState.Conv.SchemaIssues[tableId].ColumnLevelIssues[colIds[i]] = schemaissue - } - - } -} - -// SessionState stores information for the current migration session. 
-type SessionState struct { - sourceDB *sql.DB // Connection to source database in case of direct connection - dbName string // Name of source database - driver string // Name of Spanner migration tool driver in use - conv *internal.Conv // Current conversion state - sessionFile string // Path to session file -} - -// Type and issue. -type typeIssue struct { - T string - Brief string - DisplayT string -} - -type ResourceDetails struct { - ResourceType string `json:"ResourceType"` - ResourceName string `json:"ResourceName"` - ResourceUrl string `json:"ResourceUrl"` - GcloudCmd string `json:"GcloudCmd"` -} -type GeneratedResources struct { - MigrationJobId string `json:"MigrationJobId"` - DatabaseName string `json:"DatabaseName"` - DatabaseUrl string `json:"DatabaseUrl"` - BucketName string `json:"BucketName"` - BucketUrl string `json:"BucketUrl"` - //Used for single instance migration flow - DataStreamJobName string `json:"DataStreamJobName"` - DataStreamJobUrl string `json:"DataStreamJobUrl"` - DataflowJobName string `json:"DataflowJobName"` - DataflowJobUrl string `json:"DataflowJobUrl"` - DataflowGcloudCmd string `json:"DataflowGcloudCmd"` - PubsubTopicName string `json:"PubsubTopicName"` - PubsubTopicUrl string `json:"PubsubTopicUrl"` - PubsubSubscriptionName string `json:"PubsubSubscriptionName"` - PubsubSubscriptionUrl string `json:"PubsubSubscriptionUrl"` - MonitoringDashboardName string `json:"MonitoringDashboardName"` - MonitoringDashboardUrl string `json:"MonitoringDashboardUrl"` - AggMonitoringDashboardName string `json:"AggMonitoringDashboardName"` - AggMonitoringDashboardUrl string `json:"AggMonitoringDashboardUrl"` - //Used for sharded migration flow - ShardToShardResourcesMap map[string][]ResourceDetails `json:"ShardToShardResourcesMap"` -} - -func addTypeToList(convertedType string, spType string, issues []internal.SchemaIssue, l []typeIssue) []typeIssue { - if convertedType == spType { - if len(issues) > 0 { - var briefs []string - for _, issue := 
range issues { - briefs = append(briefs, reports.IssueDB[issue].Brief) - } - l = append(l, typeIssue{T: spType, Brief: fmt.Sprintf(strings.Join(briefs, ", "))}) - } else { - l = append(l, typeIssue{T: spType}) - } - } - return l -} - -func initializeTypeMap() { - sessionState := session.GetSessionState() - var toddl common.ToDdl - // Initialize mysqlTypeMap. - toddl = mysql.InfoSchemaImpl{}.GetToDdl() - for _, srcTypeName := range []string{"bool", "boolean", "varchar", "char", "text", "tinytext", "mediumtext", "longtext", "set", "enum", "json", "bit", "binary", "varbinary", "blob", "tinyblob", "mediumblob", "longblob", "tinyint", "smallint", "mediumint", "int", "integer", "bigint", "double", "float", "numeric", "decimal", "date", "datetime", "timestamp", "time", "year", "geometrycollection", "multipoint", "multilinestring", "multipolygon", "point", "linestring", "polygon", "geometry"} { - var l []typeIssue - srcType := schema.MakeType() - srcType.Name = srcTypeName - for _, spType := range []string{ddl.Bool, ddl.Bytes, ddl.Date, ddl.Float64, ddl.Int64, ddl.String, ddl.Timestamp, ddl.Numeric, ddl.JSON} { - ty, issues := toddl.ToSpannerType(sessionState.Conv, spType, srcType) - l = addTypeToList(ty.Name, spType, issues, l) - } - if srcTypeName == "tinyint" { - l = append(l, typeIssue{T: ddl.Bool, Brief: "Only tinyint(1) can be converted to BOOL, for any other mods it will be converted to INT64"}) - } - ty, _ := toddl.ToSpannerType(sessionState.Conv, "", srcType) - mysqlDefaultTypeMap[srcTypeName] = ty - mysqlTypeMap[srcTypeName] = l - } - // Initialize postgresTypeMap. 
- toddl = postgres.InfoSchemaImpl{}.GetToDdl() - for _, srcTypeName := range []string{"bool", "boolean", "bigserial", "bpchar", "character", "bytea", "date", "float8", "double precision", "float4", "real", "int8", "bigint", "int4", "integer", "int2", "smallint", "numeric", "serial", "text", "timestamptz", "timestamp with time zone", "timestamp", "timestamp without time zone", "varchar", "character varying"} { - var l []typeIssue - srcType := schema.MakeType() - srcType.Name = srcTypeName - for _, spType := range []string{ddl.Bool, ddl.Bytes, ddl.Date, ddl.Float64, ddl.Int64, ddl.String, ddl.Timestamp, ddl.Numeric, ddl.JSON} { - ty, issues := toddl.ToSpannerType(sessionState.Conv, spType, srcType) - l = addTypeToList(ty.Name, spType, issues, l) - } - ty, _ := toddl.ToSpannerType(sessionState.Conv, "", srcType) - postgresDefaultTypeMap[srcTypeName] = ty - postgresTypeMap[srcTypeName] = l - } - - // Initialize sqlserverTypeMap. - toddl = sqlserver.InfoSchemaImpl{}.GetToDdl() - for _, srcTypeName := range []string{"int", "tinyint", "smallint", "bigint", "bit", "float", "real", "numeric", "decimal", "money", "smallmoney", "char", "nchar", "varchar", "nvarchar", "text", "ntext", "date", "datetime", "datetime2", "smalldatetime", "datetimeoffset", "time", "timestamp", "rowversion", "binary", "varbinary", "image", "xml", "geography", "geometry", "uniqueidentifier", "sql_variant", "hierarchyid"} { - var l []typeIssue - srcType := schema.MakeType() - srcType.Name = srcTypeName - for _, spType := range []string{ddl.Bool, ddl.Bytes, ddl.Date, ddl.Float64, ddl.Int64, ddl.String, ddl.Timestamp, ddl.Numeric, ddl.JSON} { - ty, issues := toddl.ToSpannerType(sessionState.Conv, spType, srcType) - l = addTypeToList(ty.Name, spType, issues, l) - } - ty, _ := toddl.ToSpannerType(sessionState.Conv, "", srcType) - sqlserverDefaultTypeMap[srcTypeName] = ty - sqlserverTypeMap[srcTypeName] = l - } - - // Initialize oracleTypeMap. 
- toddl = oracle.InfoSchemaImpl{}.GetToDdl() - for _, srcTypeName := range []string{"NUMBER", "BFILE", "BLOB", "CHAR", "CLOB", "DATE", "BINARY_DOUBLE", "BINARY_FLOAT", "FLOAT", "LONG", "RAW", "LONG RAW", "NCHAR", "NVARCHAR2", "VARCHAR", "VARCHAR2", "NCLOB", "ROWID", "UROWID", "XMLTYPE", "TIMESTAMP", "INTERVAL", "SDO_GEOMETRY"} { - var l []typeIssue - srcType := schema.MakeType() - srcType.Name = srcTypeName - for _, spType := range []string{ddl.Bool, ddl.Bytes, ddl.Date, ddl.Float64, ddl.Int64, ddl.String, ddl.Timestamp, ddl.Numeric, ddl.JSON} { - ty, issues := toddl.ToSpannerType(sessionState.Conv, spType, srcType) - l = addTypeToList(ty.Name, spType, issues, l) - } - ty, _ := toddl.ToSpannerType(sessionState.Conv, "", srcType) - oracleDefaultTypeMap[srcTypeName] = ty - oracleTypeMap[srcTypeName] = l - } -} - func init() { sessionState := session.GetSessionState() utilities.InitObjectId() diff --git a/webv2/web_startup_test.go b/webv2/web_startup_test.go index 1e92128aa5..a5be9ea198 100644 --- a/webv2/web_startup_test.go +++ b/webv2/web_startup_test.go @@ -18,6 +18,8 @@ import ( "net/http" "net/http/httptest" "testing" + + "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/api" ) // Test robustness of API calls on startup. 
@@ -29,7 +31,7 @@ func TestDdlOnStartup(t *testing.T) { t.Fatal(err) } rr := httptest.NewRecorder() - handler := http.HandlerFunc(getDDL) + handler := http.HandlerFunc(api.GetDDL) handler.ServeHTTP(rr, req) if status := rr.Code; status != http.StatusOK { t.Errorf("handler returned wrong status code: got %v want %v", From 5aff5147bd5119a821dbac959829ab41c41d3ad5 Mon Sep 17 00:00:00 2001 From: Manit Gupta Date: Mon, 12 Feb 2024 16:15:16 +0530 Subject: [PATCH 08/15] Revert GCP details (#769) --- webv2/config.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/webv2/config.json b/webv2/config.json index d3bdeb50b9..8c5d5e5443 100644 --- a/webv2/config.json +++ b/webv2/config.json @@ -1,4 +1,4 @@ { - "GCPProjectID": "span-cloud-testing", - "SpannerInstanceID": "deep-heavy-100gb" + "GCPProjectID": "", + "SpannerInstanceID": "" } \ No newline at end of file From 3b95837bbe177f9245249b882ef4d40ce38596fc Mon Sep 17 00:00:00 2001 From: Deep1998 Date: Tue, 9 Jan 2024 14:19:27 +0530 Subject: [PATCH 09/15] Add dao --- dao/dao.go | 320 +++++++++++++++++++++++++++++++++++++++++ dao/dao_client.go | 48 +++++++ dao/dao_client_test.go | 116 +++++++++++++++ 3 files changed, 484 insertions(+) create mode 100644 dao/dao.go create mode 100644 dao/dao_client.go create mode 100644 dao/dao_client_test.go diff --git a/dao/dao.go b/dao/dao.go new file mode 100644 index 0000000000..c515e39539 --- /dev/null +++ b/dao/dao.go @@ -0,0 +1,320 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +package dao + +import ( + "context" + "fmt" + + "cloud.google.com/go/spanner" + "google.golang.org/api/iterator" +) + +type StateData struct { + State string `json:"state"` +} + +type DAO interface { + InsertSMTJobEntry(ctx context.Context, jobId, jobName, jobType, dialect, dbName string, jobData spanner.NullJSON) error + UpdateSMTJobState(ctx context.Context, jobId, state string) error + InsertSMTResourceEntry(ctx context.Context, resourceId, jobId, externalId, resourceName, resourceType string, resourceData spanner.NullJSON) error + UpdateSMTResourceState(ctx context.Context, resourceId, state string) error + UpdateSMTResourceExternalId(ctx context.Context, resourceId, externalId string) error +} + +type DAOImpl struct{} + +// Insert a job entry into the SMT_JOB table. +func (dao *DAOImpl) InsertSMTJobEntry(ctx context.Context, jobId, jobName, jobType, dialect, dbName string, jobData spanner.NullJSON) error { + _, err := GetClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + jobStmt := spanner.Statement{ + SQL: `INSERT INTO SMT_JOB + (JobId, JobName, JobType, JobStateData, JobData, Dialect, SpannerDatabaseName, CreatedAt, UpdatedAt) + VALUES( + @jobId, @jobName, @jobType, @jobStateData, @jobData, @dialect, @dbName, PENDING_COMMIT_TIMESTAMP(), PENDING_COMMIT_TIMESTAMP() + );`, + Params: map[string]interface{}{ + "jobId": jobId, + "jobName": jobName, + "jobType": jobType, + "jobStateData": spanner.NullJSON{Valid: true, Value: StateData{State: "CREATING"}}, + "jobData": jobData, + "dialect": dialect, + "dbName": dbName, + }, + } + _, err := txn.Update(ctx, jobStmt) + if err != nil { + return err + } + // Update job history table within the same txn. 
+ _, err = updateJobHistoryWithinTxn(ctx, txn, jobId) + if err != nil { + return err + } + return nil + }) + if err != nil { + return fmt.Errorf("could not insert SMT job entry: %v", err) + } + return nil +} + +// Update the state of the SMT job. +func (dao *DAOImpl) UpdateSMTJobState(ctx context.Context, jobId, state string) error { + _, err := GetClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + jobStmt := spanner.Statement{ + SQL: `UPDATE SMT_JOB SET JobStateData = @jobStateData, UpdatedAt = PENDING_COMMIT_TIMESTAMP() + WHERE JobId = @jobId;`, + Params: map[string]interface{}{ + "jobId": jobId, + "jobStateData": spanner.NullJSON{Valid: true, Value: StateData{State: state}}, + }, + } + _, err := txn.Update(ctx, jobStmt) + if err != nil { + return err + } + _, err = updateJobHistoryWithinTxn(ctx, txn, jobId) + if err != nil { + return err + } + return nil + }) + if err != nil { + return fmt.Errorf("error updating smt job state: %v", err) + } + return nil +} + +// Insert an entry into the SMT_RESOURCE table. 
+func (dao *DAOImpl) InsertSMTResourceEntry(ctx context.Context, resourceId, jobId, externalId, resourceName, resourceType string, resourceData spanner.NullJSON) error { + _, err := GetClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + resourceStmt := spanner.Statement{ + SQL: `INSERT INTO SMT_RESOURCE + (ResourceId, JobId, ExternalId, ResourceName, ResourceType, ResourceStateData, ResourceData, CreatedAt, UpdatedAt) + VALUES( + @resourceId, @jobId, @externalId, @resourceName, @resourceType, @resourceStateData, @resourceData, PENDING_COMMIT_TIMESTAMP(), PENDING_COMMIT_TIMESTAMP() + );`, + Params: map[string]interface{}{ + "resourceId": resourceId, + "jobId": jobId, + "externalId": externalId, + "resourceName": resourceName, + "resourceType": resourceType, + "resourceStateData": spanner.NullJSON{Valid: true, Value: StateData{State: "CREATING"}}, + "resourceData": resourceData, + }, + } + _, err := txn.Update(ctx, resourceStmt) + if err != nil { + return err + } + // Update the resource history table in the same transaction. + _, err = updateResourceHistoryWithinTxn(ctx, txn, resourceId) + if err != nil { + return err + } + return nil + }) + if err != nil { + return fmt.Errorf("error inserting smt resource entry: %v", err) + } + return nil +} + +// Update the state of the SMT resource. 
+func (dao *DAOImpl) UpdateSMTResourceState(ctx context.Context, resourceId, state string) error { + _, err := GetClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + jobStmt := spanner.Statement{ + SQL: `UPDATE SMT_RESOURCE SET ResourceStateData = @resourceStateData, UpdatedAt = PENDING_COMMIT_TIMESTAMP() + WHERE ResourceId = @resourceId;`, + Params: map[string]interface{}{ + "resourceId": resourceId, + "resourceStateData": spanner.NullJSON{Valid: true, Value: StateData{State: state}}, + }, + } + _, err := txn.Update(ctx, jobStmt) + if err != nil { + return err + } + _, err = updateResourceHistoryWithinTxn(ctx, txn, resourceId) + if err != nil { + return err + } + return nil + }) + if err != nil { + return fmt.Errorf("error updating smt resource state: %v", err) + } + return nil +} + +// Update the external of the SMT resource. +func (dao *DAOImpl) UpdateSMTResourceExternalId(ctx context.Context, resourceId, externalId string) error { + _, err := GetClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + jobStmt := spanner.Statement{ + SQL: `UPDATE SMT_RESOURCE SET ExternalId = @externalId, UpdatedAt = PENDING_COMMIT_TIMESTAMP() + WHERE ResourceId = @resourceId;`, + Params: map[string]interface{}{ + "resourceId": resourceId, + "externalId": externalId, + }, + } + _, err := txn.Update(ctx, jobStmt) + if err != nil { + return err + } + _, err = updateResourceHistoryWithinTxn(ctx, txn, resourceId) + if err != nil { + return err + } + return nil + }) + if err != nil { + return fmt.Errorf("error updating smt resource external id: %v", err) + } + return nil +} + +func updateJobHistoryWithinTxn(ctx context.Context, txn *spanner.ReadWriteTransaction, jobId string) (int64, error) { + version, err := getLatestJobVersionWithinTxn(ctx, txn, jobId) + if err != nil { + return 0, fmt.Errorf("error fetching latest job version: %v", err) + } + // Fetch the newly updated row from 
SMT_JOB table. + stmt := spanner.Statement{SQL: ` + SELECT + JobName, JobType, JobStateData, JobData, Dialect, SpannerDatabaseName + FROM SMT_JOB WHERE JobId = @jobId;`, + Params: map[string]interface{}{"jobId": jobId}, + } + iter := txn.Query(ctx, stmt) + defer iter.Stop() + var jobName, jobType, dialect, spannerDatabaseName spanner.NullString + var jobStateData, jobData spanner.NullJSON + row, err := iter.Next() + if err == iterator.Done || err != nil { + return 0, err + } + if err := row.Columns(&jobName, &jobType, &jobStateData, &jobData, &dialect, &spannerDatabaseName); err != nil { + return 0, fmt.Errorf("error reading smt job row: %v", err) + } + + // Insert entry to SMT_JOB_HISTORY table. + jobStmt := spanner.Statement{ + SQL: `INSERT INTO SMT_JOB_HISTORY + (JobId, Version, JobName, JobType, JobStateData, JobData, Dialect, SpannerDatabaseName, CreatedAt) + VALUES( + @jobId, @version, @jobName, @jobType, @jobStateData, @jobData, @dialect, @spannerDatabaseName, PENDING_COMMIT_TIMESTAMP() + );`, + Params: map[string]interface{}{ + "jobId": jobId, + "version": version + 1, + "jobName": jobName, + "jobType": jobType, + "jobStateData": jobStateData, + "jobData": jobData, + "dialect": dialect, + "spannerDatabaseName": spannerDatabaseName, + }, + } + return txn.Update(ctx, jobStmt) +} + +func getLatestJobVersionWithinTxn(ctx context.Context, txn *spanner.ReadWriteTransaction, jobId string) (int64, error) { + // Fetch latest version for the job from history table. 
+ stmt := spanner.Statement{SQL: `SELECT MAX(Version) FROM SMT_JOB_HISTORY WHERE JobId = @jobId;`, + Params: map[string]interface{}{"jobId": jobId}, + } + iter := txn.Query(ctx, stmt) + defer iter.Stop() + version := spanner.NullInt64{} + row, err := iter.Next() + if err == iterator.Done || err != nil { + return 0, err + } + if err := row.Columns(&version); err != nil { + return 0, err + } + if version.Valid { + return version.Int64, nil + } + return 0, nil +} + +func updateResourceHistoryWithinTxn(ctx context.Context, txn *spanner.ReadWriteTransaction, resourceId string) (int64, error) { + version, err := getLatestResourceVersionWithinTxn(ctx, txn, resourceId) + if err != nil { + return 0, fmt.Errorf("error fetching latest resource version: %v", err) + } + // Fetch the newly updated row from SMT_RESOURCE table. + stmt := spanner.Statement{SQL: ` + SELECT + JobId, ExternalId, ResourceName, ResourceType, ResourceStateData, ResourceData + FROM SMT_RESOURCE WHERE ResourceId = @resourceId;`, + Params: map[string]interface{}{"resourceId": resourceId}, + } + iter := txn.Query(ctx, stmt) + defer iter.Stop() + var jobId, externalId, resourceName, resourceType spanner.NullString + var resourceStateData, resourceData spanner.NullJSON + row, err := iter.Next() + if err == iterator.Done || err != nil { + return 0, err + } + if err := row.Columns(&jobId, &externalId, &resourceName, &resourceType, &resourceStateData, &resourceData); err != nil { + return 0, fmt.Errorf("error reading smt resource row: %v", err) + } + // Create new entry into the SMT_RESOURCE_HISTORY table. 
+ jobStmt := spanner.Statement{ + SQL: `INSERT INTO SMT_RESOURCE_HISTORY + (ResourceId, Version, JobId, ExternalId, ResourceName, ResourceType, ResourceStateData, ResourceData, CreatedAt) + VALUES( + @resourceId, @version, @jobId, @externalId, @resourceName, @resourceType, @resourceStateData, @resourceData, PENDING_COMMIT_TIMESTAMP() + );`, + Params: map[string]interface{}{ + "resourceId": resourceId, + "version": version + 1, + "jobId": jobId, + "externalId": externalId, + "resourceName": resourceName, + "resourceType": resourceType, + "resourceStateData": resourceStateData, + "resourceData": resourceData, + }, + } + return txn.Update(ctx, jobStmt) +} + +func getLatestResourceVersionWithinTxn(ctx context.Context, txn *spanner.ReadWriteTransaction, resourceId string) (int64, error) { + // Fetch latest version for the resource from history table. + stmt := spanner.Statement{SQL: `SELECT MAX(Version) FROM SMT_RESOURCE_HISTORY WHERE ResourceId = @resourceId;`, + Params: map[string]interface{}{"resourceId": resourceId}, + } + iter := txn.Query(ctx, stmt) + defer iter.Stop() + version := spanner.NullInt64{} + row, err := iter.Next() + if err == iterator.Done || err != nil { + return 0, err + } + if err := row.Columns(&version); err != nil { + return 0, err + } + if version.Valid { + return version.Int64, nil + } + return 0, nil +} diff --git a/dao/dao_client.go b/dao/dao_client.go new file mode 100644 index 0000000000..46964888de --- /dev/null +++ b/dao/dao_client.go @@ -0,0 +1,48 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +package dao + +import ( + "context" + "fmt" + "sync" + + sp "cloud.google.com/go/spanner" +) + +var once sync.Once +var spClient *sp.Client + +// This function is declared as a global variable to make it testable. The unit +// tests edit this function, acting like a double. +var newClient = sp.NewClient + +func GetOrCreateClient(ctx context.Context, dbURI string) (*sp.Client, error) { + var err error + if spClient == nil { + once.Do(func() { + spClient, err = newClient(ctx, dbURI) + }) + if err != nil { + return nil, fmt.Errorf("failed to create spanner database client: %v", err) + } + return spClient, nil + } + return spClient, nil +} + +// The DAO client must be initiated via GetOrCreateClient() once before using GetClient(). +func GetClient() *sp.Client { + return spClient +} diff --git a/dao/dao_client_test.go b/dao/dao_client_test.go new file mode 100644 index 0000000000..05c0352a9a --- /dev/null +++ b/dao/dao_client_test.go @@ -0,0 +1,116 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package dao + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + + sp "cloud.google.com/go/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/api/option" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func resetTest() { + spClient = nil + once = sync.Once{} +} + +func TestGetOrCreateClient_Basic(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return &sp.Client{}, nil + } + c, err := GetOrCreateClient(ctx, "testURI") + assert.NotNil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaSync(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return &sp.Client{}, nil + } + c, err := GetOrCreateClient(ctx, "testURI") + assert.NotNil(t, c) + assert.Nil(t, err) + // Explicitly set the client to nil. Running GetOrCreateClient should not create a + // new client since sync would already be executed. 
+ spClient = nil + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return nil, fmt.Errorf("test error") + } + c, err = GetOrCreateClient(ctx, "testURI") + assert.Nil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaIf(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return &sp.Client{}, nil + } + oldC, err := GetOrCreateClient(ctx, "testURI") + assert.NotNil(t, oldC) + assert.Nil(t, err) + + // Explicitly reset once. Running GetOrCreateClient should not create a + // new client the if condition should prevent it. + once = sync.Once{} + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return nil, fmt.Errorf("test error") + } + newC, err := GetOrCreateClient(ctx, "testURI") + assert.Equal(t, oldC, newC) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_Error(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return nil, fmt.Errorf("test error") + } + c, err := GetOrCreateClient(ctx, "testURI") + assert.Nil(t, c) + assert.NotNil(t, err) +} From 380a7fff7022c24283e265a153972db1e833ee69 Mon Sep 17 00:00:00 2001 From: Deep1998 Date: Mon, 12 Feb 2024 22:18:15 +0530 Subject: [PATCH 10/15] Remove version column from metadata history tables --- dao/dao.go | 80 +++++++--------------------------------- webv2/helpers/helpers.go | 6 +-- 2 files changed, 16 insertions(+), 70 deletions(-) diff --git a/dao/dao.go b/dao/dao.go index c515e39539..21bea4137f 100644 --- a/dao/dao.go +++ b/dao/dao.go @@ -26,17 +26,17 @@ type StateData struct { } type DAO interface { - 
InsertSMTJobEntry(ctx context.Context, jobId, jobName, jobType, dialect, dbName string, jobData spanner.NullJSON) error - UpdateSMTJobState(ctx context.Context, jobId, state string) error - InsertSMTResourceEntry(ctx context.Context, resourceId, jobId, externalId, resourceName, resourceType string, resourceData spanner.NullJSON) error - UpdateSMTResourceState(ctx context.Context, resourceId, state string) error - UpdateSMTResourceExternalId(ctx context.Context, resourceId, externalId string) error + InsertJobEntry(ctx context.Context, jobId, jobName, jobType, dialect, dbName string, jobData spanner.NullJSON) error + UpdateJobState(ctx context.Context, jobId, state string) error + InsertResourceEntry(ctx context.Context, resourceId, jobId, externalId, resourceName, resourceType string, resourceData spanner.NullJSON) error + UpdateResourceState(ctx context.Context, resourceId, state string) error + UpdateResourceExternalId(ctx context.Context, resourceId, externalId string) error } type DAOImpl struct{} // Insert a job entry into the SMT_JOB table. -func (dao *DAOImpl) InsertSMTJobEntry(ctx context.Context, jobId, jobName, jobType, dialect, dbName string, jobData spanner.NullJSON) error { +func (dao *DAOImpl) InsertJobEntry(ctx context.Context, jobId, jobName, jobType, dialect, dbName string, jobData spanner.NullJSON) error { _, err := GetClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { jobStmt := spanner.Statement{ SQL: `INSERT INTO SMT_JOB @@ -72,7 +72,7 @@ func (dao *DAOImpl) InsertSMTJobEntry(ctx context.Context, jobId, jobName, jobTy } // Update the state of the SMT job. 
-func (dao *DAOImpl) UpdateSMTJobState(ctx context.Context, jobId, state string) error { +func (dao *DAOImpl) UpdateJobState(ctx context.Context, jobId, state string) error { _, err := GetClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { jobStmt := spanner.Statement{ SQL: `UPDATE SMT_JOB SET JobStateData = @jobStateData, UpdatedAt = PENDING_COMMIT_TIMESTAMP() @@ -99,7 +99,7 @@ func (dao *DAOImpl) UpdateSMTJobState(ctx context.Context, jobId, state string) } // Insert an entry into the SMT_RESOURCE table. -func (dao *DAOImpl) InsertSMTResourceEntry(ctx context.Context, resourceId, jobId, externalId, resourceName, resourceType string, resourceData spanner.NullJSON) error { +func (dao *DAOImpl) InsertResourceEntry(ctx context.Context, resourceId, jobId, externalId, resourceName, resourceType string, resourceData spanner.NullJSON) error { _, err := GetClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { resourceStmt := spanner.Statement{ SQL: `INSERT INTO SMT_RESOURCE @@ -135,7 +135,7 @@ func (dao *DAOImpl) InsertSMTResourceEntry(ctx context.Context, resourceId, jobI } // Update the state of the SMT resource. -func (dao *DAOImpl) UpdateSMTResourceState(ctx context.Context, resourceId, state string) error { +func (dao *DAOImpl) UpdateResourceState(ctx context.Context, resourceId, state string) error { _, err := GetClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { jobStmt := spanner.Statement{ SQL: `UPDATE SMT_RESOURCE SET ResourceStateData = @resourceStateData, UpdatedAt = PENDING_COMMIT_TIMESTAMP() @@ -162,7 +162,7 @@ func (dao *DAOImpl) UpdateSMTResourceState(ctx context.Context, resourceId, stat } // Update the external of the SMT resource. 
-func (dao *DAOImpl) UpdateSMTResourceExternalId(ctx context.Context, resourceId, externalId string) error { +func (dao *DAOImpl) UpdateResourceExternalId(ctx context.Context, resourceId, externalId string) error { _, err := GetClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { jobStmt := spanner.Statement{ SQL: `UPDATE SMT_RESOURCE SET ExternalId = @externalId, UpdatedAt = PENDING_COMMIT_TIMESTAMP() @@ -189,10 +189,6 @@ func (dao *DAOImpl) UpdateSMTResourceExternalId(ctx context.Context, resourceId, } func updateJobHistoryWithinTxn(ctx context.Context, txn *spanner.ReadWriteTransaction, jobId string) (int64, error) { - version, err := getLatestJobVersionWithinTxn(ctx, txn, jobId) - if err != nil { - return 0, fmt.Errorf("error fetching latest job version: %v", err) - } // Fetch the newly updated row from SMT_JOB table. stmt := spanner.Statement{SQL: ` SELECT @@ -215,13 +211,12 @@ func updateJobHistoryWithinTxn(ctx context.Context, txn *spanner.ReadWriteTransa // Insert entry to SMT_JOB_HISTORY table. 
jobStmt := spanner.Statement{ SQL: `INSERT INTO SMT_JOB_HISTORY - (JobId, Version, JobName, JobType, JobStateData, JobData, Dialect, SpannerDatabaseName, CreatedAt) + (JobId, JobName, JobType, JobStateData, JobData, Dialect, SpannerDatabaseName, CreatedAt) VALUES( - @jobId, @version, @jobName, @jobType, @jobStateData, @jobData, @dialect, @spannerDatabaseName, PENDING_COMMIT_TIMESTAMP() + @jobId, @jobName, @jobType, @jobStateData, @jobData, @dialect, @spannerDatabaseName, PENDING_COMMIT_TIMESTAMP() );`, Params: map[string]interface{}{ "jobId": jobId, - "version": version + 1, "jobName": jobName, "jobType": jobType, "jobStateData": jobStateData, @@ -233,32 +228,7 @@ func updateJobHistoryWithinTxn(ctx context.Context, txn *spanner.ReadWriteTransa return txn.Update(ctx, jobStmt) } -func getLatestJobVersionWithinTxn(ctx context.Context, txn *spanner.ReadWriteTransaction, jobId string) (int64, error) { - // Fetch latest version for the job from history table. - stmt := spanner.Statement{SQL: `SELECT MAX(Version) FROM SMT_JOB_HISTORY WHERE JobId = @jobId;`, - Params: map[string]interface{}{"jobId": jobId}, - } - iter := txn.Query(ctx, stmt) - defer iter.Stop() - version := spanner.NullInt64{} - row, err := iter.Next() - if err == iterator.Done || err != nil { - return 0, err - } - if err := row.Columns(&version); err != nil { - return 0, err - } - if version.Valid { - return version.Int64, nil - } - return 0, nil -} - func updateResourceHistoryWithinTxn(ctx context.Context, txn *spanner.ReadWriteTransaction, resourceId string) (int64, error) { - version, err := getLatestResourceVersionWithinTxn(ctx, txn, resourceId) - if err != nil { - return 0, fmt.Errorf("error fetching latest resource version: %v", err) - } // Fetch the newly updated row from SMT_RESOURCE table. stmt := spanner.Statement{SQL: ` SELECT @@ -280,13 +250,12 @@ func updateResourceHistoryWithinTxn(ctx context.Context, txn *spanner.ReadWriteT // Create new entry into the SMT_RESOURCE_HISTORY table. 
jobStmt := spanner.Statement{ SQL: `INSERT INTO SMT_RESOURCE_HISTORY - (ResourceId, Version, JobId, ExternalId, ResourceName, ResourceType, ResourceStateData, ResourceData, CreatedAt) + (ResourceId, JobId, ExternalId, ResourceName, ResourceType, ResourceStateData, ResourceData, CreatedAt) VALUES( - @resourceId, @version, @jobId, @externalId, @resourceName, @resourceType, @resourceStateData, @resourceData, PENDING_COMMIT_TIMESTAMP() + @resourceId, @jobId, @externalId, @resourceName, @resourceType, @resourceStateData, @resourceData, PENDING_COMMIT_TIMESTAMP() );`, Params: map[string]interface{}{ "resourceId": resourceId, - "version": version + 1, "jobId": jobId, "externalId": externalId, "resourceName": resourceName, @@ -297,24 +266,3 @@ func updateResourceHistoryWithinTxn(ctx context.Context, txn *spanner.ReadWriteT } return txn.Update(ctx, jobStmt) } - -func getLatestResourceVersionWithinTxn(ctx context.Context, txn *spanner.ReadWriteTransaction, resourceId string) (int64, error) { - // Fetch latest version for the resource from history table. 
- stmt := spanner.Statement{SQL: `SELECT MAX(Version) FROM SMT_RESOURCE_HISTORY WHERE ResourceId = @resourceId;`, - Params: map[string]interface{}{"resourceId": resourceId}, - } - iter := txn.Query(ctx, stmt) - defer iter.Stop() - version := spanner.NullInt64{} - row, err := iter.Next() - if err == iterator.Done || err != nil { - return 0, err - } - if err := row.Columns(&version); err != nil { - return 0, err - } - if version.Valid { - return version.Int64, nil - } - return 0, nil -} diff --git a/webv2/helpers/helpers.go b/webv2/helpers/helpers.go index 980bc7eac2..d3b250e5fb 100644 --- a/webv2/helpers/helpers.go +++ b/webv2/helpers/helpers.go @@ -66,7 +66,6 @@ var TABLE_STATEMENTS = []string{ ) PRIMARY KEY(JobId)`, `CREATE TABLE IF NOT EXISTS SMT_JOB_HISTORY ( JobId STRING(100) NOT NULL, - Version INT64 NOT NULL, JobName STRING(100) NOT NULL, JobType STRING(100) NOT NULL, JobStateData JSON, @@ -74,7 +73,7 @@ var TABLE_STATEMENTS = []string{ Dialect STRING(50) NOT NULL, SpannerDatabaseName STRING(100) NOT NULL, CreatedAt TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true), - ) PRIMARY KEY(JobId, Version)`, + ) PRIMARY KEY(JobId, CreatedAt)`, `CREATE TABLE IF NOT EXISTS SMT_RESOURCE ( ResourceId STRING(100) NOT NULL, JobId STRING(100) NOT NULL, @@ -88,7 +87,6 @@ var TABLE_STATEMENTS = []string{ ) PRIMARY KEY(ResourceId)`, `CREATE TABLE IF NOT EXISTS SMT_RESOURCE_HISTORY ( ResourceId STRING(100) NOT NULL, - Version INT64 NOT NULL, JobId STRING(100) NOT NULL, ExternalId STRING(100), ResourceName STRING(100) NOT NULL, @@ -96,7 +94,7 @@ var TABLE_STATEMENTS = []string{ ResourceStateData JSON, ResourceData JSON, CreatedAt TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true), - ) PRIMARY KEY(ResourceId, Version)`, + ) PRIMARY KEY(ResourceId, CreatedAt)`, } func GetSpannerUri(projectId string, instanceId string) string { From 5416bfbc67ad2adcbc316b48aee35b696d5061da Mon Sep 17 00:00:00 2001 From: aksharauke <126752897+aksharauke@users.noreply.github.com> Date: 
Wed, 14 Feb 2024 11:06:01 +0530 Subject: [PATCH 11/15] ability to pass custom parameters to custom sharding logic (#758) * ability to pass custom parameters to custom sharding logic * updated the template location for reader * corrected typo --- .../ReverseReplicationUserGuide.md | 1 + .../RunnigReverseReplication.md | 5 +- .../reverse-replication-runner.go | 75 ++++++++++--------- 3 files changed, 43 insertions(+), 38 deletions(-) diff --git a/docs/reverse-replication/ReverseReplicationUserGuide.md b/docs/reverse-replication/ReverseReplicationUserGuide.md index df9d0927dc..0660deefe4 100644 --- a/docs/reverse-replication/ReverseReplicationUserGuide.md +++ b/docs/reverse-replication/ReverseReplicationUserGuide.md @@ -311,6 +311,7 @@ Steps to perfrom customization: 1. Write custom shard id fetcher logic [CustomShardIdFetcher.java](https://github.com/GoogleCloudPlatform/DataflowTemplates/blob/main/v2/spanner-custom-shard/src/main/java/com/custom/CustomShardIdFetcher.java). Details of the ShardIdRequest class can be found [here](https://github.com/GoogleCloudPlatform/DataflowTemplates/blob/main/v2/spanner-migrations-sdk/src/main/java/com/google/cloud/teleport/v2/spanner/utils/ShardIdRequest.java). 2. Build the [JAR](https://github.com/GoogleCloudPlatform/DataflowTemplates/tree/main/v2/spanner-custom-shard) and upload the jar to GCS 3. Invoke the reverse replication flow by passing the [custom jar path and custom class path](RunnigReverseReplication.md#custom-jar). +4. If any custom parameters are needed in the custom shard identification logic, they can be passed via the *readerShardingCustomParameters* input to the runner. These parameters will be passed to the *init* method of the custom class. The *init* method is invoked once per worker setup. 
diff --git a/docs/reverse-replication/RunnigReverseReplication.md b/docs/reverse-replication/RunnigReverseReplication.md index e1b12bd3ad..d4e5b81424 100644 --- a/docs/reverse-replication/RunnigReverseReplication.md +++ b/docs/reverse-replication/RunnigReverseReplication.md @@ -57,6 +57,7 @@ The script takes in multiple arguments to orchestrate the pipeline. They are: - `startTimestamp`: Timestamp from which the changestream should start reading changes in RFC 3339 format, defaults to empty string which is equivalent to the current timestamp. - `readerShardingCustomClassName`: the fully qualified custom class name for sharding logic. - `readerShardingCustomJarPath` : the GCS path to custom jar for sharding logic. +- `readerShardingCustomParameters`: the custom parameters to be passed to the custom sharding logic implementation. - `readerSkipDirectoryName`: Records skipped from reverse replication are written to this directory. Defaults to: skip. - `readerRunMode`: whether the reader from Spanner job runs in regular or resume mode. Default is regular. - `readerWorkers`: number of workers for ordering job. Defaults to 5. @@ -141,13 +142,13 @@ Launched writer job: id:"<>" project_id:"<>" name:"<>" current_state_time:{} cr In order to specify custom shard identification function, custom jar and class names need to give. 
The command to do that is below: ``` -go run reverse-replication-runner.go -projectId= -dataflowRegion= -instanceId= -dbName= -sourceShardsFilePath=gs://bucket-name/shards.json -sessionFilePath=gs://bucket-name/session.json -gcsPath=gs://bucket-name/ -readerShardingCustomJarPath=gs://bucket-name/custom.jar -readerShardingCustomClassName=com.custom.classname +go run reverse-replication-runner.go -projectId= -dataflowRegion= -instanceId= -dbName= -sourceShardsFilePath=gs://bucket-name/shards.json -sessionFilePath=gs://bucket-name/session.json -gcsPath=gs://bucket-name/ -readerShardingCustomJarPath=gs://bucket-name/custom.jar -readerShardingCustomClassName=com.custom.classname -readerShardingCustomParameters='a=b,c=d' ``` The sample reader job gcloud command for the same ``` -gcloud dataflow flex-template run smt-reverse-replication-reader-2024-01-05t10-33-56z --project= --region= --template-file-gcs-location=