From 0877ae6164fab14bf54e6ec831fbf5751c238102 Mon Sep 17 00:00:00 2001 From: Nathan Bosscher Date: Mon, 23 Jan 2023 15:13:00 -0500 Subject: [PATCH] Update reflectx to allow for optional nested structs Nested structs are now only instantiated when one of the database columns in that nested struct is not nil. This allows objects scanned in left/outer joins to keep their natural types (instead of setting everything to NullableX). Example: ```sql select house.id, owner.*, from house left join owner on owner.id = house.owner ``` ``` golang type House struct { ID int Owner *Person // if left join gives nulls, Owner will be nil } type Owner struct { ID int // no need to set this to sql.NullInt } ``` --- reflectx/reflect.go | 189 +++++++++++++++++++++++++++++++++++++++++++ sqlx.go | 16 ++-- sqlx_context_test.go | 104 ++++++++++++++++++++++++ 3 files changed, 304 insertions(+), 5 deletions(-) diff --git a/reflectx/reflect.go b/reflectx/reflect.go index 0b109942..6b710f22 100644 --- a/reflectx/reflect.go +++ b/reflectx/reflect.go @@ -7,8 +7,11 @@ package reflectx import ( + "database/sql" + "fmt" "reflect" "runtime" + "strconv" "strings" "sync" ) @@ -201,6 +204,192 @@ func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(in return nil } +// ObjectContext provides a single layer to abstract away +// nested struct scanning functionality +type ObjectContext struct { + value reflect.Value +} + +func NewObjectContext() *ObjectContext { + return &ObjectContext{} +} + +// NewRow updates the object reference. +// This ensures all columns point to the same object +func (o *ObjectContext) NewRow(value reflect.Value) { + o.value = value +} + +// FieldForIndexes returns the value for address. If the address is a nested struct, +// a nestedFieldScanner is returned instead of the standard value reference +func (o *ObjectContext) FieldForIndexes(indexes []int) reflect.Value { + if len(indexes) == 1 { + val := FieldByIndexes(o.value, indexes) + return val + } + + obj := &nestedFieldScanner{ + parent: o, + indexes: indexes, + } + + v := reflect.ValueOf(obj).Elem() + return v +} + +// getFieldByIndex returns a value for the field given by the struct traversal +// for the given value. +func (o *ObjectContext) getFieldByIndex(indexes []int) reflect.Value { + return FieldByIndexes(o.value, indexes) +} + +// nestedFieldScanner will only forward the Scan to the nested value if +// the database value is not nil. +type nestedFieldScanner struct { + parent *ObjectContext + indexes []int +} + +// Scan implements sql.Scanner. +// This method largely mirrors the sql.convertAssign() method with some minor changes +func (o *nestedFieldScanner) Scan(src interface{}) error { + if src == nil { + return nil + } + + dv := FieldByIndexes(o.parent.value, o.indexes) + iface := dv.Addr().Interface() + + if scan, ok := iface.(sql.Scanner); ok { + return scan.Scan(src) + } + + sv := reflect.ValueOf(src) + + // below is taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go + // with a few minor edits + + if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { + switch b := src.(type) { + case []byte: + dv.Set(reflect.ValueOf(bytesClone(b))) + default: + dv.Set(sv) + } + + return nil + } + + if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { + dv.Set(sv.Convert(dv.Type())) + return nil + } + + // The following conversions use a string value as an intermediate representation + // to convert between various numeric types. + // + // This also allows scanning into user defined types such as "type Int int64". + // For symmetry, also check for string destination types. + switch dv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + case reflect.String: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + switch v := src.(type) { + case string: + dv.SetString(v) + return nil + case []byte: + dv.SetString(string(v)) + return nil + } + } + + return fmt.Errorf("don't know how to parse type %T -> %T", src, iface) +} + +// returns internal conversion error if available +// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err + } + return err +} + +// converts value to it's string value +// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go +func asString(src interface{}) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +// bytesClone returns a copy of b[:len(b)]. +// The result may have additional unused capacity. +// Clone(nil) returns nil. +// +// bytesClone is a mirror of bytes.Clone while our go.mod is on an older version +func bytesClone(b []byte) []byte { + if b == nil { + return nil + } + return append([]byte{}, b...) +} + // FieldByIndexes returns a value for the field given by the struct traversal // for the given value. func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { diff --git a/sqlx.go b/sqlx.go index f7b28768..9b3ad8b0 100644 --- a/sqlx.go +++ b/sqlx.go @@ -621,7 +621,8 @@ func (r *Rows) StructScan(dest interface{}) error { r.started = true } - err := fieldsByTraversal(v, r.fields, r.values, true) + octx := reflectx.NewObjectContext() + err := fieldsByTraversal(octx, v, r.fields, r.values, true) if err != nil { return err } @@ -781,7 +782,9 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { } values := make([]interface{}, len(columns)) - err = fieldsByTraversal(v, fields, values, true) + octx := reflectx.NewObjectContext() + + err = fieldsByTraversal(octx, v, fields, values, true) if err != nil { return err } @@ -948,13 +951,14 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values = make([]interface{}, len(columns)) + octx := reflectx.NewObjectContext() for rows.Next() { // create a new struct type (which returns PtrTo) and indirect it vp = reflect.New(base) v = reflect.Indirect(vp) - err = fieldsByTraversal(v, fields, values, true) + err = fieldsByTraversal(octx, v, fields, values, true) if err != nil { return err } @@ -1020,18 +1024,20 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { // when iterating over many rows. Empty traversals will get an interface pointer. // Because of the necessity of requesting ptrs or values, it's considered a bit too // specialized for inclusion in reflectx itself. -func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { +func fieldsByTraversal(octx *reflectx.ObjectContext, v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return errors.New("argument not a struct") } + octx.NewRow(v) + for i, traversal := range traversals { if len(traversal) == 0 { values[i] = new(interface{}) continue } - f := reflectx.FieldByIndexes(v, traversal) + f := octx.FieldForIndexes(traversal) if ptrs { values[i] = f.Addr().Interface() } else { diff --git a/sqlx_context_test.go b/sqlx_context_test.go index e49ab8b7..107eb970 100644 --- a/sqlx_context_test.go +++ b/sqlx_context_test.go @@ -642,6 +642,110 @@ func TestNamedQueryContext(t *testing.T) { t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID) } } + + rows.Close() + + type Owner struct { + Email string `db:"email"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + } + + // Test optional nested structs with left join + type PlaceOwner struct { + Place Place `db:"place"` + Owner *Owner `db:"owner"` + } + + pl = Place{ + Name: sql.NullString{String: "the-house", Valid: true}, + } + + q4 := `INSERT INTO place (id, name) VALUES (2, :name)` + _, err = db.NamedExecContext(ctx, q4, pl) + if err != nil { + log.Fatal(err) + } + + id = 2 + pp.Place.ID = id + + q5 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` + _, err = db.NamedExecContext(ctx, q5, pp) + if err != nil { + log.Fatal(err) + } + + pp3 := &PlaceOwner{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email", + place.id AS "place.id", + place.name AS "place.name" + FROM place + LEFT JOIN placeperson ON false -- null left join + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp3) + if err != nil { + t.Error(err) + } + if pp3.Owner != nil { + t.Error("Expected `Owner`, to be nil") + } + if pp3.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String) + } + if pp3.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp3.Place.ID) + } + } + + rows.Close() + + pp3 = &PlaceOwner{} + rows, err = db.NamedQueryContext(ctx, ` + SELECT + placeperson.first_name "owner.first_name", + placeperson.last_name "owner.last_name", + placeperson.email "owner.email", + place.id AS "place.id", + place.name AS "place.name" + FROM place + left JOIN placeperson ON placeperson.place_id = place.id + WHERE + place.id=:place.id`, pp) + if err != nil { + log.Fatal(err) + } + for rows.Next() { + err = rows.StructScan(pp3) + if err != nil { + t.Error(err) + } + if pp3.Owner == nil { + t.Error("Expected `Owner`, to not be nil") + } + + if pp3.Owner.FirstName != "ben" { + t.Error("Expected first name of `ben`, got " + pp3.Owner.FirstName) + } + if pp3.Owner.LastName != "doe" { + t.Error("Expected first name of `doe`, got " + pp3.Owner.LastName) + } + if pp3.Place.Name.String != "the-house" { + t.Error("Expected place name of `the-house`, got " + pp3.Place.Name.String) + } + if pp3.Place.ID != pp.Place.ID { + t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp3.Place.ID) + } + } }) }