Skip to content

Refactor init schema logic into builder #3186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package destinationdatabasebuilder

import (
"context"
"fmt"
"log/slog"

mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
"github.com/nucleuscloud/neosync/backend/pkg/sqlmanager"
connectionmanager "github.com/nucleuscloud/neosync/internal/connection-manager"
ddbuilder_mssql "github.com/nucleuscloud/neosync/internal/destination-database-builder/mssql"
ddbuilder_mysql "github.com/nucleuscloud/neosync/internal/destination-database-builder/mysql"
ddbuilder_notsupported "github.com/nucleuscloud/neosync/internal/destination-database-builder/not-supported"
ddbuilder_postgres "github.com/nucleuscloud/neosync/internal/destination-database-builder/postgres"
destdb_shared "github.com/nucleuscloud/neosync/internal/destination-database-builder/shared"
"github.com/nucleuscloud/neosync/internal/ee/license"
)

type DestinationDatabaseBuilderService interface {
InitializeSchema(ctx context.Context, uniqueTables map[string]struct{}) ([]*destdb_shared.InitSchemaError, error)
TruncateData(ctx context.Context, uniqueTables map[string]struct{}, uniqueSchemas []string) error
CloseConnections()
}

type DestinationDatabaseBuilder interface {
NewDestinationDatabaseBuilderService(
ctx context.Context,
sourceConnection *mgmtv1alpha1.Connection,
destinationConnection *mgmtv1alpha1.Connection,
destination *mgmtv1alpha1.JobDestination,
) (DestinationDatabaseBuilderService, error)
}

type DefaultDestinationDatabaseBuilder struct {
sqlmanagerclient sqlmanager.SqlManagerClient
session connectionmanager.SessionInterface
logger *slog.Logger
eelicense license.EEInterface
}

func NewDestinationDatabaseBuilder(
sqlmanagerclient sqlmanager.SqlManagerClient,
session connectionmanager.SessionInterface,
logger *slog.Logger,
eelicense license.EEInterface,
) DestinationDatabaseBuilder {
return &DefaultDestinationDatabaseBuilder{sqlmanagerclient: sqlmanagerclient, session: session, logger: logger, eelicense: eelicense}
}

func (d *DefaultDestinationDatabaseBuilder) NewDestinationDatabaseBuilderService(
ctx context.Context,
sourceConnection *mgmtv1alpha1.Connection,
destinationConnection *mgmtv1alpha1.Connection,
destination *mgmtv1alpha1.JobDestination,
) (DestinationDatabaseBuilderService, error) {
switch cfg := destination.GetOptions().GetConfig().(type) {
case *mgmtv1alpha1.JobDestinationOptions_PostgresOptions:
opts := cfg.PostgresOptions
return ddbuilder_postgres.NewPostgresDestinationDatabaseBuilderService(ctx, d.logger, d.session, d.sqlmanagerclient, sourceConnection, destinationConnection, opts)
case *mgmtv1alpha1.JobDestinationOptions_MysqlOptions:
opts := cfg.MysqlOptions
return ddbuilder_mysql.NewMysqlDestinationDatabaseBuilderService(ctx, d.logger, d.session, d.sqlmanagerclient, sourceConnection, destinationConnection, opts)
case *mgmtv1alpha1.JobDestinationOptions_MssqlOptions:
opts := cfg.MssqlOptions
return ddbuilder_mssql.NewMssqlDestinationDatabaseBuilderService(ctx, d.logger, d.eelicense, d.session, d.sqlmanagerclient, sourceConnection, destinationConnection, opts)
case *mgmtv1alpha1.JobDestinationOptions_DynamodbOptions, *mgmtv1alpha1.JobDestinationOptions_MongodbOptions, *mgmtv1alpha1.JobDestinationOptions_AwsS3Options, *mgmtv1alpha1.JobDestinationOptions_GcpCloudstorageOptions:
// For destinations like DynamoDB, MongoDB, S3, and GCP Cloud Storage, we use a no-op implementation
// since schema initialization and data truncation don't apply to these data stores
return ddbuilder_notsupported.NewNotSupportedDestinationDatabaseBuilderService()
default:
return nil, fmt.Errorf("unsupported connection type: %T", destination.GetOptions().GetConfig())
}
}
167 changes: 167 additions & 0 deletions internal/destination-database-builder/mssql/mssql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package ddbuilder_mssql

