-
Notifications
You must be signed in to change notification settings - Fork 153
/
Copy pathpostgres.go
177 lines (164 loc) · 7.03 KB
/
postgres.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
package ddbuilder_postgres
import (
"context"
"fmt"
"log/slog"
"strings"
mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
"github.com/nucleuscloud/neosync/backend/pkg/sqlmanager"
sqlmanager_postgres "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/postgres"
sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
tabledependency "github.com/nucleuscloud/neosync/backend/pkg/table-dependency"
connectionmanager "github.com/nucleuscloud/neosync/internal/connection-manager"
destdb_shared "github.com/nucleuscloud/neosync/internal/destination-database-builder/shared"
)
// PostgresDestinationDatabaseBuilderService prepares a Postgres destination
// database: it can replay schema-initialization statements read from the
// source, truncate destination tables, and reset their sequences. Which of
// those steps run is controlled by destOpts.
type PostgresDestinationDatabaseBuilderService struct {
	logger *slog.Logger
	// sqlmanagerclient is the factory that opened sourcedb and destdb.
	sqlmanagerclient      sqlmanager.SqlManagerClient
	sourceConnection      *mgmtv1alpha1.Connection
	destinationConnection *mgmtv1alpha1.Connection
	// destOpts holds the Postgres destination settings consulted by the
	// methods (init table schema, truncate before insert, truncate cascade).
	destOpts *mgmtv1alpha1.PostgresDestinationConnectionOptions
	// destdb is the open destination connection; DDL/truncate/sequence
	// statements are executed against it.
	destdb *sqlmanager.SqlConnection
	// sourcedb is the open source connection; schema-init statements,
	// foreign-key constraints and sequence metadata are read from it.
	sourcedb *sqlmanager.SqlConnection
}
// NewPostgresDestinationDatabaseBuilderService opens SQL connections to both
// the source and the destination database through sqlmanagerclient and returns
// a service that uses them to prepare the Postgres destination.
//
// The source connection is opened first. If opening the destination connection
// then fails, the source connection is closed before the error is returned, so
// a failed construction does not leak it. On success the caller owns both
// connections and should release them with CloseConnections.
func NewPostgresDestinationDatabaseBuilderService(
	ctx context.Context,
	logger *slog.Logger,
	session connectionmanager.SessionInterface,
	sqlmanagerclient sqlmanager.SqlManagerClient,
	sourceConnection *mgmtv1alpha1.Connection,
	destinationConnection *mgmtv1alpha1.Connection,
	destOpts *mgmtv1alpha1.PostgresDestinationConnectionOptions,
) (*PostgresDestinationDatabaseBuilderService, error) {
	sourcedb, err := sqlmanagerclient.NewSqlConnection(ctx, session, sourceConnection, logger)
	if err != nil {
		return nil, fmt.Errorf("unable to create new source sql db: %w", err)
	}
	destdb, err := sqlmanagerclient.NewSqlConnection(ctx, session, destinationConnection, logger)
	if err != nil {
		// The caller never receives a service on this path and therefore
		// cannot call CloseConnections, so release the source handle here.
		sourcedb.Db().Close()
		return nil, fmt.Errorf("unable to create new destination sql db: %w", err)
	}
	return &PostgresDestinationDatabaseBuilderService{
		logger:                logger,
		sqlmanagerclient:      sqlmanagerclient,
		sourceConnection:      sourceConnection,
		destinationConnection: destinationConnection,
		destOpts:              destOpts,
		destdb:                destdb,
		sourcedb:              sourcedb,
	}, nil
}
// InitializeSchema replays the source database's schema-initialization
// statements on the destination for every table key in uniqueTables.
//
// If schema init is disabled in the destination options, it logs that fact and
// returns an empty, non-nil slice. Otherwise the statement blocks returned by
// GetSchemaInitStatements are executed in order, each as a batch. A failing
// batch aborts the call with an error, except for the schemas and extensions
// blocks. For those two, every statement of the block is re-run on its own, and
// the failures are collected into the returned slice instead of aborting.
func (d *PostgresDestinationDatabaseBuilderService) InitializeSchema(ctx context.Context, uniqueTables map[string]struct{}) ([]*destdb_shared.InitSchemaError, error) {
	collected := []*destdb_shared.InitSchemaError{}
	if enabled := d.destOpts.GetInitTableSchema(); !enabled {
		d.logger.Info("skipping schema init as it is not enabled")
		return collected, nil
	}

	schemaTables := []*sqlmanager_shared.SchemaTable{}
	for key := range uniqueTables {
		schemaName, tableName := sqlmanager_shared.SplitTableKey(key)
		schemaTables = append(schemaTables, &sqlmanager_shared.SchemaTable{Schema: schemaName, Table: tableName})
	}

	blocks, err := d.sourcedb.Db().GetSchemaInitStatements(ctx, schemaTables)
	if err != nil {
		return nil, err
	}

	for _, blk := range blocks {
		stmtCount := len(blk.Statements)
		d.logger.Info(fmt.Sprintf("[%s] found %d statements to execute during schema initialization", blk.Label, stmtCount))
		if stmtCount == 0 {
			continue
		}

		batchErr := d.destdb.Db().BatchExec(ctx, destdb_shared.BatchSizeConst, blk.Statements, &sqlmanager_shared.BatchExecOpts{})
		if batchErr == nil {
			continue
		}
		d.logger.Error(fmt.Sprintf("unable to exec pg %s statements: %s", blk.Label, batchErr.Error()))

		tolerated := blk.Label == sqlmanager_postgres.SchemasLabel || blk.Label == sqlmanager_postgres.ExtensionsLabel
		if !tolerated {
			return nil, fmt.Errorf("unable to exec pg %s statements: %w", blk.Label, batchErr)
		}

		// Re-run the block one statement at a time and record each failure
		// rather than aborting the whole initialization.
		for _, stmt := range blk.Statements {
			if execErr := d.destdb.Db().Exec(ctx, stmt); execErr != nil {
				collected = append(collected, &destdb_shared.InitSchemaError{
					Statement: stmt,
					Error:     execErr.Error(),
				})
			}
		}
	}
	return collected, nil
}
// TruncateData empties the destination tables named by the keys of uniqueTables
// (keys are split with sqlmanager_shared.SplitTableKey) and then resets the
// sequences the source reports for those tables.
//
// Behavior depends on the destination truncate options:
//   - cascade enabled: one statement per table is built with
//     BuildPgTruncateCascadeStatement and executed in batches.
//   - only truncate-before-insert enabled: foreign-key constraints for
//     uniqueSchemas are read from the source, the tables are ordered with
//     tabledependency.GetTablesOrderedByDependency, and a single statement
//     built by BuildPgTruncateStatement is executed.
//   - neither enabled: nothing is done.
//
// Sequence resets that fail because the sequence does not exist on the
// destination are skipped; any other failure is returned.
func (d *PostgresDestinationDatabaseBuilderService) TruncateData(ctx context.Context, uniqueTables map[string]struct{}, uniqueSchemas []string) error {
	truncateOpts := d.destOpts.GetTruncateTable()
	if !truncateOpts.GetTruncateBeforeInsert() && !truncateOpts.GetCascade() {
		d.logger.Info("skipping truncate as it is not enabled")
		return nil
	}
	if len(uniqueTables) == 0 {
		// Nothing to truncate and no sequences to reset; avoid building a
		// truncate statement over zero tables.
		d.logger.Info("skipping truncate as there are no tables")
		return nil
	}

	var err error
	if truncateOpts.GetCascade() {
		err = d.truncateTablesCascade(ctx, uniqueTables)
	} else {
		// The guard above guarantees truncate-before-insert is enabled here.
		err = d.truncateTablesInDependencyOrder(ctx, uniqueTables, uniqueSchemas)
	}
	if err != nil {
		return err
	}

	// Reset serial counts. Identity counts are automatically reset with the
	// truncate identity restart clause.
	return d.resetTableSequences(ctx, uniqueTables)
}

