diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000..e7546324fd --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,3 @@ +# review when someone opens a pull request + +* @GoogleCloudPlatform/spanner-migrations-team diff --git a/accessors/clients/dataflow/dataflow_client.go b/accessors/clients/dataflow/dataflow_client.go index cc76c34fda..5f70698c8d 100644 --- a/accessors/clients/dataflow/dataflow_client.go +++ b/accessors/clients/dataflow/dataflow_client.go @@ -19,19 +19,13 @@ import ( "sync" dataflow "cloud.google.com/go/dataflow/apiv1beta3" - "cloud.google.com/go/dataflow/apiv1beta3/dataflowpb" - "github.com/googleapis/gax-go/v2" ) -type DataflowClient interface { - LaunchFlexTemplate(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) -} - var once sync.Once var dfClient *dataflow.FlexTemplatesClient // This function is declared as a global variable to make it testable. The unit -// tests edit this function, acting like a double. +// tests update this function, acting like a double. var newFlexTemplatesClient = dataflow.NewFlexTemplatesClient func GetOrCreateClient(ctx context.Context) (*dataflow.FlexTemplatesClient, error) { diff --git a/accessors/clients/dataflow/interface.go b/accessors/clients/dataflow/interface.go new file mode 100644 index 0000000000..9918b48a02 --- /dev/null +++ b/accessors/clients/dataflow/interface.go @@ -0,0 +1,44 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +package dataflowclient + +import ( + "context" + + dataflow "cloud.google.com/go/dataflow/apiv1beta3" + "cloud.google.com/go/dataflow/apiv1beta3/dataflowpb" + "github.com/googleapis/gax-go/v2" +) + +// Use this interface instead of dataflow.FlexTemplatesClient to support mocking. +type DataflowClient interface { + LaunchFlexTemplate(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) +} + +// This implements the DataflowClient interface. This is the primary implementation that should be used in all places other than tests. +type DataflowClientImpl struct { + client *dataflow.FlexTemplatesClient +} + +func NewDataflowClientImpl(ctx context.Context) (*DataflowClientImpl, error) { + c, err := GetOrCreateClient(ctx) + if err != nil { + return nil, err + } + return &DataflowClientImpl{client: c}, nil +} + +func (c *DataflowClientImpl) LaunchFlexTemplate(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) { + return c.client.LaunchFlexTemplate(ctx, req, opts...) +} diff --git a/accessors/clients/dataflow/mocks.go b/accessors/clients/dataflow/mocks.go new file mode 100644 index 0000000000..ac899a7345 --- /dev/null +++ b/accessors/clients/dataflow/mocks.go @@ -0,0 +1,31 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +package dataflowclient + +import ( + "context" + + "cloud.google.com/go/dataflow/apiv1beta3/dataflowpb" + "github.com/googleapis/gax-go/v2" +) + +// Mock that implements the DataflowClient interface. +// Pass in unit tests where DataflowClient is an input parameter. +type DataflowClientMock struct { + LaunchFlexTemplateMock func(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) +} + +func (dcm *DataflowClientMock) LaunchFlexTemplate(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) { + return dcm.LaunchFlexTemplateMock(ctx, req, opts...) +} diff --git a/accessors/clients/datastream/datastream_client.go b/accessors/clients/datastream/datastream_client.go new file mode 100644 index 0000000000..a4b06fad1f --- /dev/null +++ b/accessors/clients/datastream/datastream_client.go @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package datastreamclient + +import ( + "context" + "fmt" + "sync" + + datastream "cloud.google.com/go/datastream/apiv1" +) + +var once sync.Once +var dsClient *datastream.Client + +// This function is declared as a global variable to make it testable. The unit +// tests update this function, acting like a double. 
+var newClient = datastream.NewClient + +func GetOrCreateClient(ctx context.Context) (*datastream.Client, error) { + var err error + if dsClient == nil { + once.Do(func() { + dsClient, err = newClient(ctx) + }) + if err != nil { + return nil, fmt.Errorf("failed to create datastream client: %v", err) + } + return dsClient, nil + } + return dsClient, nil +} diff --git a/accessors/clients/datastream/datastream_client_test.go b/accessors/clients/datastream/datastream_client_test.go new file mode 100644 index 0000000000..e90a68c946 --- /dev/null +++ b/accessors/clients/datastream/datastream_client_test.go @@ -0,0 +1,119 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package datastreamclient + +// TODO: Currently this test is intrusive and not using any accessors to mutate the code under test. +// Freeze on the right pattern and fork this into the test package. 
+ +import ( + "context" + "fmt" + "os" + "sync" + "testing" + + datastream "cloud.google.com/go/datastream/apiv1" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/api/option" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func resetTest() { + dsClient = nil + once = sync.Once{} +} + +func TestGetOrCreateClient_Basic(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + newClient = func(ctx context.Context, opts ...option.ClientOption) (*datastream.Client, error) { + return &datastream.Client{}, nil + } + c, err := GetOrCreateClient(ctx) + assert.NotNil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaSync(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, opts ...option.ClientOption) (*datastream.Client, error) { + return &datastream.Client{}, nil + } + c, err := GetOrCreateClient(ctx) + assert.NotNil(t, c) + assert.Nil(t, err) + // Explicitly set the client to nil. Running GetOrCreateClient should not create a + // new client since sync would already be executed. 
+ dsClient = nil + newClient = func(ctx context.Context, opts ...option.ClientOption) (*datastream.Client, error) { + return nil, fmt.Errorf("test error") + } + c, err = GetOrCreateClient(ctx) + assert.Nil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaIf(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, opts ...option.ClientOption) (*datastream.Client, error) { + return &datastream.Client{}, nil + } + oldC, err := GetOrCreateClient(ctx) + assert.NotNil(t, oldC) + assert.Nil(t, err) + + // Explicitly reset once. Running GetOrCreateClient should not create a + // new client the if condition should prevent it. + once = sync.Once{} + newClient = func(ctx context.Context, opts ...option.ClientOption) (*datastream.Client, error) { + return nil, fmt.Errorf("test error") + } + newC, err := GetOrCreateClient(ctx) + assert.Equal(t, oldC, newC) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_Error(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, opts ...option.ClientOption) (*datastream.Client, error) { + return nil, fmt.Errorf("test error") + } + c, err := GetOrCreateClient(ctx) + assert.Nil(t, c) + assert.NotNil(t, err) +} diff --git a/accessors/clients/datastream/datastream_test/mocks.go b/accessors/clients/datastream/datastream_test/mocks.go new file mode 100644 index 0000000000..339b3991de --- /dev/null +++ b/accessors/clients/datastream/datastream_test/mocks.go @@ -0,0 +1,46 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package datastreamclient_test + +import ( + "context" + + datastreampb "cloud.google.com/go/datastream/apiv1/datastreampb" + "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/operation" + "github.com/googleapis/gax-go/v2" + "github.com/stretchr/testify/mock" +) + +// Mock that implements the DatastreamClient interface. +// Pass in unit tests where DatastreamClient is an input parameter. +type DatastreamClientMock struct { + mock.Mock +} + +func (m *DatastreamClientMock) CreateStream(ctx context.Context, req *datastreampb.CreateStreamRequest, opts ...gax.CallOption) (*operation.OperationWrapper[datastreampb.Stream], error) { + args := m.Called(ctx, req, opts) + // Avoid panic for typeassertion due to null pointer. + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*operation.OperationWrapper[datastreampb.Stream]), args.Error(1) +} +func (m *DatastreamClientMock) UpdateStream(ctx context.Context, req *datastreampb.UpdateStreamRequest, opts ...gax.CallOption) (*operation.OperationWrapper[datastreampb.Stream], error) { + args := m.Called(ctx, req, opts) + // Avoid panic for typeassertion due to null pointer. 
+ if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*operation.OperationWrapper[datastreampb.Stream]), args.Error(1) +} diff --git a/accessors/clients/datastream/interface.go b/accessors/clients/datastream/interface.go new file mode 100644 index 0000000000..8c0f0faeea --- /dev/null +++ b/accessors/clients/datastream/interface.go @@ -0,0 +1,92 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package datastreamclient + +import ( + "context" + + datastream "cloud.google.com/go/datastream/apiv1" + datastreampb "cloud.google.com/go/datastream/apiv1/datastreampb" + "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/operation" + "github.com/googleapis/gax-go/v2" +) + +// Use this interface instead of datastream.Client to support mocking. 
+type DatastreamClient interface { + CreateStream(ctx context.Context, req *datastreampb.CreateStreamRequest, opts ...gax.CallOption) (*operation.OperationWrapper[datastreampb.Stream], error) + UpdateStream(ctx context.Context, req *datastreampb.UpdateStreamRequest, opts ...gax.CallOption) (*operation.OperationWrapper[datastreampb.Stream], error) + GetConnectionProfile(ctx context.Context, connectionName string) (*datastreampb.ConnectionProfile, error) + ListConnectionProfiles(ctx context.Context, listRequest *datastreampb.ListConnectionProfilesRequest, opts ...gax.CallOption) *datastream.ConnectionProfileIterator + DeleteConnectionProfile(ctx context.Context, deleteRequest *datastreampb.DeleteConnectionProfileRequest) (*operation.NilOperationWrapper, error) + CreateConnectionProfile(ctx context.Context, createRequest *datastreampb.CreateConnectionProfileRequest) (*operation.OperationWrapper[datastreampb.ConnectionProfile], error) +} + +// This implements the DatastreamClient interface. This is the primary implementation that should be used in all places other than tests. +type DatastreamClientImpl struct { + client *datastream.Client +} + +func NewDatastreamClientImpl(ctx context.Context) (*DatastreamClientImpl, error) { + c, err := GetOrCreateClient(ctx) + if err != nil { + return nil, err + } + return &DatastreamClientImpl{client: c}, nil +} + +func (c *DatastreamClientImpl) CreateStream(ctx context.Context, req *datastreampb.CreateStreamRequest, opts ...gax.CallOption) (*operation.OperationWrapper[datastreampb.Stream], error) { + o, e := c.client.CreateStream(ctx, req, opts...) + if o == nil { + return nil, e + } else { + ret := operation.NewOperationWrapper[datastreampb.Stream](o) + return &ret, nil + } +} +func (c *DatastreamClientImpl) UpdateStream(ctx context.Context, req *datastreampb.UpdateStreamRequest, opts ...gax.CallOption) (*operation.OperationWrapper[datastreampb.Stream], error) { + o, e := c.client.UpdateStream(ctx, req, opts...) 
+ if o == nil { + return nil, e + } else { + ret := operation.NewOperationWrapper[datastreampb.Stream](o) + return &ret, nil + } + +} + +func (c *DatastreamClientImpl) CreateConnectionProfile(ctx context.Context, createRequest *datastreampb.CreateConnectionProfileRequest) (*operation.OperationWrapper[datastreampb.ConnectionProfile], error) { + op, err := c.client.CreateConnectionProfile(ctx, createRequest) + if err != nil { + return nil, err + } + ret := operation.NewOperationWrapper[datastreampb.ConnectionProfile](op) + return &ret, nil +} + +func (c *DatastreamClientImpl) DeleteConnectionProfile(ctx context.Context, deleteRequest *datastreampb.DeleteConnectionProfileRequest) (*operation.NilOperationWrapper, error) { + op, err := c.client.DeleteConnectionProfile(ctx, deleteRequest) + if err != nil { + return nil, err + } + ret := operation.NewNilOperationWrapper(op) + return &ret, nil +} + +func (c *DatastreamClientImpl) GetConnectionProfile(ctx context.Context, connectionName string) (*datastreampb.ConnectionProfile, error) { + return c.client.GetConnectionProfile(ctx, &datastreampb.GetConnectionProfileRequest{Name: connectionName}) +} + +func (c *DatastreamClientImpl) ListConnectionProfiles(ctx context.Context, listRequest *datastreampb.ListConnectionProfilesRequest, opts ...gax.CallOption) *datastream.ConnectionProfileIterator { + return c.client.ListConnectionProfiles(ctx, listRequest, opts...) +} diff --git a/accessors/clients/datastream/mocks.go b/accessors/clients/datastream/mocks.go new file mode 100644 index 0000000000..8ee2635658 --- /dev/null +++ b/accessors/clients/datastream/mocks.go @@ -0,0 +1,67 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package datastreamclient + +import ( + "context" + + datastream "cloud.google.com/go/datastream/apiv1" + datastreampb "cloud.google.com/go/datastream/apiv1/datastreampb" + "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/operation" + "github.com/googleapis/gax-go/v2" + "github.com/stretchr/testify/mock" +) + +// Mock that implements the DatastreamClient interface. +// Pass in unit tests where DatastreamClient is an input parameter. +type DatastreamClientMock struct { + mock.Mock +} + +func (m *DatastreamClientMock) CreateStream(ctx context.Context, req *datastreampb.CreateStreamRequest, opts ...gax.CallOption) (*operation.OperationWrapper[datastreampb.Stream], error) { + args := m.Called(ctx, req, opts) + // Avoid panic for typeassertion due to null pointer. + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*operation.OperationWrapper[datastreampb.Stream]), args.Error(1) +} +func (m *DatastreamClientMock) UpdateStream(ctx context.Context, req *datastreampb.UpdateStreamRequest, opts ...gax.CallOption) (*operation.OperationWrapper[datastreampb.Stream], error) { + args := m.Called(ctx, req, opts) + // Avoid panic for typeassertion due to null pointer. 
+ if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*operation.OperationWrapper[datastreampb.Stream]), args.Error(1) +} + +func (m *DatastreamClientMock) GetConnectionProfile(ctx context.Context, connectionName string) (*datastreampb.ConnectionProfile, error) { + args := m.Called(ctx, connectionName) + return args.Get(0).(*datastreampb.ConnectionProfile), args.Error(1) +} + +func (m *DatastreamClientMock) ListConnectionProfiles(ctx context.Context, listRequest *datastreampb.ListConnectionProfilesRequest, opts ...gax.CallOption) *datastream.ConnectionProfileIterator{ + args := m.Called(ctx, listRequest, opts) + return args.Get(0).(*datastream.ConnectionProfileIterator) +} + +func (m *DatastreamClientMock) DeleteConnectionProfile(ctx context.Context, deleteRequest *datastreampb.DeleteConnectionProfileRequest) (*operation.NilOperationWrapper, error){ + args := m.Called(ctx, deleteRequest) + return args.Get(0).(*operation.NilOperationWrapper), args.Error(1) +} + +func (m *DatastreamClientMock) CreateConnectionProfile(ctx context.Context, createRequest *datastreampb.CreateConnectionProfileRequest) (*operation.OperationWrapper[datastreampb.ConnectionProfile], error){ + args := m.Called(ctx, createRequest) + return args.Get(0).(*operation.OperationWrapper[datastreampb.ConnectionProfile]), args.Error(1) +} \ No newline at end of file diff --git a/accessors/clients/operation/mock_operation.go b/accessors/clients/operation/mock_operation.go new file mode 100644 index 0000000000..eaceb31795 --- /dev/null +++ b/accessors/clients/operation/mock_operation.go @@ -0,0 +1,44 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package operation + +import ( + "context" + "time" + + "github.com/googleapis/gax-go/v2" +) + +type MockOperation[T any] struct { + RetVal *T + RetErr error + Delay time.Duration +} + +func (m MockOperation[T]) Wait(ctx context.Context, opts ...gax.CallOption) (*T, error) { + // As per golang docs, a 0 or -ve delay makes sleep return immediately. + time.Sleep(m.Delay) + return m.RetVal, m.RetErr +} + +type MockNilOperation struct { + RetErr error + Delay time.Duration +} + +func (m *MockNilOperation) Wait(ctx context.Context, opts ...gax.CallOption) error { + // As per golang docs, a 0 or -ve delay makes sleep return immediately. + time.Sleep(m.Delay) + return m.RetErr +} diff --git a/accessors/clients/operation/operation.go b/accessors/clients/operation/operation.go new file mode 100644 index 0000000000..eda83fe022 --- /dev/null +++ b/accessors/clients/operation/operation.go @@ -0,0 +1,56 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package operation + +import ( + "context" + + "github.com/googleapis/gax-go/v2" +) + +// Generic interface for mocking long running operations like CreateStreamOperation, UpdateStreamOperation etc. +type Operation[T any] interface { + Wait(ctx context.Context, opts ...gax.CallOption) (*T, error) +} + +// Wrapping the operation interface in a struct helps us stick to the golang idiom of not returning an interface. +type OperationWrapper[T any] struct { + elem Operation[T] +} + +func (o *OperationWrapper[T]) Wait(ctx context.Context, opts ...gax.CallOption) (*T, error) { + return o.elem.Wait(ctx, opts...) +} + +func NewOperationWrapper[T any](elem Operation[T]) OperationWrapper[T] { + return OperationWrapper[T]{elem} +} + +// Generic interface for mocking long running operations like DeleteConnectionProfileOperation etc. +type NilOperation interface { + Wait(ctx context.Context, opts ...gax.CallOption) error +} + +// Wrapping the operation interface in a struct helps us stick to the golang idiom of not returning an interface. +type NilOperationWrapper struct { + elem NilOperation +} + +func (o *NilOperationWrapper) Wait(ctx context.Context, opts ...gax.CallOption) error { + return o.elem.Wait(ctx, opts...) +} + +func NewNilOperationWrapper(elem NilOperation) NilOperationWrapper { + return NilOperationWrapper{elem} +} diff --git a/accessors/clients/operation/operation_test/operation_test.go b/accessors/clients/operation/operation_test/operation_test.go new file mode 100644 index 0000000000..6b516fd96c --- /dev/null +++ b/accessors/clients/operation/operation_test/operation_test.go @@ -0,0 +1,74 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package operation_test + +import ( + "context" + "errors" + "os" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/operation" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/googleapis/gax-go/v2" + "go.uber.org/zap" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +type intOperationValue struct { + val int64 + e error +} + +func (i intOperationValue) Wait(ctx context.Context, opts ...gax.CallOption) (*int64, error) { + return &i.val, i.e +} + +func TestWait(t *testing.T) { + ctx := context.Background() + var testVal int64 = 42 + testError := errors.New("testError") + i := intOperationValue{testVal, testError} + o := operation.NewOperationWrapper[int64](i) + v, e := o.Wait(ctx) + assert.Equal(t, *v, testVal, "operationWrapper.Wait must return correct value") + assert.Equal(t, e, testError, "operationWrapper.Wait must return correct error") +} + +type NilOperationValue struct { + e error +} + +func (i NilOperationValue) Wait(ctx context.Context, opts ...gax.CallOption) (error) { + return i.e +} + +func TestWaitNilOp(t *testing.T) { + ctx := context.Background() + testError := errors.New("testError") + i := NilOperationValue{testError} + o := operation.NewNilOperationWrapper(i) + e := o.Wait(ctx) + assert.Equal(t, e, testError, "operationWrapper.Wait must return correct error") +} \ No newline at end of file diff --git a/accessors/clients/spanner/admin/admin_client.go 
b/accessors/clients/spanner/admin/admin_client.go new file mode 100644 index 0000000000..6d60d8f194 --- /dev/null +++ b/accessors/clients/spanner/admin/admin_client.go @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spanneradmin + +import ( + "context" + "fmt" + "sync" + + database "cloud.google.com/go/spanner/admin/database/apiv1" +) + +var once sync.Once +var spannerAdminClient *database.DatabaseAdminClient + +// This function is declared as a global variable to make it testable. The unit +// tests update this function, acting like a double. 
+var newDatabaseAdminClient = database.NewDatabaseAdminClient + +func GetOrCreateClient(ctx context.Context) (*database.DatabaseAdminClient, error) { + var err error + if spannerAdminClient == nil { + once.Do(func() { + spannerAdminClient, err = newDatabaseAdminClient(ctx) + }) + if err != nil { + return nil, fmt.Errorf("failed to create spanner admin client: %v", err) + } + return spannerAdminClient, nil + } + return spannerAdminClient, nil +} diff --git a/accessors/clients/spanner/admin/admin_client_test.go b/accessors/clients/spanner/admin/admin_client_test.go new file mode 100644 index 0000000000..7c911ee096 --- /dev/null +++ b/accessors/clients/spanner/admin/admin_client_test.go @@ -0,0 +1,116 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package spanneradmin + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + + database "cloud.google.com/go/spanner/admin/database/apiv1" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/api/option" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func resetTest() { + spannerAdminClient = nil + once = sync.Once{} +} + +func TestGetOrCreateClient_Basic(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newDatabaseAdminClient + defer func() { newDatabaseAdminClient = oldFunc }() + newDatabaseAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*database.DatabaseAdminClient, error) { + return &database.DatabaseAdminClient{}, nil + } + c, err := GetOrCreateClient(ctx) + assert.NotNil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaSync(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newDatabaseAdminClient + defer func() { newDatabaseAdminClient = oldFunc }() + + newDatabaseAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*database.DatabaseAdminClient, error) { + return &database.DatabaseAdminClient{}, nil + } + c, err := GetOrCreateClient(ctx) + assert.NotNil(t, c) + assert.Nil(t, err) + // Explicitly set the client to nil. Running GetOrCreateClient should not create a + // new client since sync would already be executed. 
+ spannerAdminClient = nil + newDatabaseAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*database.DatabaseAdminClient, error) { + return nil, fmt.Errorf("test error") + } + c, err = GetOrCreateClient(ctx) + assert.Nil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaIf(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newDatabaseAdminClient + defer func() { newDatabaseAdminClient = oldFunc }() + + newDatabaseAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*database.DatabaseAdminClient, error) { + return &database.DatabaseAdminClient{}, nil + } + oldC, err := GetOrCreateClient(ctx) + assert.NotNil(t, oldC) + assert.Nil(t, err) + + // Explicitly reset once. Running GetOrCreateClient should not create a + // new client the if condition should prevent it. + once = sync.Once{} + newDatabaseAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*database.DatabaseAdminClient, error) { + return nil, fmt.Errorf("test error") + } + newC, err := GetOrCreateClient(ctx) + assert.Equal(t, oldC, newC) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_Error(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newDatabaseAdminClient + defer func() { newDatabaseAdminClient = oldFunc }() + + newDatabaseAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*database.DatabaseAdminClient, error) { + return nil, fmt.Errorf("test error") + } + c, err := GetOrCreateClient(ctx) + assert.Nil(t, c) + assert.NotNil(t, err) +} diff --git a/accessors/clients/spanner/admin/interface.go b/accessors/clients/spanner/admin/interface.go new file mode 100644 index 0000000000..47f4d138fa --- /dev/null +++ b/accessors/clients/spanner/admin/interface.go @@ -0,0 +1,95 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spanneradmin + +import ( + "context" + + database "cloud.google.com/go/spanner/admin/database/apiv1" + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" + "github.com/googleapis/gax-go/v2" +) + +// Use this interface instead of database.DatabaseAdminClient to support mocking. +type AdminClient interface { + GetDatabase(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) + CreateDatabase(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (CreateDatabaseOperation, error) + UpdateDatabaseDdl(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (UpdateDatabaseDdlOperation, error) + GetDatabaseDdl(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) +} + +// Use this interface instead of database.CreateDatabaseOperation to support mocking. +type CreateDatabaseOperation interface { + Wait(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) +} + +// Use this interface instead of database.UpdateDatabaseDdlOperation to support mocking. +type UpdateDatabaseDdlOperation interface { + Wait(ctx context.Context, opts ...gax.CallOption) error +} + +// This implements the AdminClient interface. This is the primary implementation that should be used in all places other than tests. 
+type AdminClientImpl struct { + adminClient *database.DatabaseAdminClient +} + +func NewAdminClientImpl(ctx context.Context) (*AdminClientImpl, error) { + c, err := GetOrCreateClient(ctx) + if err != nil { + return nil, err + } + return &AdminClientImpl{adminClient: c}, nil +} + +func (c *AdminClientImpl) GetDatabase(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return c.adminClient.GetDatabase(ctx, req, opts...) +} + +func (c *AdminClientImpl) CreateDatabase(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (CreateDatabaseOperation, error) { + op, err := c.adminClient.CreateDatabase(ctx, req, opts...) + if err != nil { + return nil, err + } + return &CreateDatabaseOperationImpl{dbo: op}, nil +} + +func (c *AdminClientImpl) UpdateDatabaseDdl(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (UpdateDatabaseDdlOperation, error) { + op, err := c.adminClient.UpdateDatabaseDdl(ctx, req, opts...) + if err != nil { + return nil, err + } + return &UpdateDatabaseDdlImpl{dbo: op}, nil +} + +// This implements the CreateDatabaseOperation interface. This is the primary implementation that should be used in all places other than tests. +type CreateDatabaseOperationImpl struct { + dbo *database.CreateDatabaseOperation +} + +func (c *CreateDatabaseOperationImpl) Wait(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { + return c.dbo.Wait(ctx, opts...) +} + +// This implements the UpdateDatabaseDdl interface. This is the primary implementation that should be used in all places other than tests. +type UpdateDatabaseDdlImpl struct { + dbo *database.UpdateDatabaseDdlOperation +} + +func (c *UpdateDatabaseDdlImpl) Wait(ctx context.Context, opts ...gax.CallOption) error { + return c.dbo.Wait(ctx, opts...) 
+} + +func (c *AdminClientImpl) GetDatabaseDdl(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return c.adminClient.GetDatabaseDdl(ctx, req, opts...) +} \ No newline at end of file diff --git a/accessors/clients/spanner/admin/mocks.go b/accessors/clients/spanner/admin/mocks.go new file mode 100644 index 0000000000..2cc5f51c2b --- /dev/null +++ b/accessors/clients/spanner/admin/mocks.go @@ -0,0 +1,66 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spanneradmin + +import ( + "context" + + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" + "github.com/googleapis/gax-go/v2" +) + +// Mock that implements the AdminClient interface. +// Pass in unit tests where AdminClient is an input parameter. 
+type AdminClientMock struct { + GetDatabaseMock func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) + CreateDatabaseMock func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (CreateDatabaseOperation, error) + UpdateDatabaseDdlMock func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (UpdateDatabaseDdlOperation, error) + GetDatabaseDdlMock func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) +} + +func (acm *AdminClientMock) GetDatabase(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return acm.GetDatabaseMock(ctx, req, opts...) +} + +func (acm *AdminClientMock) CreateDatabase(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (CreateDatabaseOperation, error) { + return acm.CreateDatabaseMock(ctx, req, opts...) +} + +func (acm *AdminClientMock) UpdateDatabaseDdl(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (UpdateDatabaseDdlOperation, error) { + return acm.UpdateDatabaseDdlMock(ctx, req, opts...) +} + +func (acm *AdminClientMock) GetDatabaseDdl(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return acm.GetDatabaseDdlMock(ctx, req, opts...) +} + +// Mock that implements the CreateDatabaseOperation interface. +// Pass in unit tests where CreateDatabaseOperation is an input parameter. +type CreateDatabaseOperationMock struct { + WaitMock func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) +} + +func (dbo *CreateDatabaseOperationMock) Wait(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { + return dbo.WaitMock(ctx, opts...) 
+} + +// Mock that implements the UpdateDatabaseDdlOperation interface. +// Pass in unit tests where UpdateDatabaseDdlOperation is an input parameter. +type UpdateDatabaseDdlOperationMock struct { + WaitMock func(ctx context.Context, opts ...gax.CallOption) error +} + +func (dbo *UpdateDatabaseDdlOperationMock) Wait(ctx context.Context, opts ...gax.CallOption) error { + return dbo.WaitMock(ctx, opts...) +} diff --git a/accessors/clients/spanner/client/spanner_client.go b/accessors/clients/spanner/client/spanner_client.go new file mode 100644 index 0000000000..d6edd81ece --- /dev/null +++ b/accessors/clients/spanner/client/spanner_client.go @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spannerclient + +import ( + "context" + "fmt" + "sync" + + sp "cloud.google.com/go/spanner" +) + +var once sync.Once +var spannerClient *sp.Client + +// This function is declared as a global variable to make it testable. The unit +// tests edit this function, acting like a double. 
+var newClient = sp.NewClient + +func GetOrCreateClient(ctx context.Context, dbURI string) (*sp.Client, error) { + var err error + if spannerClient == nil { + once.Do(func() { + spannerClient, err = newClient(ctx, dbURI) + }) + if err != nil { + return nil, fmt.Errorf("failed to create spanner database client: %v", err) + } + return spannerClient, nil + } + return spannerClient, nil +} diff --git a/accessors/clients/spanner/client/spanner_client_test.go b/accessors/clients/spanner/client/spanner_client_test.go new file mode 100644 index 0000000000..66f5059591 --- /dev/null +++ b/accessors/clients/spanner/client/spanner_client_test.go @@ -0,0 +1,116 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package spannerclient + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + + sp "cloud.google.com/go/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/api/option" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func resetTest() { + spannerClient = nil + once = sync.Once{} +} + +func TestGetOrCreateClient_Basic(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return &sp.Client{}, nil + } + c, err := GetOrCreateClient(ctx, "testURI") + assert.NotNil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaSync(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return &sp.Client{}, nil + } + c, err := GetOrCreateClient(ctx, "testURI") + assert.NotNil(t, c) + assert.Nil(t, err) + // Explicitly set the client to nil. Running GetOrCreateClient should not create a + // new client since sync would already be executed. 
+ spannerClient = nil + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return nil, fmt.Errorf("test error") + } + c, err = GetOrCreateClient(ctx, "testURI") + assert.Nil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaIf(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return &sp.Client{}, nil + } + oldC, err := GetOrCreateClient(ctx, "testURI") + assert.NotNil(t, oldC) + assert.Nil(t, err) + + // Explicitly reset once. Running GetOrCreateClient should not create a + // new client the if condition should prevent it. + once = sync.Once{} + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return nil, fmt.Errorf("test error") + } + newC, err := GetOrCreateClient(ctx, "testURI") + assert.Equal(t, oldC, newC) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_Error(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return nil, fmt.Errorf("test error") + } + c, err := GetOrCreateClient(ctx, "testURI") + assert.Nil(t, c) + assert.NotNil(t, err) +} diff --git a/accessors/clients/spanner/instanceadmin/interface.go b/accessors/clients/spanner/instanceadmin/interface.go new file mode 100644 index 0000000000..5e195ebf1c --- /dev/null +++ b/accessors/clients/spanner/instanceadmin/interface.go @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spinstanceadmin + +import ( + "context" + + instance "cloud.google.com/go/spanner/admin/instance/apiv1" + "cloud.google.com/go/spanner/admin/instance/apiv1/instancepb" + "github.com/googleapis/gax-go/v2" +) + +// Use this interface instead of instance.InstanceAdminClient to support mocking. +type InstanceAdminClient interface { + GetInstance(ctx context.Context, req *instancepb.GetInstanceRequest, opts ...gax.CallOption) (*instancepb.Instance, error) + GetInstanceConfig(ctx context.Context, req *instancepb.GetInstanceConfigRequest, opts ...gax.CallOption) (*instancepb.InstanceConfig, error) +} + +// This implements the InstanceAdminClient interface. This is the primary implementation that should be used in all places other than tests. +type InstanceAdminClientImpl struct { + client *instance.InstanceAdminClient +} + +func NewInstanceAdminClientImpl(ctx context.Context) (*InstanceAdminClientImpl, error) { + c, err := GetOrCreateClient(ctx) + if err != nil { + return nil, err + } + return &InstanceAdminClientImpl{client: c}, nil +} + +func (c *InstanceAdminClientImpl) GetInstance(ctx context.Context, req *instancepb.GetInstanceRequest, opts ...gax.CallOption) (*instancepb.Instance, error) { + return c.client.GetInstance(ctx, req, opts...) +} + +func (c *InstanceAdminClientImpl) GetInstanceConfig(ctx context.Context, req *instancepb.GetInstanceConfigRequest, opts ...gax.CallOption) (*instancepb.InstanceConfig, error) { + return c.client.GetInstanceConfig(ctx, req, opts...) 
+} diff --git a/accessors/clients/spanner/instanceadmin/mocks.go b/accessors/clients/spanner/instanceadmin/mocks.go new file mode 100644 index 0000000000..bb88cca00a --- /dev/null +++ b/accessors/clients/spanner/instanceadmin/mocks.go @@ -0,0 +1,36 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spinstanceadmin + +import ( + "context" + + "cloud.google.com/go/spanner/admin/instance/apiv1/instancepb" + "github.com/googleapis/gax-go/v2" +) + +// Mock that implements the InstanceAdminClient interface. +// Pass in unit tests where InstanceAdminClient is an input parameter. +type InstanceAdminClientMock struct { + GetInstanceMock func(ctx context.Context, req *instancepb.GetInstanceRequest, opts ...gax.CallOption) (*instancepb.Instance, error) + GetInstanceConfigMock func(ctx context.Context, req *instancepb.GetInstanceConfigRequest, opts ...gax.CallOption) (*instancepb.InstanceConfig, error) +} + +func (iac *InstanceAdminClientMock) GetInstance(ctx context.Context, req *instancepb.GetInstanceRequest, opts ...gax.CallOption) (*instancepb.Instance, error) { + return iac.GetInstanceMock(ctx, req, opts...) +} + +func (iac *InstanceAdminClientMock) GetInstanceConfig(ctx context.Context, req *instancepb.GetInstanceConfigRequest, opts ...gax.CallOption) (*instancepb.InstanceConfig, error) { + return iac.GetInstanceConfigMock(ctx, req, opts...) 
+} diff --git a/accessors/clients/spanner/instanceadmin/spanner_instance_admin.go b/accessors/clients/spanner/instanceadmin/spanner_instance_admin.go new file mode 100644 index 0000000000..a3d597e173 --- /dev/null +++ b/accessors/clients/spanner/instanceadmin/spanner_instance_admin.go @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spinstanceadmin + +import ( + "context" + "fmt" + "sync" + + instance "cloud.google.com/go/spanner/admin/instance/apiv1" +) + +var once sync.Once +var instanceAdminClient *instance.InstanceAdminClient + +// This function is declared as a global variable to make it testable. The unit +// tests update this function, acting like a double. 
+var newInstanceAdminClient = instance.NewInstanceAdminClient + +func GetOrCreateClient(ctx context.Context) (*instance.InstanceAdminClient, error) { + var err error + if instanceAdminClient == nil { + once.Do(func() { + instanceAdminClient, err = newInstanceAdminClient(ctx) + }) + if err != nil { + return nil, fmt.Errorf("failed to create spanner instance admin client: %v", err) + } + return instanceAdminClient, nil + } + return instanceAdminClient, nil +} diff --git a/accessors/clients/spanner/instanceadmin/spanner_instance_admin_test.go b/accessors/clients/spanner/instanceadmin/spanner_instance_admin_test.go new file mode 100644 index 0000000000..2792d03e82 --- /dev/null +++ b/accessors/clients/spanner/instanceadmin/spanner_instance_admin_test.go @@ -0,0 +1,116 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package spinstanceadmin + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + + instance "cloud.google.com/go/spanner/admin/instance/apiv1" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/api/option" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func resetTest() { + instanceAdminClient = nil + once = sync.Once{} +} + +func TestGetOrCreateClient_Basic(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newInstanceAdminClient + defer func() { newInstanceAdminClient = oldFunc }() + newInstanceAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*instance.InstanceAdminClient, error) { + return &instance.InstanceAdminClient{}, nil + } + c, err := GetOrCreateClient(ctx) + assert.NotNil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaSync(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newInstanceAdminClient + defer func() { newInstanceAdminClient = oldFunc }() + + newInstanceAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*instance.InstanceAdminClient, error) { + return &instance.InstanceAdminClient{}, nil + } + c, err := GetOrCreateClient(ctx) + assert.NotNil(t, c) + assert.Nil(t, err) + // Explicitly set the client to nil. Running GetOrCreateClient should not create a + // new client since sync would already be executed. 
+ instanceAdminClient = nil + newInstanceAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*instance.InstanceAdminClient, error) { + return nil, fmt.Errorf("test error") + } + c, err = GetOrCreateClient(ctx) + assert.Nil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaIf(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newInstanceAdminClient + defer func() { newInstanceAdminClient = oldFunc }() + + newInstanceAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*instance.InstanceAdminClient, error) { + return &instance.InstanceAdminClient{}, nil + } + oldC, err := GetOrCreateClient(ctx) + assert.NotNil(t, oldC) + assert.Nil(t, err) + + // Explicitly reset once. Running GetOrCreateClient should not create a + // new client the if condition should prevent it. + once = sync.Once{} + newInstanceAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*instance.InstanceAdminClient, error) { + return nil, fmt.Errorf("test error") + } + newC, err := GetOrCreateClient(ctx) + assert.Equal(t, oldC, newC) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_Error(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newInstanceAdminClient + defer func() { newInstanceAdminClient = oldFunc }() + + newInstanceAdminClient = func(ctx context.Context, opts ...option.ClientOption) (*instance.InstanceAdminClient, error) { + return nil, fmt.Errorf("test error") + } + c, err := GetOrCreateClient(ctx) + assert.Nil(t, c) + assert.NotNil(t, err) +} diff --git a/accessors/clients/storage/interface.go b/accessors/clients/storage/interface.go new file mode 100644 index 0000000000..71dcfce4e1 --- /dev/null +++ b/accessors/clients/storage/interface.go @@ -0,0 +1,91 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package storageclient + +import ( + "context" + "io" + + "cloud.google.com/go/storage" +) + +// Use this interface instead of storage.Client to support mocking. +type StorageClient interface { + Bucket(name string) BucketHandle +} + +// Use this interface instead of storage.BucketHandle to support mocking. +type BucketHandle interface { + Create(ctx context.Context, projectID string, attrs *storage.BucketAttrs) (err error) + Update(ctx context.Context, uattrs storage.BucketAttrsToUpdate) (attrs *storage.BucketAttrs, err error) + Object(name string) ObjectHandle + Delete(ctx context.Context) error +} + +// Use this interface instead of storage.ObjectHandle to support mocking. +type ObjectHandle interface { + NewWriter(ctx context.Context) io.WriteCloser + NewReader(ctx context.Context) (io.ReadCloser, error) +} + +// This implements the StorageClient interface. This is the primary implementation that should be used in all places other than tests. +type StorageClientImpl struct { + client *storage.Client +} + +func NewStorageClientImpl(ctx context.Context) (*StorageClientImpl, error) { + c, err := GetOrCreateClient(ctx) + if err != nil { + return nil, err + } + return &StorageClientImpl{client: c}, nil +} + +func (c *StorageClientImpl) Bucket(name string) BucketHandle { + return &BucketHandleImpl{bucketHandle: c.client.Bucket(name)} +} + +// This implements the BucketHandle interface. This is the primary implementation that should be used in all places other than tests. 
+type BucketHandleImpl struct { + bucketHandle *storage.BucketHandle +} + +func (b *BucketHandleImpl) Create(ctx context.Context, projectID string, attrs *storage.BucketAttrs) (err error) { + return b.bucketHandle.Create(ctx, projectID, attrs) +} + +func (b *BucketHandleImpl) Update(ctx context.Context, uattrs storage.BucketAttrsToUpdate) (attrs *storage.BucketAttrs, err error) { + return b.bucketHandle.Update(ctx, uattrs) +} + +func (b *BucketHandleImpl) Object(name string) ObjectHandle { + return &ObjectHandleImpl{objectHandle: b.bucketHandle.Object(name)} +} + +func (b *BucketHandleImpl) Delete(ctx context.Context) error { + return b.bucketHandle.Delete(ctx) +} + +// This implements the ObjectHandle interface. This is the primary implementation that should be used in all places other than tests. +type ObjectHandleImpl struct { + objectHandle *storage.ObjectHandle +} + +func (o *ObjectHandleImpl) NewWriter(ctx context.Context) io.WriteCloser { + return o.objectHandle.NewWriter(ctx) +} + +func (o *ObjectHandleImpl) NewReader(ctx context.Context) (io.ReadCloser, error) { + return o.objectHandle.NewReader(ctx) +} diff --git a/accessors/clients/storage/mocks.go b/accessors/clients/storage/mocks.go new file mode 100644 index 0000000000..7a9386172d --- /dev/null +++ b/accessors/clients/storage/mocks.go @@ -0,0 +1,100 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package storageclient + +import ( + "context" + "io" + + "cloud.google.com/go/storage" +) + +// Mock that implements the StorageClient interface. +// Pass in unit tests where StorageClient is an input parameter. +type StorageClientMock struct { + BucketMock func(name string) BucketHandle +} + +func (scm *StorageClientMock) Bucket(name string) BucketHandle { + return scm.BucketMock(name) +} + +// Mock that implements the BucketHandle interface. +// Pass in unit tests where BucketHandle is an input parameter. +type BucketHandleMock struct { + CreateMock func(ctx context.Context, projectID string, attrs *storage.BucketAttrs) (err error) + UpdateMock func(ctx context.Context, uattrs storage.BucketAttrsToUpdate) (attrs *storage.BucketAttrs, err error) + ObjectMock func(name string) ObjectHandle + DeleteMock func(ctx context.Context) error +} + +func (b *BucketHandleMock) Create(ctx context.Context, projectID string, attrs *storage.BucketAttrs) (err error) { + return b.CreateMock(ctx, projectID, attrs) +} + +func (b *BucketHandleMock) Update(ctx context.Context, uattrs storage.BucketAttrsToUpdate) (attrs *storage.BucketAttrs, err error) { + return b.UpdateMock(ctx, uattrs) +} + +func (b *BucketHandleMock) Object(name string) ObjectHandle { + return b.ObjectMock(name) +} +func (b *BucketHandleMock) Delete(ctx context.Context) error { + return b.DeleteMock(ctx) +} + +// Mock that implements the ObjectHandle interface. +// Pass in unit tests where ObjectHandle is an input parameter. +type ObjectHandleMock struct { + NewWriterMock func(ctx context.Context) io.WriteCloser + NewReaderMock func(ctx context.Context) (io.ReadCloser, error) +} + +func (o *ObjectHandleMock) NewWriter(ctx context.Context) io.WriteCloser { + return o.NewWriterMock(ctx) +} + +func (o *ObjectHandleMock) NewReader(ctx context.Context) (io.ReadCloser, error) { + return o.NewReaderMock(ctx) +} + +// Mock that implements the io.WriteCloser interface. 
+// Pass in unit tests where io.WriteCloser is an input parameter. +type WriterMock struct { + WriteMock func(p []byte) (n int, err error) + CloseMock func() error +} + +func (w *WriterMock) Write(p []byte) (n int, err error) { + return w.WriteMock(p) +} + +func (w *WriterMock) Close() error { + return w.CloseMock() +} + +// Mock that implements the io.ReadCloser interface. +// Pass in unit tests where io.ReadCloser is an input parameter. +type ReaderMock struct { + ReadMock func(p []byte) (n int, err error) + CloseMock func() error +} + +func (r *ReaderMock) Read(p []byte) (n int, err error) { + return r.ReadMock(p) +} + +func (r *ReaderMock) Close() error { + return r.CloseMock() +} diff --git a/accessors/clients/storage/storage_client.go b/accessors/clients/storage/storage_client.go new file mode 100644 index 0000000000..09a66c401a --- /dev/null +++ b/accessors/clients/storage/storage_client.go @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package storageclient + +import ( + "context" + "fmt" + "sync" + + "cloud.google.com/go/storage" +) + +var once sync.Once +var gcsClient *storage.Client + +// This function is declared as a global variable to make it testable. The unit +// tests update this function, acting like a double. 
+var newClient = storage.NewClient + +func GetOrCreateClient(ctx context.Context) (*storage.Client, error) { + var err error + if gcsClient == nil { + once.Do(func() { + gcsClient, err = newClient(ctx) + }) + if err != nil { + return nil, fmt.Errorf("failed to create storage client: %v", err) + } + return gcsClient, nil + } + return gcsClient, nil +} diff --git a/accessors/clients/storage/storage_client_test.go b/accessors/clients/storage/storage_client_test.go new file mode 100644 index 0000000000..73bd6b873a --- /dev/null +++ b/accessors/clients/storage/storage_client_test.go @@ -0,0 +1,117 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package storageclient + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + + "cloud.google.com/go/storage" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/api/option" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func resetTest() { + gcsClient = nil + once = sync.Once{} +} + +func TestGetOrCreateClient_Basic(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + newClient = func(ctx context.Context, opts ...option.ClientOption) (*storage.Client, error) { + return &storage.Client{}, nil + } + c, err := GetOrCreateClient(ctx) + assert.NotNil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaSync(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, opts ...option.ClientOption) (*storage.Client, error) { + return &storage.Client{}, nil + } + c, err := GetOrCreateClient(ctx) + assert.NotNil(t, c) + assert.Nil(t, err) + // Explicitly set the client to nil. Running GetOrCreateClient should not create a + // new client since sync would already be executed. 
+ gcsClient = nil + + newClient = func(ctx context.Context, opts ...option.ClientOption) (*storage.Client, error) { + return nil, fmt.Errorf("test error") + } + c, err = GetOrCreateClient(ctx) + assert.Nil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaIf(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, opts ...option.ClientOption) (*storage.Client, error) { + return &storage.Client{}, nil + } + oldC, err := GetOrCreateClient(ctx) + assert.NotNil(t, oldC) + assert.Nil(t, err) + + // Explicitly reset once. Running GetOrCreateClient should not create a + // new client the if condition should prevent it. + once = sync.Once{} + newClient = func(ctx context.Context, opts ...option.ClientOption) (*storage.Client, error) { + return nil, fmt.Errorf("test error") + } + newC, err := GetOrCreateClient(ctx) + assert.Equal(t, oldC, newC) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_Error(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, opts ...option.ClientOption) (*storage.Client, error) { + return nil, fmt.Errorf("test error") + } + c, err := GetOrCreateClient(ctx) + assert.Nil(t, c) + assert.NotNil(t, err) +} diff --git a/accessors/dataflow/dataflow_accessor.go b/accessors/dataflow/dataflow_accessor.go index aa7f1bea49..3fed9b3cc9 100644 --- a/accessors/dataflow/dataflow_accessor.go +++ b/accessors/dataflow/dataflow_accessor.go @@ -25,12 +25,14 @@ import ( "golang.org/x/exp/maps" ) +// The DataflowAccessor provides methods that internally use the dataflow client. Methods should only contain generic logic here that can be used by multiple workflows. 
type DataflowAccessor interface { // This function takes the template parameters (@parameters) and runtime environment config (@cfg) as input, and returns // the generated jobId, equivalentGcloudCommand and error if any. - LaunchFlexTemplate(ctx context.Context, c dataflowclient.DataflowClient, parameters map[string]string, cfg DataflowTuningConfig) (string, string, error) + LaunchDataflowTemplate(ctx context.Context, c dataflowclient.DataflowClient, parameters map[string]string, cfg DataflowTuningConfig) (string, string, error) } +// This implements the DataflowAccessor interface. This is the primary implementation that should be used in all places other than tests. type DataflowAccessorImpl struct{} func (dfA *DataflowAccessorImpl) LaunchDataflowTemplate(ctx context.Context, c dataflowclient.DataflowClient, parameters map[string]string, cfg DataflowTuningConfig) (string, string, error) { diff --git a/accessors/dataflow/dataflow_accessor_test.go b/accessors/dataflow/dataflow_accessor_test.go index ffb8b6537b..76cec60007 100644 --- a/accessors/dataflow/dataflow_accessor_test.go +++ b/accessors/dataflow/dataflow_accessor_test.go @@ -20,6 +20,7 @@ import ( "testing" "cloud.google.com/go/dataflow/apiv1beta3/dataflowpb" + dataflowclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/dataflow" "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" "github.com/google/go-cmp/cmp" "github.com/googleapis/gax-go/v2" @@ -157,14 +158,6 @@ func getExpectedGcloudCmd2() string { "transformationContextFilePath=gs://transformationContext.json" } -type DataflowClientMock struct { - LaunchFlexTemplateMock func(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) -} - -func (dcm *DataflowClientMock) LaunchFlexTemplate(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) { - return 
dcm.LaunchFlexTemplateMock(ctx, req, opts...) -} - func TestLaunchDataflowTemplate(t *testing.T) { ctx := context.Background() da := DataflowAccessorImpl{} @@ -172,7 +165,7 @@ func TestLaunchDataflowTemplate(t *testing.T) { name string params map[string]string cfg DataflowTuningConfig - dcm DataflowClientMock + dcm dataflowclient.DataflowClientMock expectError bool expectedJobId string expectedGcloudCmd string @@ -181,7 +174,7 @@ func TestLaunchDataflowTemplate(t *testing.T) { name: "Basic Correct", params: getParameters(), cfg: getTuningConfig(), - dcm: DataflowClientMock{ + dcm: dataflowclient.DataflowClientMock{ LaunchFlexTemplateMock: func(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) { return &dataflowpb.LaunchFlexTemplateResponse{Job: &dataflowpb.Job{Id: "1234"}}, nil }, @@ -194,7 +187,7 @@ func TestLaunchDataflowTemplate(t *testing.T) { name: "Request builder error", params: getParameters(), cfg: DataflowTuningConfig{Subnetwork: "test"}, - dcm: DataflowClientMock{ + dcm: dataflowclient.DataflowClientMock{ LaunchFlexTemplateMock: func(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) { return &dataflowpb.LaunchFlexTemplateResponse{Job: &dataflowpb.Job{Id: "1234"}}, nil }, @@ -207,7 +200,7 @@ func TestLaunchDataflowTemplate(t *testing.T) { name: "Launch flex template throws error", params: getParameters(), cfg: getTuningConfig(), - dcm: DataflowClientMock{ + dcm: dataflowclient.DataflowClientMock{ LaunchFlexTemplateMock: func(ctx context.Context, req *dataflowpb.LaunchFlexTemplateRequest, opts ...gax.CallOption) (*dataflowpb.LaunchFlexTemplateResponse, error) { return nil, fmt.Errorf("test error") }, diff --git a/accessors/dataflow/mocks.go b/accessors/dataflow/mocks.go new file mode 100644 index 0000000000..9a65d9d512 --- /dev/null +++ b/accessors/dataflow/mocks.go @@ -0,0 +1,30 @@ +// 
Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package dataflowaccessor + +import ( + "context" + + dataflowclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/dataflow" +) + +// Mock that implements the DataflowAccessor interface. +// Pass in unit tests where DataflowAccessor is an input parameter. +type DataflowAccessorMock struct { + LaunchFlexTemplateMock func(ctx context.Context, c dataflowclient.DataflowClient, parameters map[string]string, cfg DataflowTuningConfig) (string, string, error) +} + +func (d *DataflowAccessorMock) LaunchDataflowTemplate(ctx context.Context, c dataflowclient.DataflowClient, parameters map[string]string, cfg DataflowTuningConfig) (string, string, error) { + return d.LaunchFlexTemplateMock(ctx, c, parameters, cfg) +} diff --git a/accessors/datastream/datastream_accessor.go b/accessors/datastream/datastream_accessor.go new file mode 100644 index 0000000000..35084d4dc3 --- /dev/null +++ b/accessors/datastream/datastream_accessor.go @@ -0,0 +1,125 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package datastream_accessor + +import ( + "context" + "fmt" + "strings" + + "cloud.google.com/go/datastream/apiv1/datastreampb" + "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/datastream" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/streaming" + "google.golang.org/api/iterator" +) + +// The DatastreamAccessor provides methods that internally use the datstreamclient. Methods should only contain generic logic here that can be used by multiple workflows. +type DatastreamAccessor interface { + FetchTargetBucketAndPath(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, projectID string, datastreamDestinationConnCfg streaming.DstConnCfg) (string, string, error) + DeleteConnectionProfile(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, id string, projectId string, region string) error + GetConnProfilesRegion(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, projectId string, region string) ([]string, error) + CreateConnectionProfile(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, req *datastreampb.CreateConnectionProfileRequest) (*datastreampb.ConnectionProfile, error) + ConnectionProfileExists(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, projectId string, profileName string, profileLocation string, connectionProfiles map[string][]string) (bool, error) +} +type DatastreamAccessorImpl struct{} + +// FetchTargetBucketAndPath 
fetches the bucket and path name from a Datastream destination config. +func (da *DatastreamAccessorImpl) FetchTargetBucketAndPath(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, projectID string, datastreamDestinationConnCfg streaming.DstConnCfg) (string, string, error) { + if datastreamClient == nil { + return "", "", fmt.Errorf("datastream client could not be created") + } + dstProf := fmt.Sprintf("projects/%s/locations/%s/connectionProfiles/%s", projectID, datastreamDestinationConnCfg.Location, datastreamDestinationConnCfg.Name) + // `GetConnectionProfile` has out of box retries. Ref - https://github.com/googleapis/googleapis/blob/master/google/cloud/datastream/v1/datastream_grpc_service_config.json + res, err := datastreamClient.GetConnectionProfile(ctx, dstProf) + if err != nil { + return "", "", fmt.Errorf("could not get connection profiles: %v", err) + } + // Fetch the GCS path from the target connection profile. + // The Get calls for Google Cloud Storage API have out of box retries. 
+ // Reference - https://cloud.google.com/storage/docs/retry-strategy#idempotency-operations + gcsProfile := res.Profile.(*datastreampb.ConnectionProfile_GcsProfile).GcsProfile + bucketName := gcsProfile.Bucket + prefix := gcsProfile.RootPath + datastreamDestinationConnCfg.Prefix + prefix = utils.ConcatDirectoryPath(prefix, "data/") + return bucketName, prefix, nil +} + +// Deletes a connection Profile +func (da *DatastreamAccessorImpl) DeleteConnectionProfile(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, id string, projectId string, region string) error { + op, err := datastreamClient.DeleteConnectionProfile(ctx, &datastreampb.DeleteConnectionProfileRequest{ + Name: fmt.Sprintf("projects/%s/locations/%s/connectionProfiles/%s", projectId, region, id), + }) + if err != nil { + return err + } + + err = op.Wait(ctx) + if err != nil { + return err + } + return nil +} + +// Creates new connection Profile +func (da *DatastreamAccessorImpl) CreateConnectionProfile(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, req *datastreampb.CreateConnectionProfileRequest) (*datastreampb.ConnectionProfile, error) { + op, err := datastreamClient.CreateConnectionProfile(ctx, req) + if err != nil { + return nil, err + } + + return op.Wait(ctx) +} + +// Gets all connection profiles in a region +func (da *DatastreamAccessorImpl) GetConnProfilesRegion(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, projectId string, region string) ([]string, error) { + profilesIt := datastreamClient.ListConnectionProfiles(ctx, &datastreampb.ListConnectionProfilesRequest{Parent: "projects/" + projectId + "/locations/" + region}) + var profiles []string = []string{} + for { + resp, err := profilesIt.Next() + if err == iterator.Done { + break + } + if err != nil { + return nil, err + } else { + profiles = append(profiles, strings.Split(resp.Name, "/")[5]) + } + } + return profiles, nil +} + +// returns true if connection profile 
exists in a provided region else false +func (da *DatastreamAccessorImpl) ConnectionProfileExists(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, projectId string, profileName string, profileLocation string, connectionProfiles map[string][]string) (bool, error) { + // Check if connection profiles for the given region are fetched. if not, fetch them + profiles, ok := connectionProfiles[profileLocation] + var err error = nil + if !ok { + profiles, err = da.GetConnProfilesRegion(ctx, datastreamClient, projectId, profileLocation) + if err != nil { + return false, err + } + connectionProfiles[profileLocation] = profiles + } + + // Check if connection profile exists in the provided region + for _, element := range profiles { + if element == profileName { + return true, nil + } + } + + return false, nil +} diff --git a/accessors/datastream/datastream_accessor_test.go b/accessors/datastream/datastream_accessor_test.go new file mode 100644 index 0000000000..2015ca723f --- /dev/null +++ b/accessors/datastream/datastream_accessor_test.go @@ -0,0 +1,164 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package datastream_accessor_test + +import ( + "context" + "fmt" + "testing" + + "cloud.google.com/go/datastream/apiv1/datastreampb" + "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/datastream" + "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/operation" + datastream_accessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/datastream" + "github.com/GoogleCloudPlatform/spanner-migration-tool/streaming" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestFetchTargetBucketAndPath(t *testing.T) { + dstConfig := streaming.DstConnCfg{ + Location: "region-X", + Name: "profile-name", + Prefix: "/", + } + ctx := context.Background() + da := datastream_accessor.DatastreamAccessorImpl{} + testCases := []struct { + name string + dsm datastreamclient.DatastreamClientMock + connectionProfile *datastreampb.ConnectionProfile + getProfileErr error + expectedBucketName string + expectedPrefix string + expectError bool + }{ + { + name: "basic correct", + connectionProfile: &datastreampb.ConnectionProfile{Profile: &datastreampb.ConnectionProfile_GcsProfile{GcsProfile: &datastreampb.GcsProfile{Bucket: "bucket", RootPath: "/"}}}, + getProfileErr: nil, + expectedBucketName: "bucket", + expectedPrefix: "/data/", + expectError: false, + }, + { + name: "get connection profile error", + connectionProfile: nil, + getProfileErr: fmt.Errorf("error"), + expectedBucketName: "", + expectedPrefix: "", + expectError: true, + }, + { + name: "empty string", + connectionProfile: &datastreampb.ConnectionProfile{Profile: &datastreampb.ConnectionProfile_GcsProfile{GcsProfile: &datastreampb.GcsProfile{Bucket: "", RootPath: ""}}}, + getProfileErr: nil, + expectedBucketName: "", + expectedPrefix: "data/", + expectError: false, + }, + } + for _, tc := range testCases { + dsm := datastreamclient.DatastreamClientMock{} + dsm.On("GetConnectionProfile", mock.Anything, 
mock.Anything).Return(tc.connectionProfile, tc.getProfileErr) + bucketName, prefix, err := da.FetchTargetBucketAndPath(ctx, &dsm, "project-id", dstConfig) + assert.Equal(t, tc.expectError, err != nil, tc.name) + assert.Equal(t, tc.expectedBucketName, bucketName, tc.name) + assert.Equal(t, tc.expectedPrefix, prefix, tc.name) + } +} + +func TestDeleteConnectionProfile(t *testing.T) { + ctx := context.Background() + da := datastream_accessor.DatastreamAccessorImpl{} + testCases := []struct { + name string + op *operation.MockNilOperation + deleteConnProfileErr error + expectError bool + }{ + { + name: "basic correct", + op: &operation.MockNilOperation{ + RetErr: nil, + }, + expectError: false, + }, + { + name: "delete connection profile error", + op: nil, + deleteConnProfileErr: fmt.Errorf("error"), + expectError: true, + }, + { + name: "operation wait error", + op: &operation.MockNilOperation{ + RetErr: fmt.Errorf("error"), + }, + expectError: true, + }, + } + for _, tc := range testCases { + dsm := datastreamclient.DatastreamClientMock{} + op := operation.NewNilOperationWrapper(tc.op) + dsm.On("DeleteConnectionProfile", mock.Anything, mock.Anything).Return(&op, tc.deleteConnProfileErr) + err := da.DeleteConnectionProfile(ctx, &dsm, "id", "project-id", "region") + assert.Equal(t, tc.expectError, err != nil, tc.name) + } +} + +func TestCreateConnectionProfile(t *testing.T) { + ctx := context.Background() + da := datastream_accessor.DatastreamAccessorImpl{} + testCases := []struct { + name string + op operation.MockOperation[datastreampb.ConnectionProfile] + createProfileError error + expectError bool + }{ + { + name: "basic correct", + op: operation.MockOperation[datastreampb.ConnectionProfile]{ + RetVal: &datastreampb.ConnectionProfile{}, + }, + createProfileError: nil, + expectError: false, + }, + { + name: "create connection profile error", + op: operation.MockOperation[datastreampb.ConnectionProfile]{ + RetVal: &datastreampb.ConnectionProfile{}, + }, + 
createProfileError: fmt.Errorf("error"), + expectError: true, + }, + { + name: "operation wait error", + op: operation.MockOperation[datastreampb.ConnectionProfile]{ + RetVal: &datastreampb.ConnectionProfile{}, + RetErr: fmt.Errorf("error"), + }, + createProfileError: nil, + expectError: true, + }, + } + for _, tc := range testCases { + dsm := datastreamclient.DatastreamClientMock{} + op := operation.NewOperationWrapper[datastreampb.ConnectionProfile](tc.op) + dsm.On("CreateConnectionProfile", mock.Anything, mock.Anything).Return(&op, tc.createProfileError) + _, err := da.CreateConnectionProfile(ctx, &dsm, &datastreampb.CreateConnectionProfileRequest{}) + assert.Equal(t, tc.expectError, err != nil, tc.name) + } +} diff --git a/accessors/datastream/mocks.go b/accessors/datastream/mocks.go new file mode 100644 index 0000000000..3e530ae2f2 --- /dev/null +++ b/accessors/datastream/mocks.go @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package datastream_accessor + +import ( + "context" + + "cloud.google.com/go/datastream/apiv1/datastreampb" + datastreamclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/datastream" + "github.com/GoogleCloudPlatform/spanner-migration-tool/streaming" +) + +type DatastreamAccessorMock struct { + FetchTargetBucketAndPathMock func(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, projectID string, datastreamDestinationConnCfg streaming.DstConnCfg) (string, string, error) + DeleteConnectionProfileMock func(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, id string, projectId string, region string) error + GetConnProfilesRegionMock func(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, projectId string, region string) ([]string, error) + CreateConnectionProfileMock func(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, req *datastreampb.CreateConnectionProfileRequest) (*datastreampb.ConnectionProfile, error) + ConnectionProfileExistsMock func(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, projectId string, profileName string, profileLocation string, connectionProfiles map[string][]string) (bool, error) +} + +func (dam *DatastreamAccessorMock) FetchTargetBucketAndPath(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, projectID string, datastreamDestinationConnCfg streaming.DstConnCfg) (string, string, error) { + return dam.FetchTargetBucketAndPathMock(ctx, datastreamClient, projectID, datastreamDestinationConnCfg) +} + +func (dam *DatastreamAccessorMock) DeleteConnectionProfile(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, id string, projectId string, region string) error { + return dam.DeleteConnectionProfileMock(ctx, datastreamClient, id, projectId, region) +} + +func (dam *DatastreamAccessorMock) GetConnProfilesRegion(ctx context.Context, datastreamClient 
datastreamclient.DatastreamClient, projectId string, region string) ([]string, error) { + return dam.GetConnProfilesRegionMock(ctx, datastreamClient, projectId, region) +} + +func (dam *DatastreamAccessorMock) CreateConnectionProfile(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, req *datastreampb.CreateConnectionProfileRequest) (*datastreampb.ConnectionProfile, error) { + return dam.CreateConnectionProfileMock(ctx, datastreamClient, req) +} + +func (dam *DatastreamAccessorMock) ConnectionProfileExists(ctx context.Context, datastreamClient datastreamclient.DatastreamClient, projectId string, profileName string, profileLocation string, connectionProfiles map[string][]string) (bool, error) { + return dam.ConnectionProfileExistsMock(ctx, datastreamClient, projectId, profileName, profileLocation, connectionProfiles) +} diff --git a/accessors/helpers/dataflow/dataflow_helpers.go b/accessors/helpers/dataflow/dataflow_helpers.go new file mode 100644 index 0000000000..ee117ccced --- /dev/null +++ b/accessors/helpers/dataflow/dataflow_helpers.go @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This package is kept with accessors because some functions import other accessors. +// The common/utils package should not import any SMT dependency. 
+package dataflowhelpers + +import ( + "context" + "encoding/json" + + storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" + dataflowaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/dataflow" + storageaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/storage" +) + +// This package contains common helper methods using the accessor package. This will be used by multiple flows. +// Do not move to util since util only expects methods that do not import any other internal dependency. + +// Reads any local or gcs file and unmarshals the data into a DataflowTuningConfig struct. +func UnmarshalDataflowTuningConfig(ctx context.Context, sc storageclient.StorageClient, sa storageaccessor.StorageAccessor, filePath string) (dataflowaccessor.DataflowTuningConfig, error) { + jsonStr, err := sa.ReadAnyFile(ctx, sc, filePath) + if err != nil { + return dataflowaccessor.DataflowTuningConfig{}, err + } + tuningCfg := dataflowaccessor.DataflowTuningConfig{} + err = json.Unmarshal([]byte(jsonStr), &tuningCfg) + if err != nil { + return dataflowaccessor.DataflowTuningConfig{}, err + } + return tuningCfg, nil +} diff --git a/accessors/helpers/dataflow/dataflow_helpers_test.go b/accessors/helpers/dataflow/dataflow_helpers_test.go new file mode 100644 index 0000000000..1418dea448 --- /dev/null +++ b/accessors/helpers/dataflow/dataflow_helpers_test.go @@ -0,0 +1,141 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +package dataflowhelpers + +import ( + "context" + "fmt" + "os" + "testing" + + storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" + dataflowaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/dataflow" + storageaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/storage" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func TestUnmarshalDataflowTuningConfig(t *testing.T) { + testCases := []struct { + name string + sam storageaccessor.StorageAccessorMock + expectError bool + want dataflowaccessor.DataflowTuningConfig + }{ + { + name: "Basic", + sam: storageaccessor.StorageAccessorMock{ + ReadAnyFileMock: func(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) { + return `{ + "projectId": "test-project", + "jobName": "test-job-name", + "location": "us-central1", + "network": "test-network", + "subnetwork": "test-subnetwork", + "hostProjectId": "test-host-project", + "maxWorkers": 3, + "numWorkers": 2, + "serviceAccountEmail": "abc@xyz.com", + "machineType": "n1-standard-8", + "additionalUserLabels": {"my": "label"}, + "kmsKeyName": "test-key", + "gcsTemplatePath": "gs://path", + "additionalExperiments": ["xyz","123"], + "enableStreamingEngine": true + }`, nil + }, + }, + expectError: false, + want: dataflowaccessor.DataflowTuningConfig{ + ProjectId: "test-project", + JobName: "test-job-name", + Location: "us-central1", + Network: "test-network", + Subnetwork: "test-subnetwork", + VpcHostProjectId: "test-host-project", + MaxWorkers: 3, + NumWorkers: 2, + ServiceAccountEmail: "abc@xyz.com", + MachineType: "n1-standard-8", + AdditionalUserLabels: 
map[string]string{"my": "label"}, + KmsKeyName: "test-key", + GcsTemplatePath: "gs://path", + AdditionalExperiments: []string{"xyz", "123"}, + EnableStreamingEngine: true, + }, + }, + { + name: "Defaults", + sam: storageaccessor.StorageAccessorMock{ + ReadAnyFileMock: func(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) { + return `{}`, nil + }, + }, + expectError: false, + want: dataflowaccessor.DataflowTuningConfig{ + ProjectId: "", + JobName: "", + Location: "", + Network: "", + Subnetwork: "", + VpcHostProjectId: "", + MaxWorkers: 0, + NumWorkers: 0, + ServiceAccountEmail: "", + MachineType: "", + AdditionalUserLabels: nil, + KmsKeyName: "", + GcsTemplatePath: "", + AdditionalExperiments: nil, + EnableStreamingEngine: false, + }, + }, + { + name: "ReadAnyFile throws error", + sam: storageaccessor.StorageAccessorMock{ + ReadAnyFileMock: func(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) { + return "", fmt.Errorf("test error") + }, + }, + expectError: true, + want: dataflowaccessor.DataflowTuningConfig{}, + }, + { + name: "Json unmarshall throws error", + sam: storageaccessor.StorageAccessorMock{ + ReadAnyFileMock: func(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) { + return "{\"abc\"", nil + }, + }, + expectError: true, + want: dataflowaccessor.DataflowTuningConfig{}, + }, + } + ctx := context.Background() + for _, tc := range testCases { + got, err := UnmarshalDataflowTuningConfig(ctx, nil, &tc.sam, "unused/path/due/to/mock") + assert.Equal(t, tc.expectError, err != nil, tc.name) + assert.Equal(t, tc.want, got, tc.name) + } +} diff --git a/accessors/spanner/mocks.go b/accessors/spanner/mocks.go new file mode 100644 index 0000000000..fc8db94f3e --- /dev/null +++ b/accessors/spanner/mocks.go @@ -0,0 +1,85 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except 
in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package spanneraccessor
+
+import (
+	"context"
+
+	spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin"
+	spinstanceadmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/instanceadmin"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
+)
+
+// Mock that implements the SpannerAccessor interface.
+// Pass in unit tests where SpannerAccessor is an input parameter.
+// Each method forwards its arguments to the function field of the same name with a
+// "Mock" suffix, so a test only has to populate the fields for the methods it exercises.
+type SpannerAccessorMock struct {
+	GetDatabaseDialectMock          func(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (string, error)
+	CheckExistingDbMock             func(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (bool, error)
+	CreateEmptyDatabaseMock         func(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) error
+	GetSpannerLeaderLocationMock    func(ctx context.Context, instanceClient spinstanceadmin.InstanceAdminClient, instanceURI string) (string, error)
+	CheckIfChangeStreamExistsMock   func(ctx context.Context, changeStreamName, dbURI string) (bool, error)
+	ValidateChangeStreamOptionsMock func(ctx context.Context, changeStreamName, dbURI string) error
+	CreateChangeStreamMock          func(ctx context.Context, adminClient spanneradmin.AdminClient, changeStreamName, dbURI string) error
+	CreateDatabaseMock              func(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string, migrationType string) error
+	UpdateDatabaseMock              func(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string) error
+	CreateOrUpdateDatabaseMock      func(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI, driver string, conv *internal.Conv, migrationType string) error
+	VerifyDbMock                    func(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (dbExists bool, err error)
+	ValidateDDLMock                 func(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) error
+	UpdateDDLForeignKeysMock        func(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string, migrationType string)
+}
+
+// GetDatabaseDialect forwards the call to GetDatabaseDialectMock.
+func (sam *SpannerAccessorMock) GetDatabaseDialect(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (string, error) {
+	return sam.GetDatabaseDialectMock(ctx, adminClient, dbURI)
+}
+
+// CheckExistingDb forwards the call to CheckExistingDbMock.
+func (sam *SpannerAccessorMock) CheckExistingDb(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (bool, error) {
+	return sam.CheckExistingDbMock(ctx, adminClient, dbURI)
+}
+
+// CreateEmptyDatabase forwards the call to CreateEmptyDatabaseMock.
+func (sam *SpannerAccessorMock) CreateEmptyDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) error {
+	return sam.CreateEmptyDatabaseMock(ctx, adminClient, dbURI)
+}
+
+// GetSpannerLeaderLocation forwards the call to GetSpannerLeaderLocationMock.
+func (sam *SpannerAccessorMock) GetSpannerLeaderLocation(ctx context.Context, instanceClient spinstanceadmin.InstanceAdminClient, instanceURI string) (string, error) {
+	return sam.GetSpannerLeaderLocationMock(ctx, instanceClient, instanceURI)
+}
+
+// CheckIfChangeStreamExists forwards the call to CheckIfChangeStreamExistsMock.
+func (sam *SpannerAccessorMock) CheckIfChangeStreamExists(ctx context.Context, changeStreamName, dbURI string) (bool, error) {
+	return sam.CheckIfChangeStreamExistsMock(ctx, changeStreamName, dbURI)
+}
+
+// ValidateChangeStreamOptions forwards the call to ValidateChangeStreamOptionsMock.
+func (sam *SpannerAccessorMock) ValidateChangeStreamOptions(ctx context.Context, changeStreamName, dbURI string) error {
+	return sam.ValidateChangeStreamOptionsMock(ctx, changeStreamName, dbURI)
+}
+
+// CreateChangeStream forwards the call to CreateChangeStreamMock.
+func (sam *SpannerAccessorMock) CreateChangeStream(ctx context.Context, adminClient spanneradmin.AdminClient, changeStreamName, dbURI string) error {
+	return sam.CreateChangeStreamMock(ctx, adminClient, changeStreamName, dbURI)
+}
+
+// CreateDatabase forwards the call to CreateDatabaseMock.
+func (sam *SpannerAccessorMock) CreateDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string, migrationType string) error {
+	return sam.CreateDatabaseMock(ctx, adminClient, dbURI, conv, driver, migrationType)
+}
+
+// UpdateDatabase forwards the call to UpdateDatabaseMock.
+func (sam *SpannerAccessorMock) UpdateDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string) error {
+	return sam.UpdateDatabaseMock(ctx, adminClient, dbURI, conv, driver)
+}
+
+// CreateOrUpdateDatabase forwards the call to CreateOrUpdateDatabaseMock.
+func (sam *SpannerAccessorMock) CreateOrUpdateDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI, driver string, conv *internal.Conv, migrationType string) error {
+	return sam.CreateOrUpdateDatabaseMock(ctx, adminClient, dbURI, driver, conv, migrationType)
+}
+
+// VerifyDb forwards the call to VerifyDbMock.
+func (sam *SpannerAccessorMock) VerifyDb(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (dbExists bool, err error) {
+	return sam.VerifyDbMock(ctx, adminClient, dbURI)
+}
+
+// ValidateDDL forwards the call to ValidateDDLMock.
+func (sam *SpannerAccessorMock) ValidateDDL(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) error {
+	return sam.ValidateDDLMock(ctx, adminClient, dbURI)
+}
+
+// UpdateDDLForeignKeys forwards the call to UpdateDDLForeignKeysMock when that field is set.
+// It stays a no-op when the field is nil, so tests that never configured it keep working.
+func (sam *SpannerAccessorMock) UpdateDDLForeignKeys(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string, migrationType string) {
+	if sam.UpdateDDLForeignKeysMock != nil {
+		sam.UpdateDDLForeignKeysMock(ctx, adminClient, dbURI, conv, driver, migrationType)
+	}
+}
\ No newline at end of file
diff --git a/accessors/spanner/spanner_accessor.go b/accessors/spanner/spanner_accessor.go
new file mode 100644
index 0000000000..092a29dd3d
--- /dev/null
+++ b/accessors/spanner/spanner_accessor.go
@@ -0,0 +1,432 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package spanneraccessor
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"sync"
+	"time"
+
+	"cloud.google.com/go/spanner"
+	"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
+	"cloud.google.com/go/spanner/admin/instance/apiv1/instancepb"
+	spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin"
+	spannerclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/client"
+	spinstanceadmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/instanceadmin"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl"
+	"go.uber.org/zap"
+	"google.golang.org/api/iterator"
+	adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
+)
+
+var (
+	// MaxWorkers is the maximum number of concurrent workers used during
+	// foreign key creation (one in-flight DDL request per worker).
+	// This number should not be too high so as to not hit the AdminQuota limit.
+	// AdminQuota limits are mentioned here: https://cloud.google.com/spanner/quotas#administrative_limits
+	// If facing a quota limit error, consider reducing this value.
+	MaxWorkers = 50
+)
+
+
+// The SpannerAccessor provides methods that internally use a spanner client (can be adminClient/databaseclient/instanceclient etc).
+// Methods should only contain generic logic here that can be used by multiple workflows.
+type SpannerAccessor interface {
+	// GetDatabaseDialect fetches the dialect of the Spanner database at dbURI.
+	// The implementation in this package returns the dialect name lower-cased.
+	GetDatabaseDialect(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (string, error)
+	// CheckExistingDb checks whether the database with dbURI exists or not.
+	// While the API call has not responded, the implementation in this package
+	// logs a warning every 5 minutes.
+	CheckExistingDb(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (bool, error)
+	// CreateEmptyDatabase creates the database at dbURI with no schema.
+	CreateEmptyDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) error
+	// GetSpannerLeaderLocation fetches the default leader location of the Spanner instance.
+	GetSpannerLeaderLocation(ctx context.Context, instanceClient spinstanceadmin.InstanceAdminClient, instanceURI string) (string, error)
+	// CheckIfChangeStreamExists checks if a change stream named changeStreamName already exists.
+	CheckIfChangeStreamExists(ctx context.Context, changeStreamName, dbURI string) (bool, error)
+	// ValidateChangeStreamOptions validates that change stream option 'VALUE_CAPTURE_TYPE' is 'NEW_ROW'.
+	ValidateChangeStreamOptions(ctx context.Context, changeStreamName, dbURI string) error
+	// CreateChangeStream creates a change stream with default options.
+	CreateChangeStream(ctx context.Context, adminClient spanneradmin.AdminClient, changeStreamName, dbURI string) error
+	// CreateDatabase creates a new database at dbURI using the schema in conv.
+	CreateDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string, migrationType string) error
+	// UpdateDatabase applies the schema in conv to the existing database at dbURI.
+	UpdateDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string) error
+	// CreateOrUpdateDatabase updates an existing Spanner database, or creates a
+	// new one if one does not exist, using conv.
+	CreateOrUpdateDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI, driver string, conv *internal.Conv, migrationType string) error
+	// VerifyDb checks whether the db exists and if it does, verifies if the schema is what we currently support.
+	VerifyDb(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (dbExists bool, err error)
+	// ValidateDDL verifies if an existing DB's ddl follows what is supported by Spanner migration tool.
+	// Currently, we only support empty schema when db already exists.
+	ValidateDDL(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) error
+	// UpdateDDLForeignKeys updates the Spanner database with foreign key constraints using ALTER TABLE statements.
+	// It returns nothing; the implementation in this package records failures through conv.Unexpected.
+	UpdateDDLForeignKeys(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string, migrationType string)
+}
+
+// SpannerAccessorImpl implements the SpannerAccessor interface. This is the primary implementation that should be used in all places other than tests.
+type SpannerAccessorImpl struct{}
+
+// GetDatabaseDialect looks up the database at dbURI through adminClient and
+// returns its dialect name lower-cased (for example "postgresql"). The
+// underlying error is wrapped so callers can inspect it with errors.Is/As.
+func (sp *SpannerAccessorImpl) GetDatabaseDialect(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (string, error) {
+	result, err := adminClient.GetDatabase(ctx, &databasepb.GetDatabaseRequest{Name: dbURI})
+	if err != nil {
+		return "", fmt.Errorf("cannot connect to database: %w", err)
+	}
+	return strings.ToLower(result.DatabaseDialect.String()), nil
+}
+
+// CheckExistingDb reports whether the database at dbURI exists. The lookup runs
+// in a goroutine; while it is outstanding a warning is logged every 5 minutes.
+// An error whose text contains "database not found" is reported as
+// (false, nil); any other error is wrapped and returned.
+func (sp *SpannerAccessorImpl) CheckExistingDb(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (bool, error) {
+	// Buffered so the lookup goroutine can always hand over its result and exit.
+	lookupErr := make(chan error, 1)
+	go func() {
+		_, err := adminClient.GetDatabase(ctx, &databasepb.GetDatabaseRequest{Name: dbURI})
+		lookupErr <- err
+	}()
+	// One ticker for the whole wait instead of a fresh timer per loop iteration.
+	reminder := time.NewTicker(5 * time.Minute)
+	defer reminder.Stop()
+	for {
+		select {
+		case <-reminder.C:
+			logger.Log.Debug("WARNING! API call not responding: make sure that spanner api endpoint is configured properly")
+		case err := <-lookupErr:
+			if err == nil {
+				return true, nil
+			}
+			if utils.ContainsAny(strings.ToLower(err.Error()), []string{"database not found"}) {
+				return false, nil
+			}
+			return false, fmt.Errorf("can't get database info: %w", err)
+		}
+	}
+}
+
+// CreateEmptyDatabase creates the database named by dbURI without any schema
+// and waits for the long-running operation to finish.
+func (sp *SpannerAccessorImpl) CreateEmptyDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) error {
+	project, instance, dbName := utils.ParseDbURI(dbURI)
+	req := &databasepb.CreateDatabaseRequest{
+		Parent:          fmt.Sprintf("projects/%s/instances/%s", project, instance),
+		CreateStatement: "CREATE DATABASE `" + dbName + "`",
+	}
+	op, err := adminClient.CreateDatabase(ctx, req)
+	if err != nil {
+		return fmt.Errorf("can't build CreateDatabaseRequest: %w", utils.AnalyzeError(err, dbURI))
+	}
+	if _, err := op.Wait(ctx); err != nil {
+		return fmt.Errorf("createDatabase call failed: %w", utils.AnalyzeError(err, dbURI))
+	}
+	return nil
+}
+
+// GetSpannerLeaderLocation fetches the instance at instanceURI, then its
+// instance config, and returns the location of the first replica flagged as
+// the default leader. It returns an error when no replica carries that flag.
+func (sp *SpannerAccessorImpl) GetSpannerLeaderLocation(ctx context.Context, instanceClient spinstanceadmin.InstanceAdminClient, instanceURI string) (string, error) {
+	instanceInfo, err := instanceClient.GetInstance(ctx, &instancepb.GetInstanceRequest{Name: instanceURI})
+	if err != nil {
+		return "", err
+	}
+	instanceConfig, err := instanceClient.GetInstanceConfig(ctx, &instancepb.GetInstanceConfigRequest{Name: instanceInfo.Config})
+	if err != nil {
+		return "", err
+	}
+	for _, replica := range instanceConfig.Replicas {
+		if replica.DefaultLeaderLocation {
+			return replica.Location, nil
+		}
+	}
+	return "", fmt.Errorf("no leader found for spanner instance %s while trying fetch location", instanceURI)
+}
+
+// CheckIfChangeStreamExists reports whether information_schema.change_streams
+// of the database at dbURI lists a stream named changeStreamName.
+//
+// Consider using a CreateChangestream operation and check for alreadyExists error. That uses adminClient which can be unit tested.
+func (sp *SpannerAccessorImpl) CheckIfChangeStreamExists(ctx context.Context, changeStreamName, dbURI string) (bool, error) {
+	spClient, err := spannerclient.GetOrCreateClient(ctx, dbURI)
+	if err != nil {
+		return false, err
+	}
+	stmt := spanner.Statement{
+		SQL: `SELECT CHANGE_STREAM_NAME FROM information_schema.change_streams`,
+	}
+	iter := spClient.Single().Query(ctx, stmt)
+	defer iter.Stop()
+	for {
+		row, err := iter.Next()
+		if err == iterator.Done {
+			// Every row was inspected without finding a match.
+			return false, nil
+		}
+		if err != nil {
+			return false, fmt.Errorf("couldn't read row from change_streams table: %w", err)
+		}
+		var csName string
+		if err := row.Columns(&csName); err != nil {
+			return false, fmt.Errorf("can't scan row from change_streams table: %w", err)
+		}
+		if csName == changeStreamName {
+			return true, nil
+		}
+	}
+}
+
+// ValidateChangeStreamOptions returns an error if any 'value_capture_type'
+// option row recorded for changeStreamName has a value other than NEW_ROW.
+// A stream with no such option row passes validation.
+func (sp *SpannerAccessorImpl) ValidateChangeStreamOptions(ctx context.Context, changeStreamName, dbURI string) error {
+	spClient, err := spannerclient.GetOrCreateClient(ctx, dbURI)
+	if err != nil {
+		return err
+	}
+	// Validate if change stream options are set correctly.
+	stmt := spanner.Statement{
+		SQL: `SELECT option_value FROM information_schema.change_stream_options
+		WHERE change_stream_name = @p1 AND option_name = 'value_capture_type'`,
+		Params: map[string]interface{}{
+			"p1": changeStreamName,
+		},
+	}
+	iter := spClient.Single().Query(ctx, stmt)
+	defer iter.Stop()
+	for {
+		row, err := iter.Next()
+		if err == iterator.Done {
+			return nil
+		}
+		if err != nil {
+			return fmt.Errorf("couldn't read row from change_stream_options table: %w", err)
+		}
+		var optionValue string
+		if err := row.Columns(&optionValue); err != nil {
+			return fmt.Errorf("can't scan row from change_stream_options table: %w", err)
+		}
+		if optionValue != "NEW_ROW" {
+			return fmt.Errorf("VALUE_CAPTURE_TYPE for changestream %s is not NEW_ROW. Please update the changestream option or create a new one", changeStreamName)
+		}
+	}
+}
+
+// CreateChangeStream issues a DDL update that creates changeStreamName FOR ALL
+// tables with value_capture_type NEW_ROW and a 7 day retention period, then
+// waits for the operation to complete.
+func (sp *SpannerAccessorImpl) CreateChangeStream(ctx context.Context, adminClient spanneradmin.AdminClient, changeStreamName, dbURI string) error {
+	op, err := adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{
+		Database: dbURI,
+		// TODO: create change stream for only the tables present in Spanner.
+		Statements: []string{fmt.Sprintf("CREATE CHANGE STREAM %s FOR ALL OPTIONS (value_capture_type = 'NEW_ROW', retention_period = '7d')", changeStreamName)},
+	})
+	if err != nil {
+		return fmt.Errorf("cannot submit request create change stream request: %w", err)
+	}
+	if err := op.Wait(ctx); err != nil {
+		return fmt.Errorf("could not update database ddl: %w", err)
+	}
+	logger.Log.Debug("Successfully created changestream", zap.String("changeStreamName", changeStreamName))
+	return nil
+}
+
+// CreateDatabase creates the Spanner database identified by dbURI using the
+// schema held in conv. For the PostgreSQL dialect the database is created
+// empty and the schema is applied afterwards through UpdateDatabase; for other
+// dialects the DDL is sent as part of the create request.
+func (sp *SpannerAccessorImpl) CreateDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string, migrationType string) error {
+	project, instance, dbName := utils.ParseDbURI(dbURI)
+	req := &adminpb.CreateDatabaseRequest{
+		Parent: fmt.Sprintf("projects/%s/instances/%s", project, instance),
+	}
+	isPostgres := conv.SpDialect == constants.DIALECT_POSTGRESQL
+	if isPostgres {
+		// PostgreSQL dialect doesn't support:
+		// a) backticks around the database name, and
+		// b) DDL statements as part of a CreateDatabase operation (so schema
+		// must be set using a separate UpdateDatabase operation).
+		req.CreateStatement = "CREATE DATABASE \"" + dbName + "\""
+		req.DatabaseDialect = adminpb.DatabaseDialect_POSTGRESQL
+	} else {
+		req.CreateStatement = "CREATE DATABASE `" + dbName + "`"
+		// The schema we send to Spanner excludes comments (since Cloud
+		// Spanner DDL doesn't accept them), and protects table and col names
+		// using backticks (to avoid any issues with Spanner reserved words).
+		// Foreign keys are included up front only for dataflow migrations;
+		// for every other migration type they are left out of the create request.
+		req.ExtraStatements = conv.SpSchema.GetDDL(ddl.Config{
+			Comments:    false,
+			ProtectIds:  true,
+			Tables:      true,
+			ForeignKeys: migrationType == constants.DATAFLOW_MIGRATION,
+			SpDialect:   conv.SpDialect,
+			Source:      driver,
+		})
+	}
+
+	op, err := adminClient.CreateDatabase(ctx, req)
+	if err != nil {
+		return fmt.Errorf("can't build CreateDatabaseRequest: %w", utils.AnalyzeError(err, dbURI))
+	}
+	if _, err := op.Wait(ctx); err != nil {
+		return fmt.Errorf("createDatabase call failed: %w", utils.AnalyzeError(err, dbURI))
+	}
+
+	if isPostgres {
+		// Update schema separately for PG databases.
+		return sp.UpdateDatabase(ctx, adminClient, dbURI, conv, driver)
+	}
+	return nil
+}
+
+// UpdateDatabase updates an existing spanner database with the table DDL
+// generated from conv. The request and the wait on its operation share a
+// 5 minute timeout.
+func (sp *SpannerAccessorImpl) UpdateDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string) error {
+	// The schema we send to Spanner excludes comments (since Cloud
+	// Spanner DDL doesn't accept them), and protects table and col names
+	// using backticks (to avoid any issues with Spanner reserved words).
+	// Foreign Keys are set to false since we create them post data migration.
+	schema := conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: false, SpDialect: conv.SpDialect, Source: driver})
+	req := &adminpb.UpdateDatabaseDdlRequest{
+		Database:   dbURI,
+		Statements: schema,
+	}
+	// Update queries for postgres as target db return response after more
+	// than 1 min for large schemas, therefore, timeout is specified as 5 minutes
+	ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
+	defer cancel()
+	op, err := adminClient.UpdateDatabaseDdl(ctx, req)
+	if err != nil {
+		return fmt.Errorf("can't build UpdateDatabaseDdlRequest: %w", utils.AnalyzeError(err, dbURI))
+	}
+	if err := op.Wait(ctx); err != nil {
+		return fmt.Errorf("UpdateDatabaseDdl call failed: %w", utils.AnalyzeError(err, dbURI))
+	}
+	return nil
+}
+
+// CreateOrUpdateDatabase updates an existing Spanner database or creates a new one if one does not exist.
+func (sp *SpannerAccessorImpl) CreateOrUpdateDatabase(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI, driver string, conv *internal.Conv, migrationType string) error {
+	dbExists, err := sp.VerifyDb(ctx, adminClient, dbURI)
+	if err != nil {
+		return err
+	}
+	if !dbExists {
+		if err := sp.CreateDatabase(ctx, adminClient, dbURI, conv, driver, migrationType); err != nil {
+			return fmt.Errorf("can't create database: %w", err)
+		}
+		return nil
+	}
+	// Dataflow migrations into a database that already exists are rejected
+	// unless the target dialect is PostgreSQL.
+	if conv.SpDialect != constants.DIALECT_POSTGRESQL && migrationType == constants.DATAFLOW_MIGRATION {
+		return fmt.Errorf("spanner migration tool does not support minimal downtime schema/schema-and-data migrations to an existing database")
+	}
+	if err := sp.UpdateDatabase(ctx, adminClient, dbURI, conv, driver); err != nil {
+		return fmt.Errorf("can't update database schema: %w", err)
+	}
+	return nil
+}
+
+// VerifyDb checks whether the db exists and if it does, verifies if the schema is what we currently support.
+// dbExists reflects the existence check even when the schema validation that
+// follows returns an error.
+func (sp *SpannerAccessorImpl) VerifyDb(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) (dbExists bool, err error) {
+	dbExists, err = sp.CheckExistingDb(ctx, adminClient, dbURI)
+	if err != nil || !dbExists {
+		return dbExists, err
+	}
+	return true, sp.ValidateDDL(ctx, adminClient, dbURI)
+}
+
+// ValidateDDL verifies if an existing DB's ddl follows what is supported by Spanner migration tool. Currently,
+// we only support empty schema when db already exists.
+func (sp *SpannerAccessorImpl) ValidateDDL(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string) error {
+	dbDdl, err := adminClient.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{Database: dbURI})
+	if err != nil {
+		return fmt.Errorf("can't fetch database ddl: %w", err)
+	}
+	if len(dbDdl.Statements) != 0 {
+		return fmt.Errorf("spanner migration tool supports writing to existing databases only if they have an empty schema")
+	}
+	return nil
+}
+
+// UpdateDDLForeignKeys updates the Spanner database with foreign key
+// constraints using ALTER TABLE statements.
+//
+// Each foreign key statement is submitted as its own DDL update from a
+// goroutine, with at most MaxWorkers (minimum 1) requests in flight and one new
+// request started per second. A statement that fails is logged and recorded
+// through conv.Unexpected, then skipped; nothing is returned to the caller.
+// The call is a no-op for non-PostgreSQL dataflow migrations.
+func (sp *SpannerAccessorImpl) UpdateDDLForeignKeys(ctx context.Context, adminClient spanneradmin.AdminClient, dbURI string, conv *internal.Conv, driver string, migrationType string) {
+
+	if conv.SpDialect != constants.DIALECT_POSTGRESQL && migrationType == constants.DATAFLOW_MIGRATION {
+		//foreign keys were applied as part of CreateDatabase
+		return
+	}
+
+	// The schema we send to Spanner excludes comments (since Cloud
+	// Spanner DDL doesn't accept them), and protects table and col names
+	// using backticks (to avoid any issues with Spanner reserved words).
+	fkStmts := conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: false, ForeignKeys: true, SpDialect: conv.SpDialect, Source: driver})
+	if len(fkStmts) == 0 {
+		return
+	}
+	if len(fkStmts) > 50 {
+		logger.Log.Warn(`
+			Warning: Large number of foreign keys detected. Spanner can take a long amount of
+			time to create foreign keys (over 5 mins per batch of Foreign Keys even with no data).
+			Spanner migration tool does not have control over a single foreign key creation time. The number
+			of concurrent Foreign Key Creation Requests sent to spanner can be increased by
+			tweaking the MaxWorkers variable (https://github.com/GoogleCloudPlatform/spanner-migration-tool/blob/master/conversion/conversion.go#L89).
+			However, setting it to a very high value might lead to exceeding the admin quota limit. Spanner migration tool tries to stay under the
+			admin quota limit by spreading the FK creation requests over time.`)
+	}
+	msg := fmt.Sprintf("Updating schema of database %s with foreign key constraints ...", dbURI)
+	conv.Audit.Progress = *internal.NewProgress(int64(len(fkStmts)), msg, internal.Verbose(), true, int(internal.ForeignKeyUpdateInProgress))
+
+	// Read the exported knob once so the fill and drain loops below always
+	// agree, and clamp it: with fewer than one token the first `<-workers`
+	// would block forever.
+	maxWorkers := MaxWorkers
+	if maxWorkers < 1 {
+		maxWorkers = 1
+	}
+	// workers is a token pool: a goroutine may only start after taking a token
+	// and hands it back when it finishes.
+	workers := make(chan int, maxWorkers)
+	for i := 1; i <= maxWorkers; i++ {
+		workers <- i
+	}
+	var progressMutex sync.Mutex
+	progress := int64(0)
+
+	// We dispatch parallel foreign key create requests to ensure the backfill runs in parallel to reduce overall time.
+	// This cuts down the time taken to a third (approx) compared to Serial and Batched creation. We also do not want to create
+	// too many requests and get throttled due to network or hitting catalog memory limits.
+	// Ensure atmost `maxWorkers` go routines run in parallel that each update the ddl with one foreign key statement.
+	for _, fkStmt := range fkStmts {
+		workerID := <-workers
+		go func(fkStmt string, workerID int) {
+			defer func() {
+				// Locking the progress reporting otherwise progress results displayed could be in random order.
+				progressMutex.Lock()
+				progress++
+				conv.Audit.Progress.MaybeReport(progress)
+				progressMutex.Unlock()
+				workers <- workerID
+			}()
+			internal.VerbosePrintf("Submitting new FK create request: %s\n", fkStmt)
+			logger.Log.Debug("Submitting new FK create request", zap.String("fkStmt", fkStmt))
+
+			op, err := adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
+				Database:   dbURI,
+				Statements: []string{fkStmt},
+			})
+			if err == nil {
+				err = op.Wait(ctx)
+			}
+			if err != nil {
+				// Submission and wait failures are reported identically.
+				// NOTE(review): conv.Unexpected is called from worker goroutines
+				// without holding progressMutex; confirm it is safe for concurrent use.
+				logger.Log.Debug("Can't add foreign key with statement:"+fkStmt+"\n due to error:"+err.Error()+" Skipping this foreign key...\n")
+				conv.Unexpected(fmt.Sprintf("Can't add foreign key with statement %s: %s", fkStmt, err))
+				return
+			}
+			internal.VerbosePrintln("Updated schema with statement: " + fkStmt)
+			logger.Log.Debug("Updated schema with statement", zap.String("fkStmt", fkStmt))
+		}(fkStmt, workerID)
+		// Send out an FK creation request every second, with total of maxWorkers request being present in a batch.
+		time.Sleep(time.Second)
+	}
+	// Wait for all the goroutines to finish: once every token is back in hand
+	// no worker is still running.
+	for i := 1; i <= maxWorkers; i++ {
+		<-workers
+	}
+	conv.Audit.Progress.UpdateProgress("Foreign key update complete.", 100, internal.ForeignKeyUpdateComplete)
+	conv.Audit.Progress.Done()
+}
diff --git a/accessors/spanner/spanner_accessor_test.go b/accessors/spanner/spanner_accessor_test.go
new file mode 100644
index 0000000000..ed674a387b
--- /dev/null
+++ b/accessors/spanner/spanner_accessor_test.go
@@ -0,0 +1,784 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package spanneraccessor
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"testing"
+
+	"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
+	"cloud.google.com/go/spanner/admin/instance/apiv1/instancepb"
+	spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin"
+	spinstanceadmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/instanceadmin"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl"
+	"github.com/googleapis/gax-go/v2"
+	"github.com/stretchr/testify/assert"
+	"go.uber.org/zap"
+)
+
+// Silence the package logger for the duration of the tests.
+func init() {
+	logger.Log = zap.NewNop()
+}
+
+// TestMain runs the package tests and exits with their status code.
+func TestMain(m *testing.M) {
+	os.Exit(m.Run())
+}
+
+// TestSpannerAccessorImpl_GetDatabaseDialect checks the lower-cased dialect
+// name for each dialect value and the error path of GetDatabase.
+func TestSpannerAccessorImpl_GetDatabaseDialect(t *testing.T) {
+	// reporting builds an admin client double whose GetDatabase answers with the given dialect.
+	reporting := func(dialect databasepb.DatabaseDialect) spanneradmin.AdminClientMock {
+		return spanneradmin.AdminClientMock{
+			GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) {
+				return &databasepb.Database{DatabaseDialect: dialect}, nil
+			},
+		}
+	}
+	failing := spanneradmin.AdminClientMock{
+		GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) {
+			return nil, fmt.Errorf("test-error")
+		},
+	}
+	testCases := []struct {
+		name        string
+		acm         spanneradmin.AdminClientMock
+		expectError bool
+		want        string
+	}{
+		{name: "Basic", acm: reporting(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL), expectError: false, want: "google_standard_sql"},
+		{name: "Pg Dialect", acm: reporting(databasepb.DatabaseDialect_POSTGRESQL), expectError: false, want: "postgresql"},
+		{name: "Unspecified Dialect", acm: reporting(databasepb.DatabaseDialect_DATABASE_DIALECT_UNSPECIFIED), expectError: false, want: "database_dialect_unspecified"},
+		{name: "Error case", acm: failing, expectError: true, want: ""},
+	}
+	accessor := SpannerAccessorImpl{}
+	ctx := context.Background()
+	for i := range testCases {
+		tc := &testCases[i]
+		got, err := accessor.GetDatabaseDialect(ctx, &tc.acm, "testUri")
+		assert.Equal(t, tc.expectError, err != nil, tc.name)
+		assert.Equal(t, tc.want, got, tc.name)
+	}
+}
+
+// TestSpannerAccessorImpl_CheckExistingDb covers a successful lookup, the
+// "database not found" error that maps to (false, nil), and a generic failure.
+func TestSpannerAccessorImpl_CheckExistingDb(t *testing.T) {
+	// answering builds an admin client double whose GetDatabase returns (nil, lookupErr).
+	answering := func(lookupErr error) spanneradmin.AdminClientMock {
+		return spanneradmin.AdminClientMock{
+			GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) {
+				return nil, lookupErr
+			},
+		}
+	}
+	testCases := []struct {
+		name        string
+		acm         spanneradmin.AdminClientMock
+		expectError bool
+		want        bool
+	}{
+		{name: "Basic", acm: answering(nil), expectError: false, want: true},
+		{name: "Database not found error", acm: answering(fmt.Errorf("database not found")), expectError: false, want: false},
+		{name: "Could not get db info", acm: answering(fmt.Errorf("failed to connect")), expectError: true, want: false},
+	}
+	accessor := SpannerAccessorImpl{}
+	ctx := context.Background()
+	for i := range testCases {
+		tc := &testCases[i]
+		got, err := accessor.CheckExistingDb(ctx, &tc.acm, "testUri")
+		assert.Equal(t, tc.expectError, err != nil, tc.name)
+		assert.Equal(t, tc.want, got, tc.name)
+	}
+}
+
+// TestSpannerAccessorImpl_CreateEmptyDatabase covers success, a failing
+// CreateDatabase call, and a failing wait on the returned operation.
+func TestSpannerAccessorImpl_CreateEmptyDatabase(t *testing.T) {
+	// creating builds an admin client double: a non-nil createErr fails the
+	// CreateDatabase call itself, otherwise the returned operation's Wait yields waitErr.
+	creating := func(createErr, waitErr error) spanneradmin.AdminClientMock {
+		return spanneradmin.AdminClientMock{
+			CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) {
+				if createErr != nil {
+					return nil, createErr
+				}
+				return &spanneradmin.CreateDatabaseOperationMock{
+					WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { return nil, waitErr },
+				}, nil
+			},
+		}
+	}
+	testCases := []struct {
+		name        string
+		acm         spanneradmin.AdminClientMock
+		expectError bool
+	}{
+		{name: "Basic", acm: creating(nil, nil), expectError: false},
+		{name: "Create database returns error", acm: creating(fmt.Errorf("test error"), nil), expectError: true},
+		{name: "Wait returns error", acm: creating(nil, fmt.Errorf("test error")), expectError: true},
+	}
+	accessor := SpannerAccessorImpl{}
+	ctx := context.Background()
+	for i := range testCases {
+		tc := &testCases[i]
+		err := accessor.CreateEmptyDatabase(ctx, &tc.acm, "projects/test-project/instances/test-instance/databases/mydb")
+		assert.Equal(t, tc.expectError, err != nil, tc.name)
+	}
+}
+
+// TestSpannerAccessorImpl_CreateChangeStream covers success, a failing
+// UpdateDatabaseDdl call, and a failing wait on the returned operation.
+func TestSpannerAccessorImpl_CreateChangeStream(t *testing.T) {
+	// updating builds an admin client double: a non-nil submitErr fails the
+	// UpdateDatabaseDdl call itself, otherwise the returned operation's Wait yields waitErr.
+	updating := func(submitErr, waitErr error) spanneradmin.AdminClientMock {
+		return spanneradmin.AdminClientMock{
+			UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) {
+				if submitErr != nil {
+					return nil, submitErr
+				}
+				return &spanneradmin.UpdateDatabaseDdlOperationMock{
+					WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return waitErr },
+				}, nil
+			},
+		}
+	}
+	testCases := []struct {
+		name        string
+		acm         spanneradmin.AdminClientMock
+		expectError bool
+	}{
+		{name: "Basic", acm: updating(nil, nil), expectError: false},
+		{name: "Update database ddl returns error", acm: updating(fmt.Errorf("test error"), nil), expectError: true},
+		{name: "Wait returns error", acm: updating(nil, fmt.Errorf("test error")), expectError: true},
+	}
+	accessor := SpannerAccessorImpl{}
+	ctx := context.Background()
+	for i := range testCases {
+		tc := &testCases[i]
+		err := accessor.CreateChangeStream(ctx, &tc.acm, "my-changestream", "projects/test-project/instances/test-instance/databases/mydb")
+		assert.Equal(t, tc.expectError, err != nil, tc.name)
+	}
+}
+
+// TestSpannerAccessorImpl_GetSpannerLeaderLocation covers a config with a
+// default leader, failures of either instance admin call, and a config in
+// which no replica is the default leader.
+func TestSpannerAccessorImpl_GetSpannerLeaderLocation(t *testing.T) {
+	// instanceResult builds a GetInstance double: it fails with instErr when
+	// set, otherwise it returns an instance pointing at the test config.
+	instanceResult := func(instErr error) func(ctx context.Context, req *instancepb.GetInstanceRequest, opts ...gax.CallOption) (*instancepb.Instance, error) {
+		return func(ctx context.Context, req *instancepb.GetInstanceRequest, opts ...gax.CallOption) (*instancepb.Instance, error) {
+			if instErr != nil {
+				return nil, instErr
+			}
+			return &instancepb.Instance{Config: "projects/test-project/instanceConfigs/test-config"}, nil
+		}
+	}
+	// configResult builds a GetInstanceConfig double returning (cfg, cfgErr) as given.
+	configResult := func(cfg *instancepb.InstanceConfig, cfgErr error) func(ctx context.Context, req *instancepb.GetInstanceConfigRequest, opts ...gax.CallOption) (*instancepb.InstanceConfig, error) {
+		return func(ctx context.Context, req *instancepb.GetInstanceConfigRequest, opts ...gax.CallOption) (*instancepb.InstanceConfig, error) {
+			return cfg, cfgErr
+		}
+	}
+	// threeReplicas lists us-east1, india1 and europe2, flagging only the
+	// replica located at leader (none when leader is empty) as default leader.
+	threeReplicas := func(leader string) *instancepb.InstanceConfig {
+		var replicas []*instancepb.ReplicaInfo
+		for _, location := range []string{"us-east1", "india1", "europe2"} {
+			replicas = append(replicas, &instancepb.ReplicaInfo{
+				Location:              location,
+				DefaultLeaderLocation: location == leader,
+			})
+		}
+		return &instancepb.InstanceConfig{Replicas: replicas}
+	}
+	testCases := []struct {
+		name        string
+		iac         spinstanceadmin.InstanceAdminClientMock
+		expectError bool
+		want        string
+	}{
+		{
+			name: "Basic",
+			iac: spinstanceadmin.InstanceAdminClientMock{
+				GetInstanceMock:       instanceResult(nil),
+				GetInstanceConfigMock: configResult(threeReplicas("india1"), nil),
+			},
+			expectError: false,
+			want:        "india1",
+		},
+		{
+			name: "GetInstanceMock returns error",
+			iac: spinstanceadmin.InstanceAdminClientMock{
+				GetInstanceMock:       instanceResult(fmt.Errorf("test-error")),
+				GetInstanceConfigMock: configResult(nil, nil),
+			},
+			expectError: true,
+			want:        "",
+		},
+		{
+			name: "GetInstanceConfigMock returns error",
+			iac: spinstanceadmin.InstanceAdminClientMock{
+				GetInstanceMock:       instanceResult(nil),
+				GetInstanceConfigMock: configResult(nil, fmt.Errorf("test-error")),
+			},
+			expectError: true,
+			want:        "",
+		},
+		{
+			name: "No leader found returns error",
+			iac: spinstanceadmin.InstanceAdminClientMock{
+				GetInstanceMock:       instanceResult(nil),
+				GetInstanceConfigMock: configResult(threeReplicas(""), nil),
+			},
+			expectError: true,
+			want:        "",
+		},
+	}
+	accessor := SpannerAccessorImpl{}
+	ctx := context.Background()
+	for i := range testCases {
+		tc := &testCases[i]
+		got, err := accessor.GetSpannerLeaderLocation(ctx, &tc.iac, "projects/test-project/instances/test-instance")
+		assert.Equal(t, tc.expectError, err != nil, tc.name)
+		assert.Equal(t, tc.want, got, tc.name)
+	}
+}
+
+// TestSpannerAccessorImpl_CreateDatabase exercises CreateDatabase for both
+// dialects and both migration types, plus failures of the create call, of the
+// create operation, and of the follow-up schema update used for PostgreSQL.
+func TestSpannerAccessorImpl_CreateDatabase(t *testing.T) {
+	// createResult builds a CreateDatabase double: a non-nil createErr fails
+	// the call itself, otherwise the returned operation's Wait yields waitErr.
+	createResult := func(createErr, waitErr error) func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) {
+		return func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) {
+			if createErr != nil {
+				return nil, createErr
+			}
+			return &spanneradmin.CreateDatabaseOperationMock{
+				WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { return nil, waitErr },
+			}, nil
+		}
+	}
+	// updateResult builds an UpdateDatabaseDdl double: it fails with submitErr
+	// when set, otherwise it returns an operation whose Wait succeeds.
+	updateResult := func(submitErr error) func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) {
+		return func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) {
+			if submitErr != nil {
+				return nil, submitErr
+			}
+			return &spanneradmin.UpdateDatabaseDdlOperationMock{
+				WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return nil },
+			}, nil
+		}
+	}
+	testCases := []struct {
+		name          string
+		acm           spanneradmin.AdminClientMock
+		dialect       string
+		migrationType string
+		expectError   bool
+	}{
+		{
+			name:          "GoogleSql Dataflow",
+			acm:           spanneradmin.AdminClientMock{CreateDatabaseMock: createResult(nil, nil)},
+			expectError:   false,
+			dialect:       "google_standard_sql",
+			migrationType: "dataflow",
+		},
+		{
+			name: "Pg Dataflow",
+			acm: spanneradmin.AdminClientMock{
+				CreateDatabaseMock:    createResult(nil, nil),
+				UpdateDatabaseDdlMock: updateResult(nil),
+			},
+			expectError:   false,
+			dialect:       "postgresql",
+			migrationType: "dataflow",
+		},
+		{
+			name:          "GoogleSql bulk",
+			acm:           spanneradmin.AdminClientMock{CreateDatabaseMock: createResult(nil, nil)},
+			expectError:   false,
+			dialect:       "google_standard_sql",
+			migrationType: "bulk",
+		},
+		{
+			name: "Pg bulk",
+			acm: spanneradmin.AdminClientMock{
+				CreateDatabaseMock:    createResult(nil, nil),
+				UpdateDatabaseDdlMock: updateResult(nil),
+			},
+			expectError:   false,
+			dialect:       "postgresql",
+			migrationType: "bulk",
+		},
+		{
+			name:          "GoogleSql Dataflow create database error",
+			acm:           spanneradmin.AdminClientMock{CreateDatabaseMock: createResult(fmt.Errorf("error"), nil)},
+			expectError:   true,
+			dialect:       "google_standard_sql",
+			migrationType: "dataflow",
+		},
+		{
+			name: "Pg Dataflow update error",
+			acm: spanneradmin.AdminClientMock{
+				CreateDatabaseMock:    createResult(nil, nil),
+				UpdateDatabaseDdlMock: updateResult(fmt.Errorf("error")),
+			},
+			expectError:   true,
+			dialect:       "postgresql",
+			migrationType: "dataflow",
+		},
+		{
+			name:          "GoogleSql Dataflow operation error",
+			acm:           spanneradmin.AdminClientMock{CreateDatabaseMock: createResult(nil, fmt.Errorf("error"))},
+			expectError:   true,
+			dialect:       "google_standard_sql",
+			migrationType: "dataflow",
+		},
+	}
+	accessor := SpannerAccessorImpl{}
+	ctx := context.Background()
+	for i := range testCases {
+		tc := &testCases[i]
+		dbURI := "projects/project-id/instances/instance-id/databases/database-id"
+		conv := internal.MakeConv()
+		conv.SpDialect = tc.dialect
+		err := accessor.CreateDatabase(ctx, &tc.acm, dbURI, conv, "", tc.migrationType)
+		assert.Equal(t, tc.expectError, err != nil, tc.name)
+	}
+}
+
+func
TestSpannerAccessorImpl_CreateOrUpdateDatabase(t *testing.T) { + testCases := []struct { + name string + acm spanneradmin.AdminClientMock + dialect string + migrationType string + expectError bool + }{ + { + name: "GoogleSql Dataflow db does not exist", + acm: spanneradmin.AdminClientMock{ + CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) { + return &spanneradmin.CreateDatabaseOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { return nil, nil }, + }, nil + }, + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return nil, fmt.Errorf("database not found") + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{}, nil + }, + }, + expectError: false, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + { + name: "GoogleSql Dataflow db exists", + acm: spanneradmin.AdminClientMock{ + CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) { + return &spanneradmin.CreateDatabaseOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { return nil, nil }, + }, nil + }, + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return &databasepb.Database{DatabaseDialect: databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL}, nil + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{}, nil + }, 
+ }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + { + name: "Postgres Dataflow db exists", + acm: spanneradmin.AdminClientMock{ + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return &databasepb.Database{DatabaseDialect: databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL}, nil + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{}, nil + }, + }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + { + name: "Postgres bulk db exists", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return nil }, + }, nil + }, + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return &databasepb.Database{DatabaseDialect: databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL}, nil + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{}, nil + }, + }, + expectError: false, + dialect: "google_standard_sql", + migrationType: "bulk", + }, + { + name: "GoogleSql Dataflow db get database error", + acm: spanneradmin.AdminClientMock{ + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return nil, fmt.Errorf("error") + }, + GetDatabaseDdlMock: func(ctx context.Context, req 
*databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{}, nil + }, + }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + { + name: "GoogleSql Dataflow db ddl statements nto empty", + acm: spanneradmin.AdminClientMock{ + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return &databasepb.Database{DatabaseDialect: databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL}, nil + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{Statements: []string{"string"}}, nil + }, + }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + { + name: "GoogleSql Dataflow db get database ddl error", + acm: spanneradmin.AdminClientMock{ + CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) { + return &spanneradmin.CreateDatabaseOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) (*databasepb.Database, error) { return nil, nil }, + }, nil + }, + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return &databasepb.Database{DatabaseDialect: databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL}, nil + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return nil, fmt.Errorf("error") + }, + }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + { + name: "Postgres bulk db exists update ddl error", + acm: spanneradmin.AdminClientMock{ + 
UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return nil, fmt.Errorf("error") + }, + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return &databasepb.Database{DatabaseDialect: databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL}, nil + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{}, nil + }, + }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "bulk", + }, + { + name: "GoogleSql Dataflow db does not exist create error", + acm: spanneradmin.AdminClientMock{ + CreateDatabaseMock: func(ctx context.Context, req *databasepb.CreateDatabaseRequest, opts ...gax.CallOption) (spanneradmin.CreateDatabaseOperation, error) { + return nil, fmt.Errorf("error") + }, + GetDatabaseMock: func(ctx context.Context, req *databasepb.GetDatabaseRequest, opts ...gax.CallOption) (*databasepb.Database, error) { + return nil, fmt.Errorf("database not found") + }, + GetDatabaseDdlMock: func(ctx context.Context, req *databasepb.GetDatabaseDdlRequest, opts ...gax.CallOption) (*databasepb.GetDatabaseDdlResponse, error) { + return &databasepb.GetDatabaseDdlResponse{}, nil + }, + }, + expectError: true, + dialect: "google_standard_sql", + migrationType: "dataflow", + }, + } + ctx := context.Background() + spA := SpannerAccessorImpl{} + for _, tc := range testCases { + dbURI := "projects/project-id/instances/instance-id/databases/database-id" + conv := internal.MakeConv() + conv.SpDialect = tc.dialect + err := spA.CreateOrUpdateDatabase(ctx, &tc.acm, dbURI, "", conv, tc.migrationType) + assert.Equal(t, tc.expectError, err != nil, tc.name) + } +} + +func TestSpannerAccessorImpl_UpdateDatabase(t *testing.T) { + 
testCases := []struct { + name string + acm spanneradmin.AdminClientMock + expectError bool + }{ + { + name: "Update Database successful", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return nil }, + }, nil + }, + }, + expectError: false, + }, + { + name: "Update Database request error", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return nil, fmt.Errorf("Error") + }, + }, + expectError: true, + }, + { + name: "Update Database operation error", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return fmt.Errorf("error") }, + }, nil + }, + }, + expectError: true, + }, + } + ctx := context.Background() + spA := SpannerAccessorImpl{} + for _, tc := range testCases { + dbURI := "projects/project-id/instances/instance-id/databases/database-id" + conv := internal.MakeConv() + err := spA.UpdateDatabase(ctx, &tc.acm, dbURI, conv, "") + assert.Equal(t, tc.expectError, err != nil, tc.name) + } +} + + +func TestSpannerAccessorImpl_UpdateDDLForeignKey(t *testing.T) { + schemaWithStatements:=map[string]ddl.CreateTable{ + "table_id" : { + Name: "table1", + Id: "table_id", + }, + "table_id2" : { + Name: "table2", + Id: "table_id2", + ParentId: "table1", + ForeignKeys: []ddl.Foreignkey{ + { + Name: "fk", + ColIds: []string{"columns"}, + ReferTableId:"table1", + 
ReferColumnIds:[]string{"column"}, + Id:"table_id", + }, + }, + }, + } + testCases := []struct { + name string + acm spanneradmin.AdminClientMock + dialect string + migrationType string + spSchema ddl.Schema + }{ + { + name: "Update DDL ForeignKey successful pg dataflow", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return nil }, + }, nil + }, + }, + dialect: "postgresql", + spSchema: schemaWithStatements, + migrationType: "dataflow", + }, + { + name: "Update DDL ForeignKey successful pg dataflow no statement", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return nil }, + }, nil + }, + }, + dialect: "postgresql", + spSchema: map[string]ddl.CreateTable{}, + migrationType: "dataflow", + }, + { + name: "Update DDL ForeignKey successful google_standard_sql", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return nil }, + }, nil + }, + }, + dialect: "google_standard_sql", + spSchema: schemaWithStatements, + migrationType: "dataflow", + }, + { + name: "Update DDL ForeignKey update database error", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts 
...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return nil, fmt.Errorf("error") + }, + }, + dialect: "postgresql", + spSchema: schemaWithStatements, + migrationType: "dataflow", + }, + { + name: "Update DDL ForeignKey operation error", + acm: spanneradmin.AdminClientMock{ + UpdateDatabaseDdlMock: func(ctx context.Context, req *databasepb.UpdateDatabaseDdlRequest, opts ...gax.CallOption) (spanneradmin.UpdateDatabaseDdlOperation, error) { + return &spanneradmin.UpdateDatabaseDdlOperationMock{ + WaitMock: func(ctx context.Context, opts ...gax.CallOption) error { return fmt.Errorf("error") }, + }, nil + }, + }, + dialect: "postgresql", + spSchema: schemaWithStatements, + migrationType: "dataflow", + }, + } + ctx := context.Background() + spA := SpannerAccessorImpl{} + for _, tc := range testCases { + dbURI := "projects/project-id/instances/instance-id/databases/database-id" + conv := internal.MakeConv() + conv.SpDialect = tc.dialect + conv.SpSchema = tc.spSchema + spA.UpdateDDLForeignKeys(ctx, &tc.acm, dbURI, conv, "", tc.migrationType) + } +} \ No newline at end of file diff --git a/accessors/storage/mocks.go b/accessors/storage/mocks.go new file mode 100644 index 0000000000..966ba4a948 --- /dev/null +++ b/accessors/storage/mocks.go @@ -0,0 +1,60 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package storageaccessor + +import ( + "context" + + storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" +) + +// Mock that implements the StorageAccessor interface. +// Pass in unit tests where StorageAccessor is an input parameter. +type StorageAccessorMock struct { + CreateGCSBucketMock func(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error + ApplyBucketLifecycleDeleteRuleMock func(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error + UploadLocalFileToGCSMock func(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, localFilePath string) error + WriteDataToGCSMock func(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, data string) error + ReadGcsFileMock func(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) + ReadAnyFileMock func(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) + DeleteGCSBucketMock func(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error +} + +func (sam *StorageAccessorMock) CreateGCSBucket(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error { + return sam.CreateGCSBucketMock(ctx, sc, req) +} + +func (sam *StorageAccessorMock) ApplyBucketLifecycleDeleteRule(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error { + return sam.ApplyBucketLifecycleDeleteRuleMock(ctx, sc, req) +} + +func (sam *StorageAccessorMock) UploadLocalFileToGCS(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, localFilePath string) error { + return sam.UploadLocalFileToGCSMock(ctx, sc, filePath, fileName, localFilePath) +} + +func (sam *StorageAccessorMock) WriteDataToGCS(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, data string) error { + return sam.WriteDataToGCSMock(ctx, sc, filePath, fileName, 
data) +} + +func (sam *StorageAccessorMock) ReadGcsFile(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) { + return sam.ReadGcsFileMock(ctx, sc, filePath) +} + +func (sam *StorageAccessorMock) ReadAnyFile(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) { + return sam.ReadAnyFileMock(ctx, sc, filePath) +} + +func (sam *StorageAccessorMock) DeleteGCSBucket(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error { + return sam.DeleteGCSBucketMock(ctx, sc, req) +} diff --git a/accessors/storage/storage_accessor.go b/accessors/storage/storage_accessor.go new file mode 100644 index 0000000000..8eabcd3633 --- /dev/null +++ b/accessors/storage/storage_accessor.go @@ -0,0 +1,201 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package storageaccessor + +import ( + "context" + "fmt" + "io" + "os" + "strings" + + "cloud.google.com/go/storage" + storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "google.golang.org/api/googleapi" +) + +// The StorageAccessor provides methods that internally use a storage client. 
+// Methods should only contain generic logic here that can be used by multiple workflows. +type StorageAccessor interface { + // Create a GCS bucket with the given name in the input projectId and location. If ttl is > 0, + // also apply a delete lifecycle rule with the input ttl and prefixes. Set @ttl to 0 to skip creating lifecycle rules. + CreateGCSBucket(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error + // Applies the bucket lifecycle with delete rule. Only accepts the Age and prefix rule conditions as it is only used for the Datastream destination + // bucket currently. + ApplyBucketLifecycleDeleteRule(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error + // UploadLocalFileToGCS uploads a local file at @localFilePath to a gcs file path @filePath with name @fileName. + UploadLocalFileToGCS(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, localFilePath string) error + // Uploads a gcs object to gs://@filePath/@fileName with @data as content. + WriteDataToGCS(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, data string) error + // Read a Gcs file path and returns the contents as a string. + ReadGcsFile(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) + // Read a local or gcs file path. Files starting with a 'gs://' are treated as GCS files. + ReadAnyFile(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) + // Delete a given gcs bucket + DeleteGCSBucket(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error +} + +// This implements the StorageAccessor interface. This is the primary implementation that should be used in all places other than tests. 
+type StorageAccessorImpl struct{} + +func (sa *StorageAccessorImpl) CreateGCSBucket(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error { + bucket := sc.Bucket(req.BucketName) + attrs := storage.BucketAttrs{ + Location: req.Location, + } + if req.Ttl > 0 { + attrs.Lifecycle = storage.Lifecycle{ + Rules: []storage.LifecycleRule{ + { + Action: storage.LifecycleAction{Type: "Delete"}, + Condition: storage.LifecycleCondition{ + AgeInDays: req.Ttl, + // The prefixes should not contain the bucket names and starting slash. + // For object gs://my_bucket/pictures/paris_2022.jpg, + // you would use a condition such as "matchesPrefix":["pictures/paris_"]. + MatchesPrefix: req.MatchesPrefix, + }, + }, + }, + } + } + + if err := bucket.Create(ctx, req.ProjectID, &attrs); err != nil { + if e, ok := err.(*googleapi.Error); ok { + // Ignoring the bucket already exists error. + if e.Code != 409 { + return fmt.Errorf("failed to create bucket: %v", err) + } else { + fmt.Printf("Using the existing bucket: %v \n", req.BucketName) + } + } else { + return fmt.Errorf("failed to create bucket: %v", err) + } + + } else { + logger.Log.Info(fmt.Sprintf("Created new GCS bucket: %v\n", req.BucketName)) + } + return nil +} + + +func (sa *StorageAccessorImpl) DeleteGCSBucket(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error { + bucket := sc.Bucket(req.BucketName) + return bucket.Delete(ctx) +} + +func (sa *StorageAccessorImpl) ApplyBucketLifecycleDeleteRule(ctx context.Context, sc storageclient.StorageClient, req StorageBucketMetadata) error { + for i, str := range req.MatchesPrefix { + req.MatchesPrefix[i] = strings.TrimPrefix(str, "/") + } + bucket := sc.Bucket(req.BucketName) + bucketAttrsToUpdate := storage.BucketAttrsToUpdate{ + Lifecycle: &storage.Lifecycle{ + Rules: []storage.LifecycleRule{ + { + Action: storage.LifecycleAction{Type: "Delete"}, + Condition: storage.LifecycleCondition{ + AgeInDays: req.Ttl, + // The 
prefixes should not contain the bucket names and starting slash. + // For object gs://my_bucket/pictures/paris_2022.jpg, + // you would use a condition such as "matchesPrefix":["pictures/paris_"]. + MatchesPrefix: req.MatchesPrefix, + }, + }, + }, + }, + } + + attrs, err := bucket.Update(ctx, bucketAttrsToUpdate) + if err != nil { + return fmt.Errorf("could not bucket with lifecycle: %w", err) + } + logger.Log.Info(fmt.Sprintf("Added lifecycle rule to bucket %v\n. Rule Action: %v\t Rule Condition: %v\n", + req.BucketName, attrs.Lifecycle.Rules[0].Action, attrs.Lifecycle.Rules[0].Condition)) + return nil +} + +func (sa *StorageAccessorImpl) UploadLocalFileToGCS(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, localFilePath string) error { + data, err := os.ReadFile(localFilePath) + if err != nil { + return fmt.Errorf("could not read file %s: %w", localFilePath, err) + } + return sa.WriteDataToGCS(ctx, sc, filePath, fileName, string(data)) +} + +func (sa *StorageAccessorImpl) WriteDataToGCS(ctx context.Context, sc storageclient.StorageClient, filePath, fileName, data string) error { + u, err := utils.ParseGCSFilePath(filePath) + if err != nil { + return fmt.Errorf("parseFilePath: unable to parse file path: %v", err) + } + bucketName := u.Host + bucket := sc.Bucket(bucketName) + fullFilePath := u.Path + fileName + if strings.HasPrefix(fullFilePath, "/") { + fullFilePath = u.Path[1:] + fileName + } + obj := bucket.Object(fullFilePath) + + w := obj.NewWriter(ctx) + logger.Log.Info(fmt.Sprintf("Writing data to %s", filePath)) + n, err := fmt.Fprint(w, data) + if err != nil { + fmt.Printf("Failed to write to Cloud Storage: %s\n", filePath) + return err + } + logger.Log.Info(fmt.Sprintf("Wrote %d bytes to GCS", n)) + + if err := w.Close(); err != nil { + fmt.Printf("Failed to close GCS file: %s\n", filePath) + return err + } + return nil +} + +func (sa *StorageAccessorImpl) ReadGcsFile(ctx context.Context, sc storageclient.StorageClient, filePath 
string) (string, error) { + u, err := utils.ParseGCSFilePath(filePath) + if err != nil { + return "", fmt.Errorf("unable to parse file path: %v", err) + } + bucketName := u.Host + bucket := sc.Bucket(bucketName) + obj := bucket.Object(u.Path[1:]) + rc, err := obj.NewReader(ctx) + if err != nil { + return "", err + } + defer rc.Close() + buf := new(strings.Builder) + logger.Log.Info(fmt.Sprintf("Reading from %s", filePath)) + n, err := io.Copy(buf, rc) + if err != nil { + return "", err + } + logger.Log.Info(fmt.Sprintf("Read %d bytes", n)) + return buf.String(), nil +} + +func (sa *StorageAccessorImpl) ReadAnyFile(ctx context.Context, sc storageclient.StorageClient, filePath string) (string, error) { + if strings.HasPrefix(filePath, constants.GCS_FILE_PREFIX) { + return sa.ReadGcsFile(ctx, sc, filePath) + } + buf, err := os.ReadFile(filePath) + if err != nil { + return "", err + } + return string(buf), nil +} diff --git a/accessors/storage/storage_accessor_test.go b/accessors/storage/storage_accessor_test.go new file mode 100644 index 0000000000..cd38c737f1 --- /dev/null +++ b/accessors/storage/storage_accessor_test.go @@ -0,0 +1,438 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a +package storageaccessor + +import ( + "context" + "fmt" + "io" + "os" + "strings" + "testing" + + "cloud.google.com/go/storage" + storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/api/googleapi" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func TestStorageAccessorImpl_CreateGCSBucket(t *testing.T) { + testCases := []struct { + name string + scm storageclient.StorageClientMock + expectError bool + }{ + { + name: "Basic", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + CreateMock: func(ctx context.Context, projectID string, attrs *storage.BucketAttrs) (err error) { + return nil + }, + } + }, + }, + expectError: false, + }, + { + name: "random error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + CreateMock: func(ctx context.Context, projectID string, attrs *storage.BucketAttrs) (err error) { + return fmt.Errorf("random error") + }, + } + }, + }, + expectError: true, + }, + { + name: "Bucket already exists", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + CreateMock: func(ctx context.Context, projectID string, attrs *storage.BucketAttrs) (err error) { + return &googleapi.Error{Code: 409} + }, + } + }, + }, + expectError: false, + }, + { + name: "Other google api error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + CreateMock: func(ctx context.Context, projectID string, attrs *storage.BucketAttrs) 
(err error) { + return &googleapi.Error{Code: 100} + }, + } + }, + }, + expectError: true, + }, + } + ctx := context.Background() + sa := StorageAccessorImpl{} + for _, tc := range testCases { + err := sa.CreateGCSBucket(ctx, &tc.scm, StorageBucketMetadata{ + BucketName: "test-bucket", + ProjectID: "test-project", + Location: "india2", + Ttl: 1, + MatchesPrefix: nil, + }) + assert.Equal(t, tc.expectError, err != nil, tc.name) + } +} + +func TestStorageAccessorImpl_ApplyBucketLifecycleDeleteRule(t *testing.T) { + testCases := []struct { + name string + scm storageclient.StorageClientMock + ttl int64 + matchesPrefix []string + expectError bool + }{ + { + name: "Basic", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + UpdateMock: func(ctx context.Context, uattrs storage.BucketAttrsToUpdate) (attrs *storage.BucketAttrs, err error) { + if strings.HasPrefix(uattrs.Lifecycle.Rules[0].Condition.MatchesPrefix[0], "/") { + return nil, fmt.Errorf("test error") + } + return &storage.BucketAttrs{Lifecycle: storage.Lifecycle{Rules: []storage.LifecycleRule{ + { + Action: storage.LifecycleAction{Type: "Delete"}, + Condition: storage.LifecycleCondition{ + AgeInDays: 5, + MatchesPrefix: []string{}, + }, + }, + }}}, nil + }, + } + }, + }, + ttl: 5, + matchesPrefix: []string{"test"}, + expectError: false, + }, + { + name: "Update error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + UpdateMock: func(ctx context.Context, uattrs storage.BucketAttrsToUpdate) (attrs *storage.BucketAttrs, err error) { + return nil, fmt.Errorf("test error") + }, + } + }, + }, + ttl: 5, + matchesPrefix: []string{"test"}, + expectError: true, + }, + { + name: "Prefix '/' gets removed", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return 
&storageclient.BucketHandleMock{ + UpdateMock: func(ctx context.Context, uattrs storage.BucketAttrsToUpdate) (attrs *storage.BucketAttrs, err error) { + if strings.HasPrefix(uattrs.Lifecycle.Rules[0].Condition.MatchesPrefix[0], "/") { + return nil, fmt.Errorf("test error") + } + return &storage.BucketAttrs{Lifecycle: storage.Lifecycle{Rules: []storage.LifecycleRule{ + { + Action: storage.LifecycleAction{Type: "Delete"}, + Condition: storage.LifecycleCondition{ + AgeInDays: 5, + MatchesPrefix: []string{}, + }, + }, + }}}, nil + }, + } + }, + }, + ttl: 5, + matchesPrefix: []string{"/test"}, + expectError: false, + }, + } + ctx := context.Background() + sa := StorageAccessorImpl{} + for _, tc := range testCases { + err := sa.ApplyBucketLifecycleDeleteRule(ctx, &tc.scm, StorageBucketMetadata{ + BucketName: "test-bucket", + Ttl: tc.ttl, + MatchesPrefix: tc.matchesPrefix, + }) + assert.Equal(t, tc.expectError, err != nil, tc.name) + } +} + +func TestStorageAccessorImpl_WriteDataToGCS(t *testing.T) { + testCases := []struct { + name string + scm storageclient.StorageClientMock + filePath string + fileName string + data string + expectError bool + }{ + { + name: "Basic", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + ObjectMock: func(name string) storageclient.ObjectHandle { + return &storageclient.ObjectHandleMock{ + NewWriterMock: func(ctx context.Context) io.WriteCloser { + return &storageclient.WriterMock{ + WriteMock: func(p []byte) (n int, err error) { return len(p), nil }, + CloseMock: func() error { return nil }, + } + }, + } + }, + } + }, + }, + filePath: "gs://bucket/path", + fileName: "test-file", + data: "abcd", + expectError: false, + }, + { + name: "File parsing error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + ObjectMock: func(name string) 
storageclient.ObjectHandle { + return &storageclient.ObjectHandleMock{ + NewWriterMock: func(ctx context.Context) io.WriteCloser { + return &storageclient.WriterMock{ + WriteMock: func(p []byte) (n int, err error) { return len(p), nil }, + CloseMock: func() error { return nil }, + } + }, + } + }, + } + }, + }, + filePath: "://bucket/path", + fileName: "test-file", + data: "abcd", + expectError: true, + }, + { + name: "Write error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + ObjectMock: func(name string) storageclient.ObjectHandle { + return &storageclient.ObjectHandleMock{ + NewWriterMock: func(ctx context.Context) io.WriteCloser { + return &storageclient.WriterMock{ + WriteMock: func(p []byte) (n int, err error) { return 0, fmt.Errorf("test-error") }, + CloseMock: func() error { return nil }, + } + }, + } + }, + } + }, + }, + filePath: "gs://bucket/path", + fileName: "test-file", + data: "abcd", + expectError: true, + }, + { + name: "Close error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + ObjectMock: func(name string) storageclient.ObjectHandle { + return &storageclient.ObjectHandleMock{ + NewWriterMock: func(ctx context.Context) io.WriteCloser { + return &storageclient.WriterMock{ + WriteMock: func(p []byte) (n int, err error) { return len(p), nil }, + CloseMock: func() error { return fmt.Errorf("test error") }, + } + }, + } + }, + } + }, + }, + filePath: "gs://bucket/path", + fileName: "test-file", + data: "abcd", + expectError: true, + }, + } + ctx := context.Background() + sa := StorageAccessorImpl{} + for _, tc := range testCases { + err := sa.WriteDataToGCS(ctx, &tc.scm, tc.filePath, tc.fileName, tc.data) + assert.Equal(t, tc.expectError, err != nil, tc.name) + } +} + +func TestStorageAccessorImpl_ReadGcsFile(t *testing.T) { + testCases := []struct { + name 
string + scm storageclient.StorageClientMock + filePath string + expectError bool + want string + }{ + { + name: "Basic", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + ObjectMock: func(name string) storageclient.ObjectHandle { + return &storageclient.ObjectHandleMock{ + NewReaderMock: func(ctx context.Context) (io.ReadCloser, error) { + return &storageclient.ReaderMock{ + ReadMock: func(p []byte) (n int, err error) { + copy(p, "hello") + return 5, io.EOF + }, + CloseMock: func() error { return nil }, + }, nil + }, + } + }, + } + }, + }, + filePath: "gs://bucket/path", + expectError: false, + want: "hello", + }, + { + name: "Parse error", + scm: storageclient.StorageClientMock{}, + filePath: "://bucket/path", + expectError: true, + want: "", + }, + { + name: "New reader error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + ObjectMock: func(name string) storageclient.ObjectHandle { + return &storageclient.ObjectHandleMock{ + NewReaderMock: func(ctx context.Context) (io.ReadCloser, error) { + return nil, fmt.Errorf("test error") + }, + } + }, + } + }, + }, + filePath: "gs://bucket/path", + expectError: true, + want: "", + }, + { + name: "Read error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + ObjectMock: func(name string) storageclient.ObjectHandle { + return &storageclient.ObjectHandleMock{ + NewReaderMock: func(ctx context.Context) (io.ReadCloser, error) { + return &storageclient.ReaderMock{ + ReadMock: func(p []byte) (n int, err error) { + return 0, fmt.Errorf("test error") + }, + CloseMock: func() error { return nil }, + }, nil + }, + } + }, + } + }, + }, + filePath: "gs://bucket/path", + expectError: true, + want: "", + }, + } + ctx := context.Background() + sa := 
StorageAccessorImpl{} + for _, tc := range testCases { + got, err := sa.ReadGcsFile(ctx, &tc.scm, tc.filePath) + assert.Equal(t, tc.expectError, err != nil, tc.name) + assert.Equal(t, tc.want, got, tc.name) + } +} + +func TestStorageAccessorImpl_DeleteGCSBucket(t *testing.T) { + testCases := []struct { + name string + scm storageclient.StorageClientMock + expectError bool + }{ + { + name: "Basic", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + DeleteMock: func(ctx context.Context) error { + return nil + }, + } + }, + }, + expectError: false, + }, + { + name: "Error", + scm: storageclient.StorageClientMock{ + BucketMock: func(name string) storageclient.BucketHandle { + return &storageclient.BucketHandleMock{ + DeleteMock: func(ctx context.Context) error { + return fmt.Errorf("error") + }, + } + }, + }, + expectError: true, + }, + } + ctx := context.Background() + sa := StorageAccessorImpl{} + for _, tc := range testCases { + err := sa.DeleteGCSBucket(ctx, &tc.scm, StorageBucketMetadata{ + BucketName: "test-bucket", + }) + assert.Equal(t, tc.expectError, err != nil, tc.name) + } +} \ No newline at end of file diff --git a/accessors/storage/types.go b/accessors/storage/types.go new file mode 100644 index 0000000000..5a20e1a002 --- /dev/null +++ b/accessors/storage/types.go @@ -0,0 +1,24 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +package storageaccessor + +type StorageBucketMetadata struct { + BucketName string + // Not required for Updates. + ProjectID string + // Not required for Updates. + Location string + Ttl int64 + MatchesPrefix []string +} diff --git a/cmd/data.go b/cmd/data.go index c908f8ad94..a028682583 100644 --- a/cmd/data.go +++ b/cmd/data.go @@ -26,6 +26,8 @@ import ( sp "cloud.google.com/go/spanner" database "cloud.google.com/go/spanner/admin/database/apiv1" + spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin" + spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" "github.com/GoogleCloudPlatform/spanner-migration-tool/conversion" @@ -154,7 +156,8 @@ func (cmd *DataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface banner = utils.GetBanner(now, dbURI) } else { conv.Audit.DryRun = true - bw, err = conversion.DataConv(ctx, sourceProfile, targetProfile, &ioHelper, nil, conv, true, cmd.WriteLimit) + convImpl := &conversion.ConvImpl{} + bw, err = convImpl.DataConv(ctx, sourceProfile, targetProfile, &ioHelper, nil, conv, true, cmd.WriteLimit, &conversion.DataFromSourceImpl{}) if err != nil { err = fmt.Errorf("can't finish data conversion for db %s: %v", dbName, err) return subcommands.ExitFailure @@ -169,7 +172,8 @@ func (cmd *DataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface if cmd.filePrefix == "" { cmd.filePrefix = targetProfile.Conn.Sp.Dbname } - conversion.Report(sourceProfile.Driver, bw.DroppedRowsByTable(), ioHelper.BytesRead, banner, conv, cmd.filePrefix, dbName, ioHelper.Out) + reportImpl := conversion.ReportImpl{} + reportImpl.GenerateReport(sourceProfile.Driver, bw.DroppedRowsByTable(), ioHelper.BytesRead, 
banner, conv, cmd.filePrefix, dbName, ioHelper.Out) conversion.WriteBadData(bw, conv, banner, cmd.filePrefix+badDataFile, ioHelper.Out) // Cleanup smt tmp data directory. os.RemoveAll(filepath.Join(os.TempDir(), constants.SMT_TMP_DIR)) @@ -178,7 +182,12 @@ func (cmd *DataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface // validateExistingDb validates that the existing spanner schema is in accordance with the one specified in the session file. func validateExistingDb(ctx context.Context, spDialect, dbURI string, adminClient *database.DatabaseAdminClient, client *sp.Client, conv *internal.Conv) error { - dbExists, err := conversion.CheckExistingDb(ctx, adminClient, dbURI) + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) + if err != nil { + return err + } + spA := spanneraccessor.SpannerAccessorImpl{} + dbExists, err := spA.CheckExistingDb(ctx, adminClientImpl, dbURI) if err != nil { err = fmt.Errorf("can't verify target database: %v", err) return err diff --git a/cmd/schema.go b/cmd/schema.go index ec408bddb1..76025ab518 100644 --- a/cmd/schema.go +++ b/cmd/schema.go @@ -112,7 +112,8 @@ func (cmd *SchemaCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interfa schemaConversionStartTime := time.Now() var conv *internal.Conv - conv, err = conversion.SchemaConv(sourceProfile, targetProfile, &ioHelper) + convImpl := &conversion.ConvImpl{} + conv, err = convImpl.SchemaConv(sourceProfile, targetProfile, &ioHelper, &conversion.SchemaFromSourceImpl{}) if err != nil { return subcommands.ExitFailure } @@ -136,7 +137,8 @@ func (cmd *SchemaCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interfa schemaCoversionEndTime := time.Now() conv.Audit.SchemaConversionDuration = schemaCoversionEndTime.Sub(schemaConversionStartTime) banner := utils.GetBanner(schemaConversionStartTime, dbName) - conversion.Report(sourceProfile.Driver, nil, ioHelper.BytesRead, banner, conv, cmd.filePrefix, dbName, ioHelper.Out) + reportImpl := 
conversion.ReportImpl{} + reportImpl.GenerateReport(sourceProfile.Driver, nil, ioHelper.BytesRead, banner, conv, cmd.filePrefix, dbName, ioHelper.Out) // Cleanup smt tmp data directory. os.RemoveAll(filepath.Join(os.TempDir(), constants.SMT_TMP_DIR)) return subcommands.ExitSuccess diff --git a/cmd/schema_and_data.go b/cmd/schema_and_data.go index 1296f6fc55..0dc524db46 100644 --- a/cmd/schema_and_data.go +++ b/cmd/schema_and_data.go @@ -121,7 +121,8 @@ func (cmd *SchemaAndDataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ... banner string dbURI string ) - conv, err = conversion.SchemaConv(sourceProfile, targetProfile, &ioHelper) + convImpl := &conversion.ConvImpl{} + conv, err = convImpl.SchemaConv(sourceProfile, targetProfile, &ioHelper, &conversion.SchemaFromSourceImpl{}) if err != nil { panic(err) } @@ -136,9 +137,9 @@ func (cmd *SchemaAndDataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ... conversion.WriteSchemaFile(conv, schemaConversionStartTime, cmd.filePrefix+schemaFile, ioHelper.Out, sourceProfile.Driver) conversion.WriteSessionFile(conv, cmd.filePrefix+sessionFile, ioHelper.Out) conv.Audit.SkipMetricsPopulation = os.Getenv("SKIP_METRICS_POPULATION") == "true" - + reportImpl := conversion.ReportImpl{} if !cmd.dryRun { - conversion.Report(sourceProfile.Driver, nil, ioHelper.BytesRead, "", conv, cmd.filePrefix, dbName, ioHelper.Out) + reportImpl.GenerateReport(sourceProfile.Driver, nil, ioHelper.BytesRead, "", conv, cmd.filePrefix, dbName, ioHelper.Out) bw, err = MigrateDatabase(ctx, targetProfile, sourceProfile, dbName, &ioHelper, cmd, conv, nil) if err != nil { err = fmt.Errorf("can't finish database migration for db %s: %v", dbName, err) @@ -152,7 +153,7 @@ func (cmd *SchemaAndDataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ... 
conv.Audit.DryRun = true schemaCoversionEndTime := time.Now() conv.Audit.SchemaConversionDuration = schemaCoversionEndTime.Sub(schemaConversionStartTime) - bw, err = conversion.DataConv(ctx, sourceProfile, targetProfile, &ioHelper, nil, conv, true, cmd.WriteLimit) + bw, err = convImpl.DataConv(ctx, sourceProfile, targetProfile, &ioHelper, nil, conv, true, cmd.WriteLimit, &conversion.DataFromSourceImpl{}) if err != nil { err = fmt.Errorf("can't finish data conversion for db %s: %v", dbName, err) return subcommands.ExitFailure @@ -161,7 +162,7 @@ func (cmd *SchemaAndDataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ... conv.Audit.DataConversionDuration = dataCoversionEndTime.Sub(schemaCoversionEndTime) banner = utils.GetBanner(schemaConversionStartTime, dbName) } - conversion.Report(sourceProfile.Driver, bw.DroppedRowsByTable(), ioHelper.BytesRead, banner, conv, cmd.filePrefix, dbName, ioHelper.Out) + reportImpl.GenerateReport(sourceProfile.Driver, bw.DroppedRowsByTable(), ioHelper.BytesRead, banner, conv, cmd.filePrefix, dbName, ioHelper.Out) conversion.WriteBadData(bw, conv, banner, cmd.filePrefix+badDataFile, ioHelper.Out) // Cleanup smt tmp data directory. 
diff --git a/cmd/utils.go b/cmd/utils.go index a1cf40523a..16c4add784 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -16,17 +16,24 @@ package cmd import ( "context" + "encoding/base64" "fmt" "time" sp "cloud.google.com/go/spanner" database "cloud.google.com/go/spanner/admin/database/apiv1" + spanneradmin "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/spanner/admin" + spanneraccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/metrics" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" "github.com/GoogleCloudPlatform/spanner-migration-tool/conversion" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer" "github.com/GoogleCloudPlatform/spanner-migration-tool/webv2/helpers" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" ) var ( @@ -40,12 +47,22 @@ const ( completionPercentage = 100 ) +func metricsPopulation(ctx context.Context, driver string, conv *internal.Conv) { + if !conv.Audit.SkipMetricsPopulation { + // Adding migration metadata to the outgoing context. + migrationData := metrics.GetMigrationData(conv, driver, constants.SchemaConv) + serializedMigrationData, _ := proto.Marshal(migrationData) + migrationMetadataValue := base64.StdEncoding.EncodeToString(serializedMigrationData) + ctx = metadata.AppendToOutgoingContext(ctx, constants.MigrationMetadataKey, migrationMetadataValue) + } +} + // CreateDatabaseClient creates new database client and admin client. 
func CreateDatabaseClient(ctx context.Context, targetProfile profiles.TargetProfile, driver, dbName string, ioHelper utils.IOStreams) (*database.DatabaseAdminClient, *sp.Client, string, error) { if targetProfile.Conn.Sp.Dbname == "" { targetProfile.Conn.Sp.Dbname = dbName } - project, instance, dbName, err := targetProfile.GetResourceIds(ctx, time.Now(), driver, ioHelper.Out) + project, instance, dbName, err := targetProfile.GetResourceIds(ctx, time.Now(), driver, ioHelper.Out, &utils.GetUtilInfoImpl{}) if err != nil { return nil, nil, "", err } @@ -74,7 +91,8 @@ func PrepareMigrationPrerequisites(sourceProfileString, targetProfileString, sou return profiles.SourceProfile{}, profiles.TargetProfile{}, utils.IOStreams{}, "", err } - sourceProfile, err := profiles.NewSourceProfile(sourceProfileString, source) + n := profiles.NewSourceProfileImpl{} + sourceProfile, err := profiles.NewSourceProfile(sourceProfileString, source, &n) if err != nil { return profiles.SourceProfile{}, targetProfile, utils.IOStreams{}, "", err } @@ -92,7 +110,8 @@ func PrepareMigrationPrerequisites(sourceProfileString, targetProfileString, sou defer ioHelper.In.Close() } - dbName, err := utils.GetDatabaseName(sourceProfile.Driver, time.Now()) + getInfo := utils.GetUtilInfoImpl{} + dbName, err := getInfo.GetDatabaseName(sourceProfile.Driver, time.Now()) if err != nil { err = fmt.Errorf("can't generate database name for prefix: %v", err) return sourceProfile, targetProfile, ioHelper, "", err @@ -137,13 +156,19 @@ func MigrateDatabase(ctx context.Context, targetProfile profiles.TargetProfile, func migrateSchema(ctx context.Context, targetProfile profiles.TargetProfile, sourceProfile profiles.SourceProfile, ioHelper *utils.IOStreams, conv *internal.Conv, dbURI string, adminClient *database.DatabaseAdminClient) error { - err := conversion.CreateOrUpdateDatabase(ctx, adminClient, dbURI, sourceProfile.Driver, conv, ioHelper.Out, sourceProfile.Config.ConfigType) - if err != nil { - err = 
fmt.Errorf("can't create/update database: %v", err) - return err - } - conv.Audit.Progress.UpdateProgress("Schema migration complete.", completionPercentage, internal.SchemaMigrationComplete) - return nil + spA := spanneraccessor.SpannerAccessorImpl{} + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) + if err != nil { + return err + } + err = spA.CreateOrUpdateDatabase(ctx, adminClientImpl, dbURI, sourceProfile.Driver, conv, sourceProfile.Config.ConfigType) + if err != nil { + err = fmt.Errorf("can't create/update database: %v", err) + return err + } + metricsPopulation(ctx, sourceProfile.Driver, conv) + conv.Audit.Progress.UpdateProgress("Schema migration complete.", completionPercentage, internal.SchemaMigrationComplete) + return nil } func migrateData(ctx context.Context, targetProfile profiles.TargetProfile, sourceProfile profiles.SourceProfile, @@ -160,30 +185,40 @@ func migrateData(ctx context.Context, targetProfile profiles.TargetProfile, sour } fmt.Printf("Schema validated successfully for data migration for db %s\n", dbURI) } - bw, err = conversion.DataConv(ctx, sourceProfile, targetProfile, ioHelper, client, conv, true, cmd.WriteLimit) + c := &conversion.ConvImpl{} + bw, err = c.DataConv(ctx, sourceProfile, targetProfile, ioHelper, client, conv, true, cmd.WriteLimit, &conversion.DataFromSourceImpl{}) if err != nil { err = fmt.Errorf("can't finish data conversion for db %s: %v", dbURI, err) return nil, err } conv.Audit.Progress.UpdateProgress("Data migration complete.", completionPercentage, internal.DataMigrationComplete) if !cmd.SkipForeignKeys { - if err = conversion.UpdateDDLForeignKeys(ctx, adminClient, dbURI, conv, ioHelper.Out, sourceProfile.Driver, sourceProfile.Config.ConfigType); err != nil { - err = fmt.Errorf("can't perform update schema on db %s with foreign keys: %v", dbURI, err) + spA := spanneraccessor.SpannerAccessorImpl{} + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) + if err != nil { return bw, err } + 
spA.UpdateDDLForeignKeys(ctx, adminClientImpl, dbURI, conv, sourceProfile.Driver, sourceProfile.Config.ConfigType) } return bw, nil } func migrateSchemaAndData(ctx context.Context, targetProfile profiles.TargetProfile, sourceProfile profiles.SourceProfile, ioHelper *utils.IOStreams, conv *internal.Conv, dbURI string, adminClient *database.DatabaseAdminClient, client *sp.Client, cmd *SchemaAndDataCmd) (*writer.BatchWriter, error) { - err := conversion.CreateOrUpdateDatabase(ctx, adminClient, dbURI, sourceProfile.Driver, conv, ioHelper.Out, sourceProfile.Config.ConfigType) + spA := spanneraccessor.SpannerAccessorImpl{} + adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) + if err != nil { + return nil, err + } + err = spA.CreateOrUpdateDatabase(ctx, adminClientImpl, dbURI, sourceProfile.Driver, conv, sourceProfile.Config.ConfigType) if err != nil { err = fmt.Errorf("can't create/update database: %v", err) return nil, err } + metricsPopulation(ctx, sourceProfile.Driver, conv) conv.Audit.Progress.UpdateProgress("Schema migration complete.", completionPercentage, internal.SchemaMigrationComplete) - bw, err := conversion.DataConv(ctx, sourceProfile, targetProfile, ioHelper, client, conv, true, cmd.WriteLimit) + convImpl := &conversion.ConvImpl{} + bw, err := convImpl.DataConv(ctx, sourceProfile, targetProfile, ioHelper, client, conv, true, cmd.WriteLimit, &conversion.DataFromSourceImpl{}) if err != nil { err = fmt.Errorf("can't finish data conversion for db %s: %v", dbURI, err) return nil, err @@ -191,10 +226,7 @@ func migrateSchemaAndData(ctx context.Context, targetProfile profiles.TargetProf conv.Audit.Progress.UpdateProgress("Data migration complete.", completionPercentage, internal.DataMigrationComplete) if !cmd.SkipForeignKeys { - if err = conversion.UpdateDDLForeignKeys(ctx, adminClient, dbURI, conv, ioHelper.Out, sourceProfile.Driver, sourceProfile.Config.ConfigType); err != nil { - err = fmt.Errorf("can't perform update schema on db %s with foreign 
keys: %v", dbURI, err) - return bw, err - } + spA.UpdateDDLForeignKeys(ctx, adminClientImpl, dbURI, conv, sourceProfile.Driver, sourceProfile.Config.ConfigType) } return bw, nil } diff --git a/common/constants/constants.go b/common/constants/constants.go index e98855a81d..c0eaa87685 100644 --- a/common/constants/constants.go +++ b/common/constants/constants.go @@ -64,7 +64,8 @@ const ( MigrationMetadataKey string = "cloud-spanner-migration-metadata" // Scheme used for GCS paths - GCS_SCHEME string = "gs" + GCS_SCHEME string = "gs" + GCS_FILE_PREFIX string = "gs://" // File upload prefix for dump and session load. UPLOAD_FILE_DIR string = "upload-file" diff --git a/common/utils/storage_utils.go b/common/utils/storage_utils.go new file mode 100644 index 0000000000..3968b1731e --- /dev/null +++ b/common/utils/storage_utils.go @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* +Package utils contains common helper functions used across multiple other packages. +Utils should not import any Spanner migration tool packages. 
+*/ +package utils + +import ( + "fmt" + "net/url" + + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" +) + +func ParseGCSFilePath(filePath string) (*url.URL, error) { + if len(filePath) == 0 { + return nil, fmt.Errorf("found empty GCS path") + } + if filePath[len(filePath)-1] != '/' { + filePath = filePath + "/" + } + u, err := url.Parse(filePath) + if err != nil { + return nil, fmt.Errorf("parseFilePath: unable to parse file path %s", filePath) + } + if u.Scheme != constants.GCS_SCHEME { + return nil, fmt.Errorf("not a valid GCS path: %s, should start with 'gs'", filePath) + } + return u, nil +} diff --git a/common/utils/storage_utils_test.go b/common/utils/storage_utils_test.go new file mode 100644 index 0000000000..fa00f0c302 --- /dev/null +++ b/common/utils/storage_utils_test.go @@ -0,0 +1,107 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package utils + +import ( + "net/url" + "os" + "testing" + + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func TestParseGCSFilePath(t *testing.T) { + testCases := []struct { + name string + filePath string + expectError bool + want *url.URL + }{ + { + name: "Basic", + filePath: "gs://test-bucket/path/to/folder/", + expectError: false, + want: &url.URL{ + Scheme: "gs", + Host: "test-bucket", + Path: "/path/to/folder/", + }, + }, + { + name: "Append Slash", + filePath: "gs://test-bucket/path/to/folder", + expectError: false, + want: &url.URL{ + Scheme: "gs", + Host: "test-bucket", + Path: "/path/to/folder/", + }, + }, + { + name: "Empty path", + filePath: "gs://test-bucket", + expectError: false, + want: &url.URL{ + Scheme: "gs", + Host: "test-bucket", + Path: "/", + }, + }, + { + name: "Empty path with leading slash", + filePath: "gs://test-bucket/", + expectError: false, + want: &url.URL{ + Scheme: "gs", + Host: "test-bucket", + Path: "/", + }, + }, + { + name: "Empty File path", + filePath: "", + expectError: true, + want: nil, + }, + { + name: "Wrong Scheme", + filePath: "ab://testpath", + expectError: true, + want: nil, + }, + { + name: "Malformed Path", + filePath: "://path", + expectError: true, + want: nil, + }, + } + + for _, tc := range testCases { + got, err := ParseGCSFilePath(tc.filePath) + assert.Equal(t, tc.expectError, err != nil, tc.name) + assert.Equal(t, tc.want, got, tc.name) + } +} diff --git a/common/utils/utils.go b/common/utils/utils.go index 70ec53f882..956bdc65c4 100644 --- a/common/utils/utils.go +++ b/common/utils/utils.go @@ -19,11 +19,11 @@ package utils import ( "bufio" "context" + "crypto/rand" "fmt" "io" "io/ioutil" "log" - "math/rand" "net/url" "os" "os/exec" @@ -37,6 +37,7 @@ import ( sp "cloud.google.com/go/spanner" database 
"cloud.google.com/go/spanner/admin/database/apiv1" instance "cloud.google.com/go/spanner/admin/instance/apiv1" + "cloud.google.com/go/spanner/admin/instance/apiv1/instancepb" "cloud.google.com/go/storage" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" @@ -44,10 +45,8 @@ import ( "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/spanner" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "golang.org/x/crypto/ssh/terminal" - "google.golang.org/api/googleapi" "google.golang.org/api/iterator" "google.golang.org/api/option" - instancepb "google.golang.org/genproto/googleapis/spanner/admin/instance/v1" ) // IOStreams is a struct that contains the file descriptor for dumpFile. @@ -62,6 +61,16 @@ type ManifestTable struct { File_patterns []string `json:"file_patterns"` } +// Interface to fetch spanner details +type GetUtilInfoInterface interface { + GetProject() (string, error) + GetInstance(ctx context.Context, project string, out *os.File) (string, error) + GetPassword() string + GetDatabaseName(driver string, now time.Time) (string, error) +} + +type GetUtilInfoImpl struct{} + // NewIOStreams returns a new IOStreams struct such that input stream is set // to open file descriptor for dumpFile if driver is PGDUMP or MYSQLDUMP. // Input stream defaults to stdin. Output stream is always set to stdout. 
@@ -173,86 +182,10 @@ func PreloadGCSFiles(tables []ManifestTable) ([]ManifestTable, error) { return tables, nil } -func ParseGCSFilePath(filePath string) (*url.URL, error) { - if len(filePath) == 0 { - return nil, fmt.Errorf("found empty GCS path") - } - if filePath[len(filePath)-1] != '/' { - filePath = filePath + "/" - } - u, err := url.Parse(filePath) - if err != nil { - return nil, fmt.Errorf("parseFilePath: unable to parse file path %s", filePath) - } - if u.Scheme != constants.GCS_SCHEME { - return nil, fmt.Errorf("not a valid GCS path: %s, should start with 'gs'", filePath) - } - return u, nil -} - -func WriteToGCS(filePath, fileName, data string) error { - ctx := context.Background() - - client, err := storage.NewClient(ctx) - if err != nil { - fmt.Printf("Failed to create GCS client") - return err - } - defer client.Close() - u, err := ParseGCSFilePath(filePath) - if err != nil { - return fmt.Errorf("parseFilePath: unable to parse file path: %v", err) - } - bucketName := u.Host - bucket := client.Bucket(bucketName) - obj := bucket.Object(u.Path[1:] + fileName) - - w := obj.NewWriter(ctx) - if _, err := fmt.Fprint(w, data); err != nil { - fmt.Printf("Failed to write to Cloud Storage: %s", filePath) - return err - } - if err := w.Close(); err != nil { - fmt.Printf("Failed to close GCS file: %s", filePath) - return err - } - return nil -} - -func CreateGCSBucket(bucketName, projectID, location string) error { - ctx := context.Background() - - client, err := storage.NewClient(ctx) - if err != nil { - return fmt.Errorf("failed to create GCS client: %v", err) - } - defer client.Close() - bucket := client.Bucket(bucketName) - attrs := storage.BucketAttrs{ - Location: location, - } - if err := bucket.Create(ctx, projectID, &attrs); err != nil { - if e, ok := err.(*googleapi.Error); ok { - // Ignoring the bucket already exists error. 
- if e.Code != 409 { - return fmt.Errorf("failed to create bucket: %v", err) - } else { - fmt.Printf("Using the existing bucket: %v \n", bucketName) - } - } else { - return fmt.Errorf("failed to create bucket: %v", err) - } - - } else { - fmt.Printf("Created new GCS bucket: %v\n", bucketName) - } - return nil -} - // GetProject returns the cloud project we should use for accessing Spanner. // Use environment variable GCLOUD_PROJECT if it is set. // Otherwise, use the default project returned from gcloud. -func GetProject() (string, error) { +func (gui *GetUtilInfoImpl) GetProject() (string, error) { project := os.Getenv("GCLOUD_PROJECT") if project != "" { return project, nil @@ -269,7 +202,7 @@ func GetProject() (string, error) { // GetInstance returns the Spanner instance we should use for creating DBs. // If the user specified instance (via flag 'instance') then use that. // Otherwise try to deduce the instance using gcloud. -func GetInstance(ctx context.Context, project string, out *os.File) (string, error) { +func (gui *GetUtilInfoImpl) GetInstance(ctx context.Context, project string, out *os.File) (string, error) { l, err := getInstances(ctx, project) if err != nil { return "", err @@ -315,7 +248,7 @@ func getInstances(ctx context.Context, project string) ([]string, error) { return l, nil } -func GetPassword() string { +func (gui *GetUtilInfoImpl) GetPassword() string { calledFromGCloud := os.Getenv("GCLOUD_HB_PLUGIN") if strings.EqualFold(calledFromGCloud, "true") { fmt.Println("\n Please specify password in enviroment variables (recommended) or --source-profile " + @@ -333,7 +266,7 @@ func GetPassword() string { } // GetDatabaseName generates database name with driver_date prefix. 
-func GetDatabaseName(driver string, now time.Time) (string, error) { +func (gui *GetUtilInfoImpl) GetDatabaseName(driver string, now time.Time) (string, error) { return GenerateName(fmt.Sprintf("%s_%s", driver, now.Format("2006-01-02"))) } @@ -347,6 +280,12 @@ func GenerateName(prefix string) (string, error) { return fmt.Sprintf("%s_%x-%x", prefix, b[0:2], b[2:4]), nil } +func GenerateHashStr() string { + b := make([]byte, 4) + rand.Read(b) + return fmt.Sprintf("%x-%x", b[0:2], b[2:4]) +} + // parseURI parses an unknown URI string that could be a database, instance or project URI. func parseURI(URI string) (project, instance, dbName string) { project, instance, dbName = "", "", "" @@ -565,7 +504,8 @@ func GetLegacyModeSupportedDrivers() []string { // ReadSpannerSchema fills conv by querying Spanner infoschema treating Spanner as both the source and dest. func ReadSpannerSchema(ctx context.Context, conv *internal.Conv, client *sp.Client) error { infoSchema := spanner.InfoSchemaImpl{Client: client, Ctx: ctx, SpDialect: conv.SpDialect} - err := common.ProcessSchema(conv, infoSchema, common.DefaultWorkers, internal.AdditionalSchemaAttributes{IsSharded: false}) + processSchema := common.ProcessSchemaImpl{} + err := processSchema.ProcessSchema(conv, infoSchema, common.DefaultWorkers, internal.AdditionalSchemaAttributes{IsSharded: false}, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) if err != nil { return fmt.Errorf("error trying to read and convert spanner schema: %v", err) } @@ -651,7 +591,7 @@ func CompareSchema(sessionFileConv, actualSpannerConv *internal.Conv) error { } else { if sessionColDef.Name != spannerColDef.Name || sessionColDef.T.IsArray != spannerColDef.T.IsArray || sessionColDef.T.Len != spannerColDef.T.Len || sessionColDef.T.Name != spannerColDef.T.Name || sessionColDef.NotNull != spannerColDef.NotNull { - return fmt.Errorf("column detail for table %v don't match: session column: %v, spanner column: %v", 
sessionTable.Name, sessionColDef, spannerColDef) + return fmt.Errorf("column detail for table %v don't match: session column: %v, spanner column: %v", sessionTable.Name, sessionColDef, spannerColDef) } } } diff --git a/conversion/conversion.go b/conversion/conversion.go index 20216dfea4..21972e36cc 100644 --- a/conversion/conversion.go +++ b/conversion/conversion.go @@ -26,61 +26,25 @@ package conversion import ( "bufio" "context" - "database/sql" - "encoding/base64" "encoding/json" "fmt" - "io" - "io/ioutil" - "net" "os" "strings" "sync" - "sync/atomic" - "syscall" - "time" - "cloud.google.com/go/cloudsqlconn" datastream "cloud.google.com/go/datastream/apiv1" sp "cloud.google.com/go/spanner" - database "cloud.google.com/go/spanner/admin/database/apiv1" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" - "github.com/GoogleCloudPlatform/spanner-migration-tool/common/metrics" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal/reports" - "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/csv" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/dynamodb" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/mysql" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/oracle" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/postgres" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/spanner" - "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/sqlserver" - "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer" - 
"github.com/GoogleCloudPlatform/spanner-migration-tool/streaming" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - dydb "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodbstreams" - mysqldriver "github.com/go-sql-driver/mysql" - "go.uber.org/zap" - adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" - "google.golang.org/grpc/metadata" - "google.golang.org/protobuf/proto" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/stdlib" ) var ( - // Set the maximum number of concurrent workers during foreign key creation. - // This number should not be too high so as to not hit the AdminQuota limit. - // AdminQuota limits are mentioned here: https://cloud.google.com/spanner/quotas#administrative_limits - // If facing a quota limit error, consider reducing this value. - MaxWorkers = 50 once sync.Once datastreamClient *datastream.Client ) @@ -95,14 +59,20 @@ func getDatastreamClient(ctx context.Context) *datastream.Client { return datastreamClient } +type ConvInterface interface { + SchemaConv(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams, schemaFromSource SchemaFromSourceInterface) (*internal.Conv, error) + DataConv(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool, writeLimit int64, dataFromSource DataFromSourceInterface) (*writer.BatchWriter, error) +} +type ConvImpl struct {} + // SchemaConv performs the schema conversion // The SourceProfile param provides the connection details to use the go SQL library. 
-func SchemaConv(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams) (*internal.Conv, error) { +func (ci *ConvImpl) SchemaConv(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams, schemaFromSource SchemaFromSourceInterface) (*internal.Conv, error) { switch sourceProfile.Driver { case constants.POSTGRES, constants.MYSQL, constants.DYNAMODB, constants.SQLSERVER, constants.ORACLE: - return schemaFromDatabase(sourceProfile, targetProfile) + return schemaFromSource.schemaFromDatabase(sourceProfile, targetProfile, &GetInfoImpl{}, &common.ProcessSchemaImpl{}) case constants.PGDUMP, constants.MYSQLDUMP: - return SchemaFromDump(sourceProfile.Driver, targetProfile.Conn.Sp.Dialect, ioHelper) + return schemaFromSource.SchemaFromDump(sourceProfile.Driver, targetProfile.Conn.Sp.Dialect, ioHelper, &ProcessDumpByDialectImpl{}) default: return nil, fmt.Errorf("schema conversion for driver %s not supported", sourceProfile.Driver) } @@ -110,7 +80,7 @@ func SchemaConv(sourceProfile profiles.SourceProfile, targetProfile profiles.Tar // DataConv performs the data conversion // The SourceProfile param provides the connection details to use the go SQL library. 
-func DataConv(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool, writeLimit int64) (*writer.BatchWriter, error) { +func (ci *ConvImpl) DataConv(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool, writeLimit int64, dataFromSource DataFromSourceInterface) (*writer.BatchWriter, error) { config := writer.BatchWriterConfig{ BytesLimit: 100 * 1000 * 1000, WriteLimit: writeLimit, @@ -119,606 +89,32 @@ func DataConv(ctx context.Context, sourceProfile profiles.SourceProfile, targetP } switch sourceProfile.Driver { case constants.POSTGRES, constants.MYSQL, constants.DYNAMODB, constants.SQLSERVER, constants.ORACLE: - return dataFromDatabase(ctx, sourceProfile, targetProfile, config, conv, client) + return dataFromSource.dataFromDatabase(ctx, sourceProfile, targetProfile, config, conv, client, &GetInfoImpl{}, &DataFromDatabaseImpl{}, &SnapshotMigrationImpl{}) case constants.PGDUMP, constants.MYSQLDUMP: if conv.SpSchema.CheckInterleaved() { return nil, fmt.Errorf("spanner migration tool does not currently support data conversion from dump files\nif the schema contains interleaved tables. Suggest using direct access to source database\ni.e. 
using drivers postgres and mysql") } - return dataFromDump(sourceProfile.Driver, config, ioHelper, client, conv, dataOnly) + return dataFromSource.dataFromDump(sourceProfile.Driver, config, ioHelper, client, conv, dataOnly, &ProcessDumpByDialectImpl{}, &PopulateDataConvImpl{}) case constants.CSV: - return dataFromCSV(ctx, sourceProfile, targetProfile, config, conv, client) + return dataFromSource.dataFromCSV(ctx, sourceProfile, targetProfile, config, conv, client, &PopulateDataConvImpl{}, &csv.CsvImpl{}) default: return nil, fmt.Errorf("data conversion for driver %s not supported", sourceProfile.Driver) } } -func connectionConfig(sourceProfile profiles.SourceProfile) (interface{}, error) { - switch sourceProfile.Driver { - // For PG and MYSQL, When called as part of the subcommand flow, host/user/db etc will - // never be empty as we error out right during source profile creation. If any of them - // are empty, that means this was called through the legacy cmd flow and we create the - // string using env vars. - case constants.POSTGRES: - pgConn := sourceProfile.Conn.Pg - if !(pgConn.Host != "" && pgConn.User != "" && pgConn.Db != "") { - return profiles.GeneratePGSQLConnectionStr() - } else { - return profiles.GetSQLConnectionStr(sourceProfile), nil - } - case constants.MYSQL: - // If empty, this is called as part of the legacy mode witih global CLI flags. - // When using source-profile mode is used, the sqlConnectionStr is already populated. - mysqlConn := sourceProfile.Conn.Mysql - if !(mysqlConn.Host != "" && mysqlConn.User != "" && mysqlConn.Db != "") { - return profiles.GenerateMYSQLConnectionStr() - } else { - return profiles.GetSQLConnectionStr(sourceProfile), nil - } - // For Dynamodb, both legacy and new flows use env vars. 
- case constants.DYNAMODB: - return getDynamoDBClientConfig() - case constants.SQLSERVER: - return profiles.GetSQLConnectionStr(sourceProfile), nil - case constants.ORACLE: - return profiles.GetSQLConnectionStr(sourceProfile), nil - default: - return "", fmt.Errorf("driver %s not supported", sourceProfile.Driver) - } -} - -func getDbNameFromSQLConnectionStr(driver, sqlConnectionStr string) string { - switch driver { - case constants.POSTGRES: - dbParam := strings.Split(sqlConnectionStr, " ")[4] - return strings.Split(dbParam, "=")[1] - case constants.MYSQL: - return strings.Split(sqlConnectionStr, ")/")[1] - case constants.SQLSERVER: - splts := strings.Split(sqlConnectionStr, "?database=") - return splts[len(splts)-1] - case constants.ORACLE: - // connection string formate : "oracle://user:password@104.108.154.85:1521/XE" - substr := sqlConnectionStr[9:] - dbName := strings.Split(substr, ":")[0] - return dbName - } - return "" -} - -func getInfoSchemaForShard(shardConnInfo profiles.DirectConnectionConfig, driver string, targetProfile profiles.TargetProfile) (common.InfoSchema, error) { - params := make(map[string]string) - params["host"] = shardConnInfo.Host - params["user"] = shardConnInfo.User - params["dbName"] = shardConnInfo.DbName - params["port"] = shardConnInfo.Port - params["password"] = shardConnInfo.Password - //while adding other sources, a switch-case will be added here on the basis of the driver input param passed. - //pased on the driver name, profiles.NewSourceProfileConnection will need to be called to create - //the source profile information. 
- sourceProfileConnectionMySQL, err := profiles.NewSourceProfileConnectionMySQL(params) - if err != nil { - return nil, fmt.Errorf("cannot parse connection configuration for the primary shard") - } - sourceProfileConnection := profiles.SourceProfileConnection{Mysql: sourceProfileConnectionMySQL, Ty: profiles.SourceProfileConnectionTypeMySQL} - //create a source profile which contains the sourceProfileConnection object for the primary shard - //this is done because GetSQLConnectionStr() should not be aware of sharding - newSourceProfile := profiles.SourceProfile{Conn: sourceProfileConnection, Ty: profiles.SourceProfileTypeConnection} - newSourceProfile.Driver = driver - infoSchema, err := GetInfoSchema(newSourceProfile, targetProfile) - if err != nil { - return nil, err - } - return infoSchema, nil -} - -func schemaFromDatabase(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (*internal.Conv, error) { - conv := internal.MakeConv() - conv.SpDialect = targetProfile.Conn.Sp.Dialect - //handle fetching schema differently for sharded migrations, we only connect to the primary shard to - //fetch the schema. We reuse the SourceProfileConnection object for this purpose. 
- var infoSchema common.InfoSchema - var err error - isSharded := false - switch sourceProfile.Ty { - case profiles.SourceProfileTypeConfig: - isSharded = true - //Find Primary Shard Name - if sourceProfile.Config.ConfigType == constants.BULK_MIGRATION { - schemaSource := sourceProfile.Config.ShardConfigurationBulk.SchemaSource - infoSchema, err = getInfoSchemaForShard(schemaSource, sourceProfile.Driver, targetProfile) - if err != nil { - return conv, err - } - } else if sourceProfile.Config.ConfigType == constants.DATAFLOW_MIGRATION { - schemaSource := sourceProfile.Config.ShardConfigurationDataflow.SchemaSource - infoSchema, err = getInfoSchemaForShard(schemaSource, sourceProfile.Driver, targetProfile) - if err != nil { - return conv, err - } - } else if sourceProfile.Config.ConfigType == constants.DMS_MIGRATION { - // TODO: Define the schema processing logic for DMS migrations here. - return conv, fmt.Errorf("dms based migrations are not implemented yet") - } else { - return conv, fmt.Errorf("unknown type of migration, please select one of bulk, dataflow or dms") - } - case profiles.SourceProfileTypeCloudSQL: - infoSchema, err = GetInfoSchemaFromCloudSQL(sourceProfile, targetProfile) - if err != nil { - return conv, err - } - - default: - infoSchema, err = GetInfoSchema(sourceProfile, targetProfile) - if err != nil { - return conv, err - } - } - additionalSchemaAttributes := internal.AdditionalSchemaAttributes{ - IsSharded: isSharded, - } - return conv, common.ProcessSchema(conv, infoSchema, common.DefaultWorkers, additionalSchemaAttributes) -} - -func performSnapshotMigration(config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, infoSchema common.InfoSchema, additionalAttributes internal.AdditionalDataAttributes) *writer.BatchWriter { - common.SetRowStats(conv, infoSchema) - totalRows := conv.Rows() - if !conv.Audit.DryRun { - conv.Audit.Progress = *internal.NewProgress(totalRows, "Writing data to Spanner", internal.Verbose(), false, 
int(internal.DataWriteInProgress)) - } - batchWriter := populateDataConv(conv, config, client) - common.ProcessData(conv, infoSchema, additionalAttributes) - batchWriter.Flush() - return batchWriter -} - -func snapshotMigrationHandler(sourceProfile profiles.SourceProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, infoSchema common.InfoSchema) (*writer.BatchWriter, error) { - switch sourceProfile.Driver { - // Skip snapshot migration via Spanner migration tool for mysql and oracle since dataflow job will job will handle this from backfilled data. - case constants.MYSQL, constants.ORACLE, constants.POSTGRES: - return &writer.BatchWriter{}, nil - case constants.DYNAMODB: - return performSnapshotMigration(config, conv, client, infoSchema, internal.AdditionalDataAttributes{ShardId: ""}), nil - default: - return &writer.BatchWriter{}, fmt.Errorf("streaming migration not supported for driver %s", sourceProfile.Driver) - } -} - -func updateShardsWithTuningConfigs(shardedTuningConfig profiles.ShardConfigurationDataflow) { - for _, dataShard := range shardedTuningConfig.DataShards { - dataShard.DatastreamConfig = shardedTuningConfig.DatastreamConfig - dataShard.GcsConfig = shardedTuningConfig.GcsConfig - dataShard.DataflowConfig = shardedTuningConfig.DataflowConfig - } -} - -func dataFromDatabase(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client) (*writer.BatchWriter, error) { - //handle migrating data for sharded migrations differently - //sharded migrations are identified via the config= flag, if that flag is not present - //carry on with the existing code path in the else block - switch sourceProfile.Ty { - case profiles.SourceProfileTypeConfig: - ////There are three cases to cover here, bulk migrations and sharded migrations (and later DMS) - //We provide an if-else based handling for each within the sharded code branch - 
//This will be determined via the configType, which can be "bulk", "dataflow" or "dms" - if sourceProfile.Config.ConfigType == constants.BULK_MIGRATION { - return dataFromDatabaseForBulkMigration(sourceProfile, targetProfile, config, conv, client) - } else if sourceProfile.Config.ConfigType == constants.DATAFLOW_MIGRATION { - return dataFromDatabaseForDataflowMigration(targetProfile, ctx, sourceProfile, conv) - } else if sourceProfile.Config.ConfigType == constants.DMS_MIGRATION { - return dataFromDatabaseForDMSMigration() - } else { - return nil, fmt.Errorf("configType should be one of 'bulk', 'dataflow' or 'dms'") - } - default: - var infoSchema common.InfoSchema - var err error - if sourceProfile.Ty == profiles.SourceProfileTypeCloudSQL { - infoSchema, err = GetInfoSchemaFromCloudSQL(sourceProfile, targetProfile) - if err != nil { - return nil, err - } - } else { - infoSchema, err = GetInfoSchema(sourceProfile, targetProfile) - if err != nil { - return nil, err - } - } - var streamInfo map[string]interface{} - // minimal downtime migration for a single shard - if sourceProfile.Conn.Streaming { - //Generate a job Id - migrationJobId := conv.Audit.MigrationRequestId - logger.Log.Info(fmt.Sprintf("Creating a migration job with id: %v. 
This jobId can be used in future commmands (such as cleanup) to refer to this job.\n", migrationJobId)) - streamInfo, err = infoSchema.StartChangeDataCapture(ctx, conv) - if err != nil { - return nil, err - } - bw, err := snapshotMigrationHandler(sourceProfile, config, conv, client, infoSchema) - if err != nil { - return nil, err - } - dfOutput, err := infoSchema.StartStreamingMigration(ctx, client, conv, streamInfo) - if err != nil { - return nil, err - } - dfJobId := dfOutput.JobID - gcloudCmd := dfOutput.GCloudCmd - streamingCfg, _ := streamInfo["streamingCfg"].(streaming.StreamingCfg) - // Fetch and store the GCS bucket associated with the datastream - dsClient := getDatastreamClient(ctx) - gcsBucket, gcsDestPrefix, fetchGcsErr := streaming.FetchTargetBucketAndPath(ctx, dsClient, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg.DestinationConnectionConfig) - if fetchGcsErr != nil { - logger.Log.Info("Could not fetch GCS Bucket, hence Monitoring Dashboard will not contain Metrics for the gcs bucket\n") - logger.Log.Debug("Error", zap.Error(fetchGcsErr)) - } - - // Try to apply lifecycle rule to Datastream destination bucket. - gcsConfig := streamingCfg.GcsCfg - if gcsConfig.TtlInDaysSet { - err = streaming.EnableBucketLifecycleDeleteRule(ctx, gcsBucket, []string{gcsDestPrefix}, gcsConfig.TtlInDays) - if err != nil { - logger.Log.Warn(fmt.Sprintf("\nWARNING: could not update Datastream destination GCS bucket with lifecycle rule, error: %v\n", err)) - logger.Log.Warn("Please apply the lifecycle rule manually. 
Continuing...\n") - } - } - - monitoringResources := metrics.MonitoringMetricsResources{ - ProjectId: targetProfile.Conn.Sp.Project, - DataflowJobId: dfOutput.JobID, - DatastreamId: streamingCfg.DatastreamCfg.StreamId, - JobMetadataGcsBucket: gcsBucket, - PubsubSubscriptionId: streamingCfg.PubsubCfg.SubscriptionId, - SpannerInstanceId: targetProfile.Conn.Sp.Instance, - SpannerDatabaseId: targetProfile.Conn.Sp.Dbname, - ShardId: "", - MigrationRequestId: conv.Audit.MigrationRequestId, - } - respDash, dashboardErr := monitoringResources.CreateDataflowShardMonitoringDashboard(ctx) - var dashboardName string - if dashboardErr != nil { - dashboardName = "" - logger.Log.Info("Creation of the monitoring dashboard failed, please create the dashboard manually") - logger.Log.Debug("Error", zap.Error(dashboardErr)) - } else { - dashboardName = strings.Split(respDash.Name, "/")[3] - fmt.Printf("Monitoring Dashboard: %+v\n", dashboardName) - } - // store the generated resources locally in conv, this is used as source of truth for persistence and the UI (should change to persisted values) - streaming.StoreGeneratedResources(conv, streamingCfg, dfJobId, gcloudCmd, targetProfile.Conn.Sp.Project, "", internal.GcsResources{BucketName: gcsBucket}, dashboardName) - //persist job and shard level data in the metadata db - err = streaming.PersistJobDetails(ctx, targetProfile, sourceProfile, conv, migrationJobId, false) - if err != nil { - logger.Log.Info(fmt.Sprintf("Error storing job details in SMT metadata store...the migration job will still continue as intended. %v", err)) - } else { - //only attempt persisting shard level data if the job level data is persisted - err = streaming.PersistResources(ctx, targetProfile, sourceProfile, conv, migrationJobId, constants.DEFAULT_SHARD_ID) - if err != nil { - logger.Log.Info(fmt.Sprintf("Error storing details for migration job: %s, data shard: %s in SMT metadata store...the migration job will still continue as intended. 
err = %v\n", migrationJobId, constants.DEFAULT_SHARD_ID, err)) - } - } - return bw, nil - } - //bulk migration for a single shard - return performSnapshotMigration(config, conv, client, infoSchema, internal.AdditionalDataAttributes{ShardId: ""}), nil - } +type ReportInterface interface { + GenerateReport(driver string, badWrites map[string]int64, BytesRead int64, banner string, conv *internal.Conv, reportFileName string, dbName string, out *os.File) } -// TODO: Define the data processing logic for DMS migrations here. -func dataFromDatabaseForDMSMigration() (*writer.BatchWriter, error) { - return nil, fmt.Errorf("dms configType is not implemented yet, please use one of 'bulk' or 'dataflow'") -} - -// 1. Create batch for each physical shard -// 2. Create streaming cfg from the config source type. -// 3. Verify the CFG and update it with SMT defaults -// 4. Launch the stream for the physical shard -// 5. Perform streaming migration via dataflow -func dataFromDatabaseForDataflowMigration(targetProfile profiles.TargetProfile, ctx context.Context, sourceProfile profiles.SourceProfile, conv *internal.Conv) (*writer.BatchWriter, error) { - updateShardsWithTuningConfigs(sourceProfile.Config.ShardConfigurationDataflow) - //Generate a job Id - migrationJobId := conv.Audit.MigrationRequestId - fmt.Printf("Creating a migration job with id: %v. 
This jobId can be used in future commmands (such as cleanup) to refer to this job.\n", migrationJobId) - conv.Audit.StreamingStats.ShardToShardResourcesMap = make(map[string]internal.ShardResources) - schemaDetails, err := common.GetIncludedSrcTablesFromConv(conv) - if err != nil { - fmt.Printf("unable to determine tableList from schema, falling back to full database") - schemaDetails = map[string]internal.SchemaDetails{} - } - err = streaming.PersistJobDetails(ctx, targetProfile, sourceProfile, conv, migrationJobId, true) - if err != nil { - logger.Log.Info(fmt.Sprintf("Error storing job details in SMT metadata store...the migration job will still continue as intended. %v", err)) - } - asyncProcessShards := func(p *profiles.DataShard, mutex *sync.Mutex) common.TaskResult[*profiles.DataShard] { - dbNameToShardIdMap := make(map[string]string) - for _, l := range p.LogicalShards { - dbNameToShardIdMap[l.DbName] = l.LogicalShardId - } - if p.DataShardId == "" { - dataShardId, err := utils.GenerateName("smt-datashard") - dataShardId = strings.Replace(dataShardId, "_", "-", -1) - if err != nil { - return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} - } - p.DataShardId = dataShardId - fmt.Printf("Data shard id generated: %v\n", p.DataShardId) - } - streamingCfg := streaming.CreateStreamingConfig(*p) - err := streaming.VerifyAndUpdateCfg(&streamingCfg, targetProfile.Conn.Sp.Dbname, schemaDetails) - if err != nil { - err = fmt.Errorf("failed to process shard: %s, there seems to be an error in the sharding configuration, error: %v", p.DataShardId, err) - return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} - } - fmt.Printf("Initiating migration for shard: %v\n", p.DataShardId) - pubsubCfg, err := streaming.CreatePubsubResources(ctx, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg.DestinationConnectionConfig, targetProfile.Conn.Sp.Dbname) - if err != nil { - return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} - } - 
streamingCfg.PubsubCfg = *pubsubCfg - err = streaming.LaunchStream(ctx, sourceProfile, p.LogicalShards, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg) - if err != nil { - return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} - } - streamingCfg.DataflowCfg.DbNameToShardIdMap = dbNameToShardIdMap - dfOutput, err := streaming.StartDataflow(ctx, targetProfile, streamingCfg, conv) - if err != nil { - return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} - } - // store the generated resources locally in conv, this is used as source of truth for persistence and the UI (should change to persisted values) - - // Fetch and store the GCS bucket associated with the datastream - dsClient := getDatastreamClient(ctx) - gcsBucket, gcsDestPrefix, fetchGcsErr := streaming.FetchTargetBucketAndPath(ctx, dsClient, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg.DestinationConnectionConfig) - if fetchGcsErr != nil { - logger.Log.Info(fmt.Sprintf("Could not fetch GCS Bucket for Shard %s hence Monitoring Dashboard will not contain Metrics for the gcs bucket\n", p.DataShardId)) - logger.Log.Debug("Error", zap.Error(fetchGcsErr)) - } - - // Try to apply lifecycle rule to Datastream destination bucket. - gcsConfig := streamingCfg.GcsCfg - if gcsConfig.TtlInDaysSet { - err = streaming.EnableBucketLifecycleDeleteRule(ctx, gcsBucket, []string{gcsDestPrefix}, gcsConfig.TtlInDays) - if err != nil { - logger.Log.Warn(fmt.Sprintf("\nWARNING: could not update Datastream destination GCS bucket with lifecycle rule, error: %v\n", err)) - logger.Log.Warn("Please apply the lifecycle rule manually. 
Continuing...\n") - } - } - - // create monitoring dashboard for a single shard - monitoringResources := metrics.MonitoringMetricsResources{ - ProjectId: targetProfile.Conn.Sp.Project, - DataflowJobId: dfOutput.JobID, - DatastreamId: streamingCfg.DatastreamCfg.StreamId, - JobMetadataGcsBucket: gcsBucket, - PubsubSubscriptionId: streamingCfg.PubsubCfg.SubscriptionId, - SpannerInstanceId: targetProfile.Conn.Sp.Instance, - SpannerDatabaseId: targetProfile.Conn.Sp.Dbname, - ShardId: p.DataShardId, - MigrationRequestId: conv.Audit.MigrationRequestId, - } - respDash, dashboardErr := monitoringResources.CreateDataflowShardMonitoringDashboard(ctx) - var dashboardName string - if dashboardErr != nil { - dashboardName = "" - logger.Log.Info(fmt.Sprintf("Creation of the monitoring dashboard for shard %s failed, please create the dashboard manually\n", p.DataShardId)) - logger.Log.Debug("Error", zap.Error(dashboardErr)) - } else { - dashboardName = strings.Split(respDash.Name, "/")[3] - fmt.Printf("Monitoring Dashboard for shard %v: %+v\n", p.DataShardId, dashboardName) - } - streaming.StoreGeneratedResources(conv, streamingCfg, dfOutput.JobID, dfOutput.GCloudCmd, targetProfile.Conn.Sp.Project, p.DataShardId, internal.GcsResources{BucketName: gcsBucket}, dashboardName) - //persist the generated resources in a metadata db - err = streaming.PersistResources(ctx, targetProfile, sourceProfile, conv, migrationJobId, p.DataShardId) - if err != nil { - fmt.Printf("Error storing generated resources in SMT metadata store for dataShardId: %s...the migration job will still continue as intended, error: %v\n", p.DataShardId, err) - } - return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} - } - _, err = common.RunParallelTasks(sourceProfile.Config.ShardConfigurationDataflow.DataShards, 20, asyncProcessShards, true) - if err != nil { - return nil, fmt.Errorf("unable to start minimal downtime migrations: %v", err) - } - - // create monitoring aggregated dashboard for sharded 
migration - aggMonitoringResources := metrics.MonitoringMetricsResources{ - ProjectId: targetProfile.Conn.Sp.Project, - SpannerInstanceId: targetProfile.Conn.Sp.Instance, - SpannerDatabaseId: targetProfile.Conn.Sp.Dbname, - ShardToShardResourcesMap: conv.Audit.StreamingStats.ShardToShardResourcesMap, - MigrationRequestId: conv.Audit.MigrationRequestId, - } - aggRespDash, dashboardErr := aggMonitoringResources.CreateDataflowAggMonitoringDashboard(ctx) - if dashboardErr != nil { - logger.Log.Error(fmt.Sprintf("Creation of the aggregated monitoring dashboard failed, please create the dashboard manually\n error=%v\n", dashboardErr)) - } else { - fmt.Printf("Aggregated Monitoring Dashboard: %+v\n", strings.Split(aggRespDash.Name, "/")[3]) - conv.Audit.StreamingStats.AggMonitoringResources = internal.MonitoringResources{DashboardName: strings.Split(aggRespDash.Name, "/")[3]} - } - err = streaming.PersistAggregateMonitoringResources(ctx, targetProfile, sourceProfile, conv, migrationJobId) - if err != nil { - logger.Log.Info(fmt.Sprintf("Unable to store aggregated monitoring dashboard in metadata database\n error=%v\n", err)) - } else { - logger.Log.Debug("Aggregate monitoring resources stored successfully.\n") - } - return &writer.BatchWriter{}, nil -} - -// 1. Migrate the data from the data shards, the schema shard needs to be specified here again. -// 2. Create a connection profile object for it -// 3. Perform a snapshot migration for the shard -// 4. 
Once all shard migrations are complete, return the batch writer object -func dataFromDatabaseForBulkMigration(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client) (*writer.BatchWriter, error) { - var bw *writer.BatchWriter - for _, dataShard := range sourceProfile.Config.ShardConfigurationBulk.DataShards { - - fmt.Printf("Initiating migration for shard: %v\n", dataShard.DbName) - infoSchema, err := getInfoSchemaForShard(dataShard, sourceProfile.Driver, targetProfile) - if err != nil { - return nil, err - } - additionalDataAttributes := internal.AdditionalDataAttributes{ - ShardId: dataShard.DataShardId, - } - bw = performSnapshotMigration(config, conv, client, infoSchema, additionalDataAttributes) - } - - return bw, nil -} - -func getDynamoDBClientConfig() (*aws.Config, error) { - cfg := aws.Config{} - endpointOverride := os.Getenv("DYNAMODB_ENDPOINT_OVERRIDE") - if endpointOverride != "" { - cfg.Endpoint = aws.String(endpointOverride) - } - return &cfg, nil -} - -func SchemaFromDump(driver string, spDialect string, ioHelper *utils.IOStreams) (*internal.Conv, error) { - f, n, err := getSeekable(ioHelper.In) - if err != nil { - utils.PrintSeekError(driver, err, ioHelper.Out) - return nil, fmt.Errorf("can't get seekable input file") - } - ioHelper.SeekableIn = f - ioHelper.BytesRead = n - conv := internal.MakeConv() - conv.SpDialect = spDialect - p := internal.NewProgress(n, "Generating schema", internal.Verbose(), false, int(internal.SchemaCreationInProgress)) - r := internal.NewReader(bufio.NewReader(f), p) - conv.SetSchemaMode() // Build schema and ignore data in dump. 
- conv.SetDataSink(nil) - err = ProcessDump(driver, conv, r) - if err != nil { - fmt.Fprintf(ioHelper.Out, "Failed to parse the data file: %v", err) - return nil, fmt.Errorf("failed to parse the data file") - } - p.Done() - return conv, nil -} - -func dataFromDump(driver string, config writer.BatchWriterConfig, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool) (*writer.BatchWriter, error) { - // TODO: refactor of the way we handle getSeekable - // to avoid the code duplication here - if !dataOnly { - _, err := ioHelper.SeekableIn.Seek(0, 0) - if err != nil { - fmt.Printf("\nCan't seek to start of file (preparation for second pass): %v\n", err) - return nil, fmt.Errorf("can't seek to start of file") - } - } else { - // Note: input file is kept seekable to plan for future - // changes in showing progress for data migration. - f, n, err := getSeekable(ioHelper.In) - if err != nil { - utils.PrintSeekError(driver, err, ioHelper.Out) - return nil, fmt.Errorf("can't get seekable input file") - } - ioHelper.SeekableIn = f - ioHelper.BytesRead = n - } - totalRows := conv.Rows() - - conv.Audit.Progress = *internal.NewProgress(totalRows, "Writing data to Spanner", internal.Verbose(), false, int(internal.DataWriteInProgress)) - r := internal.NewReader(bufio.NewReader(ioHelper.SeekableIn), nil) - batchWriter := populateDataConv(conv, config, client) - ProcessDump(driver, conv, r) - batchWriter.Flush() - conv.Audit.Progress.Done() - - return batchWriter, nil -} - -func dataFromCSV(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client) (*writer.BatchWriter, error) { - if targetProfile.Conn.Sp.Dbname == "" { - return nil, fmt.Errorf("dbName is mandatory in target-profile for csv source") - } - conv.SpDialect = targetProfile.Conn.Sp.Dialect - dialect, err := targetProfile.FetchTargetDialect(ctx) - if err != nil { - return nil, 
fmt.Errorf("could not fetch dialect: %v", err) - } - if strings.ToLower(dialect) != constants.DIALECT_POSTGRESQL { - dialect = constants.DIALECT_GOOGLESQL - } - - if dialect != conv.SpDialect { - return nil, fmt.Errorf("dialect specified in target profile does not match spanner dialect") - } - - delimiterStr := sourceProfile.Csv.Delimiter - if len(delimiterStr) != 1 { - return nil, fmt.Errorf("delimiter should only be a single character long, found '%s'", delimiterStr) - } - - delimiter := rune(delimiterStr[0]) - - err = utils.ReadSpannerSchema(ctx, conv, client) - if err != nil { - return nil, fmt.Errorf("error trying to read and convert spanner schema: %v", err) - } - - tables, err := csv.GetCSVFiles(conv, sourceProfile) - if err != nil { - return nil, fmt.Errorf("error finding csv files: %v", err) - } - - // Find the number of rows in each csv file for generating stats. - err = csv.SetRowStats(conv, tables, delimiter) - if err != nil { - return nil, err - } - - totalRows := conv.Rows() - conv.Audit.Progress = *internal.NewProgress(totalRows, "Writing data to Spanner", internal.Verbose(), false, int(internal.DataWriteInProgress)) - batchWriter := populateDataConv(conv, config, client) - err = csv.ProcessCSV(conv, tables, sourceProfile.Csv.NullStr, delimiter) - if err != nil { - return nil, fmt.Errorf("can't process csv: %v", err) - } - batchWriter.Flush() - conv.Audit.Progress.Done() - return batchWriter, nil -} - -func populateDataConv(conv *internal.Conv, config writer.BatchWriterConfig, client *sp.Client) *writer.BatchWriter { - rows := int64(0) - config.Write = func(m []*sp.Mutation) error { - ctx := context.Background() - if !conv.Audit.SkipMetricsPopulation { - migrationData := metrics.GetMigrationData(conv, "", constants.DataConv) - serializedMigrationData, _ := proto.Marshal(migrationData) - migrationMetadataValue := base64.StdEncoding.EncodeToString(serializedMigrationData) - ctx = metadata.AppendToOutgoingContext(context.Background(), 
constants.MigrationMetadataKey, migrationMetadataValue) - } - _, err := client.Apply(ctx, m) - if err != nil { - return err - } - atomic.AddInt64(&rows, int64(len(m))) - conv.Audit.Progress.MaybeReport(atomic.LoadInt64(&rows)) - return nil - } - batchWriter := writer.NewBatchWriter(config) - conv.SetDataMode() - if !conv.Audit.DryRun { - conv.SetDataSink( - func(table string, cols []string, vals []interface{}) { - batchWriter.AddRow(table, cols, vals) - }) - conv.DataFlush = func() { - batchWriter.Flush() - } - } - - return batchWriter -} +type ReportImpl struct {} // Report generates a report of schema and data conversion. -func Report(driver string, badWrites map[string]int64, BytesRead int64, banner string, conv *internal.Conv, reportFileName string, dbName string, out *os.File) { +func (r *ReportImpl) GenerateReport(driver string, badWrites map[string]int64, BytesRead int64, banner string, conv *internal.Conv, reportFileName string, dbName string, out *os.File) { //Write the structured report file structuredReportFileName := fmt.Sprintf("%s.%s", reportFileName, "structured_report.json") - structuredReport := reports.GenerateStructuredReport(driver, dbName, conv, badWrites, true, true) + reportGenerator := reports.ReportImpl{} + structuredReport := reportGenerator.GenerateStructuredReport(driver, dbName, conv, badWrites, true, true) fBytes, _ := json.MarshalIndent(structuredReport, "", " ") f, err := os.Create(structuredReportFileName) if err != nil { @@ -742,7 +138,7 @@ func Report(driver string, badWrites map[string]int64, BytesRead int64, banner s } w := bufio.NewWriter(f) w.WriteString(banner) - reports.GenerateTextReport(structuredReport, w) + reportGenerator.GenerateTextReport(structuredReport, w) w.Flush() var isDump bool @@ -763,663 +159,3 @@ func Report(driver string, badWrites map[string]int64, BytesRead int64, banner s fmt.Fprintf(out, "See file '%s' for details of the schema and data conversions.\n", reportFileName) } } - -// getSeekable returns a 
seekable file (with same content as f) and the size of the content (in bytes). -func getSeekable(f *os.File) (*os.File, int64, error) { - _, err := f.Seek(0, 0) - if err == nil { // Stdin is seekable, let's just use that. This happens when you run 'cmd < file'. - n, err := utils.GetFileSize(f) - return f, n, err - } - internal.VerbosePrintln("Creating a tmp file with a copy of stdin because stdin is not seekable.") - logger.Log.Debug("Creating a tmp file with a copy of stdin because stdin is not seekable.") - - // Create file in os.TempDir. Its not clear this is a good idea e.g. if the - // pg_dump/mysqldump output is large (tens of GBs) and os.TempDir points to a directory - // (such as /tmp) that's configured with a small amount of disk space. - // To workaround such limits on Unix, set $TMPDIR to a directory with lots - // of disk space. - fcopy, err := ioutil.TempFile("", "spanner-migration-tool.data") - if err != nil { - return nil, 0, err - } - syscall.Unlink(fcopy.Name()) // File will be deleted when this process exits. - _, err = io.Copy(fcopy, f) - if err != nil { - return nil, 0, fmt.Errorf("can't write stdin to tmp file: %w", err) - } - _, err = fcopy.Seek(0, 0) - if err != nil { - return nil, 0, fmt.Errorf("can't reset file offset: %w", err) - } - n, _ := utils.GetFileSize(fcopy) - return fcopy, n, nil -} - -// VerifyDb checks whether the db exists and if it does, verifies if the schema is what we currently support. -func VerifyDb(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string) (dbExists bool, err error) { - dbExists, err = CheckExistingDb(ctx, adminClient, dbURI) - if err != nil { - return dbExists, err - } - if dbExists { - err = ValidateDDL(ctx, adminClient, dbURI) - } - return dbExists, err -} - -// CheckExistingDb checks whether the database with dbURI exists or not. -// If API call doesn't respond then user is informed after every 5 minutes on command line. 
-func CheckExistingDb(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string) (bool, error) { - gotResponse := make(chan bool) - var err error - go func() { - _, err = adminClient.GetDatabase(ctx, &adminpb.GetDatabaseRequest{Name: dbURI}) - gotResponse <- true - }() - for { - select { - case <-time.After(5 * time.Minute): - fmt.Println("WARNING! API call not responding: make sure that spanner api endpoint is configured properly") - case <-gotResponse: - if err != nil { - if utils.ContainsAny(strings.ToLower(err.Error()), []string{"database not found"}) { - return false, nil - } - return false, fmt.Errorf("can't get database info: %s", err) - } - return true, nil - } - } -} - -// ValidateTables validates that all the tables in the database are empty. -// It returns the name of the first non-empty table if found, and an empty string otherwise. -func ValidateTables(ctx context.Context, client *sp.Client, spDialect string) (string, error) { - infoSchema := spanner.InfoSchemaImpl{Client: client, Ctx: ctx, SpDialect: spDialect} - tables, err := infoSchema.GetTables() - if err != nil { - return "", err - } - for _, table := range tables { - count, err := infoSchema.GetRowCount(table) - if err != nil { - return "", err - } - if count != 0 { - return table.Name, nil - } - } - return "", nil -} - -// ValidateDDL verifies if an existing DB's ddl follows what is supported by Spanner migration tool. Currently, -// we only support empty schema when db already exists. 
-func ValidateDDL(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string) error { - dbDdl, err := adminClient.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{Database: dbURI}) - if err != nil { - return fmt.Errorf("can't fetch database ddl: %v", err) - } - if len(dbDdl.Statements) != 0 { - return fmt.Errorf("spanner migration tool supports writing to existing databases only if they have an empty schema") - } - return nil -} - -// CreatesOrUpdatesDatabase updates an existing Spanner database or creates a new one if one does not exist. -func CreateOrUpdateDatabase(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI, driver string, conv *internal.Conv, out *os.File, migrationType string) error { - dbExists, err := VerifyDb(ctx, adminClient, dbURI) - if err != nil { - return err - } - if !conv.Audit.SkipMetricsPopulation { - // Adding migration metadata to the outgoing context. - migrationData := metrics.GetMigrationData(conv, driver, constants.SchemaConv) - serializedMigrationData, _ := proto.Marshal(migrationData) - migrationMetadataValue := base64.StdEncoding.EncodeToString(serializedMigrationData) - ctx = metadata.AppendToOutgoingContext(ctx, constants.MigrationMetadataKey, migrationMetadataValue) - } - if dbExists { - if conv.SpDialect != constants.DIALECT_POSTGRESQL && migrationType == constants.DATAFLOW_MIGRATION { - return fmt.Errorf("spanner migration tool does not support minimal downtime schema/schema-and-data migrations to an existing database") - } - err := UpdateDatabase(ctx, adminClient, dbURI, conv, out, driver) - if err != nil { - return fmt.Errorf("can't update database schema: %v", err) - } - } else { - err := CreateDatabase(ctx, adminClient, dbURI, conv, out, driver, migrationType) - if err != nil { - return fmt.Errorf("can't create database: %v", err) - } - } - return nil -} - -// CreateDatabase returns a newly create Spanner DB. 
-// It automatically determines an appropriate project, selects a -// Spanner instance to use, generates a new Spanner DB name, -// and call into the Spanner admin interface to create the new DB. -func CreateDatabase(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string, conv *internal.Conv, out *os.File, driver string, migrationType string) error { - project, instance, dbName := utils.ParseDbURI(dbURI) - fmt.Fprintf(out, "Creating new database %s in instance %s with default permissions ... \n", dbName, instance) - // The schema we send to Spanner excludes comments (since Cloud - // Spanner DDL doesn't accept them), and protects table and col names - // using backticks (to avoid any issues with Spanner reserved words). - // Foreign Keys are set to false since we create them post data migration. - req := &adminpb.CreateDatabaseRequest{ - Parent: fmt.Sprintf("projects/%s/instances/%s", project, instance), - } - if conv.SpDialect == constants.DIALECT_POSTGRESQL { - // PostgreSQL dialect doesn't support: - // a) backticks around the database name, and - // b) DDL statements as part of a CreateDatabase operation (so schema - // must be set using a separate UpdateDatabase operation). 
- req.CreateStatement = "CREATE DATABASE \"" + dbName + "\"" - req.DatabaseDialect = adminpb.DatabaseDialect_POSTGRESQL - } else { - req.CreateStatement = "CREATE DATABASE `" + dbName + "`" - if migrationType == constants.DATAFLOW_MIGRATION { - req.ExtraStatements = conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: true, SpDialect: conv.SpDialect, Source: driver}) - } else { - req.ExtraStatements = conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: false, SpDialect: conv.SpDialect, Source: driver}) - } - - } - - op, err := adminClient.CreateDatabase(ctx, req) - if err != nil { - return fmt.Errorf("can't build CreateDatabaseRequest: %w", utils.AnalyzeError(err, dbURI)) - } - if _, err := op.Wait(ctx); err != nil { - return fmt.Errorf("createDatabase call failed: %w", utils.AnalyzeError(err, dbURI)) - } - fmt.Fprintf(out, "Created database successfully.\n") - - if conv.SpDialect == constants.DIALECT_POSTGRESQL { - // Update schema separately for PG databases. - return UpdateDatabase(ctx, adminClient, dbURI, conv, out, driver) - } - return nil -} - -// UpdateDatabase updates an existing spanner database. -func UpdateDatabase(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string, conv *internal.Conv, out *os.File, driver string) error { - fmt.Fprintf(out, "Updating schema for %s with default permissions ... \n", dbURI) - // The schema we send to Spanner excludes comments (since Cloud - // Spanner DDL doesn't accept them), and protects table and col names - // using backticks (to avoid any issues with Spanner reserved words). - // Foreign Keys are set to false since we create them post data migration. 
- schema := conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: false, SpDialect: conv.SpDialect, Source: driver}) - req := &adminpb.UpdateDatabaseDdlRequest{ - Database: dbURI, - Statements: schema, - } - // Update queries for postgres as target db return response after more - // than 1 min for large schemas, therefore, timeout is specified as 5 minutes - ctx, cancel := context.WithTimeout(ctx, 5*time.Minute) - defer cancel() - op, err := adminClient.UpdateDatabaseDdl(ctx, req) - if err != nil { - return fmt.Errorf("can't build UpdateDatabaseDdlRequest: %w", utils.AnalyzeError(err, dbURI)) - } - if err := op.Wait(ctx); err != nil { - return fmt.Errorf("UpdateDatabaseDdl call failed: %w", utils.AnalyzeError(err, dbURI)) - } - fmt.Fprintf(out, "Updated schema successfully.\n") - return nil -} - -// UpdateDDLForeignKeys updates the Spanner database with foreign key -// constraints using ALTER TABLE statements. -func UpdateDDLForeignKeys(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string, conv *internal.Conv, out *os.File, driver string, migrationType string) error { - - if conv.SpDialect != constants.DIALECT_POSTGRESQL && migrationType == constants.DATAFLOW_MIGRATION { - //foreign keys were applied as part of CreateDatabase - return nil - } - - // The schema we send to Spanner excludes comments (since Cloud - // Spanner DDL doesn't accept them), and protects table and col names - // using backticks (to avoid any issues with Spanner reserved words). - fkStmts := conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: false, ForeignKeys: true, SpDialect: conv.SpDialect, Source: driver}) - if len(fkStmts) == 0 { - return nil - } - if len(fkStmts) > 50 { - fmt.Println(` -Warning: Large number of foreign keys detected. Spanner can take a long amount of -time to create foreign keys (over 5 mins per batch of Foreign Keys even with no data). 
-Spanner migration tool does not have control over a single foreign key creation time. The number -of concurrent Foreign Key Creation Requests sent to spanner can be increased by -tweaking the MaxWorkers variable (https://github.com/GoogleCloudPlatform/spanner-migration-tool/blob/master/conversion/conversion.go#L89). -However, setting it to a very high value might lead to exceeding the admin quota limit. Spanner migration tool tries to stay under the -admin quota limit by spreading the FK creation requests over time.`) - } - msg := fmt.Sprintf("Updating schema of database %s with foreign key constraints ...", dbURI) - conv.Audit.Progress = *internal.NewProgress(int64(len(fkStmts)), msg, internal.Verbose(), true, int(internal.ForeignKeyUpdateInProgress)) - - workers := make(chan int, MaxWorkers) - for i := 1; i <= MaxWorkers; i++ { - workers <- i - } - var progressMutex sync.Mutex - progress := int64(0) - - // We dispatch parallel foreign key create requests to ensure the backfill runs in parallel to reduce overall time. - // This cuts down the time taken to a third (approx) compared to Serial and Batched creation. We also do not want to create - // too many requests and get throttled due to network or hitting catalog memory limits. - // Ensure atmost `MaxWorkers` go routines run in parallel that each update the ddl with one foreign key statement. - for _, fkStmt := range fkStmts { - workerID := <-workers - go func(fkStmt string, workerID int) { - defer func() { - // Locking the progress reporting otherwise progress results displayed could be in random order. 
- progressMutex.Lock() - progress++ - conv.Audit.Progress.MaybeReport(progress) - progressMutex.Unlock() - workers <- workerID - }() - internal.VerbosePrintf("Submitting new FK create request: %s\n", fkStmt) - logger.Log.Debug("Submitting new FK create request", zap.String("fkStmt", fkStmt)) - - op, err := adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ - Database: dbURI, - Statements: []string{fkStmt}, - }) - if err != nil { - fmt.Printf("Cannot submit request for create foreign key with statement: %s\n due to error: %s. Skipping this foreign key...\n", fkStmt, err) - conv.Unexpected(fmt.Sprintf("Can't add foreign key with statement %s: %s", fkStmt, err)) - return - } - if err := op.Wait(ctx); err != nil { - fmt.Printf("Can't add foreign key with statement: %s\n due to error: %s. Skipping this foreign key...\n", fkStmt, err) - conv.Unexpected(fmt.Sprintf("Can't add foreign key with statement %s: %s", fkStmt, err)) - return - } - internal.VerbosePrintln("Updated schema with statement: " + fkStmt) - logger.Log.Debug("Updated schema with statement", zap.String("fkStmt", fkStmt)) - }(fkStmt, workerID) - // Send out an FK creation request every second, with total of maxWorkers request being present in a batch. - time.Sleep(time.Second) - } - // Wait for all the goroutines to finish. - for i := 1; i <= MaxWorkers; i++ { - <-workers - } - conv.Audit.Progress.UpdateProgress("Foreign key update complete.", 100, internal.ForeignKeyUpdateComplete) - conv.Audit.Progress.Done() - return nil -} - -// WriteSchemaFile writes DDL statements in a file. It includes CREATE TABLE -// statements and ALTER TABLE statements to add foreign keys. -// The parameter name should end with a .txt. 
-func WriteSchemaFile(conv *internal.Conv, now time.Time, name string, out *os.File, driver string) { - f, err := os.Create(name) - if err != nil { - fmt.Fprintf(out, "Can't create schema file %s: %v\n", name, err) - return - } - - // The schema file we write out below is optimized for reading. It includes comments, foreign keys - // and doesn't add backticks around table and column names. This file is - // intended for explanatory and documentation purposes, and is not strictly - // legal Cloud Spanner DDL (Cloud Spanner doesn't currently support comments). - spDDL := conv.SpSchema.GetDDL(ddl.Config{Comments: true, ProtectIds: false, Tables: true, ForeignKeys: true, SpDialect: conv.SpDialect, Source: driver}) - if len(spDDL) == 0 { - spDDL = []string{"\n-- Schema is empty -- no tables found\n"} - } - l := []string{ - fmt.Sprintf("-- Schema generated %s\n", now.Format("2006-01-02 15:04:05")), - strings.Join(spDDL, ";\n\n"), - "\n", - } - if _, err := f.WriteString(strings.Join(l, "")); err != nil { - fmt.Fprintf(out, "Can't write out schema file: %v\n", err) - return - } - fmt.Fprintf(out, "Wrote schema to file '%s'.\n", name) - - // Convert . to .ddl.. - nameSplit := strings.Split(name, ".") - nameSplit = append(nameSplit[:len(nameSplit)-1], "ddl", nameSplit[len(nameSplit)-1]) - name = strings.Join(nameSplit, ".") - f, err = os.Create(name) - if err != nil { - fmt.Fprintf(out, "Can't create legal schema ddl file %s: %v\n", name, err) - return - } - - // We change 'Comments' to false and 'ProtectIds' to true below to write out a - // schema file that is a legal Cloud Spanner DDL. 
- spDDL = conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: true, SpDialect: conv.SpDialect, Source: driver}) - if len(spDDL) == 0 { - spDDL = []string{"\n-- Schema is empty -- no tables found\n"} - } - l = []string{ - strings.Join(spDDL, ";\n\n"), - "\n", - } - if _, err = f.WriteString(strings.Join(l, "")); err != nil { - fmt.Fprintf(out, "Can't write out legal schema ddl file: %v\n", err) - return - } - fmt.Fprintf(out, "Wrote legal schema ddl to file '%s'.\n", name) -} - -// WriteSessionFile writes conv struct to a file in JSON format. -func WriteSessionFile(conv *internal.Conv, name string, out *os.File) { - f, err := os.Create(name) - if err != nil { - fmt.Fprintf(out, "Can't create session file %s: %v\n", name, err) - return - } - // Session file will basically contain 'conv' struct in JSON format. - // It contains all the information for schema and data conversion state. - convJSON, err := json.MarshalIndent(conv, "", " ") - if err != nil { - fmt.Fprintf(out, "Can't encode session state to JSON: %v\n", err) - return - } - if _, err := f.Write(convJSON); err != nil { - fmt.Fprintf(out, "Can't write out session file: %v\n", err) - return - } - fmt.Fprintf(out, "Wrote session to file '%s'.\n", name) -} - -// WriteConvGeneratedFiles creates a directory labeled downloads with the current timestamp -// where it writes the sessionfile, report summary and DDLs then returns the directory where it writes. 
-func WriteConvGeneratedFiles(conv *internal.Conv, dbName string, driver string, BytesRead int64, out *os.File) (string, error) { - now := time.Now() - dirPath := "spanner_migration_tool_output/" + dbName + "/" - err := os.MkdirAll(dirPath, os.ModePerm) - if err != nil { - fmt.Fprintf(out, "Can't create directory %s: %v\n", dirPath, err) - return "", err - } - schemaFileName := dirPath + dbName + "_schema.txt" - WriteSchemaFile(conv, now, schemaFileName, out, driver) - reportFileName := dirPath + dbName - Report(driver, nil, BytesRead, "", conv, reportFileName, dbName, out) - sessionFileName := dirPath + dbName + ".session.json" - WriteSessionFile(conv, sessionFileName, out) - return dirPath, nil -} - -// ReadSessionFile reads a session JSON file and -// unmarshal it's content into *internal.Conv. -func ReadSessionFile(conv *internal.Conv, sessionJSON string) error { - s, err := ioutil.ReadFile(sessionJSON) - if err != nil { - return err - } - err = json.Unmarshal(s, &conv) - if err != nil { - return err - } - return nil -} - -// WriteBadData prints summary stats about bad rows and writes detailed info -// to file 'name'. -func WriteBadData(bw *writer.BatchWriter, conv *internal.Conv, banner, name string, out *os.File) { - badConversions := conv.BadRows() - badWrites := utils.SumMapValues(bw.DroppedRowsByTable()) - - badDataStreaming := int64(0) - if conv.Audit.StreamingStats.Streaming { - badDataStreaming = getBadStreamingDataCount(conv) - } - - if badConversions == 0 && badWrites == 0 && badDataStreaming == 0 { - os.Remove(name) // Cleanup bad-data file from previous run. 
- return - } - f, err := os.Create(name) - if err != nil { - fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) - return - } - f.WriteString(banner) - maxRows := 100 - if badConversions > 0 { - l := conv.SampleBadRows(maxRows) - if int64(len(l)) < badConversions { - f.WriteString("A sample of rows that generated conversion errors:\n") - } else { - f.WriteString("Rows that generated conversion errors:\n") - } - for _, r := range l { - _, err := f.WriteString(" " + r + "\n") - if err != nil { - fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) - return - } - } - } - if badWrites > 0 { - l := bw.SampleBadRows(maxRows) - if int64(len(l)) < badWrites { - f.WriteString("A sample of rows that successfully converted but couldn't be written to Spanner:\n") - } else { - f.WriteString("Rows that successfully converted but couldn't be written to Spanner:\n") - } - for _, r := range l { - _, err := f.WriteString(" " + r + "\n") - if err != nil { - fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) - return - } - } - } - if badDataStreaming > 0 { - err = writeBadStreamingData(conv, f) - if err != nil { - fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) - return - } - } - - fmt.Fprintf(out, "See file '%s' for details of bad rows\n", name) -} - -// getBadStreamingDataCount returns the total sum of bad and dropped records during -// streaming migration process. -func getBadStreamingDataCount(conv *internal.Conv) int64 { - badDataCount := int64(0) - - for _, x := range conv.Audit.StreamingStats.BadRecords { - badDataCount += utils.SumMapValues(x) - } - for _, x := range conv.Audit.StreamingStats.DroppedRecords { - badDataCount += utils.SumMapValues(x) - } - return badDataCount -} - -// writeBadStreamingData writes sample of bad records and dropped records during streaming -// migration process to bad data file. 
-func writeBadStreamingData(conv *internal.Conv, f *os.File) error { - f.WriteString("\nBad data encountered during streaming migration:\n\n") - - stats := (conv.Audit.StreamingStats) - - badRecords := int64(0) - for _, x := range stats.BadRecords { - badRecords += utils.SumMapValues(x) - } - droppedRecords := int64(0) - for _, x := range stats.DroppedRecords { - droppedRecords += utils.SumMapValues(x) - } - - if badRecords > 0 { - l := stats.SampleBadRecords - if int64(len(l)) < badRecords { - f.WriteString("A sample of records that generated conversion errors:\n") - } else { - f.WriteString("Records that generated conversion errors:\n") - } - for _, r := range l { - _, err := f.WriteString(" " + r + "\n") - if err != nil { - return err - } - } - f.WriteString("\n") - } - if droppedRecords > 0 { - l := stats.SampleBadWrites - if int64(len(l)) < droppedRecords { - f.WriteString("A sample of records that successfully converted but couldn't be written to Spanner:\n") - } else { - f.WriteString("Records that successfully converted but couldn't be written to Spanner:\n") - } - for _, r := range l { - _, err := f.WriteString(" " + r + "\n") - if err != nil { - return err - } - } - } - return nil -} - -// ProcessDump invokes process dump function from a sql package based on driver selected. 
-func ProcessDump(driver string, conv *internal.Conv, r *internal.Reader) error { - switch driver { - case constants.MYSQLDUMP: - return common.ProcessDbDump(conv, r, mysql.DbDumpImpl{}) - case constants.PGDUMP: - return common.ProcessDbDump(conv, r, postgres.DbDumpImpl{}) - default: - return fmt.Errorf("process dump for driver %s not supported", driver) - } -} - -func GetInfoSchemaFromCloudSQL(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) { - driver := sourceProfile.Driver - switch driver { - case constants.MYSQL: - d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithIAMAuthN()) - if err != nil { - return nil, fmt.Errorf("cloudsqlconn.NewDialer: %w", err) - } - var opts []cloudsqlconn.DialOption - instanceName := fmt.Sprintf("%s:%s:%s", sourceProfile.ConnCloudSQL.Mysql.Project, sourceProfile.ConnCloudSQL.Mysql.Region, sourceProfile.ConnCloudSQL.Mysql.InstanceName) - mysqldriver.RegisterDialContext("cloudsqlconn", - func(ctx context.Context, addr string) (net.Conn, error) { - return d.Dial(ctx, instanceName, opts...) 
- }) - - dbURI := fmt.Sprintf("%s:empty@cloudsqlconn(localhost:3306)/%s?parseTime=true", - sourceProfile.ConnCloudSQL.Mysql.User, sourceProfile.ConnCloudSQL.Mysql.Db) - - db, err := sql.Open("mysql", dbURI) - if err != nil { - return nil, fmt.Errorf("sql.Open: %w", err) - } - return mysql.InfoSchemaImpl{ - DbName: sourceProfile.ConnCloudSQL.Mysql.Db, - Db: db, - SourceProfile: sourceProfile, - TargetProfile: targetProfile, - }, nil - case constants.POSTGRES: - d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithIAMAuthN()) - if err != nil { - return nil, fmt.Errorf("cloudsqlconn.NewDialer: %w", err) - } - var opts []cloudsqlconn.DialOption - - dsn := fmt.Sprintf("user=%s database=%s", sourceProfile.ConnCloudSQL.Pg.User, sourceProfile.ConnCloudSQL.Pg.Db) - config, err := pgx.ParseConfig(dsn) - if err != nil { - return nil, err - } - instanceName := fmt.Sprintf("%s:%s:%s", sourceProfile.ConnCloudSQL.Pg.Project, sourceProfile.ConnCloudSQL.Pg.Region, sourceProfile.ConnCloudSQL.Pg.InstanceName) - config.DialFunc = func(ctx context.Context, network, instance string) (net.Conn, error) { - return d.Dial(ctx, instanceName, opts...) 
- } - dbURI := stdlib.RegisterConnConfig(config) - db, err := sql.Open("pgx", dbURI) - if err != nil { - return nil, fmt.Errorf("sql.Open: %w", err) - } - temp := false - return postgres.InfoSchemaImpl{ - Db: db, - SourceProfile: sourceProfile, - TargetProfile: targetProfile, - IsSchemaUnique: &temp, //this is a workaround to set a bool pointer - }, nil - default: - return nil, fmt.Errorf("driver %s not supported", driver) - } -} - -func GetInfoSchema(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) { - connectionConfig, err := connectionConfig(sourceProfile) - if err != nil { - return nil, err - } - driver := sourceProfile.Driver - switch driver { - case constants.MYSQL: - db, err := sql.Open(driver, connectionConfig.(string)) - dbName := getDbNameFromSQLConnectionStr(driver, connectionConfig.(string)) - if err != nil { - return nil, err - } - return mysql.InfoSchemaImpl{ - DbName: dbName, - Db: db, - SourceProfile: sourceProfile, - TargetProfile: targetProfile, - }, nil - case constants.POSTGRES: - db, err := sql.Open(driver, connectionConfig.(string)) - if err != nil { - return nil, err - } - temp := false - return postgres.InfoSchemaImpl{ - Db: db, - SourceProfile: sourceProfile, - TargetProfile: targetProfile, - IsSchemaUnique: &temp, //this is a workaround to set a bool pointer - }, nil - case constants.DYNAMODB: - mySession := session.Must(session.NewSession()) - dydbClient := dydb.New(mySession, connectionConfig.(*aws.Config)) - var dydbStreamsClient *dynamodbstreams.DynamoDBStreams - if sourceProfile.Conn.Streaming { - newSession := session.Must(session.NewSession()) - dydbStreamsClient = dynamodbstreams.New(newSession, connectionConfig.(*aws.Config)) - } - return dynamodb.InfoSchemaImpl{ - DynamoClient: dydbClient, - SampleSize: profiles.GetSchemaSampleSize(sourceProfile), - DynamoStreamsClient: dydbStreamsClient, - }, nil - case constants.SQLSERVER: - db, err := sql.Open(driver, 
connectionConfig.(string)) - dbName := getDbNameFromSQLConnectionStr(driver, connectionConfig.(string)) - if err != nil { - return nil, err - } - return sqlserver.InfoSchemaImpl{DbName: dbName, Db: db}, nil - case constants.ORACLE: - db, err := sql.Open(driver, connectionConfig.(string)) - dbName := getDbNameFromSQLConnectionStr(driver, connectionConfig.(string)) - if err != nil { - return nil, err - } - return oracle.InfoSchemaImpl{DbName: strings.ToUpper(dbName), Db: db, SourceProfile: sourceProfile, TargetProfile: targetProfile}, nil - default: - return nil, fmt.Errorf("driver %s not supported", driver) - } -} diff --git a/conversion/conversion_from_source.go b/conversion/conversion_from_source.go new file mode 100644 index 0000000000..8a1b73c87f --- /dev/null +++ b/conversion/conversion_from_source.go @@ -0,0 +1,339 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package conversion handles initial setup for the command line tool +// and web APIs. + +// TODO:(searce) Organize code in go style format to make this file more readable. 
+//
+// public constants first
+// key public type definitions next (although often it makes sense to put them next to public functions that use them)
+// then public functions (and relevant type definitions)
+// and helper functions and other non-public definitions last (generally in order of importance)
+package conversion
+
+import (
+	"bufio"
+	"context"
+	"fmt"
+	"strings"
+
+	sp "cloud.google.com/go/spanner"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/metrics"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/profiles"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/csv"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/streaming"
+	storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage"
+	storageaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/storage"
+	"go.uber.org/zap"
+)
+
+// SchemaFromSourceInterface abstracts building a schema conversion (*internal.Conv)
+// either from a live source database or from a dump file.
+type SchemaFromSourceInterface interface {
+	schemaFromDatabase(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, getInfo GetInfoInterface, processSchema common.ProcessSchemaInterface) (*internal.Conv, error)
+	SchemaFromDump(driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error)
+}
+
+// SchemaFromSourceImpl is the stateless implementation of SchemaFromSourceInterface defined in this file.
+type SchemaFromSourceImpl struct{}
+
+// DataFromSourceInterface abstracts moving data into Spanner from a live
+// database, a dump file or CSV files. Each method returns the BatchWriter
+// that was used for the writes.
+type DataFromSourceInterface interface {
+	dataFromDatabase(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, getInfo GetInfoInterface, dataFromDb DataFromDatabaseInterface, snapshotMigration SnapshotMigrationInterface) (*writer.BatchWriter, error)
+	dataFromDump(driver string, config writer.BatchWriterConfig, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool, processDump ProcessDumpByDialectInterface, populateDataConv PopulateDataConvInterface) (*writer.BatchWriter, error)
+	dataFromCSV(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, populateDataConv PopulateDataConvInterface, csv csv.CsvInterface) (*writer.BatchWriter, error)
+}
+
+// DataFromSourceImpl is the stateless implementation of DataFromSourceInterface defined in this file.
+type DataFromSourceImpl struct{}
+
+// schemaFromDatabase builds a Conv by reading the schema of the source
+// database described by sourceProfile and handing it to processSchema.
+//
+// The info schema is obtained differently per profile type: config based
+// (sharded) profiles read it from the configured schema source shard,
+// Cloud SQL profiles use the Cloud SQL path, and everything else uses the
+// regular connection. On failure the (possibly partially filled) conv is
+// returned together with the error, as before.
+func (sads *SchemaFromSourceImpl) schemaFromDatabase(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, getInfo GetInfoInterface, processSchema common.ProcessSchemaInterface) (*internal.Conv, error) {
+	conv := internal.MakeConv()
+	conv.SpDialect = targetProfile.Conn.Sp.Dialect
+	//handle fetching schema differently for sharded migrations, we only connect to the primary shard to
+	//fetch the schema. We reuse the SourceProfileConnection object for this purpose.
+	var infoSchema common.InfoSchema
+	var err error
+	isSharded := false
+	switch sourceProfile.Ty {
+	case profiles.SourceProfileTypeConfig:
+		isSharded = true
+		//Find Primary Shard Name
+		switch sourceProfile.Config.ConfigType {
+		case constants.BULK_MIGRATION:
+			schemaSource := sourceProfile.Config.ShardConfigurationBulk.SchemaSource
+			infoSchema, err = getInfo.getInfoSchemaForShard(schemaSource, sourceProfile.Driver, targetProfile, &profiles.SourceProfileDialectImpl{}, &GetInfoImpl{})
+		case constants.DATAFLOW_MIGRATION:
+			schemaSource := sourceProfile.Config.ShardConfigurationDataflow.SchemaSource
+			infoSchema, err = getInfo.getInfoSchemaForShard(schemaSource, sourceProfile.Driver, targetProfile, &profiles.SourceProfileDialectImpl{}, &GetInfoImpl{})
+		case constants.DMS_MIGRATION:
+			// TODO: Define the schema processing logic for DMS migrations here.
+			return conv, fmt.Errorf("dms based migrations are not implemented yet")
+		default:
+			return conv, fmt.Errorf("unknown type of migration, please select one of bulk, dataflow or dms")
+		}
+	case profiles.SourceProfileTypeCloudSQL:
+		infoSchema, err = getInfo.GetInfoSchemaFromCloudSQL(sourceProfile, targetProfile)
+	default:
+		infoSchema, err = getInfo.GetInfoSchema(sourceProfile, targetProfile)
+	}
+	// Every branch that reaches this point produced (infoSchema, err).
+	if err != nil {
+		return conv, err
+	}
+	additionalSchemaAttributes := internal.AdditionalSchemaAttributes{
+		IsSharded: isSharded,
+	}
+	return conv, processSchema.ProcessSchema(conv, infoSchema, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{})
+}
+
+// SchemaFromDump builds a Conv by parsing the dump available on ioHelper.In
+// in schema mode (data statements are ignored).
+//
+// As a side effect it stores the seekable copy of the input and its size in
+// ioHelper.SeekableIn / ioHelper.BytesRead so a later data pass can rewind it.
+// Details of a failure are printed to ioHelper.Out; the returned error is generic.
+func (sads *SchemaFromSourceImpl) SchemaFromDump(driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error) {
+	f, n, err := getSeekable(ioHelper.In)
+	if err != nil {
+		utils.PrintSeekError(driver, err, ioHelper.Out)
+		return nil, fmt.Errorf("can't get seekable input file")
+	}
+	ioHelper.SeekableIn = f
+	ioHelper.BytesRead = n
+	conv := internal.MakeConv()
+	conv.SpDialect = spDialect
+	p := internal.NewProgress(n, "Generating schema", internal.Verbose(), false, int(internal.SchemaCreationInProgress))
+	r := internal.NewReader(bufio.NewReader(f), p)
+	conv.SetSchemaMode() // Build schema and ignore data in dump.
+	conv.SetDataSink(nil)
+	err = processDump.ProcessDump(driver, conv, r)
+	if err != nil {
+		fmt.Fprintf(ioHelper.Out, "Failed to parse the data file: %v", err)
+		return nil, fmt.Errorf("failed to parse the data file")
+	}
+	p.Done()
+	return conv, nil
+}
+
+
+// dataFromDump streams the rows found in the dump to Spanner through a
+// BatchWriter created by populateDataConv.
+//
+// When dataOnly is false the dump was already read once for the schema, so
+// ioHelper.SeekableIn is rewound; otherwise a seekable copy of ioHelper.In is
+// obtained first. An error from processDump is returned to the caller after
+// the rows converted so far have been flushed.
+func (sads *DataFromSourceImpl) dataFromDump(driver string, config writer.BatchWriterConfig, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool, processDump ProcessDumpByDialectInterface, populateDataConv PopulateDataConvInterface) (*writer.BatchWriter, error) {
+	// TODO: refactor of the way we handle getSeekable
+	// to avoid the code duplication here
+	if !dataOnly {
+		_, err := ioHelper.SeekableIn.Seek(0, 0)
+		if err != nil {
+			fmt.Printf("\nCan't seek to start of file (preparation for second pass): %v\n", err)
+			return nil, fmt.Errorf("can't seek to start of file")
+		}
+	} else {
+		// Note: input file is kept seekable to plan for future
+		// changes in showing progress for data migration.
+		f, n, err := getSeekable(ioHelper.In)
+		if err != nil {
+			utils.PrintSeekError(driver, err, ioHelper.Out)
+			return nil, fmt.Errorf("can't get seekable input file")
+		}
+		ioHelper.SeekableIn = f
+		ioHelper.BytesRead = n
+	}
+	totalRows := conv.Rows()
+
+	conv.Audit.Progress = *internal.NewProgress(totalRows, "Writing data to Spanner", internal.Verbose(), false, int(internal.DataWriteInProgress))
+	r := internal.NewReader(bufio.NewReader(ioHelper.SeekableIn), nil)
+	batchWriter := populateDataConv.populateDataConv(conv, config, client)
+	// Keep the parse error instead of dropping it: a failed parse used to be
+	// reported as a successful (but partial) data load.
+	dumpErr := processDump.ProcessDump(driver, conv, r)
+	// Flush and close the progress bar in both cases so rows converted
+	// before a failure are still written, exactly as before.
+	batchWriter.Flush()
+	conv.Audit.Progress.Done()
+	if dumpErr != nil {
+		return nil, fmt.Errorf("can't process dump: %w", dumpErr)
+	}
+
+	return batchWriter, nil
+}
+
+// dataFromCSV loads the CSV files described by sourceProfile into an existing
+// Spanner database.
+//
+// It requires a database name in the target profile, checks that the dialect
+// of the target database matches the one in the profile, reads the Spanner
+// schema into conv, gathers row statistics and then writes the rows through
+// a BatchWriter created by populateDataConv. The delimiter must be exactly
+// one byte long.
+func (sads *DataFromSourceImpl) dataFromCSV(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, populateDataConv PopulateDataConvInterface, csv csv.CsvInterface) (*writer.BatchWriter, error) {
+	if targetProfile.Conn.Sp.Dbname == "" {
+		return nil, fmt.Errorf("dbName is mandatory in target-profile for csv source")
+	}
+	conv.SpDialect = targetProfile.Conn.Sp.Dialect
+	dialect, err := targetProfile.FetchTargetDialect(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("could not fetch dialect: %v", err)
+	}
+	// Anything that is not PostgreSQL is treated as GoogleSQL.
+	if strings.ToLower(dialect) != constants.DIALECT_POSTGRESQL {
+		dialect = constants.DIALECT_GOOGLESQL
+	}
+
+	if dialect != conv.SpDialect {
+		return nil, fmt.Errorf("dialect specified in target profile does not match spanner dialect")
+	}
+
+	delimiterStr := sourceProfile.Csv.Delimiter
+	if len(delimiterStr) != 1 {
+		return nil, fmt.Errorf("delimiter should only be a single character long, found '%s'", delimiterStr)
+	}
+
+	delimiter := rune(delimiterStr[0])
+
+	err = utils.ReadSpannerSchema(ctx, conv, client)
+	if err != nil {
+		return nil, fmt.Errorf("error trying to read and convert spanner schema: %v", err)
+	}
+
+	tables, err := csv.GetCSVFiles(conv, sourceProfile)
+	if err != nil {
+		return nil, fmt.Errorf("error finding csv files: %v", err)
+	}
+
+	// Find the number of rows in each csv file for generating stats.
+	err = csv.SetRowStats(conv, tables, delimiter)
+	if err != nil {
+		return nil, err
+	}
+
+	totalRows := conv.Rows()
+	conv.Audit.Progress = *internal.NewProgress(totalRows, "Writing data to Spanner", internal.Verbose(), false, int(internal.DataWriteInProgress))
+	batchWriter := populateDataConv.populateDataConv(conv, config, client)
+	err = csv.ProcessCSV(conv, tables, sourceProfile.Csv.NullStr, delimiter)
+	if err != nil {
+		return nil, fmt.Errorf("can't process csv: %v", err)
+	}
+	batchWriter.Flush()
+	conv.Audit.Progress.Done()
+	return batchWriter, nil
+}
+
+
+// dataFromDatabase migrates data from a live source database.
+//
+// Config based (sharded) profiles are dispatched on their ConfigType to the
+// matching dataFromDb method. All other profiles are handled here: when
+// sourceProfile.Conn.Streaming is set a minimal downtime migration is started
+// (change data capture, snapshot, streaming job, optional bucket lifecycle
+// rule, monitoring dashboard, persistence of job metadata), otherwise a plain
+// snapshot migration is performed.
+func (sads *DataFromSourceImpl) dataFromDatabase(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, getInfo GetInfoInterface, dataFromDb DataFromDatabaseInterface, snapshotMigration SnapshotMigrationInterface) (*writer.BatchWriter, error) {
+	//handle migrating data for sharded migrations differently
+	//sharded migrations are identified via the config= flag, if that flag is not present
+	//carry on with the existing code path in the default block
+	switch sourceProfile.Ty {
+	case profiles.SourceProfileTypeConfig:
+		//There are three cases to cover here, bulk migrations and sharded migrations (and later DMS)
+		//This will be determined via the configType, which can be "bulk", "dataflow" or "dms"
+		switch sourceProfile.Config.ConfigType {
+		case constants.BULK_MIGRATION:
+			return dataFromDb.dataFromDatabaseForBulkMigration(sourceProfile, targetProfile, config, conv, client, getInfo, snapshotMigration)
+		case constants.DATAFLOW_MIGRATION:
+			return dataFromDb.dataFromDatabaseForDataflowMigration(targetProfile, ctx, sourceProfile, conv, &common.InfoSchemaImpl{})
+		case constants.DMS_MIGRATION:
+			return dataFromDb.dataFromDatabaseForDMSMigration()
+		default:
+			return nil, fmt.Errorf("configType should be one of 'bulk', 'dataflow' or 'dms'")
+		}
+	default:
+		var infoSchema common.InfoSchema
+		var err error
+		if sourceProfile.Ty == profiles.SourceProfileTypeCloudSQL {
+			infoSchema, err = getInfo.GetInfoSchemaFromCloudSQL(sourceProfile, targetProfile)
+		} else {
+			infoSchema, err = getInfo.GetInfoSchema(sourceProfile, targetProfile)
+		}
+		if err != nil {
+			return nil, err
+		}
+		var streamInfo map[string]interface{}
+		// minimal downtime migration for a single shard
+		if sourceProfile.Conn.Streaming {
+			//Generate a job Id
+			migrationJobId := conv.Audit.MigrationRequestId
+			logger.Log.Info(fmt.Sprintf("Creating a migration job with id: %v. This jobId can be used in future commmands (such as cleanup) to refer to this job.\n", migrationJobId))
+			streamInfo, err = infoSchema.StartChangeDataCapture(ctx, conv)
+			if err != nil {
+				return nil, err
+			}
+			bw, err := snapshotMigration.snapshotMigrationHandler(sourceProfile, config, conv, client, infoSchema)
+			if err != nil {
+				return nil, err
+			}
+			dfOutput, err := infoSchema.StartStreamingMigration(ctx, client, conv, streamInfo)
+			if err != nil {
+				return nil, err
+			}
+			dfJobId := dfOutput.JobID
+			gcloudCmd := dfOutput.GCloudCmd
+			// A failed assertion leaves streamingCfg at its zero value.
+			streamingCfg, _ := streamInfo["streamingCfg"].(streaming.StreamingCfg)
+			// Fetch and store the GCS bucket associated with the datastream
+			dsClient := getDatastreamClient(ctx)
+			gcsBucket, gcsDestPrefix, fetchGcsErr := streaming.FetchTargetBucketAndPath(ctx, dsClient, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg.DestinationConnectionConfig)
+			if fetchGcsErr != nil {
+				logger.Log.Info("Could not fetch GCS Bucket, hence Monitoring Dashboard will not contain Metrics for the gcs bucket\n")
+				logger.Log.Debug("Error", zap.Error(fetchGcsErr))
+			}
+
+			// Try to apply lifecycle rule to Datastream destination bucket.
+			gcsConfig := streamingCfg.GcsCfg
+			sc, err := storageclient.NewStorageClientImpl(ctx)
+			if err != nil {
+				return nil, err
+			}
+			sa := storageaccessor.StorageAccessorImpl{}
+			if gcsConfig.TtlInDaysSet {
+				err = sa.ApplyBucketLifecycleDeleteRule(ctx, sc, storageaccessor.StorageBucketMetadata{
+					BucketName:    gcsBucket,
+					Ttl:           gcsConfig.TtlInDays,
+					MatchesPrefix: []string{gcsDestPrefix},
+				})
+				if err != nil {
+					// Best effort: the migration continues without the rule.
+					logger.Log.Warn(fmt.Sprintf("\nWARNING: could not update Datastream destination GCS bucket with lifecycle rule, error: %v\n", err))
+					logger.Log.Warn("Please apply the lifecycle rule manually. Continuing...\n")
+				}
+			}
+
+			monitoringResources := metrics.MonitoringMetricsResources{
+				ProjectId:            targetProfile.Conn.Sp.Project,
+				DataflowJobId:        dfOutput.JobID,
+				DatastreamId:         streamingCfg.DatastreamCfg.StreamId,
+				JobMetadataGcsBucket: gcsBucket,
+				PubsubSubscriptionId: streamingCfg.PubsubCfg.SubscriptionId,
+				SpannerInstanceId:    targetProfile.Conn.Sp.Instance,
+				SpannerDatabaseId:    targetProfile.Conn.Sp.Dbname,
+				ShardId:              "",
+				MigrationRequestId:   conv.Audit.MigrationRequestId,
+			}
+			respDash, dashboardErr := monitoringResources.CreateDataflowShardMonitoringDashboard(ctx)
+			var dashboardName string
+			if dashboardErr != nil {
+				dashboardName = ""
+				logger.Log.Info("Creation of the monitoring dashboard failed, please create the dashboard manually")
+				logger.Log.Debug("Error", zap.Error(dashboardErr))
+			} else {
+				// The fourth "/"-separated segment of the resource name is used as
+				// the dashboard name. Guard the index so an unexpectedly shaped name
+				// cannot panic after the streaming job has already been started.
+				nameParts := strings.Split(respDash.Name, "/")
+				if len(nameParts) > 3 {
+					dashboardName = nameParts[3]
+				} else {
+					dashboardName = respDash.Name
+				}
+				fmt.Printf("Monitoring Dashboard: %+v\n", dashboardName)
+			}
+			// store the generated resources locally in conv, this is used as source of truth for persistence and the UI (should change to persisted values)
+			streaming.StoreGeneratedResources(conv, streamingCfg, dfJobId, gcloudCmd, targetProfile.Conn.Sp.Project, "", internal.GcsResources{BucketName: gcsBucket}, dashboardName)
+			//persist job and shard level data in the metadata db
+			err = streaming.PersistJobDetails(ctx, targetProfile, sourceProfile, conv, migrationJobId, false)
+			if err != nil {
+				logger.Log.Info(fmt.Sprintf("Error storing job details in SMT metadata store...the migration job will still continue as intended. %v", err))
+			} else {
+				//only attempt persisting shard level data if the job level data is persisted
+				err = streaming.PersistResources(ctx, targetProfile, sourceProfile, conv, migrationJobId, constants.DEFAULT_SHARD_ID)
+				if err != nil {
+					logger.Log.Info(fmt.Sprintf("Error storing details for migration job: %s, data shard: %s in SMT metadata store...the migration job will still continue as intended. err = %v\n", migrationJobId, constants.DEFAULT_SHARD_ID, err))
+				}
+			}
+			return bw, nil
+		}
+		//bulk migration for a single shard
+		return snapshotMigration.performSnapshotMigration(config, conv, client, infoSchema, internal.AdditionalDataAttributes{ShardId: ""}, &common.InfoSchemaImpl{}, &PopulateDataConvImpl{}), nil
+	}
+}
diff --git a/conversion/conversion_from_source_test.go b/conversion/conversion_from_source_test.go
new file mode 100644
index 0000000000..f236807553
--- /dev/null
+++ b/conversion/conversion_from_source_test.go
@@ -0,0 +1,175 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package conversion handles initial setup for the command line tool
+// and web APIs.
+
+// TODO:(searce) Organize code in go style format to make this file more readable.
+//
+// public constants first
+// key public type definitions next (although often it makes sense to put them next to public functions that use them)
+// then public functions (and relevant type definitions)
+// and helper functions and other non-public definitions last (generally in order of importance)
+package conversion
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/profiles"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/mysql"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+)
+
+
+// TestSchemaFromDatabase drives schemaFromDatabase with mocked info-schema
+// and process-schema collaborators and only checks whether an error is
+// returned for each kind of source profile.
+func TestSchemaFromDatabase(t *testing.T) {
+	spannerTarget := profiles.TargetProfile{
+		Conn: profiles.TargetProfileConnection{
+			Sp: profiles.TargetProfileConnectionSpanner{
+				Dialect: "google_standard_sql",
+			},
+		},
+	}
+
+	// shardedProfile returns a source profile of type 3 carrying the given config type.
+	shardedProfile := func(configType string) profiles.SourceProfile {
+		return profiles.SourceProfile{
+			Ty: profiles.SourceProfileType(3),
+			Config: profiles.SourceProfileConfig{
+				ConfigType: configType,
+			},
+		}
+	}
+	bulkProfile := shardedProfile("bulk")
+	dataflowProfile := shardedProfile("dataflow")
+	dmsProfile := shardedProfile("dms")
+	invalidProfile := shardedProfile("invalid")
+	cloudSQLProfile := profiles.SourceProfile{
+		Ty: profiles.SourceProfileType(5),
+	}
+	defaultProfile := profiles.SourceProfile{}
+
+	// Avoid getting/setting env variables in the unit tests.
+	// Error fields that are left out are nil, i.e. the mocked call succeeds.
+	scenarios := []struct {
+		name               string
+		sourceProfile      profiles.SourceProfile
+		getInfoError       error
+		processSchemaError error
+		errorExpected      bool
+	}{
+		{
+			name:          "successful source profile config for bulk migration",
+			sourceProfile: bulkProfile,
+		},
+		{
+			name:          "source profile config for bulk migration: get info error",
+			sourceProfile: bulkProfile,
+			getInfoError:  fmt.Errorf("error"),
+			errorExpected: true,
+		},
+		{
+			name:               "source profile config for bulk migration: process schema error",
+			sourceProfile:      bulkProfile,
+			processSchemaError: fmt.Errorf("error"),
+			errorExpected:      true,
+		},
+		{
+			name:          "successful source profile config for dataflow migration",
+			sourceProfile: dataflowProfile,
+		},
+		{
+			name:          "source profile config for dataflow migration: get info error",
+			sourceProfile: dataflowProfile,
+			getInfoError:  fmt.Errorf("error"),
+			errorExpected: true,
+		},
+		{
+			name:          "source profile config for dms migration",
+			sourceProfile: dmsProfile,
+			errorExpected: true,
+		},
+		{
+			name:          "invalid source profile config",
+			sourceProfile: invalidProfile,
+			errorExpected: true,
+		},
+		{
+			name:          "successful source profile cloud sql",
+			sourceProfile: cloudSQLProfile,
+		},
+		{
+			name:          "source profile cloud sql: get info error",
+			sourceProfile: cloudSQLProfile,
+			getInfoError:  fmt.Errorf("error"),
+			errorExpected: true,
+		},
+		{
+			name:          "successful source profile default",
+			sourceProfile: defaultProfile,
+		},
+		{
+			name:          "source profile default: get info error",
+			sourceProfile: defaultProfile,
+			getInfoError:  fmt.Errorf("error"),
+			errorExpected: true,
+		},
+	}
+
+	for _, sc := range scenarios {
+		// Fresh mocks per scenario so expectations never leak between cases.
+		getInfoMock := MockGetInfo{}
+		processSchemaMock := common.MockProcessSchema{}
+
+		// All three info-schema entry points hand back the same stub and error.
+		getInfoMock.On("getInfoSchemaForShard", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(mysql.InfoSchemaImpl{}, sc.getInfoError)
+		getInfoMock.On("GetInfoSchemaFromCloudSQL", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(mysql.InfoSchemaImpl{}, sc.getInfoError)
+		getInfoMock.On("GetInfoSchema", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(mysql.InfoSchemaImpl{}, sc.getInfoError)
+		processSchemaMock.On("ProcessSchema", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(sc.processSchemaError)
+
+		underTest := SchemaFromSourceImpl{}
+		_, gotErr := underTest.schemaFromDatabase(sc.sourceProfile, spannerTarget, &getInfoMock, &processSchemaMock)
+		assert.Equal(t, sc.errorExpected, gotErr != nil, sc.name)
+	}
+}
\ No newline at end of file
diff --git a/conversion/conversion_helper.go b/conversion/conversion_helper.go
new file mode 100644
index 0000000000..8aa9d30b8e
--- /dev/null
+++ b/conversion/conversion_helper.go
@@ -0,0 +1,212 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package conversion handles initial setup for the command line tool
+// and web APIs.
+
+// TODO:(searce) Organize code in go style format to make this file more readable.
+//
+// public constants first
+// key public type definitions next (although often it makes sense to put them next to public functions that use them)
+// then public functions (and relevant type definitions)
+// and helper functions and other non-public definitions last (generally in order of importance)
+
+package conversion
+
+import (
+	"context"
+	"encoding/base64"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"os"
+	"strings"
+	"sync/atomic"
+	"syscall"
+
+	sp "cloud.google.com/go/spanner"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/metrics"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/internal"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/logger"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/profiles"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/mysql"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/sources/postgres"
+	"github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer"
+	"github.com/aws/aws-sdk-go/aws"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/protobuf/proto"
+)
+
+// ProcessDumpByDialectInterface abstracts parsing a database dump read from r
+// into conv, with the dump format selected by driver.
+type ProcessDumpByDialectInterface interface{
+	ProcessDump(driver string, conv *internal.Conv, r *internal.Reader) error
+}
+
+// ProcessDumpByDialectImpl is the stateless implementation of
+// ProcessDumpByDialectInterface defined in this file.
+type ProcessDumpByDialectImpl struct{}
+
+// PopulateDataConvInterface abstracts preparing conv for a data pass and
+// creating the BatchWriter that rows are written through.
+type PopulateDataConvInterface interface{
+	populateDataConv(conv *internal.Conv, config writer.BatchWriterConfig, client *sp.Client) *writer.BatchWriter
+}
+
+// PopulateDataConvImpl is the stateless implementation of
+// PopulateDataConvInterface defined in this file.
+type PopulateDataConvImpl struct{}
+// getSeekable returns a seekable file (with same content as f) and the size of the content (in bytes).
+//
+// If f itself supports Seek it is rewound and returned as is. Otherwise its
+// content is copied into an anonymous temporary file, which is returned
+// positioned at offset 0. On any failure the temporary file is closed (which
+// releases its disk space) and a nil file is returned with the error.
+func getSeekable(f *os.File) (*os.File, int64, error) {
+	if _, err := f.Seek(0, 0); err == nil { // Stdin is seekable, let's just use that. This happens when you run 'cmd < file'.
+		n, err := utils.GetFileSize(f)
+		return f, n, err
+	}
+	internal.VerbosePrintln("Creating a tmp file with a copy of stdin because stdin is not seekable.")
+	logger.Log.Debug("Creating a tmp file with a copy of stdin because stdin is not seekable.")
+
+	// Create file in os.TempDir. Its not clear this is a good idea e.g. if the
+	// pg_dump/mysqldump output is large (tens of GBs) and os.TempDir points to a directory
+	// (such as /tmp) that's configured with a small amount of disk space.
+	// To workaround such limits on Unix, set $TMPDIR to a directory with lots
+	// of disk space.
+	fcopy, err := ioutil.TempFile("", "spanner-migration-tool.data")
+	if err != nil {
+		return nil, 0, err
+	}
+	// Best effort: remove the directory entry right away so the data goes
+	// away once the descriptor is closed or the process exits.
+	syscall.Unlink(fcopy.Name())
+	if _, err = io.Copy(fcopy, f); err != nil {
+		// Do not keep the (possibly very large) partial copy alive until exit.
+		fcopy.Close()
+		return nil, 0, fmt.Errorf("can't write stdin to tmp file: %w", err)
+	}
+	if _, err = fcopy.Seek(0, 0); err != nil {
+		fcopy.Close()
+		return nil, 0, fmt.Errorf("can't reset file offset: %w", err)
+	}
+	// Report a size lookup failure like the seekable path above does instead
+	// of silently continuing with an unreliable size.
+	n, err := utils.GetFileSize(fcopy)
+	if err != nil {
+		fcopy.Close()
+		return nil, 0, fmt.Errorf("can't get size of tmp file: %w", err)
+	}
+	return fcopy, n, nil
+}
+
+// ProcessDump invokes process dump function from a sql package based on driver selected.
+// It returns an error for any driver other than the MySQL and PostgreSQL dump drivers.
+func (pdd *ProcessDumpByDialectImpl) ProcessDump(driver string, conv *internal.Conv, r *internal.Reader) error {
+	switch driver {
+	case constants.MYSQLDUMP:
+		return common.ProcessDbDump(conv, r, mysql.DbDumpImpl{})
+	case constants.PGDUMP:
+		return common.ProcessDbDump(conv, r, postgres.DbDumpImpl{})
+	default:
+		return fmt.Errorf("process dump for driver %s not supported", driver)
+	}
+}
+
+
+// populateDataConv switches conv to data mode and returns a BatchWriter whose
+// write callback applies mutations to Spanner through client.
+//
+// Unless conv.Audit.SkipMetricsPopulation is set, every write carries the
+// serialized migration metadata in the outgoing gRPC context. Unless
+// conv.Audit.DryRun is set, conv's data sink and flush hook are pointed at
+// the returned writer.
+func (pdc *PopulateDataConvImpl) populateDataConv(conv *internal.Conv, config writer.BatchWriterConfig, client *sp.Client) *writer.BatchWriter {
+	// Running total of mutations applied; only touched through sync/atomic.
+	rows := int64(0)
+	config.Write = func(m []*sp.Mutation) error {
+		ctx := context.Background()
+		if !conv.Audit.SkipMetricsPopulation {
+			migrationData := metrics.GetMigrationData(conv, "", constants.DataConv)
+			// A marshalling failure only degrades the attached metadata, so it is deliberately ignored.
+			serializedMigrationData, _ := proto.Marshal(migrationData)
+			migrationMetadataValue := base64.StdEncoding.EncodeToString(serializedMigrationData)
+			ctx = metadata.AppendToOutgoingContext(context.Background(), constants.MigrationMetadataKey, migrationMetadataValue)
+		}
+		_, err := client.Apply(ctx, m)
+		if err != nil {
+			return err
+		}
+		atomic.AddInt64(&rows, int64(len(m)))
+		conv.Audit.Progress.MaybeReport(atomic.LoadInt64(&rows))
+		return nil
+	}
+	batchWriter := writer.NewBatchWriter(config)
+	conv.SetDataMode()
+	if !conv.Audit.DryRun {
+		conv.SetDataSink(
+			func(table string, cols []string, vals []interface{}) {
+				batchWriter.AddRow(table, cols, vals)
+			})
+		conv.DataFlush = func() {
+			batchWriter.Flush()
+		}
+	}
+
+	return batchWriter
+}
+
+
+// connectionConfig returns the driver specific connection configuration for
+// sourceProfile: a connection string for the SQL drivers and an *aws.Config
+// for DynamoDB. Unsupported drivers yield an error.
+func connectionConfig(sourceProfile profiles.SourceProfile) (interface{}, error) {
+	switch sourceProfile.Driver {
+	// For PG and MYSQL, When called as part of the subcommand flow, host/user/db etc will
+	// never be empty as we error out right during source profile creation. If any of them
+	// are empty, that means this was called through the legacy cmd flow and we create the
+	// string using env vars.
+	case constants.POSTGRES:
+		pgConn := sourceProfile.Conn.Pg
+		if pgConn.Host == "" || pgConn.User == "" || pgConn.Db == "" {
+			return profiles.GeneratePGSQLConnectionStr()
+		}
+		return profiles.GetSQLConnectionStr(sourceProfile), nil
+	case constants.MYSQL:
+		// If empty, this is called as part of the legacy mode with global CLI flags.
+		// When source-profile mode is used, the sqlConnectionStr is already populated.
+		mysqlConn := sourceProfile.Conn.Mysql
+		if mysqlConn.Host == "" || mysqlConn.User == "" || mysqlConn.Db == "" {
+			return profiles.GenerateMYSQLConnectionStr()
+		}
+		return profiles.GetSQLConnectionStr(sourceProfile), nil
+	// For Dynamodb, both legacy and new flows use env vars.
+	case constants.DYNAMODB:
+		return getDynamoDBClientConfig()
+	case constants.SQLSERVER:
+		return profiles.GetSQLConnectionStr(sourceProfile), nil
+	case constants.ORACLE:
+		return profiles.GetSQLConnectionStr(sourceProfile), nil
+	default:
+		return "", fmt.Errorf("driver %s not supported", sourceProfile.Driver)
+	}
+}
+
+// getDbNameFromSQLConnectionStr extracts the database name from a connection
+// string in the layout used for the given driver. It returns "" for unknown
+// drivers and for strings that do not have the expected layout (previously
+// such strings caused an index out of range panic).
+func getDbNameFromSQLConnectionStr(driver, sqlConnectionStr string) string {
+	switch driver {
+	case constants.POSTGRES:
+		// The fifth space separated field is expected to be the "key=value" pair holding the database.
+		fields := strings.Split(sqlConnectionStr, " ")
+		if len(fields) <= 4 {
+			return ""
+		}
+		keyValue := strings.Split(fields[4], "=")
+		if len(keyValue) < 2 {
+			return ""
+		}
+		return keyValue[1]
+	case constants.MYSQL:
+		// Everything after the first ")/" up to the next ")/" (if any).
+		parts := strings.Split(sqlConnectionStr, ")/")
+		if len(parts) < 2 {
+			return ""
+		}
+		return parts[1]
+	case constants.SQLSERVER:
+		splts := strings.Split(sqlConnectionStr, "?database=")
+		return splts[len(splts)-1]
+	case constants.ORACLE:
+		// connection string format : "oracle://user:password@104.108.154.85:1521/XE"
+		// The first 9 bytes (the "oracle://" prefix in that format) are skipped
+		// and the text up to the next ":" is returned.
+		const prefixLen = 9
+		if len(sqlConnectionStr) < prefixLen {
+			return ""
+		}
+		substr := sqlConnectionStr[prefixLen:]
+		dbName := strings.Split(substr, ":")[0]
+		return dbName
+	}
+	return ""
+}
+
+// updateShardsWithTuningConfigs copies the top level Datastream, GCS and
+// Dataflow tuning configs onto every entry of DataShards.
+// NOTE(review): this only has an effect if DataShards holds pointers; with a
+// slice of values the assignments would hit the loop copy - confirm against
+// the profiles package.
+func updateShardsWithTuningConfigs(shardedTuningConfig profiles.ShardConfigurationDataflow) {
+	for _, dataShard := range shardedTuningConfig.DataShards {
+		dataShard.DatastreamConfig = shardedTuningConfig.DatastreamConfig
+		dataShard.GcsConfig = shardedTuningConfig.GcsConfig
+		dataShard.DataflowConfig = shardedTuningConfig.DataflowConfig
+	}
+}
+
+// getDynamoDBClientConfig returns an AWS config whose endpoint is taken from
+// the DYNAMODB_ENDPOINT_OVERRIDE environment variable when that is set. The
+// returned error is always nil.
+func getDynamoDBClientConfig() (*aws.Config, error) {
+	cfg := aws.Config{}
+	endpointOverride := os.Getenv("DYNAMODB_ENDPOINT_OVERRIDE")
+	if endpointOverride != "" {
+		cfg.Endpoint = aws.String(endpointOverride)
+	}
+	return &cfg, nil
+}
\ No newline at end of file
diff --git a/conversion/conversion_test.go b/conversion/conversion_test.go
new file mode 100644
index 0000000000..2eb4ed3808
--- /dev/null
+++ b/conversion/conversion_test.go
@@ -0,0 +1,203 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package conversion handles initial setup for the command line tool
+// and web APIs.
+
+// TODO:(searce) Organize code in go style format to make this file more readable.
+// +// public constants first +// key public type definitions next (although often it makes sense to put them next to public functions that use them) +// then public functions (and relevant type definitions) +// and helper functions and other non-public definitions last (generally in order of importance) +package conversion + +import ( + "context" + "testing" + + sp "cloud.google.com/go/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestSchemaConv(t *testing.T) { + // Avoid getting/setting env variables in the unit tests. + testCases := []struct { + name string + sourceProfileDriver string + output interface{} + function string + errorExpected bool + }{ + { + name: "postgres driver", + sourceProfileDriver: "postgres", + output: &internal.Conv{}, + function: "schemaFromDatabase", + errorExpected: false, + }, + { + name: "mysql driver", + sourceProfileDriver: "mysql", + output: &internal.Conv{}, + function: "schemaFromDatabase", + errorExpected: false, + }, + { + name: "dynamodb driver", + sourceProfileDriver: "dynamodb", + output: &internal.Conv{}, + function: "schemaFromDatabase", + errorExpected: false, + }, + { + name: "sqlserver driver", + sourceProfileDriver: "sqlserver", + output: &internal.Conv{}, + function: "schemaFromDatabase", + errorExpected: false, + }, + { + name: "oracle driver", + sourceProfileDriver: "oracle", + output: &internal.Conv{}, + function: "schemaFromDatabase", + errorExpected: false, + }, + { + name: "pg dump driver", + sourceProfileDriver: "pg_dump", + output: &internal.Conv{}, + function: "SchemaFromDump", + errorExpected: false, + }, + { + name: "mysql dump driver", + sourceProfileDriver: "mysqldump", + output: 
&internal.Conv{}, + function: "SchemaFromDump", + errorExpected: false, + }, + { + name: "invalid driver", + sourceProfileDriver: "invalid", + output: nil, + function: "", + errorExpected: true, + }, + } + + for _, tc := range testCases { + m := MockSchemaFromSource{} + m.On(tc.function, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.output, nil) + c := ConvImpl{} + _, err := c.SchemaConv(profiles.SourceProfile{Driver: tc.sourceProfileDriver}, profiles.TargetProfile{}, &utils.IOStreams{}, &m) + assert.Equal(t, tc.errorExpected, err != nil, tc.name) + if err == nil { + m.AssertExpectations(t) + } + } +} + +func TestDataConv(t *testing.T) { + // Avoid getting/setting env variables in the unit tests. + testCases := []struct { + name string + sourceProfileDriver string + output interface{} + function string + errorExpected bool + }{ + { + name: "postgres driver", + sourceProfileDriver: "postgres", + output: &writer.BatchWriter{}, + function: "dataFromDatabase", + errorExpected: false, + }, + { + name: "mysql driver", + sourceProfileDriver: "mysql", + output: &writer.BatchWriter{}, + function: "dataFromDatabase", + errorExpected: false, + }, + { + name: "dynamodb driver", + sourceProfileDriver: "dynamodb", + output: &writer.BatchWriter{}, + function: "dataFromDatabase", + errorExpected: false, + }, + { + name: "sqlserver driver", + sourceProfileDriver: "sqlserver", + output: &writer.BatchWriter{}, + function: "dataFromDatabase", + errorExpected: false, + }, + { + name: "oracle driver", + sourceProfileDriver: "oracle", + output: &writer.BatchWriter{}, + function: "dataFromDatabase", + errorExpected: false, + }, + { + name: "pg dump driver", + sourceProfileDriver: "pg_dump", + output: &writer.BatchWriter{}, + function: "dataFromDump", + errorExpected: false, + }, + { + name: "mysql dump driver", + sourceProfileDriver: "mysqldump", + output: &writer.BatchWriter{}, + function: "dataFromDump", + errorExpected: false, + }, 
+ { + name: "crv driver", + sourceProfileDriver: "csv", + output: &writer.BatchWriter{}, + function: "dataFromCSV", + errorExpected: false, + }, + { + name: "invalid driver", + sourceProfileDriver: "invalid", + output: nil, + function: "", + errorExpected: true, + }, + } + + ctx:= context.Background() + for _, tc := range testCases { + m := MockDataFromSource{} + m.On(tc.function, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.output, nil) + c := ConvImpl{} + _, err := c.DataConv(ctx, profiles.SourceProfile{Driver: tc.sourceProfileDriver}, profiles.TargetProfile{}, &utils.IOStreams{}, &sp.Client{}, &internal.Conv{}, true, int64(5), &m) + assert.Equal(t, tc.errorExpected, err != nil, tc.name) + if err == nil { + m.AssertExpectations(t) + } + } +} \ No newline at end of file diff --git a/conversion/data_from_database.go b/conversion/data_from_database.go new file mode 100644 index 0000000000..706b61dd90 --- /dev/null +++ b/conversion/data_from_database.go @@ -0,0 +1,224 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package conversion handles initial setup for the command line tool +// and web APIs. + +// TODO:(searce) Organize code in go style format to make this file more readable. 
+// +// public constants first +// key public type definitions next (although often it makes sense to put them next to public functions that use them) +// then public functions (and relevant type definitions) +// and helper functions and other non-public definitions last (generally in order of importance) +package conversion + +import ( + "context" + "fmt" + "strings" + "sync" + + sp "cloud.google.com/go/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/metrics" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer" + "github.com/GoogleCloudPlatform/spanner-migration-tool/streaming" + storageaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/storage" + storageclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/clients/storage" + "go.uber.org/zap" +) + +type DataFromDatabaseInterface interface{ + dataFromDatabaseForDMSMigration() (*writer.BatchWriter, error) + dataFromDatabaseForDataflowMigration(targetProfile profiles.TargetProfile, ctx context.Context, sourceProfile profiles.SourceProfile, conv *internal.Conv, is common.InfoSchemaInterface) (*writer.BatchWriter, error) + dataFromDatabaseForBulkMigration(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, gi GetInfoInterface, sm SnapshotMigrationInterface) (*writer.BatchWriter, error) + +} + +type DataFromDatabaseImpl struct{} + + +// TODO: Define the data processing logic for DMS migrations here. 
+func (dd *DataFromDatabaseImpl) dataFromDatabaseForDMSMigration() (*writer.BatchWriter, error) { + return nil, fmt.Errorf("dms configType is not implemented yet, please use one of 'bulk' or 'dataflow'") +} + +// 1. Create batch for each physical shard +// 2. Create streaming cfg from the config source type. +// 3. Verify the CFG and update it with SMT defaults +// 4. Launch the stream for the physical shard +// 5. Perform streaming migration via dataflow +func (dd *DataFromDatabaseImpl) dataFromDatabaseForDataflowMigration(targetProfile profiles.TargetProfile, ctx context.Context, sourceProfile profiles.SourceProfile, conv *internal.Conv, is common.InfoSchemaInterface) (*writer.BatchWriter, error) { + updateShardsWithTuningConfigs(sourceProfile.Config.ShardConfigurationDataflow) + //Generate a job Id + migrationJobId := conv.Audit.MigrationRequestId + fmt.Printf("Creating a migration job with id: %v. This jobId can be used in future commmands (such as cleanup) to refer to this job.\n", migrationJobId) + conv.Audit.StreamingStats.ShardToShardResourcesMap = make(map[string]internal.ShardResources) + schemaDetails, err := is.GetIncludedSrcTablesFromConv(conv) + if err != nil { + fmt.Printf("unable to determine tableList from schema, falling back to full database") + schemaDetails = map[string]internal.SchemaDetails{} + } + err = streaming.PersistJobDetails(ctx, targetProfile, sourceProfile, conv, migrationJobId, true) + if err != nil { + logger.Log.Info(fmt.Sprintf("Error storing job details in SMT metadata store...the migration job will still continue as intended. 
%v", err)) + } + asyncProcessShards := func(p *profiles.DataShard, mutex *sync.Mutex) common.TaskResult[*profiles.DataShard] { + dbNameToShardIdMap := make(map[string]string) + for _, l := range p.LogicalShards { + dbNameToShardIdMap[l.DbName] = l.LogicalShardId + } + if p.DataShardId == "" { + dataShardId, err := utils.GenerateName("smt-datashard") + dataShardId = strings.Replace(dataShardId, "_", "-", -1) + if err != nil { + return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} + } + p.DataShardId = dataShardId + fmt.Printf("Data shard id generated: %v\n", p.DataShardId) + } + streamingCfg := streaming.CreateStreamingConfig(*p) + err := streaming.VerifyAndUpdateCfg(&streamingCfg, targetProfile.Conn.Sp.Dbname, schemaDetails) + if err != nil { + err = fmt.Errorf("failed to process shard: %s, there seems to be an error in the sharding configuration, error: %v", p.DataShardId, err) + return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} + } + fmt.Printf("Initiating migration for shard: %v\n", p.DataShardId) + pubsubCfg, err := streaming.CreatePubsubResources(ctx, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg.DestinationConnectionConfig, targetProfile.Conn.Sp.Dbname) + if err != nil { + return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} + } + streamingCfg.PubsubCfg = *pubsubCfg + err = streaming.LaunchStream(ctx, sourceProfile, p.LogicalShards, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg) + if err != nil { + return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} + } + streamingCfg.DataflowCfg.DbNameToShardIdMap = dbNameToShardIdMap + dfOutput, err := streaming.StartDataflow(ctx, targetProfile, streamingCfg, conv) + if err != nil { + return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} + } + // store the generated resources locally in conv, this is used as source of truth for persistence and the UI (should change to persisted values) + + // Fetch and store the GCS bucket 
associated with the datastream + dsClient := getDatastreamClient(ctx) + gcsBucket, gcsDestPrefix, fetchGcsErr := streaming.FetchTargetBucketAndPath(ctx, dsClient, targetProfile.Conn.Sp.Project, streamingCfg.DatastreamCfg.DestinationConnectionConfig) + if fetchGcsErr != nil { + logger.Log.Info(fmt.Sprintf("Could not fetch GCS Bucket for Shard %s hence Monitoring Dashboard will not contain Metrics for the gcs bucket\n", p.DataShardId)) + logger.Log.Debug("Error", zap.Error(fetchGcsErr)) + } + + // Try to apply lifecycle rule to Datastream destination bucket. + gcsConfig := streamingCfg.GcsCfg + sc, err := storageclient.NewStorageClientImpl(ctx) + if err != nil { + return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} + } + sa := storageaccessor.StorageAccessorImpl{} + if gcsConfig.TtlInDaysSet { + err = sa.ApplyBucketLifecycleDeleteRule(ctx, sc, storageaccessor.StorageBucketMetadata{ + BucketName: gcsBucket, + Ttl: gcsConfig.TtlInDays, + MatchesPrefix: []string{gcsDestPrefix}, + }) + if err != nil { + logger.Log.Warn(fmt.Sprintf("\nWARNING: could not update Datastream destination GCS bucket with lifecycle rule, error: %v\n", err)) + logger.Log.Warn("Please apply the lifecycle rule manually. 
Continuing...\n") + } + } + + // create monitoring dashboard for a single shard + monitoringResources := metrics.MonitoringMetricsResources{ + ProjectId: targetProfile.Conn.Sp.Project, + DataflowJobId: dfOutput.JobID, + DatastreamId: streamingCfg.DatastreamCfg.StreamId, + JobMetadataGcsBucket: gcsBucket, + PubsubSubscriptionId: streamingCfg.PubsubCfg.SubscriptionId, + SpannerInstanceId: targetProfile.Conn.Sp.Instance, + SpannerDatabaseId: targetProfile.Conn.Sp.Dbname, + ShardId: p.DataShardId, + MigrationRequestId: conv.Audit.MigrationRequestId, + } + respDash, dashboardErr := monitoringResources.CreateDataflowShardMonitoringDashboard(ctx) + var dashboardName string + if dashboardErr != nil { + dashboardName = "" + logger.Log.Info(fmt.Sprintf("Creation of the monitoring dashboard for shard %s failed, please create the dashboard manually\n", p.DataShardId)) + logger.Log.Debug("Error", zap.Error(dashboardErr)) + } else { + dashboardName = strings.Split(respDash.Name, "/")[3] + fmt.Printf("Monitoring Dashboard for shard %v: %+v\n", p.DataShardId, dashboardName) + } + streaming.StoreGeneratedResources(conv, streamingCfg, dfOutput.JobID, dfOutput.GCloudCmd, targetProfile.Conn.Sp.Project, p.DataShardId, internal.GcsResources{BucketName: gcsBucket}, dashboardName) + //persist the generated resources in a metadata db + err = streaming.PersistResources(ctx, targetProfile, sourceProfile, conv, migrationJobId, p.DataShardId) + if err != nil { + fmt.Printf("Error storing generated resources in SMT metadata store for dataShardId: %s...the migration job will still continue as intended, error: %v\n", p.DataShardId, err) + } + return common.TaskResult[*profiles.DataShard]{Result: p, Err: err} + } + _, err = common.RunParallelTasks(sourceProfile.Config.ShardConfigurationDataflow.DataShards, 20, asyncProcessShards, true) + if err != nil { + return nil, fmt.Errorf("unable to start minimal downtime migrations: %v", err) + } + + // create monitoring aggregated dashboard for sharded 
migration + aggMonitoringResources := metrics.MonitoringMetricsResources{ + ProjectId: targetProfile.Conn.Sp.Project, + SpannerInstanceId: targetProfile.Conn.Sp.Instance, + SpannerDatabaseId: targetProfile.Conn.Sp.Dbname, + ShardToShardResourcesMap: conv.Audit.StreamingStats.ShardToShardResourcesMap, + MigrationRequestId: conv.Audit.MigrationRequestId, + } + aggRespDash, dashboardErr := aggMonitoringResources.CreateDataflowAggMonitoringDashboard(ctx) + if dashboardErr != nil { + logger.Log.Error(fmt.Sprintf("Creation of the aggregated monitoring dashboard failed, please create the dashboard manually\n error=%v\n", dashboardErr)) + } else { + fmt.Printf("Aggregated Monitoring Dashboard: %+v\n", strings.Split(aggRespDash.Name, "/")[3]) + conv.Audit.StreamingStats.AggMonitoringResources = internal.MonitoringResources{DashboardName: strings.Split(aggRespDash.Name, "/")[3]} + } + err = streaming.PersistAggregateMonitoringResources(ctx, targetProfile, sourceProfile, conv, migrationJobId) + if err != nil { + logger.Log.Info(fmt.Sprintf("Unable to store aggregated monitoring dashboard in metadata database\n error=%v\n", err)) + } else { + logger.Log.Debug("Aggregate monitoring resources stored successfully.\n") + } + return &writer.BatchWriter{}, nil +} + +// 1. Migrate the data from the data shards, the schema shard needs to be specified here again. +// 2. Create a connection profile object for it +// 3. Perform a snapshot migration for the shard +// 4. 
Once all shard migrations are complete, return the batch writer object +func (dd *DataFromDatabaseImpl) dataFromDatabaseForBulkMigration(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, gi GetInfoInterface, sm SnapshotMigrationInterface) (*writer.BatchWriter, error) { + var bw *writer.BatchWriter + for _, dataShard := range sourceProfile.Config.ShardConfigurationBulk.DataShards { + + fmt.Printf("Initiating migration for shard: %v\n", dataShard.DbName) + infoSchema, err := gi.getInfoSchemaForShard(dataShard, sourceProfile.Driver, targetProfile, &profiles.SourceProfileDialectImpl{}, &GetInfoImpl{}) + if err != nil { + return nil, err + } + additionalDataAttributes := internal.AdditionalDataAttributes{ + ShardId: dataShard.DataShardId, + } + bw = sm.performSnapshotMigration(config, conv, client, infoSchema, additionalDataAttributes, &common.InfoSchemaImpl{}, &PopulateDataConvImpl{}) + } + + return bw, nil +} \ No newline at end of file diff --git a/conversion/get_info.go b/conversion/get_info.go new file mode 100644 index 0000000000..a43fb8a0f0 --- /dev/null +++ b/conversion/get_info.go @@ -0,0 +1,210 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package conversion handles initial setup for the command line tool +// and web APIs. + +// TODO:(searce) Organize code in go style format to make this file more readable. 
+// +// public constants first +// key public type definitions next (although often it makes sense to put them next to public functions that use them) +// then public functions (and relevant type definitions) +// and helper functions and other non-public definitions last (generally in order of importance) +package conversion + +import ( + "context" + "database/sql" + "fmt" + "net" + "strings" + + "cloud.google.com/go/cloudsqlconn" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/dynamodb" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/mysql" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/oracle" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/postgres" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/sqlserver" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + dydb "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodbstreams" + mysqldriver "github.com/go-sql-driver/mysql" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" +) + +type GetInfoInterface interface{ + getInfoSchemaForShard(shardConnInfo profiles.DirectConnectionConfig, driver string, targetProfile profiles.TargetProfile, sourceProfileDialect profiles.SourceProfileDialectInterface, getInfo GetInfoInterface) (common.InfoSchema, error) + GetInfoSchemaFromCloudSQL(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) + GetInfoSchema(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) +} +type GetInfoImpl struct{} + +func (gi *GetInfoImpl) 
getInfoSchemaForShard(shardConnInfo profiles.DirectConnectionConfig, driver string, targetProfile profiles.TargetProfile, sourceProfileDialect profiles.SourceProfileDialectInterface, getInfo GetInfoInterface) (common.InfoSchema, error) { + params := make(map[string]string) + params["host"] = shardConnInfo.Host + params["user"] = shardConnInfo.User + params["dbName"] = shardConnInfo.DbName + params["port"] = shardConnInfo.Port + params["password"] = shardConnInfo.Password + //while adding other sources, a switch-case will be added here on the basis of the driver input param passed. + //based on the driver name, profiles.NewSourceProfileConnection will need to be called to create + //the source profile information. + getUtilsInfo := utils.GetUtilInfoImpl{} + sourceProfileConnectionMySQL, err := sourceProfileDialect.NewSourceProfileConnectionMySQL(params, &getUtilsInfo) + if err != nil { + return nil, fmt.Errorf("cannot parse connection configuration for the primary shard") + } + sourceProfileConnection := profiles.SourceProfileConnection{Mysql: sourceProfileConnectionMySQL, Ty: profiles.SourceProfileConnectionTypeMySQL} + //create a source profile which contains the sourceProfileConnection object for the primary shard + //this is done because GetSQLConnectionStr() should not be aware of sharding + newSourceProfile := profiles.SourceProfile{Conn: sourceProfileConnection, Ty: profiles.SourceProfileTypeConnection} + newSourceProfile.Driver = driver + infoSchema, err := getInfo.GetInfoSchema(newSourceProfile, targetProfile) + if err != nil { + return nil, err + } + return infoSchema, nil +} + + +func (gi *GetInfoImpl) GetInfoSchemaFromCloudSQL(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) { + driver := sourceProfile.Driver + switch driver { + case constants.MYSQL: + d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithIAMAuthN()) + if err != nil { + return nil, fmt.Errorf("cloudsqlconn.NewDialer: 
%w", err) + } + var opts []cloudsqlconn.DialOption + instanceName := fmt.Sprintf("%s:%s:%s", sourceProfile.ConnCloudSQL.Mysql.Project, sourceProfile.ConnCloudSQL.Mysql.Region, sourceProfile.ConnCloudSQL.Mysql.InstanceName) + mysqldriver.RegisterDialContext("cloudsqlconn", + func(ctx context.Context, addr string) (net.Conn, error) { + return d.Dial(ctx, instanceName, opts...) + }) + + dbURI := fmt.Sprintf("%s:empty@cloudsqlconn(localhost:3306)/%s?parseTime=true", + sourceProfile.ConnCloudSQL.Mysql.User, sourceProfile.ConnCloudSQL.Mysql.Db) + + db, err := sql.Open("mysql", dbURI) + if err != nil { + return nil, fmt.Errorf("sql.Open: %w", err) + } + return mysql.InfoSchemaImpl{ + DbName: sourceProfile.ConnCloudSQL.Mysql.Db, + Db: db, + SourceProfile: sourceProfile, + TargetProfile: targetProfile, + }, nil + case constants.POSTGRES: + d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithIAMAuthN()) + if err != nil { + return nil, fmt.Errorf("cloudsqlconn.NewDialer: %w", err) + } + var opts []cloudsqlconn.DialOption + + dsn := fmt.Sprintf("user=%s database=%s", sourceProfile.ConnCloudSQL.Pg.User, sourceProfile.ConnCloudSQL.Pg.Db) + config, err := pgx.ParseConfig(dsn) + if err != nil { + return nil, err + } + instanceName := fmt.Sprintf("%s:%s:%s", sourceProfile.ConnCloudSQL.Pg.Project, sourceProfile.ConnCloudSQL.Pg.Region, sourceProfile.ConnCloudSQL.Pg.InstanceName) + config.DialFunc = func(ctx context.Context, network, instance string) (net.Conn, error) { + return d.Dial(ctx, instanceName, opts...) 
+ } + dbURI := stdlib.RegisterConnConfig(config) + db, err := sql.Open("pgx", dbURI) + if err != nil { + return nil, fmt.Errorf("sql.Open: %w", err) + } + temp := false + return postgres.InfoSchemaImpl{ + Db: db, + SourceProfile: sourceProfile, + TargetProfile: targetProfile, + IsSchemaUnique: &temp, //this is a workaround to set a bool pointer + }, nil + default: + return nil, fmt.Errorf("driver %s not supported", driver) + } +} + + +func (gi *GetInfoImpl) GetInfoSchema(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) { + connectionConfig, err := connectionConfig(sourceProfile) + if err != nil { + return nil, err + } + driver := sourceProfile.Driver + switch driver { + case constants.MYSQL: + db, err := sql.Open(driver, connectionConfig.(string)) + dbName := getDbNameFromSQLConnectionStr(driver, connectionConfig.(string)) + if err != nil { + return nil, err + } + return mysql.InfoSchemaImpl{ + DbName: dbName, + Db: db, + SourceProfile: sourceProfile, + TargetProfile: targetProfile, + }, nil + case constants.POSTGRES: + db, err := sql.Open(driver, connectionConfig.(string)) + if err != nil { + return nil, err + } + temp := false + return postgres.InfoSchemaImpl{ + Db: db, + SourceProfile: sourceProfile, + TargetProfile: targetProfile, + IsSchemaUnique: &temp, //this is a workaround to set a bool pointer + }, nil + case constants.DYNAMODB: + mySession := session.Must(session.NewSession()) + dydbClient := dydb.New(mySession, connectionConfig.(*aws.Config)) + var dydbStreamsClient *dynamodbstreams.DynamoDBStreams + if sourceProfile.Conn.Streaming { + newSession := session.Must(session.NewSession()) + dydbStreamsClient = dynamodbstreams.New(newSession, connectionConfig.(*aws.Config)) + } + return dynamodb.InfoSchemaImpl{ + DynamoClient: dydbClient, + SampleSize: profiles.GetSchemaSampleSize(sourceProfile), + DynamoStreamsClient: dydbStreamsClient, + }, nil + case constants.SQLSERVER: + db, err := sql.Open(driver, 
connectionConfig.(string)) + dbName := getDbNameFromSQLConnectionStr(driver, connectionConfig.(string)) + if err != nil { + return nil, err + } + return sqlserver.InfoSchemaImpl{DbName: dbName, Db: db}, nil + case constants.ORACLE: + db, err := sql.Open(driver, connectionConfig.(string)) + dbName := getDbNameFromSQLConnectionStr(driver, connectionConfig.(string)) + if err != nil { + return nil, err + } + return oracle.InfoSchemaImpl{DbName: strings.ToUpper(dbName), Db: db, SourceProfile: sourceProfile, TargetProfile: targetProfile}, nil + default: + return nil, fmt.Errorf("driver %s not supported", driver) + } +} \ No newline at end of file diff --git a/conversion/mocks.go b/conversion/mocks.go new file mode 100644 index 0000000000..65f53f0181 --- /dev/null +++ b/conversion/mocks.go @@ -0,0 +1,83 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package conversion handles initial setup for the command line tool +// and web APIs. + +// TODO:(searce) Organize code in go style format to make this file more readable. 
+// +// public constants first +// key public type definitions next (although often it makes sense to put them next to public functions that use them) +// then public functions (and relevant type definitions) +// and helper functions and other non-public definitions last (generally in order of importance) +package conversion + +import ( + "context" + + sp "cloud.google.com/go/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/csv" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer" + "github.com/stretchr/testify/mock" +) + + +type MockGetInfo struct { + mock.Mock +} + +func (mgi *MockGetInfo) getInfoSchemaForShard(shardConnInfo profiles.DirectConnectionConfig, driver string, targetProfile profiles.TargetProfile, s profiles.SourceProfileDialectInterface, g GetInfoInterface) (common.InfoSchema, error) { + args := mgi.Called(shardConnInfo, driver, targetProfile, s, g) + return args.Get(0).(common.InfoSchema), args.Error(1) +} +func (mgi *MockGetInfo) GetInfoSchemaFromCloudSQL(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) { + args := mgi.Called(sourceProfile, targetProfile) + return args.Get(0).(common.InfoSchema), args.Error(1) +} +func (mgi *MockGetInfo) GetInfoSchema(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (common.InfoSchema, error) { + args := mgi.Called(sourceProfile, targetProfile) + return args.Get(0).(common.InfoSchema), args.Error(1) +} + +type MockSchemaFromSource struct { + mock.Mock +} +func (msads *MockSchemaFromSource) schemaFromDatabase(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, getInfo GetInfoInterface, 
processSchema common.ProcessSchemaInterface) (*internal.Conv, error) { + args := msads.Called(sourceProfile, targetProfile, getInfo, processSchema) + return args.Get(0).(*internal.Conv), args.Error(1) +} +func (msads *MockSchemaFromSource) SchemaFromDump(driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error) { + args := msads.Called(driver, spDialect, ioHelper, processDump) + return args.Get(0).(*internal.Conv), args.Error(1) +} + +type MockDataFromSource struct { + mock.Mock +} +func (msads *MockDataFromSource) dataFromDatabase(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, getInfo GetInfoInterface, dataFromDb DataFromDatabaseInterface, snapshotMigration SnapshotMigrationInterface) (*writer.BatchWriter, error) { + args := msads.Called(ctx, sourceProfile, targetProfile, config, conv, client, getInfo, dataFromDb, snapshotMigration) + return args.Get(0).(*writer.BatchWriter), args.Error(1) +} +func (msads *MockDataFromSource) dataFromDump(driver string, config writer.BatchWriterConfig, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool, processDump ProcessDumpByDialectInterface, populateDataConv PopulateDataConvInterface) (*writer.BatchWriter, error) { + args := msads.Called(driver, config, ioHelper, client, conv, dataOnly, processDump, populateDataConv) + return args.Get(0).(*writer.BatchWriter), args.Error(1) +} +func (msads *MockDataFromSource) dataFromCSV(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, pdc PopulateDataConvInterface, csv csv.CsvInterface) (*writer.BatchWriter, error) { + args := msads.Called(ctx, sourceProfile, targetProfile, config, conv, client, pdc, csv) + return args.Get(0).(*writer.BatchWriter), 
args.Error(1) +} \ No newline at end of file diff --git a/conversion/snapshot_migration.go b/conversion/snapshot_migration.go new file mode 100644 index 0000000000..f0a4d5bd19 --- /dev/null +++ b/conversion/snapshot_migration.go @@ -0,0 +1,66 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package conversion handles initial setup for the command line tool +// and web APIs. + +// TODO:(searce) Organize code in go style format to make this file more readable. 
+// +// public constants first +// key public type definitions next (although often it makes sense to put them next to public functions that use them) +// then public functions (and relevant type definitions) +// and helper functions and other non-public definitions last (generally in order of importance) + +package conversion + +import ( + "fmt" + + sp "cloud.google.com/go/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/profiles" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer" +) + +type SnapshotMigrationInterface interface { + performSnapshotMigration(config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, infoSchema common.InfoSchema, additionalAttributes internal.AdditionalDataAttributes, infoSchemaI common.InfoSchemaInterface, populateDataConv PopulateDataConvInterface) *writer.BatchWriter + snapshotMigrationHandler(sourceProfile profiles.SourceProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, infoSchema common.InfoSchema) (*writer.BatchWriter, error) +} +type SnapshotMigrationImpl struct {} + +func (sm *SnapshotMigrationImpl) performSnapshotMigration(config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, infoSchema common.InfoSchema, additionalAttributes internal.AdditionalDataAttributes, infoSchemaI common.InfoSchemaInterface, populateDataConv PopulateDataConvInterface) *writer.BatchWriter { + infoSchemaI.SetRowStats(conv, infoSchema) + totalRows := conv.Rows() + if !conv.Audit.DryRun { + conv.Audit.Progress = *internal.NewProgress(totalRows, "Writing data to Spanner", internal.Verbose(), false, int(internal.DataWriteInProgress)) + } + batchWriter := populateDataConv.populateDataConv(conv, config, client) + infoSchemaI.ProcessData(conv, 
infoSchema, additionalAttributes) + batchWriter.Flush() + return batchWriter +} + +func (sm *SnapshotMigrationImpl) snapshotMigrationHandler(sourceProfile profiles.SourceProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client, infoSchema common.InfoSchema) (*writer.BatchWriter, error) { + switch sourceProfile.Driver { + // Skip snapshot migration via Spanner migration tool for mysql and oracle since dataflow job will job will handle this from backfilled data. + case constants.MYSQL, constants.ORACLE, constants.POSTGRES: + return &writer.BatchWriter{}, nil + case constants.DYNAMODB: + return sm.performSnapshotMigration(config, conv, client, infoSchema, internal.AdditionalDataAttributes{ShardId: ""}, &common.InfoSchemaImpl{}, &PopulateDataConvImpl{}), nil + default: + return &writer.BatchWriter{}, fmt.Errorf("streaming migration not supported for driver %s", sourceProfile.Driver) + } +} \ No newline at end of file diff --git a/conversion/store_files.go b/conversion/store_files.go new file mode 100644 index 0000000000..9b457c88ad --- /dev/null +++ b/conversion/store_files.go @@ -0,0 +1,275 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package conversion handles initial setup for the command line tool +// and web APIs. + +// TODO:(searce) Organize code in go style format to make this file more readable. 
+// +// public constants first +// key public type definitions next (although often it makes sense to put them next to public functions that use them) +// then public functions (and relevant type definitions) +// and helper functions and other non-public definitions last (generally in order of importance) +package conversion + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + "strings" + "time" + + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/utils" + "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/ddl" + "github.com/GoogleCloudPlatform/spanner-migration-tool/spanner/writer" +) + +// WriteSchemaFile writes DDL statements in a file. It includes CREATE TABLE +// statements and ALTER TABLE statements to add foreign keys. +// The parameter name should end with a .txt. +func WriteSchemaFile(conv *internal.Conv, now time.Time, name string, out *os.File, driver string) { + f, err := os.Create(name) + if err != nil { + fmt.Fprintf(out, "Can't create schema file %s: %v\n", name, err) + return + } + + // The schema file we write out below is optimized for reading. It includes comments, foreign keys + // and doesn't add backticks around table and column names. This file is + // intended for explanatory and documentation purposes, and is not strictly + // legal Cloud Spanner DDL (Cloud Spanner doesn't currently support comments). 
+ spDDL := conv.SpSchema.GetDDL(ddl.Config{Comments: true, ProtectIds: false, Tables: true, ForeignKeys: true, SpDialect: conv.SpDialect, Source: driver}) + if len(spDDL) == 0 { + spDDL = []string{"\n-- Schema is empty -- no tables found\n"} + } + l := []string{ + fmt.Sprintf("-- Schema generated %s\n", now.Format("2006-01-02 15:04:05")), + strings.Join(spDDL, ";\n\n"), + "\n", + } + if _, err := f.WriteString(strings.Join(l, "")); err != nil { + fmt.Fprintf(out, "Can't write out schema file: %v\n", err) + return + } + fmt.Fprintf(out, "Wrote schema to file '%s'.\n", name) + + // Convert . to .ddl.. + nameSplit := strings.Split(name, ".") + nameSplit = append(nameSplit[:len(nameSplit)-1], "ddl", nameSplit[len(nameSplit)-1]) + name = strings.Join(nameSplit, ".") + f, err = os.Create(name) + if err != nil { + fmt.Fprintf(out, "Can't create legal schema ddl file %s: %v\n", name, err) + return + } + + // We change 'Comments' to false and 'ProtectIds' to true below to write out a + // schema file that is a legal Cloud Spanner DDL. + spDDL = conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: true, SpDialect: conv.SpDialect, Source: driver}) + if len(spDDL) == 0 { + spDDL = []string{"\n-- Schema is empty -- no tables found\n"} + } + l = []string{ + strings.Join(spDDL, ";\n\n"), + "\n", + } + if _, err = f.WriteString(strings.Join(l, "")); err != nil { + fmt.Fprintf(out, "Can't write out legal schema ddl file: %v\n", err) + return + } + fmt.Fprintf(out, "Wrote legal schema ddl to file '%s'.\n", name) +} + +// WriteSessionFile writes conv struct to a file in JSON format. +func WriteSessionFile(conv *internal.Conv, name string, out *os.File) { + f, err := os.Create(name) + if err != nil { + fmt.Fprintf(out, "Can't create session file %s: %v\n", name, err) + return + } + // Session file will basically contain 'conv' struct in JSON format. + // It contains all the information for schema and data conversion state. 
+ convJSON, err := json.MarshalIndent(conv, "", " ") + if err != nil { + fmt.Fprintf(out, "Can't encode session state to JSON: %v\n", err) + return + } + if _, err := f.Write(convJSON); err != nil { + fmt.Fprintf(out, "Can't write out session file: %v\n", err) + return + } + fmt.Fprintf(out, "Wrote session to file '%s'.\n", name) +} + +// WriteConvGeneratedFiles creates a directory labeled downloads with the current timestamp +// where it writes the sessionfile, report summary and DDLs then returns the directory where it writes. +func WriteConvGeneratedFiles(conv *internal.Conv, dbName string, driver string, BytesRead int64, out *os.File) (string, error) { + now := time.Now() + dirPath := "spanner_migration_tool_output/" + dbName + "/" + err := os.MkdirAll(dirPath, os.ModePerm) + if err != nil { + fmt.Fprintf(out, "Can't create directory %s: %v\n", dirPath, err) + return "", err + } + schemaFileName := dirPath + dbName + "_schema.txt" + WriteSchemaFile(conv, now, schemaFileName, out, driver) + reportFileName := dirPath + dbName + reportImpl := ReportImpl{} + reportImpl.GenerateReport(driver, nil, BytesRead, "", conv, reportFileName, dbName, out) + sessionFileName := dirPath + dbName + ".session.json" + WriteSessionFile(conv, sessionFileName, out) + return dirPath, nil +} + +// ReadSessionFile reads a session JSON file and +// unmarshal it's content into *internal.Conv. +func ReadSessionFile(conv *internal.Conv, sessionJSON string) error { + s, err := ioutil.ReadFile(sessionJSON) + if err != nil { + return err + } + err = json.Unmarshal(s, &conv) + if err != nil { + return err + } + return nil +} + +// WriteBadData prints summary stats about bad rows and writes detailed info +// to file 'name'. 
+func WriteBadData(bw *writer.BatchWriter, conv *internal.Conv, banner, name string, out *os.File) { + badConversions := conv.BadRows() + badWrites := utils.SumMapValues(bw.DroppedRowsByTable()) + + badDataStreaming := int64(0) + if conv.Audit.StreamingStats.Streaming { + badDataStreaming = getBadStreamingDataCount(conv) + } + + if badConversions == 0 && badWrites == 0 && badDataStreaming == 0 { + os.Remove(name) // Cleanup bad-data file from previous run. + return + } + f, err := os.Create(name) + if err != nil { + fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) + return + } + f.WriteString(banner) + maxRows := 100 + if badConversions > 0 { + l := conv.SampleBadRows(maxRows) + if int64(len(l)) < badConversions { + f.WriteString("A sample of rows that generated conversion errors:\n") + } else { + f.WriteString("Rows that generated conversion errors:\n") + } + for _, r := range l { + _, err := f.WriteString(" " + r + "\n") + if err != nil { + fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) + return + } + } + } + if badWrites > 0 { + l := bw.SampleBadRows(maxRows) + if int64(len(l)) < badWrites { + f.WriteString("A sample of rows that successfully converted but couldn't be written to Spanner:\n") + } else { + f.WriteString("Rows that successfully converted but couldn't be written to Spanner:\n") + } + for _, r := range l { + _, err := f.WriteString(" " + r + "\n") + if err != nil { + fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) + return + } + } + } + if badDataStreaming > 0 { + err = writeBadStreamingData(conv, f) + if err != nil { + fmt.Fprintf(out, "Can't write out bad data file: %v\n", err) + return + } + } + + fmt.Fprintf(out, "See file '%s' for details of bad rows\n", name) +} + + +// writeBadStreamingData writes sample of bad records and dropped records during streaming +// migration process to bad data file. 
+func writeBadStreamingData(conv *internal.Conv, f *os.File) error { + f.WriteString("\nBad data encountered during streaming migration:\n\n") + + stats := (conv.Audit.StreamingStats) + + badRecords := int64(0) + for _, x := range stats.BadRecords { + badRecords += utils.SumMapValues(x) + } + droppedRecords := int64(0) + for _, x := range stats.DroppedRecords { + droppedRecords += utils.SumMapValues(x) + } + + if badRecords > 0 { + l := stats.SampleBadRecords + if int64(len(l)) < badRecords { + f.WriteString("A sample of records that generated conversion errors:\n") + } else { + f.WriteString("Records that generated conversion errors:\n") + } + for _, r := range l { + _, err := f.WriteString(" " + r + "\n") + if err != nil { + return err + } + } + f.WriteString("\n") + } + if droppedRecords > 0 { + l := stats.SampleBadWrites + if int64(len(l)) < droppedRecords { + f.WriteString("A sample of records that successfully converted but couldn't be written to Spanner:\n") + } else { + f.WriteString("Records that successfully converted but couldn't be written to Spanner:\n") + } + for _, r := range l { + _, err := f.WriteString(" " + r + "\n") + if err != nil { + return err + } + } + } + return nil +} + +// getBadStreamingDataCount returns the total sum of bad and dropped records during +// streaming migration process. 
+func getBadStreamingDataCount(conv *internal.Conv) int64 { + badDataCount := int64(0) + + for _, x := range conv.Audit.StreamingStats.BadRecords { + badDataCount += utils.SumMapValues(x) + } + for _, x := range conv.Audit.StreamingStats.DroppedRecords { + badDataCount += utils.SumMapValues(x) + } + return badDataCount +} diff --git a/conversion/validations.go b/conversion/validations.go new file mode 100644 index 0000000000..01722d1174 --- /dev/null +++ b/conversion/validations.go @@ -0,0 +1,52 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package conversion handles initial setup for the command line tool +// and web APIs. + +// TODO:(searce) Organize code in go style format to make this file more readable. +// +// public constants first +// key public type definitions next (although often it makes sense to put them next to public functions that use them) +// then public functions (and relevant type definitions) +// and helper functions and other non-public definitions last (generally in order of importance) + +package conversion + +import ( + "context" + + sp "cloud.google.com/go/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/spanner" +) + +// ValidateTables validates that all the tables in the database are empty. +// It returns the name of the first non-empty table if found, and an empty string otherwise. 
+func ValidateTables(ctx context.Context, client *sp.Client, spDialect string) (string, error) { + infoSchema := spanner.InfoSchemaImpl{Client: client, Ctx: ctx, SpDialect: spDialect} + tables, err := infoSchema.GetTables() + if err != nil { + return "", err + } + for _, table := range tables { + count, err := infoSchema.GetRowCount(table) + if err != nil { + return "", err + } + if count != 0 { + return table.Name, nil + } + } + return "", nil +} diff --git a/dao/dao.go b/dao/dao.go new file mode 100644 index 0000000000..21bea4137f --- /dev/null +++ b/dao/dao.go @@ -0,0 +1,268 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package dao + +import ( + "context" + "fmt" + + "cloud.google.com/go/spanner" + "google.golang.org/api/iterator" +) + +type StateData struct { + State string `json:"state"` +} + +type DAO interface { + InsertJobEntry(ctx context.Context, jobId, jobName, jobType, dialect, dbName string, jobData spanner.NullJSON) error + UpdateJobState(ctx context.Context, jobId, state string) error + InsertResourceEntry(ctx context.Context, resourceId, jobId, externalId, resourceName, resourceType string, resourceData spanner.NullJSON) error + UpdateResourceState(ctx context.Context, resourceId, state string) error + UpdateResourceExternalId(ctx context.Context, resourceId, externalId string) error +} + +type DAOImpl struct{} + +// Insert a job entry into the SMT_JOB table. 
+func (dao *DAOImpl) InsertJobEntry(ctx context.Context, jobId, jobName, jobType, dialect, dbName string, jobData spanner.NullJSON) error { + _, err := GetClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + jobStmt := spanner.Statement{ + SQL: `INSERT INTO SMT_JOB + (JobId, JobName, JobType, JobStateData, JobData, Dialect, SpannerDatabaseName, CreatedAt, UpdatedAt) + VALUES( + @jobId, @jobName, @jobType, @jobStateData, @jobData, @dialect, @dbName, PENDING_COMMIT_TIMESTAMP(), PENDING_COMMIT_TIMESTAMP() + );`, + Params: map[string]interface{}{ + "jobId": jobId, + "jobName": jobName, + "jobType": jobType, + "jobStateData": spanner.NullJSON{Valid: true, Value: StateData{State: "CREATING"}}, + "jobData": jobData, + "dialect": dialect, + "dbName": dbName, + }, + } + _, err := txn.Update(ctx, jobStmt) + if err != nil { + return err + } + // Update job history table within the same txn. + _, err = updateJobHistoryWithinTxn(ctx, txn, jobId) + if err != nil { + return err + } + return nil + }) + if err != nil { + return fmt.Errorf("could not insert SMT job entry: %v", err) + } + return nil +} + +// Update the state of the SMT job. 
+func (dao *DAOImpl) UpdateJobState(ctx context.Context, jobId, state string) error { + _, err := GetClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + jobStmt := spanner.Statement{ + SQL: `UPDATE SMT_JOB SET JobStateData = @jobStateData, UpdatedAt = PENDING_COMMIT_TIMESTAMP() + WHERE JobId = @jobId;`, + Params: map[string]interface{}{ + "jobId": jobId, + "jobStateData": spanner.NullJSON{Valid: true, Value: StateData{State: state}}, + }, + } + _, err := txn.Update(ctx, jobStmt) + if err != nil { + return err + } + _, err = updateJobHistoryWithinTxn(ctx, txn, jobId) + if err != nil { + return err + } + return nil + }) + if err != nil { + return fmt.Errorf("error updating smt job state: %v", err) + } + return nil +} + +// Insert an entry into the SMT_RESOURCE table. +func (dao *DAOImpl) InsertResourceEntry(ctx context.Context, resourceId, jobId, externalId, resourceName, resourceType string, resourceData spanner.NullJSON) error { + _, err := GetClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + resourceStmt := spanner.Statement{ + SQL: `INSERT INTO SMT_RESOURCE + (ResourceId, JobId, ExternalId, ResourceName, ResourceType, ResourceStateData, ResourceData, CreatedAt, UpdatedAt) + VALUES( + @resourceId, @jobId, @externalId, @resourceName, @resourceType, @resourceStateData, @resourceData, PENDING_COMMIT_TIMESTAMP(), PENDING_COMMIT_TIMESTAMP() + );`, + Params: map[string]interface{}{ + "resourceId": resourceId, + "jobId": jobId, + "externalId": externalId, + "resourceName": resourceName, + "resourceType": resourceType, + "resourceStateData": spanner.NullJSON{Valid: true, Value: StateData{State: "CREATING"}}, + "resourceData": resourceData, + }, + } + _, err := txn.Update(ctx, resourceStmt) + if err != nil { + return err + } + // Update the resource history table in the same transaction. 
+ _, err = updateResourceHistoryWithinTxn(ctx, txn, resourceId) + if err != nil { + return err + } + return nil + }) + if err != nil { + return fmt.Errorf("error inserting smt resource entry: %v", err) + } + return nil +} + +// Update the state of the SMT resource. +func (dao *DAOImpl) UpdateResourceState(ctx context.Context, resourceId, state string) error { + _, err := GetClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + jobStmt := spanner.Statement{ + SQL: `UPDATE SMT_RESOURCE SET ResourceStateData = @resourceStateData, UpdatedAt = PENDING_COMMIT_TIMESTAMP() + WHERE ResourceId = @resourceId;`, + Params: map[string]interface{}{ + "resourceId": resourceId, + "resourceStateData": spanner.NullJSON{Valid: true, Value: StateData{State: state}}, + }, + } + _, err := txn.Update(ctx, jobStmt) + if err != nil { + return err + } + _, err = updateResourceHistoryWithinTxn(ctx, txn, resourceId) + if err != nil { + return err + } + return nil + }) + if err != nil { + return fmt.Errorf("error updating smt resource state: %v", err) + } + return nil +} + +// Update the external of the SMT resource. 
+func (dao *DAOImpl) UpdateResourceExternalId(ctx context.Context, resourceId, externalId string) error { + _, err := GetClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + jobStmt := spanner.Statement{ + SQL: `UPDATE SMT_RESOURCE SET ExternalId = @externalId, UpdatedAt = PENDING_COMMIT_TIMESTAMP() + WHERE ResourceId = @resourceId;`, + Params: map[string]interface{}{ + "resourceId": resourceId, + "externalId": externalId, + }, + } + _, err := txn.Update(ctx, jobStmt) + if err != nil { + return err + } + _, err = updateResourceHistoryWithinTxn(ctx, txn, resourceId) + if err != nil { + return err + } + return nil + }) + if err != nil { + return fmt.Errorf("error updating smt resource external id: %v", err) + } + return nil +} + +func updateJobHistoryWithinTxn(ctx context.Context, txn *spanner.ReadWriteTransaction, jobId string) (int64, error) { + // Fetch the newly updated row from SMT_JOB table. + stmt := spanner.Statement{SQL: ` + SELECT + JobName, JobType, JobStateData, JobData, Dialect, SpannerDatabaseName + FROM SMT_JOB WHERE JobId = @jobId;`, + Params: map[string]interface{}{"jobId": jobId}, + } + iter := txn.Query(ctx, stmt) + defer iter.Stop() + var jobName, jobType, dialect, spannerDatabaseName spanner.NullString + var jobStateData, jobData spanner.NullJSON + row, err := iter.Next() + if err == iterator.Done || err != nil { + return 0, err + } + if err := row.Columns(&jobName, &jobType, &jobStateData, &jobData, &dialect, &spannerDatabaseName); err != nil { + return 0, fmt.Errorf("error reading smt job row: %v", err) + } + + // Insert entry to SMT_JOB_HISTORY table. 
+ jobStmt := spanner.Statement{ + SQL: `INSERT INTO SMT_JOB_HISTORY + (JobId, JobName, JobType, JobStateData, JobData, Dialect, SpannerDatabaseName, CreatedAt) + VALUES( + @jobId, @jobName, @jobType, @jobStateData, @jobData, @dialect, @spannerDatabaseName, PENDING_COMMIT_TIMESTAMP() + );`, + Params: map[string]interface{}{ + "jobId": jobId, + "jobName": jobName, + "jobType": jobType, + "jobStateData": jobStateData, + "jobData": jobData, + "dialect": dialect, + "spannerDatabaseName": spannerDatabaseName, + }, + } + return txn.Update(ctx, jobStmt) +} + +func updateResourceHistoryWithinTxn(ctx context.Context, txn *spanner.ReadWriteTransaction, resourceId string) (int64, error) { + // Fetch the newly updated row from SMT_RESOURCE table. + stmt := spanner.Statement{SQL: ` + SELECT + JobId, ExternalId, ResourceName, ResourceType, ResourceStateData, ResourceData + FROM SMT_RESOURCE WHERE ResourceId = @resourceId;`, + Params: map[string]interface{}{"resourceId": resourceId}, + } + iter := txn.Query(ctx, stmt) + defer iter.Stop() + var jobId, externalId, resourceName, resourceType spanner.NullString + var resourceStateData, resourceData spanner.NullJSON + row, err := iter.Next() + if err == iterator.Done || err != nil { + return 0, err + } + if err := row.Columns(&jobId, &externalId, &resourceName, &resourceType, &resourceStateData, &resourceData); err != nil { + return 0, fmt.Errorf("error reading smt resource row: %v", err) + } + // Create new entry into the SMT_RESOURCE_HISTORY table. 
+ jobStmt := spanner.Statement{ + SQL: `INSERT INTO SMT_RESOURCE_HISTORY + (ResourceId, JobId, ExternalId, ResourceName, ResourceType, ResourceStateData, ResourceData, CreatedAt) + VALUES( + @resourceId, @jobId, @externalId, @resourceName, @resourceType, @resourceStateData, @resourceData, PENDING_COMMIT_TIMESTAMP() + );`, + Params: map[string]interface{}{ + "resourceId": resourceId, + "jobId": jobId, + "externalId": externalId, + "resourceName": resourceName, + "resourceType": resourceType, + "resourceStateData": resourceStateData, + "resourceData": resourceData, + }, + } + return txn.Update(ctx, jobStmt) +} diff --git a/dao/dao_client.go b/dao/dao_client.go new file mode 100644 index 0000000000..46964888de --- /dev/null +++ b/dao/dao_client.go @@ -0,0 +1,48 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package dao + +import ( + "context" + "fmt" + "sync" + + sp "cloud.google.com/go/spanner" +) + +var once sync.Once +var spClient *sp.Client + +// This function is declared as a global variable to make it testable. The unit +// tests edit this function, acting like a double. 
+var newClient = sp.NewClient + +func GetOrCreateClient(ctx context.Context, dbURI string) (*sp.Client, error) { + var err error + if spClient == nil { + once.Do(func() { + spClient, err = newClient(ctx, dbURI) + }) + if err != nil { + return nil, fmt.Errorf("failed to create spanner database client: %v", err) + } + return spClient, nil + } + return spClient, nil +} + +// The DAO client must be initiated via GetOrCreateClient() once before using GetClient(). +func GetClient() *sp.Client { + return spClient +} diff --git a/dao/dao_client_test.go b/dao/dao_client_test.go new file mode 100644 index 0000000000..05c0352a9a --- /dev/null +++ b/dao/dao_client_test.go @@ -0,0 +1,116 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package dao + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + + sp "cloud.google.com/go/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/api/option" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func resetTest() { + spClient = nil + once = sync.Once{} +} + +func TestGetOrCreateClient_Basic(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return &sp.Client{}, nil + } + c, err := GetOrCreateClient(ctx, "testURI") + assert.NotNil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaSync(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return &sp.Client{}, nil + } + c, err := GetOrCreateClient(ctx, "testURI") + assert.NotNil(t, c) + assert.Nil(t, err) + // Explicitly set the client to nil. Running GetOrCreateClient should not create a + // new client since sync would already be executed. 
+ spClient = nil + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return nil, fmt.Errorf("test error") + } + c, err = GetOrCreateClient(ctx, "testURI") + assert.Nil(t, c) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaIf(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return &sp.Client{}, nil + } + oldC, err := GetOrCreateClient(ctx, "testURI") + assert.NotNil(t, oldC) + assert.Nil(t, err) + + // Explicitly reset once. Running GetOrCreateClient should not create a + // new client the if condition should prevent it. + once = sync.Once{} + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return nil, fmt.Errorf("test error") + } + newC, err := GetOrCreateClient(ctx, "testURI") + assert.Equal(t, oldC, newC) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_Error(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return nil, fmt.Errorf("test error") + } + c, err := GetOrCreateClient(ctx, "testURI") + assert.Nil(t, c) + assert.NotNil(t, err) +} diff --git a/docs/reverse-replication/ReverseReplicationUserGuide.md b/docs/reverse-replication/ReverseReplicationUserGuide.md index df9d0927dc..0660deefe4 100644 --- a/docs/reverse-replication/ReverseReplicationUserGuide.md +++ b/docs/reverse-replication/ReverseReplicationUserGuide.md @@ -311,6 +311,7 @@ Steps to perfrom customization: 1. 
Write custom shard id fetcher logic [CustomShardIdFetcher.java](https://github.com/GoogleCloudPlatform/DataflowTemplates/blob/main/v2/spanner-custom-shard/src/main/java/com/custom/CustomShardIdFetcher.java). Details of the ShardIdRequest class can be found [here](https://github.com/GoogleCloudPlatform/DataflowTemplates/blob/main/v2/spanner-migrations-sdk/src/main/java/com/google/cloud/teleport/v2/spanner/utils/ShardIdRequest.java). 2. Build the [JAR](https://github.com/GoogleCloudPlatform/DataflowTemplates/tree/main/v2/spanner-custom-shard) and upload the jar to GCS 3. Invoke the reverse replication flow by passing the [custom jar path and custom class path](RunnigReverseReplication.md#custom-jar). +4. If any custom parameters are needed in the custom shard identification logic, they can be passed via the *readerShardingCustomParameters* input to the runner. These parameters will be passed to the *init* method of the custom class. The *init* method is invoked once per worker setup. diff --git a/docs/reverse-replication/RunnigReverseReplication.md b/docs/reverse-replication/RunnigReverseReplication.md index e1b12bd3ad..d4e5b81424 100644 --- a/docs/reverse-replication/RunnigReverseReplication.md +++ b/docs/reverse-replication/RunnigReverseReplication.md @@ -57,6 +57,7 @@ The script takes in multiple arguments to orchestrate the pipeline. They are: - `startTimestamp`: Timestamp from which the changestream should start reading changes in RFC 3339 format, defaults to empty string which is equivalent to the current timestamp. - `readerShardingCustomClassName`: the fully qualified custom class name for sharding logic. - `readerShardingCustomJarPath` : the GCS path to custom jar for sharding logic. +- `readerShardingCustomParameters`: the custom parameters to be passed to the custom sharding logic implementation. - `readerSkipDirectoryName`: Records skipped from reverse replication are written to this directory. Defaults to: skip. 
- `readerRunMode`: whether the reader from Spanner job runs in regular or resume mode. Default is regular. - `readerWorkers`: number of workers for ordering job. Defaults to 5. @@ -141,13 +142,13 @@ Launched writer job: id:"<>" project_id:"<>" name:"<>" current_state_time:{} cr In order to specify custom shard identification function, custom jar and class names need to give. The command to do that is below: ``` -go run reverse-replication-runner.go -projectId= -dataflowRegion= -instanceId= -dbName= -sourceShardsFilePath=gs://bucket-name/shards.json -sessionFilePath=gs://bucket-name/session.json -gcsPath=gs://bucket-name/ -readerShardingCustomJarPath=gs://bucket-name/custom.jar -readerShardingCustomClassName=com.custom.classname +go run reverse-replication-runner.go -projectId= -dataflowRegion= -instanceId= -dbName= -sourceShardsFilePath=gs://bucket-name/shards.json -sessionFilePath=gs://bucket-name/session.json -gcsPath=gs://bucket-name/ -readerShardingCustomJarPath=gs://bucket-name/custom.jar -readerShardingCustomClassName=com.custom.classname -readerShardingCustomParameters='a=b,c=d' ``` The sample reader job gcloud command for the same ``` -gcloud dataflow flex-template run smt-reverse-replication-reader-2024-01-05t10-33-56z --project= --region= --template-file-gcs-location=