import (
"context"
"fmt"
"log/slog"

mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
"github.com/nucleuscloud/neosync/backend/pkg/sqlmanager"
sqlmanager_mssql "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/mssql"
sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
tabledependency "github.com/nucleuscloud/neosync/backend/pkg/table-dependency"
connectionmanager "github.com/nucleuscloud/neosync/internal/connection-manager"
destdb_shared "github.com/nucleuscloud/neosync/internal/destination-database-builder/shared"
"github.com/nucleuscloud/neosync/internal/ee/license"
ee_sqlmanager_mssql "github.com/nucleuscloud/neosync/internal/ee/mssql-manager"
)

type MssqlDestinationDatabaseBuilderService struct {
logger *slog.Logger
eelicense license.EEInterface
sqlmanagerclient sqlmanager.SqlManagerClient
sourceConnection *mgmtv1alpha1.Connection
destinationConnection *mgmtv1alpha1.Connection
destOpts *mgmtv1alpha1.MssqlDestinationConnectionOptions
destdb *sqlmanager.SqlConnection
sourcedb *sqlmanager.SqlConnection
}

func NewMssqlDestinationDatabaseBuilderService(
ctx context.Context,
logger *slog.Logger,
eelicense license.EEInterface,
session connectionmanager.SessionInterface,
sqlmanagerclient sqlmanager.SqlManagerClient,
sourceConnection *mgmtv1alpha1.Connection,
destinationConnection *mgmtv1alpha1.Connection,
destOpts *mgmtv1alpha1.MssqlDestinationConnectionOptions,
) (*MssqlDestinationDatabaseBuilderService, error) {
sourcedb, err := sqlmanagerclient.NewSqlConnection(ctx, session, sourceConnection, logger)
if err != nil {
return nil, fmt.Errorf("unable to create new sql db: %w", err)
}

destdb, err := sqlmanagerclient.NewSqlConnection(ctx, session, destinationConnection, logger)
if err != nil {
return nil, fmt.Errorf("unable to create new sql db: %w", err)
}

return &MssqlDestinationDatabaseBuilderService{
logger: logger,
eelicense: eelicense,
sqlmanagerclient: sqlmanagerclient,
sourceConnection: sourceConnection,
destinationConnection: destinationConnection,
destOpts: destOpts,
destdb: destdb,
sourcedb: sourcedb,
}, nil
}

func (d *MssqlDestinationDatabaseBuilderService) InitializeSchema(ctx context.Context, uniqueTables map[string]struct{}) ([]*destdb_shared.InitSchemaError, error) {
initErrors := []*destdb_shared.InitSchemaError{}
if !d.destOpts.GetInitTableSchema() {
d.logger.Info("skipping schema init as it is not enabled")
return initErrors, nil
}
if !d.eelicense.IsValid() {
return nil, fmt.Errorf("invalid or non-existent Neosync License. SQL Server schema init requires valid Enterprise license.")
}
tables := []*sqlmanager_shared.SchemaTable{}
for tableKey := range uniqueTables {
schema, table := sqlmanager_shared.SplitTableKey(tableKey)
tables = append(tables, &sqlmanager_shared.SchemaTable{Schema: schema, Table: table})
}

initblocks, err := d.sourcedb.Db().GetSchemaInitStatements(ctx, tables)
if err != nil {
return nil, err
}

for _, block := range initblocks {
d.logger.Info(fmt.Sprintf("[%s] found %d statements to execute during schema initialization", block.Label, len(block.Statements)))
if len(block.Statements) == 0 {
continue
}
for _, stmt := range block.Statements {
err = d.destdb.Db().Exec(ctx, stmt)
if err != nil {
d.logger.Error(fmt.Sprintf("unable to exec mssql %s statements: %s", block.Label, err.Error()))
if block.Label != ee_sqlmanager_mssql.SchemasLabel && block.Label != ee_sqlmanager_mssql.ViewsFunctionsLabel && block.Label != ee_sqlmanager_mssql.TableIndexLabel {
return nil, fmt.Errorf("unable to exec mssql %s statements: %w", block.Label, err)
}
initErrors = append(initErrors, &destdb_shared.InitSchemaError{
Statement: stmt,
Error: err.Error(),
})
}
}
}
return initErrors, nil
}

