|
3 | 3 | // allows for Go-compatible named attribute access, including accessing embedded
|
4 | 4 | // struct attributes and the ability to use functions and struct tags to
|
5 | 5 | // customize field names.
|
6 |
| -// |
7 | 6 | package reflectx
|
8 | 7 |
|
9 | 8 | import (
|
| 9 | + "database/sql" |
| 10 | + "fmt" |
10 | 11 | "reflect"
|
11 | 12 | "runtime"
|
| 13 | + "strconv" |
12 | 14 | "strings"
|
13 | 15 | "sync"
|
14 | 16 | )
|
@@ -201,6 +203,192 @@ func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(in
|
201 | 203 | return nil
|
202 | 204 | }
|
203 | 205 |
|
| 206 | +// ObjectContext provides a single layer to abstract away |
| 207 | +// nested struct scanning functionality |
| 208 | +type ObjectContext struct { |
| 209 | + value reflect.Value |
| 210 | +} |
| 211 | + |
| 212 | +func NewObjectContext() *ObjectContext { |
| 213 | + return &ObjectContext{} |
| 214 | +} |
| 215 | + |
| 216 | +// NewRow updates the object reference. |
| 217 | +// This ensures all columns point to the same object |
| 218 | +func (o *ObjectContext) NewRow(value reflect.Value) { |
| 219 | + o.value = value |
| 220 | +} |
| 221 | + |
| 222 | +// FieldForIndexes returns the value for address. If the address is a nested struct, |
| 223 | +// a nestedFieldScanner is returned instead of the standard value reference |
| 224 | +func (o *ObjectContext) FieldForIndexes(indexes []int) reflect.Value { |
| 225 | + if len(indexes) == 1 { |
| 226 | + val := FieldByIndexes(o.value, indexes) |
| 227 | + return val |
| 228 | + } |
| 229 | + |
| 230 | + obj := &nestedFieldScanner{ |
| 231 | + parent: o, |
| 232 | + indexes: indexes, |
| 233 | + } |
| 234 | + |
| 235 | + v := reflect.ValueOf(obj).Elem() |
| 236 | + return v |
| 237 | +} |
| 238 | + |
| 239 | +// getFieldByIndex returns a value for the field given by the struct traversal |
| 240 | +// for the given value. |
| 241 | +func (o *ObjectContext) getFieldByIndex(indexes []int) reflect.Value { |
| 242 | + return FieldByIndexes(o.value, indexes) |
| 243 | +} |
| 244 | + |
| 245 | +// nestedFieldScanner will only forward the Scan to the nested value if |
| 246 | +// the database value is not nil. |
| 247 | +type nestedFieldScanner struct { |
| 248 | + parent *ObjectContext |
| 249 | + indexes []int |
| 250 | +} |
| 251 | + |
| 252 | +// Scan implements sql.Scanner. |
| 253 | +// This method largely mirrors the sql.convertAssign() method with some minor changes |
| 254 | +func (o *nestedFieldScanner) Scan(src interface{}) error { |
| 255 | + if src == nil { |
| 256 | + return nil |
| 257 | + } |
| 258 | + |
| 259 | + dv := FieldByIndexes(o.parent.value, o.indexes) |
| 260 | + iface := dv.Addr().Interface() |
| 261 | + |
| 262 | + if scan, ok := iface.(sql.Scanner); ok { |
| 263 | + return scan.Scan(src) |
| 264 | + } |
| 265 | + |
| 266 | + sv := reflect.ValueOf(src) |
| 267 | + |
| 268 | + // below is taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go |
| 269 | + // with a few minor edits |
| 270 | + |
| 271 | + if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { |
| 272 | + switch b := src.(type) { |
| 273 | + case []byte: |
| 274 | + dv.Set(reflect.ValueOf(bytesClone(b))) |
| 275 | + default: |
| 276 | + dv.Set(sv) |
| 277 | + } |
| 278 | + |
| 279 | + return nil |
| 280 | + } |
| 281 | + |
| 282 | + if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { |
| 283 | + dv.Set(sv.Convert(dv.Type())) |
| 284 | + return nil |
| 285 | + } |
| 286 | + |
| 287 | + // The following conversions use a string value as an intermediate representation |
| 288 | + // to convert between various numeric types. |
| 289 | + // |
| 290 | + // This also allows scanning into user defined types such as "type Int int64". |
| 291 | + // For symmetry, also check for string destination types. |
| 292 | + switch dv.Kind() { |
| 293 | + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
| 294 | + if src == nil { |
| 295 | + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) |
| 296 | + } |
| 297 | + s := asString(src) |
| 298 | + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) |
| 299 | + if err != nil { |
| 300 | + err = strconvErr(err) |
| 301 | + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) |
| 302 | + } |
| 303 | + dv.SetInt(i64) |
| 304 | + return nil |
| 305 | + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
| 306 | + if src == nil { |
| 307 | + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) |
| 308 | + } |
| 309 | + s := asString(src) |
| 310 | + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) |
| 311 | + if err != nil { |
| 312 | + err = strconvErr(err) |
| 313 | + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) |
| 314 | + } |
| 315 | + dv.SetUint(u64) |
| 316 | + return nil |
| 317 | + case reflect.Float32, reflect.Float64: |
| 318 | + if src == nil { |
| 319 | + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) |
| 320 | + } |
| 321 | + s := asString(src) |
| 322 | + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) |
| 323 | + if err != nil { |
| 324 | + err = strconvErr(err) |
| 325 | + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) |
| 326 | + } |
| 327 | + dv.SetFloat(f64) |
| 328 | + return nil |
| 329 | + case reflect.String: |
| 330 | + if src == nil { |
| 331 | + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) |
| 332 | + } |
| 333 | + switch v := src.(type) { |
| 334 | + case string: |
| 335 | + dv.SetString(v) |
| 336 | + return nil |
| 337 | + case []byte: |
| 338 | + dv.SetString(string(v)) |
| 339 | + return nil |
| 340 | + } |
| 341 | + } |
| 342 | + |
| 343 | + return fmt.Errorf("don't know how to parse type %T -> %T", src, iface) |
| 344 | +} |
| 345 | + |
| 346 | +// returns internal conversion error if available |
| 347 | +// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go |
| 348 | +func strconvErr(err error) error { |
| 349 | + if ne, ok := err.(*strconv.NumError); ok { |
| 350 | + return ne.Err |
| 351 | + } |
| 352 | + return err |
| 353 | +} |
| 354 | + |
| 355 | +// converts value to it's string value |
| 356 | +// taken from https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/database/sql/convert.go |
| 357 | +func asString(src interface{}) string { |
| 358 | + switch v := src.(type) { |
| 359 | + case string: |
| 360 | + return v |
| 361 | + case []byte: |
| 362 | + return string(v) |
| 363 | + } |
| 364 | + rv := reflect.ValueOf(src) |
| 365 | + switch rv.Kind() { |
| 366 | + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: |
| 367 | + return strconv.FormatInt(rv.Int(), 10) |
| 368 | + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: |
| 369 | + return strconv.FormatUint(rv.Uint(), 10) |
| 370 | + case reflect.Float64: |
| 371 | + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) |
| 372 | + case reflect.Float32: |
| 373 | + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) |
| 374 | + case reflect.Bool: |
| 375 | + return strconv.FormatBool(rv.Bool()) |
| 376 | + } |
| 377 | + return fmt.Sprintf("%v", src) |
| 378 | +} |
| 379 | + |
| 380 | +// bytesClone returns a copy of b[:len(b)]. |
| 381 | +// The result may have additional unused capacity. |
| 382 | +// Clone(nil) returns nil. |
| 383 | +// |
| 384 | +// bytesClone is a mirror of bytes.Clone while our go.mod is on an older version |
| 385 | +func bytesClone(b []byte) []byte { |
| 386 | + if b == nil { |
| 387 | + return nil |
| 388 | + } |
| 389 | + return append([]byte{}, b...) |
| 390 | +} |
| 391 | + |
204 | 392 | // FieldByIndexes returns a value for the field given by the struct traversal
|
205 | 393 | // for the given value.
|
206 | 394 | func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
|
|
0 commit comments