// truncateTablesCascade builds one cascade truncate statement per table and
// executes them on the destination in batches of destdb_shared.BatchSizeConst.
func (d *PostgresDestinationDatabaseBuilderService) truncateTablesCascade(ctx context.Context, uniqueTables map[string]struct{}) error {
	stmts := make([]string, 0, len(uniqueTables))
	for tableKey := range uniqueTables {
		schema, table := sqlmanager_shared.SplitTableKey(tableKey)
		stmt, err := sqlmanager_postgres.BuildPgTruncateCascadeStatement(schema, table)
		if err != nil {
			return err
		}
		stmts = append(stmts, stmt)
	}
	d.logger.Info(fmt.Sprintf("executing %d sql statements that will truncate cascade tables", len(stmts)))
	if err := d.destdb.Db().BatchExec(ctx, destdb_shared.BatchSizeConst, stmts, &sqlmanager_shared.BatchExecOpts{}); err != nil {
		return fmt.Errorf("unable to exec truncate cascade statements: %w", err)
	}
	return nil
}

// truncateTablesInDependencyOrder orders the tables by the foreign keys found
// in the source for uniqueSchemas and truncates them on the destination with a
// single statement.
func (d *PostgresDestinationDatabaseBuilderService) truncateTablesInDependencyOrder(ctx context.Context, uniqueTables map[string]struct{}, uniqueSchemas []string) error {
	tableDependencies, err := d.sourcedb.Db().GetTableConstraintsBySchema(ctx, uniqueSchemas)
	if err != nil {
		return fmt.Errorf("unable to retrieve database foreign key constraints: %w", err)
	}
	d.logger.Info(fmt.Sprintf("found %d foreign key constraints for database", len(tableDependencies.ForeignKeyConstraints)))
	tablePrimaryDependencyMap := destdb_shared.GetFilteredForeignToPrimaryTableMap(tableDependencies.ForeignKeyConstraints, uniqueTables)
	orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency(tablePrimaryDependencyMap)
	if err != nil {
		return err
	}
	d.logger.Info(fmt.Sprintf("executing %d sql statements that will truncate tables", len(orderedTablesResp.OrderedTables)))
	truncateStmt, err := sqlmanager_postgres.BuildPgTruncateStatement(orderedTablesResp.OrderedTables)
	if err != nil {
		return fmt.Errorf("unable to build postgres truncate statement: %w", err)
	}
	if err := d.destdb.Db().Exec(ctx, truncateStmt); err != nil {
		return fmt.Errorf("unable to exec ordered truncate statements: %w", err)
	}
	return nil
}