func (d *MssqlDestinationDatabaseBuilderService) TruncateData(ctx context.Context, uniqueTables map[string]struct{}, uniqueSchemas []string) error {
if !d.destOpts.GetTruncateTable().GetTruncateBeforeInsert() {
d.logger.Info("skipping truncate as it is not enabled")
return nil
}
tableDependencies, err := d.sourcedb.Db().GetTableConstraintsBySchema(ctx, uniqueSchemas)
if err != nil {
return fmt.Errorf("unable to retrieve database foreign key constraints: %w", err)
}
d.logger.Info(fmt.Sprintf("found %d foreign key constraints for database", len(tableDependencies.ForeignKeyConstraints)))
tablePrimaryDependencyMap := destdb_shared.GetFilteredForeignToPrimaryTableMap(tableDependencies.ForeignKeyConstraints, uniqueTables)
orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency(tablePrimaryDependencyMap)
if err != nil {
return err
}

orderedTableDelete := []string{}
for i := len(orderedTablesResp.OrderedTables) - 1; i >= 0; i-- {
st := orderedTablesResp.OrderedTables[i]
stmt, err := sqlmanager_mssql.BuildMssqlDeleteStatement(st.Schema, st.Table)
if err != nil {
return err
}
orderedTableDelete = append(orderedTableDelete, stmt)
}

d.logger.Info(fmt.Sprintf("executing %d sql statements that will delete from tables", len(orderedTableDelete)))
err = d.destdb.Db().BatchExec(ctx, 10, orderedTableDelete, &sqlmanager_shared.BatchExecOpts{})
if err != nil {
return fmt.Errorf("unable to exec ordered delete from statements: %w", err)
}

// reset identity column counts
schemaColMap, err := d.sourcedb.Db().GetSchemaColumnMap(ctx)
if err != nil {
return err
}

identityStmts := []string{}
for table, cols := range schemaColMap {
if _, ok := uniqueTables[table]; !ok {
continue
}
for _, c := range cols {
if c.IdentityGeneration != nil && *c.IdentityGeneration != "" {
schema, table := sqlmanager_shared.SplitTableKey(table)
identityResetStatement := sqlmanager_mssql.BuildMssqlIdentityColumnResetStatement(schema, table, c.IdentitySeed, c.IdentityIncrement)
identityStmts = append(identityStmts, identityResetStatement)
}
}
}
if len(identityStmts) > 0 {
err = d.destdb.Db().BatchExec(ctx, 10, identityStmts, &sqlmanager_shared.BatchExecOpts{})
if err != nil {
return fmt.Errorf("unable to exec identity reset statements: %w", err)
}
}
return nil
}

func (d *MssqlDestinationDatabaseBuilderService) CloseConnections() {
d.destdb.Db().Close()
d.sourcedb.Db().Close()
}
124 changes: 124 additions & 0 deletions internal/destination-database-builder/mysql/mysql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package ddbuilder_mysql

import (
"context"
"fmt"
"log/slog"

mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
"github.com/nucleuscloud/neosync/backend/pkg/sqlmanager"
sqlmanager_mysql "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/mysql"
sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
connectionmanager "github.com/nucleuscloud/neosync/internal/connection-manager"
destdb_shared "github.com/nucleuscloud/neosync/internal/destination-database-builder/shared"
)

type MysqlDestinationDatabaseBuilderService struct {
logger *slog.Logger
sqlmanagerclient sqlmanager.SqlManagerClient
sourceConnection *mgmtv1alpha1.Connection
destinationConnection *mgmtv1alpha1.Connection
destOpts *mgmtv1alpha1.MysqlDestinationConnectionOptions
destdb *sqlmanager.SqlConnection
sourcedb *sqlmanager.SqlConnection
}

func NewMysqlDestinationDatabaseBuilderService(
ctx context.Context,
logger *slog.Logger,
session connectionmanager.SessionInterface,
sqlmanagerclient sqlmanager.SqlManagerClient,
sourceConnection *mgmtv1alpha1.Connection,
destinationConnection *mgmtv1alpha1.Connection,
destOpts *mgmtv1alpha1.MysqlDestinationConnectionOptions,
) (*MysqlDestinationDatabaseBuilderService, error) {
sourcedb, err := sqlmanagerclient.NewSqlConnection(ctx, session, sourceConnection, logger)
if err != nil {
return nil, fmt.Errorf("unable to create new sql db: %w", err)
}
defer sourcedb.Db().Close()

destdb, err := sqlmanagerclient.NewSqlConnection(ctx, session, destinationConnection, logger)
if err != nil {
return nil, fmt.Errorf("unable to create new sql db: %w", err)
}

return &MysqlDestinationDatabaseBuilderService{
logger: logger,
sqlmanagerclient: sqlmanagerclient,
sourceConnection: sourceConnection,
destinationConnection: destinationConnection,
destOpts: destOpts,
destdb: destdb,
sourcedb: sourcedb,
}, nil
}

