Skip to content

feat/sql-pool-config: Added command support for poolConfig for sql #4681

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions core/stores/postgres/postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ import (
const postgresDriverName = "pgx"

// New returns a postgres connection.
func New(datasource string, opts ...sqlx.SqlOption) sqlx.SqlConn {
return sqlx.NewSqlConn(postgresDriverName, datasource, opts...)
func New(datasource string, poolConfig sqlx.PoolConfig, opts ...sqlx.SqlOption) sqlx.SqlConn {
return sqlx.NewSqlConn(postgresDriverName, datasource, poolConfig, opts...)
}
8 changes: 7 additions & 1 deletion core/stores/postgres/postgresql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@ package postgres

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stores/sqlx"
)

func TestPostgreSql(t *testing.T) {
assert.NotNil(t, New("postgre"))
assert.NotNil(t, New("postgre", sqlx.PoolConfig{
MaxIdleConns: 10,
MaxOpenConns: 10,
MaxLifetime: time.Minute,
}))
}
4 changes: 2 additions & 2 deletions core/stores/sqlx/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ const (
)

// NewMysql returns a mysql connection.
func NewMysql(datasource string, opts ...SqlOption) SqlConn {
func NewMysql(datasource string, poolConfig PoolConfig, opts ...SqlOption) SqlConn {
opts = append([]SqlOption{withMysqlAcceptable()}, opts...)
return NewSqlConn(mysqlDriverName, datasource, opts...)
return NewSqlConn(mysqlDriverName, datasource, poolConfig, opts...)
}

func mysqlAcceptable(err error) bool {
Expand Down
8 changes: 6 additions & 2 deletions core/stores/sqlx/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package sqlx
import (
"errors"
"testing"

"time"
"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
Expand Down Expand Up @@ -35,7 +35,11 @@ func TestBreakerOnNotHandlingDuplicateEntry(t *testing.T) {
}

func TestMysqlAcceptable(t *testing.T) {
conn := NewMysql("nomysql").(*commonSqlConn)
conn := NewMysql("nomysql", PoolConfig{
MaxIdleConns: 10,
MaxOpenConns: 10,
MaxLifetime: time.Minute,
}).(*commonSqlConn)
withMysqlAcceptable()(conn)
assert.True(t, mysqlAcceptable(nil))
assert.False(t, mysqlAcceptable(errors.New("any")))
Expand Down
11 changes: 9 additions & 2 deletions core/stores/sqlx/sqlconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"time"

"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/errorx"
Expand Down Expand Up @@ -54,6 +55,12 @@ type (
accept breaker.Acceptable
}

PoolConfig struct {
MaxIdleConns int
MaxOpenConns int
MaxLifetime time.Duration
}

connProvider func() (*sql.DB, error)

sessionConn interface {
Expand All @@ -65,10 +72,10 @@ type (
)

// NewSqlConn returns a SqlConn with given driver name and datasource.
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
func NewSqlConn(driverName, datasource string, poolConfig PoolConfig, opts ...SqlOption) SqlConn {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
return getSqlConn(driverName, datasource)
return getSqlConn(driverName, datasource, poolConfig)
},
onError: func(ctx context.Context, err error) {
logInstanceError(ctx, datasource, err)
Expand Down
14 changes: 11 additions & 3 deletions core/stores/sqlx/sqlconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"errors"
"io"
"testing"

"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
Expand All @@ -26,11 +26,19 @@ func TestSqlConn(t *testing.T) {
assert.Nil(t, err)
mock.ExpectExec("any")
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
conn := NewMysql(mockedDatasource)
conn := NewMysql(mockedDatasource, PoolConfig{
MaxIdleConns: 10,
MaxOpenConns: 10,
MaxLifetime: time.Minute,
})
db, err := conn.RawDB()
assert.Nil(t, err)
rawConn := NewSqlConnFromDB(db, withMysqlAcceptable())
badConn := NewMysql("badsql")
badConn := NewMysql("badsql", PoolConfig{
MaxIdleConns: 10,
MaxOpenConns: 10,
MaxLifetime: time.Minute,
})
_, err = conn.Exec("any", "value")
assert.NotNil(t, err)
_, err = badConn.Exec("any", "value")
Expand Down
23 changes: 8 additions & 15 deletions core/stores/sqlx/sqlmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,17 @@ import (
"database/sql"
"encoding/hex"
"io"
"time"

"github.com/go-sql-driver/mysql"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/syncx"
)

const (
maxIdleConns = 64
maxOpenConns = 64
maxLifetime = time.Minute
)

var connManager = syncx.NewResourceManager()

func getCachedSqlConn(driverName, server string) (*sql.DB, error) {
func getCachedSqlConn(driverName, server string, poolConfig PoolConfig) (*sql.DB, error) {
val, err := connManager.GetResource(server, func() (io.Closer, error) {
conn, err := newDBConnection(driverName, server)
conn, err := newDBConnection(driverName, server, poolConfig)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -53,16 +46,16 @@ func getCachedSqlConn(driverName, server string) (*sql.DB, error) {
return val.(*sql.DB), nil
}

func getSqlConn(driverName, server string) (*sql.DB, error) {
conn, err := getCachedSqlConn(driverName, server)
func getSqlConn(driverName, server string, poolConfig PoolConfig) (*sql.DB, error) {
conn, err := getCachedSqlConn(driverName, server, poolConfig)
if err != nil {
return nil, err
}

return conn, nil
}

func newDBConnection(driverName, datasource string) (*sql.DB, error) {
func newDBConnection(driverName, datasource string, poolConfig PoolConfig) (*sql.DB, error) {
conn, err := sql.Open(driverName, datasource)
if err != nil {
return nil, err
Expand All @@ -72,9 +65,9 @@ func newDBConnection(driverName, datasource string) (*sql.DB, error) {
// discussed here https://github.com/go-sql-driver/mysql/issues/257
// if the discussed SetMaxIdleTimeout methods added, we'll change this behavior
// 8 means we can't have more than 8 goroutines to concurrently access the same database.
conn.SetMaxIdleConns(maxIdleConns)
conn.SetMaxOpenConns(maxOpenConns)
conn.SetConnMaxLifetime(maxLifetime)
conn.SetMaxIdleConns(poolConfig.MaxIdleConns)
conn.SetMaxOpenConns(poolConfig.MaxOpenConns)
conn.SetConnMaxLifetime(poolConfig.MaxLifetime)

if err := conn.Ping(); err != nil {
_ = conn.Close()
Expand Down
6 changes: 6 additions & 0 deletions tools/goctl/model/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ func init() {
datasourceCmdFlags.StringVar(&command.VarStringHome, "home")
datasourceCmdFlags.StringVar(&command.VarStringRemote, "remote")
datasourceCmdFlags.StringVar(&command.VarStringBranch, "branch")
datasourceCmdFlags.IntVarP(&command.VarIntMaxIdleConns, "max-idle-conns", "i", 64)
datasourceCmdFlags.IntVarP(&command.VarIntMaxOpenConns, "max-open-conns", "o", 64)
datasourceCmdFlags.IntVarP(&command.VarIntMaxLifetime, "max-lifetime", "l", 60)

pgDatasourceCmdFlags.StringVar(&command.VarStringURL, "url")
pgDatasourceCmdFlags.StringSliceVarP(&command.VarStringSliceTable, "table", "t")
Expand All @@ -56,6 +59,9 @@ func init() {
pgDatasourceCmdFlags.StringVar(&command.VarStringHome, "home")
pgDatasourceCmdFlags.StringVar(&command.VarStringRemote, "remote")
pgDatasourceCmdFlags.StringVar(&command.VarStringBranch, "branch")
pgDatasourceCmdFlags.IntVarP(&command.VarIntMaxIdleConns, "max-idle-conns", "i", 64)
pgDatasourceCmdFlags.IntVarP(&command.VarIntMaxOpenConns, "max-open-conns", "o", 64)
pgDatasourceCmdFlags.IntVarP(&command.VarIntMaxLifetime, "max-lifetime", "l", 60)
pgCmd.PersistentFlags().StringSliceVarPWithDefaultValue(&command.VarStringSliceIgnoreColumns,
"ignore-columns", "i", []string{"create_at", "created_at", "create_time", "update_at", "updated_at", "update_time"})

Expand Down
34 changes: 30 additions & 4 deletions tools/goctl/model/sql/command/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ var (
VarStringSliceIgnoreColumns []string
// VarStringCachePrefix describes the prefix of cache.
VarStringCachePrefix string
// VarIntMaxIdleConns describes the max idle conns.
VarIntMaxIdleConns int
// VarIntMaxOpenConns describes the max open conns.
VarIntMaxOpenConns int
// VarIntMaxLifetime describes the max lifetime.
VarIntMaxLifetime int
)

var errNotMatched = errors.New("sql not matched")
Expand Down Expand Up @@ -113,6 +119,9 @@ func MySqlDataSource(_ *cobra.Command, _ []string) error {
home := VarStringHome
remote := VarStringRemote
branch := VarStringBranch
maxIdleConns := VarIntMaxIdleConns
maxOpenConns := VarIntMaxOpenConns
maxLifetime := VarIntMaxLifetime
if len(remote) > 0 {
repo, _ := file.CloneIntoGitHome(remote, branch)
if len(repo) > 0 {
Expand Down Expand Up @@ -140,6 +149,9 @@ func MySqlDataSource(_ *cobra.Command, _ []string) error {
strict: VarBoolStrict,
ignoreColumns: mergeColumns(VarStringSliceIgnoreColumns),
prefix: VarStringCachePrefix,
maxIdleConns: maxIdleConns,
maxOpenConns: maxOpenConns,
maxLifetime: maxLifetime,
}
return fromMysqlDataSource(arg)
}
Expand Down Expand Up @@ -204,6 +216,9 @@ func PostgreSqlDataSource(_ *cobra.Command, _ []string) error {
home := VarStringHome
remote := VarStringRemote
branch := VarStringBranch
maxIdleConns := VarIntMaxIdleConns
maxOpenConns := VarIntMaxOpenConns
maxLifetime := VarIntMaxLifetime
if len(remote) > 0 {
repo, _ := file.CloneIntoGitHome(remote, branch)
if len(repo) > 0 {
Expand All @@ -225,7 +240,7 @@ func PostgreSqlDataSource(_ *cobra.Command, _ []string) error {
}
ignoreColumns := mergeColumns(VarStringSliceIgnoreColumns)

return fromPostgreSqlDataSource(url, patterns, dir, schema, cfg, cache, idea, VarBoolStrict, ignoreColumns)
return fromPostgreSqlDataSource(url, patterns, dir, schema, cfg, cache, idea, VarBoolStrict, ignoreColumns, maxIdleConns, maxOpenConns, maxLifetime)
}

type ddlArg struct {
Expand Down Expand Up @@ -278,6 +293,9 @@ type dataSourceArg struct {
strict bool
ignoreColumns []string
prefix string
maxIdleConns int
maxOpenConns int
maxLifetime int
}

func fromMysqlDataSource(arg dataSourceArg) error {
Expand All @@ -299,7 +317,11 @@ func fromMysqlDataSource(arg dataSourceArg) error {

logx.Disable()
databaseSource := strings.TrimSuffix(arg.url, "/"+dsn.DBName) + "/information_schema"
db := sqlx.NewMysql(databaseSource)
db := sqlx.NewMysql(databaseSource, sqlx.PoolConfig{
MaxIdleConns: arg.maxIdleConns,
MaxOpenConns: arg.maxOpenConns,
MaxLifetime: arg.maxLifetime,
})
im := model.NewInformationSchemaModel(db)

tables, err := im.GetAllTables(dsn.DBName)
Expand Down Expand Up @@ -339,7 +361,7 @@ func fromMysqlDataSource(arg dataSourceArg) error {
return generator.StartFromInformationSchema(matchTables, arg.cache, arg.strict)
}

func fromPostgreSqlDataSource(url string, pattern pattern, dir, schema string, cfg *config.Config, cache, idea, strict bool, ignoreColumns []string) error {
func fromPostgreSqlDataSource(url string, pattern pattern, dir, schema string, cfg *config.Config, cache, idea, strict bool, ignoreColumns []string, maxIdleConns, maxOpenConns, maxLifetime int) error {
log := console.NewConsole(idea)
if len(url) == 0 {
log.Error("%v", "expected data source of postgresql, but nothing found")
Expand All @@ -350,7 +372,11 @@ func fromPostgreSqlDataSource(url string, pattern pattern, dir, schema string, c
log.Error("%v", "expected table or table globbing patterns, but nothing found")
return nil
}
db := postgres.New(url)
db := postgres.New(url, sqlx.PoolConfig{
MaxIdleConns: maxIdleConns,
MaxOpenConns: maxOpenConns,
MaxLifetime: time.Duration(maxLifetime),
})
im := model.NewPostgreSqlModel(db)

tables, err := im.GetAllTables(schema)
Expand Down