Skip to content

Commit 1f5d8bf

Browse files
authored
Avoid escaping times (#256)
1 parent 41dc46a commit 1f5d8bf

File tree

8 files changed

+244
-113
lines changed

8 files changed

+244
-113
lines changed

driver/driver.go

Lines changed: 108 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,45 @@
2020
// - a [serializable] transaction is always "immediate";
2121
// - a [read-only] transaction is always "deferred".
2222
//
23+
// # Datatypes In SQLite
24+
//
25+
// SQLite is dynamically typed.
26+
// Columns can mostly hold any value regardless of their declared type.
27+
// SQLite supports most [driver.Value] types out of the box,
28+
// but bool and [time.Time] require special care.
29+
//
30+
// Booleans can be stored on any column type and scanned back to a *bool.
31+
// However, if scanned to a *any, booleans may either become an
32+
// int64, string or bool, depending on the declared type of the column.
33+
// If you use BOOLEAN for your column type,
34+
// 1 and 0 will always scan as true and false.
35+
//
2336
// # Working with time
2437
//
38+
// Time values can similarly be stored on any column type.
2539
// The time encoding/decoding format can be specified using "_timefmt":
2640
//
2741
// sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite")
2842
//
29-
// Possible values are: "auto" (the default), "sqlite", "rfc3339";
43+
// Special values are: "auto" (the default), "sqlite", "rfc3339";
3044
// - "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite;
3145
// - "sqlite" encodes as SQLite and decodes any [format] supported by SQLite;
3246
// - "rfc3339" encodes and decodes RFC 3339 only.
3347
//
48+
// You can also set "_timefmt" to an arbitrary [sqlite3.TimeFormat] or [time.Layout].
49+
//
50+
// If you encode as RFC 3339 (the default),
51+
// consider using the TIME [collating sequence] to produce time-ordered sequences.
52+
//
3453
// If you encode as RFC 3339 (the default),
35-
// consider using the TIME [collating sequence] to produce a time-ordered sequence.
54+
// time values will scan back to a *time.Time unless your column type is TEXT.
55+
// Otherwise, if scanned to a *any, time values may either become an
56+
// int64, float64 or string, depending on the time format and declared type of the column.
57+
// If you use DATE, TIME, DATETIME, or TIMESTAMP for your column type,
58+
// "_timefmt" will be used to decode values.
3659
//
37-
// To scan values in other formats, [sqlite3.TimeFormat.Scanner] may be helpful.
38-
// To bind values in other formats, [sqlite3.TimeFormat.Encode] them before binding.
60+
// To scan values in custom formats, [sqlite3.TimeFormat.Scanner] may be helpful.
61+
// To bind values in custom formats, [sqlite3.TimeFormat.Encode] them before binding.
3962
//
4063
// When using a custom time struct, you'll have to implement
4164
// [database/sql/driver.Valuer] and [database/sql.Scanner].
@@ -48,7 +71,7 @@
4871
// The Scan method needs to take into account that the value it receives can be of differing types.
4972
// It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules.
5073
// Or it can be a: string, int64, float64, []byte, or nil,
51-
// depending on the column type and what whoever wrote the value.
74+
// depending on the column type and whoever wrote the value.
5275
// [sqlite3.TimeFormat.Decode] may help.
5376
//
5477
// # Setting PRAGMAs
@@ -595,6 +618,28 @@ const (
595618
_TIME
596619
)
597620

621+
func scanFromDecl(decl string) scantype {
622+
// These types are only used before we have rows,
623+
// and otherwise as type hints.
624+
// The first few ensure STRICT tables are strictly typed.
625+
// The other two are type hints for booleans and time.
626+
switch decl {
627+
case "INT", "INTEGER":
628+
return _INT
629+
case "REAL":
630+
return _REAL
631+
case "TEXT":
632+
return _TEXT
633+
case "BLOB":
634+
return _BLOB
635+
case "BOOLEAN":
636+
return _BOOL
637+
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
638+
return _TIME
639+
}
640+
return _ANY
641+
}
642+
598643
var (
599644
// Ensure these interfaces are implemented:
600645
_ driver.RowsColumnTypeDatabaseTypeName = &rows{}
@@ -619,6 +664,18 @@ func (r *rows) Columns() []string {
619664
return r.names
620665
}
621666

667+
func (r *rows) scanType(index int) scantype {
668+
if r.scans == nil {
669+
count := r.Stmt.ColumnCount()
670+
scans := make([]scantype, count)
671+
for i := range scans {
672+
scans[i] = scanFromDecl(strings.ToUpper(r.Stmt.ColumnDeclType(i)))
673+
}
674+
r.scans = scans
675+
}
676+
return r.scans[index]
677+
}
678+
622679
func (r *rows) loadColumnMetadata() {
623680
if r.nulls == nil {
624681
count := r.Stmt.ColumnCount()
@@ -632,24 +689,7 @@ func (r *rows) loadColumnMetadata() {
632689
r.Stmt.ColumnTableName(i),
633690
col)
634691
types[i] = strings.ToUpper(types[i])
635-
// These types are only used before we have rows,
636-
// and otherwise as type hints.
637-
// The first few ensure STRICT tables are strictly typed.
638-
// The other two are type hints for booleans and time.
639-
switch types[i] {
640-
case "INT", "INTEGER":
641-
scans[i] = _INT
642-
case "REAL":
643-
scans[i] = _REAL
644-
case "TEXT":
645-
scans[i] = _TEXT
646-
case "BLOB":
647-
scans[i] = _BLOB
648-
case "BOOLEAN":
649-
scans[i] = _BOOL
650-
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
651-
scans[i] = _TIME
652-
}
692+
scans[i] = scanFromDecl(types[i])
653693
}
654694
}
655695
r.nulls = nulls
@@ -658,27 +698,15 @@ func (r *rows) loadColumnMetadata() {
658698
}
659699
}
660700

661-
func (r *rows) declType(index int) string {
662-
if r.types == nil {
663-
count := r.Stmt.ColumnCount()
664-
types := make([]string, count)
665-
for i := range types {
666-
types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i))
667-
}
668-
r.types = types
669-
}
670-
return r.types[index]
671-
}
672-
673701
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
674702
r.loadColumnMetadata()
675-
decltype := r.types[index]
676-
if len := len(decltype); len > 0 && decltype[len-1] == ')' {
677-
if i := strings.LastIndexByte(decltype, '('); i >= 0 {
678-
decltype = decltype[:i]
703+
decl := r.types[index]
704+
if len := len(decl); len > 0 && decl[len-1] == ')' {
705+
if i := strings.LastIndexByte(decl, '('); i >= 0 {
706+
decl = decl[:i]
679707
}
680708
}
681-
return strings.TrimSpace(decltype)
709+
return strings.TrimSpace(decl)
682710
}
683711

684712
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
@@ -745,36 +773,49 @@ func (r *rows) Next(dest []driver.Value) error {
745773
}
746774

747775
data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest))
748-
err := r.Stmt.Columns(data...)
749-
for i := range dest {
750-
if t, ok := r.decodeTime(i, dest[i]); ok {
751-
dest[i] = t
752-
}
776+
if err := r.Stmt.ColumnsRaw(data...); err != nil {
777+
return err
753778
}
754-
return err
755-
}
756-
757-
func (r *rows) decodeTime(i int, v any) (_ time.Time, ok bool) {
758-
switch v := v.(type) {
759-
case int64, float64:
760-
// could be a time value
761-
case string:
762-
if r.tmWrite != "" && r.tmWrite != time.RFC3339 && r.tmWrite != time.RFC3339Nano {
779+
for i := range dest {
780+
scan := r.scanType(i)
781+
switch v := dest[i].(type) {
782+
case int64:
783+
if scan == _BOOL {
784+
switch v {
785+
case 1:
786+
dest[i] = true
787+
case 0:
788+
dest[i] = false
789+
}
790+
continue
791+
}
792+
case []byte:
793+
if len(v) == cap(v) { // a BLOB
794+
continue
795+
}
796+
if scan != _TEXT {
797+
switch r.tmWrite {
798+
case "", time.RFC3339, time.RFC3339Nano:
799+
t, ok := maybeTime(v)
800+
if ok {
801+
dest[i] = t
802+
continue
803+
}
804+
}
805+
}
806+
dest[i] = string(v)
807+
case float64:
763808
break
809+
default:
810+
continue
764811
}
765-
t, ok := maybeTime(v)
766-
if ok {
767-
return t, true
812+
if scan == _TIME {
813+
t, err := r.tmRead.Decode(dest[i])
814+
if err == nil {
815+
dest[i] = t
816+
continue
817+
}
768818
}
769-
default:
770-
return
771819
}
772-
switch r.declType(i) {
773-
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
774-
// could be a time value
775-
default:
776-
return
777-
}
778-
t, err := r.tmRead.Decode(v)
779-
return t, err == nil
820+
return nil
780821
}

driver/example2_test.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
1-
//go:build linux || darwin || windows || freebsd || openbsd || netbsd || dragonfly || illumos || sqlite3_flock || sqlite3_dotlk
2-
31
package driver_test
42

5-
// Adapted from: https://go.dev/doc/tutorial/database-access
6-
73
import (
84
"database/sql"
95
"database/sql/driver"
@@ -27,7 +23,7 @@ func Example_customTime() {
2723
_, err = db.Exec(`
2824
CREATE TABLE data (
2925
id INTEGER PRIMARY KEY,
30-
date_time TEXT
26+
date_time ANY
3127
) STRICT;
3228
`)
3329
if err != nil {

driver/time.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
package driver
22

3-
import "time"
3+
import (
4+
"bytes"
5+
"time"
6+
)
47

58
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
69
// if it roundtrips back to the same string.
710
// This way times can be persisted to, and recovered from, the database,
811
// but if a string is needed, [database/sql] will recover the same string.
9-
func maybeTime(text string) (_ time.Time, _ bool) {
12+
func maybeTime(text []byte) (_ time.Time, _ bool) {
1013
// Weed out (some) values that can't possibly be
1114
// [time.RFC3339Nano] timestamps.
1215
if len(text) < len("2006-01-02T15:04:05Z") {
@@ -21,8 +24,8 @@ func maybeTime(text string) (_ time.Time, _ bool) {
2124

2225
// Slow path.
2326
var buf [len(time.RFC3339Nano)]byte
24-
date, err := time.Parse(time.RFC3339Nano, text)
25-
if err == nil && text == string(date.AppendFormat(buf[:0], time.RFC3339Nano)) {
27+
date, err := time.Parse(time.RFC3339Nano, string(text))
28+
if err == nil && bytes.Equal(text, date.AppendFormat(buf[:0], time.RFC3339Nano)) {
2629
return date, true
2730
}
2831
return

driver/time_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func Fuzz_stringOrTime_1(f *testing.F) {
2222
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
2323

2424
f.Fuzz(func(t *testing.T, str string) {
25-
v, ok := maybeTime(str)
25+
v, ok := maybeTime([]byte(str))
2626
if ok {
2727
// Make sure times round-trip to the same string:
2828
// https://pkg.go.dev/database/sql#Rows.Scan
@@ -51,7 +51,7 @@ func Fuzz_stringOrTime_2(f *testing.F) {
5151
f.Add(int64(-763421161058), int64(222_222_222)) // twosday, year 22222BC
5252

5353
checkTime := func(t testing.TB, date time.Time) {
54-
v, ok := maybeTime(date.Format(time.RFC3339Nano))
54+
v, ok := maybeTime(date.AppendFormat(nil, time.RFC3339Nano))
5555
if ok {
5656
// Make sure times round-trip to the same time:
5757
if !v.Equal(date) {

0 commit comments

Comments
 (0)