func (d *MysqlDestinationDatabaseBuilderService) InitializeSchema(ctx context.Context, uniqueTables map[string]struct{}) ([]*destdb_shared.InitSchemaError, error) {
initErrors := []*destdb_shared.InitSchemaError{}
if !d.destOpts.GetInitTableSchema() {
d.logger.Info("skipping schema init as it is not enabled")
return initErrors, nil
}
tables := []*sqlmanager_shared.SchemaTable{}
for tableKey := range uniqueTables {
schema, table := sqlmanager_shared.SplitTableKey(tableKey)
tables = append(tables, &sqlmanager_shared.SchemaTable{Schema: schema, Table: table})
}

initblocks, err := d.sourcedb.Db().GetSchemaInitStatements(ctx, tables)
if err != nil {
return nil, err
}

for _, block := range initblocks {
d.logger.Info(fmt.Sprintf("[%s] found %d statements to execute during schema initialization", block.Label, len(block.Statements)))
if len(block.Statements) == 0 {
continue
}
err = d.destdb.Db().BatchExec(ctx, destdb_shared.BatchSizeConst, block.Statements, &sqlmanager_shared.BatchExecOpts{})
if err != nil {
d.logger.Error(fmt.Sprintf("unable to exec mysql %s statements: %s", block.Label, err.Error()))
if block.Label != sqlmanager_mysql.SchemasLabel {
return nil, fmt.Errorf("unable to exec mysql %s statements: %w", block.Label, err)
}
for _, stmt := range block.Statements {
err = d.destdb.Db().BatchExec(ctx, 1, []string{stmt}, &sqlmanager_shared.BatchExecOpts{})
if err != nil {
initErrors = append(initErrors, &destdb_shared.InitSchemaError{
Statement: stmt,
Error: err.Error(),
})
}
}
}
}
return initErrors, nil
}

func (d *MysqlDestinationDatabaseBuilderService) TruncateData(ctx context.Context, uniqueTables map[string]struct{}, uniqueSchemas []string) error {
if !d.destOpts.GetTruncateTable().GetTruncateBeforeInsert() {
d.logger.Info("skipping truncate as it is not enabled")
return nil
}
tableTruncate := []string{}
for table := range uniqueTables {
schema, table := sqlmanager_shared.SplitTableKey(table)
stmt, err := sqlmanager_mysql.BuildMysqlTruncateStatement(schema, table)
if err != nil {
return err
}
tableTruncate = append(tableTruncate, stmt)
}
d.logger.Info(fmt.Sprintf("executing %d sql statements that will truncate tables", len(tableTruncate)))
disableFkChecks := sqlmanager_shared.DisableForeignKeyChecks
err := d.destdb.Db().BatchExec(ctx, destdb_shared.BatchSizeConst, tableTruncate, &sqlmanager_shared.BatchExecOpts{Prefix: &disableFkChecks})
if err != nil {
return err
}
return nil
}