// resetTableSequences looks up, per schema, the sequences the source reports
// for the given tables and runs the BuildPgResetSequenceSql statement for each
// of them on the destination.
//
// Statements are first sent in batches. If that fails with a "does not exist"
// error, a single missing sequence may have kept the other statements of the
// batch (and any later batches) from taking effect. In that case every
// statement is re-run individually: missing sequences are logged and skipped,
// and any other error is returned.
func (d *PostgresDestinationDatabaseBuilderService) resetTableSequences(ctx context.Context, uniqueTables map[string]struct{}) error {
	const resetSeqBatchSize = 10

	schemaTableMap := map[string][]string{}
	for schemaTable := range uniqueTables {
		schema, table := sqlmanager_shared.SplitTableKey(schemaTable)
		schemaTableMap[schema] = append(schemaTableMap[schema], table)
	}

	resetSeqStmts := []string{}
	for schema, tables := range schemaTableMap {
		sequences, err := d.sourcedb.Db().GetSequencesByTables(ctx, schema, tables)
		if err != nil {
			return err
		}
		for _, seq := range sequences {
			resetSeqStmts = append(resetSeqStmts, sqlmanager_postgres.BuildPgResetSequenceSql(seq.Name))
		}
	}
	if len(resetSeqStmts) == 0 {
		return nil
	}

	batchErr := d.destdb.Db().BatchExec(ctx, resetSeqBatchSize, resetSeqStmts, &sqlmanager_shared.BatchExecOpts{})
	if batchErr == nil {
		return nil
	}
	// handle not found errors
	if !strings.Contains(batchErr.Error(), `does not exist`) {
		return fmt.Errorf("unable to exec postgres sequence reset statements: %w", batchErr)
	}

	// NOTE(review): re-running a reset that already succeeded in the batch is
	// assumed to be harmless — confirm BuildPgResetSequenceSql is idempotent.
	for _, stmt := range resetSeqStmts {
		execErr := d.destdb.Db().Exec(ctx, stmt)
		if execErr == nil {
			continue
		}
		if strings.Contains(execErr.Error(), `does not exist`) {
			d.logger.Warn(fmt.Sprintf("skipping postgres sequence reset: %s", execErr.Error()))
			continue
		}
		return fmt.Errorf("unable to exec postgres sequence reset statement: %w", execErr)
	}
	return nil
}
// CloseConnections calls Close on the database handle of the source
// connection and then on that of the destination connection. Nothing is
// reported back to the caller.
func (d *PostgresDestinationDatabaseBuilderService) CloseConnections() {
	src, dst := d.sourcedb, d.destdb
	src.Db().Close()
	dst.Db().Close()
}