func (d *MysqlDestinationDatabaseBuilderService) CloseConnections() {
d.destdb.Db().Close()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package ddbuilder_notsupported

import (
"context"

destdb_shared "github.com/nucleuscloud/neosync/internal/destination-database-builder/shared"
)

type NotSupportedDestinationDatabaseBuilderService struct {
}

func NewNotSupportedDestinationDatabaseBuilderService() (*NotSupportedDestinationDatabaseBuilderService, error) {
return &NotSupportedDestinationDatabaseBuilderService{}, nil
}

func (d *NotSupportedDestinationDatabaseBuilderService) InitializeSchema(ctx context.Context, uniqueTables map[string]struct{}) ([]*destdb_shared.InitSchemaError, error) {
return []*destdb_shared.InitSchemaError{}, nil
}

func (d *NotSupportedDestinationDatabaseBuilderService) TruncateData(ctx context.Context, uniqueTables map[string]struct{}, uniqueSchemas []string) error {
return nil
}

func (d *NotSupportedDestinationDatabaseBuilderService) CloseConnections() {
}
177 changes: 177 additions & 0 deletions internal/destination-database-builder/postgres/postgres.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package ddbuilder_postgres

import (
"context"
"fmt"
"log/slog"
"strings"

mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
"github.com/nucleuscloud/neosync/backend/pkg/sqlmanager"
sqlmanager_postgres "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/postgres"
sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
tabledependency "github.com/nucleuscloud/neosync/backend/pkg/table-dependency"
connectionmanager "github.com/nucleuscloud/neosync/internal/connection-manager"
destdb_shared "github.com/nucleuscloud/neosync/internal/destination-database-builder/shared"
)

type PostgresDestinationDatabaseBuilderService struct {
logger *slog.Logger
sqlmanagerclient sqlmanager.SqlManagerClient
sourceConnection *mgmtv1alpha1.Connection
destinationConnection *mgmtv1alpha1.Connection
destOpts *mgmtv1alpha1.PostgresDestinationConnectionOptions
destdb *sqlmanager.SqlConnection
sourcedb *sqlmanager.SqlConnection
}

func NewPostgresDestinationDatabaseBuilderService(
ctx context.Context,
logger *slog.Logger,
session connectionmanager.SessionInterface,
sqlmanagerclient sqlmanager.SqlManagerClient,
sourceConnection *mgmtv1alpha1.Connection,
destinationConnection *mgmtv1alpha1.Connection,
destOpts *mgmtv1alpha1.PostgresDestinationConnectionOptions,
) (*PostgresDestinationDatabaseBuilderService, error) {
sourcedb, err := sqlmanagerclient.NewSqlConnection(ctx, session, sourceConnection, logger)
if err != nil {
return nil, fmt.Errorf("unable to create new sql db: %w", err)
}

destdb, err := sqlmanagerclient.NewSqlConnection(ctx, session, destinationConnection, logger)
if err != nil {
return nil, fmt.Errorf("unable to create new sql db: %w", err)
}

return &PostgresDestinationDatabaseBuilderService{
logger: logger,
sqlmanagerclient: sqlmanagerclient,
sourceConnection: sourceConnection,
destinationConnection: destinationConnection,
destOpts: destOpts,
destdb: destdb,
sourcedb: sourcedb,
}, nil
}

func (d *PostgresDestinationDatabaseBuilderService) InitializeSchema(ctx context.Context, uniqueTables map[string]struct{}) ([]*destdb_shared.InitSchemaError, error) {
initErrors := []*destdb_shared.InitSchemaError{}
if !d.destOpts.GetInitTableSchema() {
d.logger.Info("skipping schema init as it is not enabled")
return initErrors, nil
}
tables := []*sqlmanager_shared.SchemaTable{}
for tableKey := range uniqueTables {
schema, table := sqlmanager_shared.SplitTableKey(tableKey)
tables = append(tables, &sqlmanager_shared.SchemaTable{Schema: schema, Table: table})
}

initblocks, err := d.sourcedb.Db().GetSchemaInitStatements(ctx, tables)
if err != nil {
return nil, err
}

for _, block := range initblocks {
d.logger.Info(fmt.Sprintf("[%s] found %d statements to execute during schema initialization", block.Label, len(block.Statements)))
if len(block.Statements) == 0 {
continue
}
err = d.destdb.Db().BatchExec(ctx, destdb_shared.BatchSizeConst, block.Statements, &sqlmanager_shared.BatchExecOpts{})
if err != nil {
d.logger.Error(fmt.Sprintf("unable to exec pg %s statements: %s", block.Label, err.Error()))
if block.Label != sqlmanager_postgres.SchemasLabel && block.Label != sqlmanager_postgres.ExtensionsLabel {
return nil, fmt.Errorf("unable to exec pg %s statements: %w", block.Label, err)
}
for _, stmt := range block.Statements {
err := d.destdb.Db().Exec(ctx, stmt)
if err != nil {
initErrors = append(initErrors, &destdb_shared.InitSchemaError{
Statement: stmt,
Error: err.Error(),
})
}
}
}
}
return initErrors, nil
}

func (d *PostgresDestinationDatabaseBuilderService) TruncateData(ctx context.Context, uniqueTables map[string]struct{}, uniqueSchemas []string) error {
if !d.destOpts.GetTruncateTable().GetTruncateBeforeInsert() && !d.destOpts.GetTruncateTable().GetCascade() {
d.logger.Info("skipping truncate as it is not enabled")
return nil
}
if d.destOpts.GetTruncateTable().GetCascade() {
tableTruncateStmts := []string{}
for table := range uniqueTables {
schema, table := sqlmanager_shared.SplitTableKey(table)
stmt, err := sqlmanager_postgres.BuildPgTruncateCascadeStatement(schema, table)
if err != nil {
return err
}
tableTruncateStmts = append(tableTruncateStmts, stmt)
}
d.logger.Info(fmt.Sprintf("executing %d sql statements that will truncate cascade tables", len(tableTruncateStmts)))
err := d.destdb.Db().BatchExec(ctx, destdb_shared.BatchSizeConst, tableTruncateStmts, &sqlmanager_shared.BatchExecOpts{})
if err != nil {
return fmt.Errorf("unable to exec truncate cascade statements: %w", err)
}
} else if d.destOpts.GetTruncateTable().GetTruncateBeforeInsert() {
tableDependencies, err := d.sourcedb.Db().GetTableConstraintsBySchema(ctx, uniqueSchemas)
if err != nil {
return fmt.Errorf("unable to retrieve database foreign key constraints: %w", err)
}
d.logger.Info(fmt.Sprintf("found %d foreign key constraints for database", len(tableDependencies.ForeignKeyConstraints)))
tablePrimaryDependencyMap := destdb_shared.GetFilteredForeignToPrimaryTableMap(tableDependencies.ForeignKeyConstraints, uniqueTables)
orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency(tablePrimaryDependencyMap)
if err != nil {
return err
}

d.logger.Info(fmt.Sprintf("executing %d sql statements that will truncate tables", len(orderedTablesResp.OrderedTables)))
truncateStmt, err := sqlmanager_postgres.BuildPgTruncateStatement(orderedTablesResp.OrderedTables)
if err != nil {
return fmt.Errorf("unable to build postgres truncate statement: %w", err)
}
err = d.destdb.Db().Exec(ctx, truncateStmt)
if err != nil {
return fmt.Errorf("unable to exec ordered truncate statements: %w", err)
}
}
if d.destOpts.GetTruncateTable().GetTruncateBeforeInsert() || d.destOpts.GetTruncateTable().GetCascade() {
// reset serial counts
// identity counts are automatically reset with truncate identity restart clause
schemaTableMap := map[string][]string{}
for schemaTable := range uniqueTables {
schema, table := sqlmanager_shared.SplitTableKey(schemaTable)
schemaTableMap[schema] = append(schemaTableMap[schema], table)
}

resetSeqStmts := []string{}
for schema, tables := range schemaTableMap {
sequences, err := d.sourcedb.Db().GetSequencesByTables(ctx, schema, tables)
if err != nil {
return err
}
for _, seq := range sequences {
resetSeqStmts = append(resetSeqStmts, sqlmanager_postgres.BuildPgResetSequenceSql(seq.Name))
}
}
if len(resetSeqStmts) > 0 {
err := d.destdb.Db().BatchExec(ctx, 10, resetSeqStmts, &sqlmanager_shared.BatchExecOpts{})
if err != nil {
// handle not found errors
if !strings.Contains(err.Error(), `does not exist`) {
return fmt.Errorf("unable to exec postgres sequence reset statements: %w", err)
}
}
}
}
return nil
}

func (d *PostgresDestinationDatabaseBuilderService) CloseConnections() {
d.sourcedb.Db().Close()
d.destdb.Db().Close()
}
35 changes: 35 additions & 0 deletions internal/destination-database-builder/shared/shared.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package destinationdatabasebuilder_shared

import sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"

const (
BatchSizeConst = 20
)

type InitSchemaError struct {
Statement string
Error string
}

// filtered by tables found in job mappings
func GetFilteredForeignToPrimaryTableMap(td map[string][]*sqlmanager_shared.ForeignConstraint, uniqueTables map[string]struct{}) map[string][]string {
dpMap := map[string][]string{}
for table := range uniqueTables {
_, dpOk := dpMap[table]
if !dpOk {
dpMap[table] = []string{}
}
constraints, ok := td[table]
if !ok {
continue
}
for _, dep := range constraints {
_, ok := uniqueTables[dep.ForeignKey.Table]
// only add to map if dependency is an included table
if ok {
dpMap[table] = append(dpMap[table], dep.ForeignKey.Table)
}
}
}
return dpMap
}
92 changes: 92 additions & 0 deletions internal/destination-database-builder/shared/shared_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package destinationdatabasebuilder_shared

import (
"testing"

sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
"github.com/stretchr/testify/assert"
)

func Test_getFilteredForeignToPrimaryTableMap(t *testing.T) {
t.Parallel()
tables := map[string]struct{}{
"public.regions": {},
"public.jobs": {},
"public.countries": {},
"public.locations": {},
"public.dependents": {},
"public.departments": {},
"public.employees": {},
}
dependencies := map[string][]*sqlmanager_shared.ForeignConstraint{
"public.countries": {
{Columns: []string{"region_id"}, NotNullable: []bool{true}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.regions", Columns: []string{"region_id"}}},
},
"public.departments": {
{Columns: []string{"location_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.locations", Columns: []string{"location_id"}}},
},
"public.dependents": {
{Columns: []string{"dependent_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.employees", Columns: []string{"employees_id"}}},
},
"public.locations": {
{Columns: []string{"country_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.countries", Columns: []string{"country_id"}}},
},
"public.employees": {
{Columns: []string{"department_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.departments", Columns: []string{"department_id"}}},
{Columns: []string{"job_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.jobs", Columns: []string{"job_id"}}},
{Columns: []string{"manager_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.employees", Columns: []string{"employee_id"}}},
},
}

expected := map[string][]string{
"public.regions": {},
"public.jobs": {},
"public.countries": {"public.regions"},
"public.departments": {"public.locations"},
"public.dependents": {"public.employees"},
"public.employees": {"public.departments", "public.jobs", "public.employees"},
"public.locations": {"public.countries"},
}
actual := GetFilteredForeignToPrimaryTableMap(dependencies, tables)
assert.Len(t, actual, len(expected))
for table, deps := range actual {
assert.Len(t, deps, len(expected[table]))
assert.ElementsMatch(t, expected[table], deps)
}
}

func Test_getFilteredForeignToPrimaryTableMap_filtered(t *testing.T) {
t.Parallel()
tables := map[string]struct{}{
"public.countries": {},
}
dependencies := map[string][]*sqlmanager_shared.ForeignConstraint{
"public.countries": {
{Columns: []string{"region_id"}, NotNullable: []bool{true}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.regions", Columns: []string{"region_id"}}}},

"public.departments": {
{Columns: []string{"location_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.locations", Columns: []string{"location_id"}}},
},
"public.dependents": {
{Columns: []string{"dependent_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.employees", Columns: []string{"employees_id"}}},
},
"public.locations": {
{Columns: []string{"country_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.countries", Columns: []string{"country_id"}}},
},
"public.employees": {
{Columns: []string{"department_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.departments", Columns: []string{"department_id"}}},
{Columns: []string{"job_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.jobs", Columns: []string{"job_id"}}},
{Columns: []string{"manager_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.employees", Columns: []string{"employee_id"}}},
},
}

expected := map[string][]string{
"public.countries": {},
}
actual := GetFilteredForeignToPrimaryTableMap(dependencies, tables)
assert.Len(t, actual, len(expected))
for table, deps := range actual {
assert.Len(t, deps, len(expected[table]))
assert.ElementsMatch(t, expected[table], deps)
}
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -183,7 +183,6 @@ func Test_InitStatementBuilder_Pg_Generate_NoInitStatement(t *testing.T) {
mockSqlManager := sqlmanager.NewMockSqlManagerClient(t)
connectionId := "456"

mockJobClient.On("SetRunContext", mock.Anything, mock.Anything).Return(connect.NewResponse(&mgmtv1alpha1.SetRunContextResponse{}), nil)
mockJobClient.On("GetJob", mock.Anything, mock.Anything).
Return(connect.NewResponse(&mgmtv1alpha1.GetJobResponse{
Job: &mgmtv1alpha1.Job{
@@ -238,6 +237,13 @@ func Test_InitStatementBuilder_Pg_Generate_NoInitStatement(t *testing.T) {
Destinations: []*mgmtv1alpha1.JobDestination{
{
ConnectionId: "456",
Options: &mgmtv1alpha1.JobDestinationOptions{
Config: &mgmtv1alpha1.JobDestinationOptions_PostgresOptions{
PostgresOptions: &mgmtv1alpha1.PostgresDestinationConnectionOptions{
InitTableSchema: false,
},
},
},
},
},
},
@@ -701,7 +707,6 @@ func Test_InitStatementBuilder_Mysql_Generate(t *testing.T) {
mockSqlManager := sqlmanager.NewMockSqlManagerClient(t)
connectionId := "456"

mockJobClient.On("SetRunContext", mock.Anything, mock.Anything).Return(connect.NewResponse(&mgmtv1alpha1.SetRunContextResponse{}), nil)
mockJobClient.On("GetJob", mock.Anything, mock.Anything).
Return(connect.NewResponse(&mgmtv1alpha1.GetJobResponse{
Job: &mgmtv1alpha1.Job{
@@ -805,90 +810,6 @@ func Test_InitStatementBuilder_Mysql_Generate(t *testing.T) {
assert.Nil(t, err)
}

func Test_getFilteredForeignToPrimaryTableMap(t *testing.T) {
t.Parallel()
tables := map[string]struct{}{
"public.regions": {},
"public.jobs": {},
"public.countries": {},
"public.locations": {},
"public.dependents": {},
"public.departments": {},
"public.employees": {},
}
dependencies := map[string][]*sqlmanager_shared.ForeignConstraint{
"public.countries": {
{Columns: []string{"region_id"}, NotNullable: []bool{true}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.regions", Columns: []string{"region_id"}}},
},
"public.departments": {
{Columns: []string{"location_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.locations", Columns: []string{"location_id"}}},
},
"public.dependents": {
{Columns: []string{"dependent_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.employees", Columns: []string{"employees_id"}}},
},
"public.locations": {
{Columns: []string{"country_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.countries", Columns: []string{"country_id"}}},
},
"public.employees": {
{Columns: []string{"department_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.departments", Columns: []string{"department_id"}}},
{Columns: []string{"job_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.jobs", Columns: []string{"job_id"}}},
{Columns: []string{"manager_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.employees", Columns: []string{"employee_id"}}},
},
}

expected := map[string][]string{
"public.regions": {},
"public.jobs": {},
"public.countries": {"public.regions"},
"public.departments": {"public.locations"},
"public.dependents": {"public.employees"},
"public.employees": {"public.departments", "public.jobs", "public.employees"},
"public.locations": {"public.countries"},
}
actual := getFilteredForeignToPrimaryTableMap(dependencies, tables)
assert.Len(t, actual, len(expected))
for table, deps := range actual {
assert.Len(t, deps, len(expected[table]))
assert.ElementsMatch(t, expected[table], deps)
}
}

func Test_getFilteredForeignToPrimaryTableMap_filtered(t *testing.T) {
t.Parallel()
tables := map[string]struct{}{
"public.countries": {},
}
dependencies := map[string][]*sqlmanager_shared.ForeignConstraint{
"public.countries": {
{Columns: []string{"region_id"}, NotNullable: []bool{true}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.regions", Columns: []string{"region_id"}}}},

"public.departments": {
{Columns: []string{"location_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.locations", Columns: []string{"location_id"}}},
},
"public.dependents": {
{Columns: []string{"dependent_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.employees", Columns: []string{"employees_id"}}},
},
"public.locations": {
{Columns: []string{"country_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.countries", Columns: []string{"country_id"}}},
},
"public.employees": {
{Columns: []string{"department_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.departments", Columns: []string{"department_id"}}},
{Columns: []string{"job_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.jobs", Columns: []string{"job_id"}}},
{Columns: []string{"manager_id"}, NotNullable: []bool{false}, ForeignKey: &sqlmanager_shared.ForeignKey{Table: "public.employees", Columns: []string{"employee_id"}}},
},
}

expected := map[string][]string{
"public.countries": {},
}
actual := getFilteredForeignToPrimaryTableMap(dependencies, tables)
assert.Len(t, actual, len(expected))
for table, deps := range actual {
assert.Len(t, deps, len(expected[table]))
assert.ElementsMatch(t, expected[table], deps)
}
}

func compareSlices(slice1, slice2 []string) bool {
for _, ele := range slice1 {
if !slices.Contains(slice2, ele) {