@@ -24,7 +24,7 @@ type Comparison func() Result
24
24
// The comparison can be customized using comparison Options.
25
25
// Package http://pkg.go.dev/gotest.tools/v3/assert/opt provides some additional
26
26
// commonly used Options.
27
- func DeepEqual (x , y interface {} , opts ... cmp.Option ) Comparison {
27
+ func DeepEqual [ ANY any ] (x , y ANY , opts ... cmp.Option ) Comparison {
28
28
return func () (result Result ) {
29
29
defer func () {
30
30
if panicmsg , handled := handleCmpPanic (recover ()); handled {
@@ -63,7 +63,9 @@ func toResult(success bool, msg string) Result {
63
63
64
64
// RegexOrPattern may be either a *regexp.Regexp or a string that is a valid
65
65
// regexp pattern.
66
- type RegexOrPattern interface {}
66
+ type RegexOrPattern interface {
67
+ ~ string | * regexp.Regexp
68
+ }
67
69
68
70
// Regexp succeeds if value v matches regular expression re.
69
71
//
@@ -72,15 +74,15 @@ type RegexOrPattern interface{}
72
74
// assert.Assert(t, cmp.Regexp("^[0-9a-f]{32}$", str))
73
75
// r := regexp.MustCompile("^[0-9a-f]{32}$")
74
76
// assert.Assert(t, cmp.Regexp(r, str))
75
- func Regexp (re RegexOrPattern , v string ) Comparison {
77
+ func Regexp [ R RegexOrPattern ] (re R , v string ) Comparison {
76
78
match := func (re * regexp.Regexp ) Result {
77
79
return toResult (
78
80
re .MatchString (v ),
79
81
fmt .Sprintf ("value %q does not match regexp %q" , v , re .String ()))
80
82
}
81
83
82
84
return func () Result {
83
- switch regex := re .(type ) {
85
+ switch regex := any ( re ) .(type ) {
84
86
case * regexp.Regexp :
85
87
return match (regex )
86
88
case string :
@@ -96,13 +98,13 @@ func Regexp(re RegexOrPattern, v string) Comparison {
96
98
}
97
99
98
100
// Equal succeeds if x == y. See assert.Equal for full documentation.
99
- func Equal (x , y interface {} ) Comparison {
101
+ func Equal [ ANY any ] (x , y ANY ) Comparison {
100
102
return func () Result {
101
103
switch {
102
- case x == y :
104
+ case any ( x ) == any ( y ) :
103
105
return ResultSuccess
104
106
case isMultiLineStringCompare (x , y ):
105
- diff := format .UnifiedDiff (format.DiffConfig {A : x .(string ), B : y .(string )})
107
+ diff := format .UnifiedDiff (format.DiffConfig {A : any ( x ) .(string ), B : any ( y ) .(string )})
106
108
return multiLineDiffResult (diff , x , y )
107
109
}
108
110
return ResultFailureTemplate (`
@@ -117,7 +119,7 @@ func Equal(x, y interface{}) Comparison {
117
119
}
118
120
}
119
121
120
- func isMultiLineStringCompare (x , y interface {} ) bool {
122
+ func isMultiLineStringCompare (x , y any ) bool {
121
123
strX , ok := x .(string )
122
124
if ! ok {
123
125
return false
@@ -129,7 +131,7 @@ func isMultiLineStringCompare(x, y interface{}) bool {
129
131
return strings .Contains (strX , "\n " ) || strings .Contains (strY , "\n " )
130
132
}
131
133
132
- func multiLineDiffResult (diff string , x , y interface {} ) Result {
134
+ func multiLineDiffResult (diff string , x , y any ) Result {
133
135
return ResultFailureTemplate (`
134
136
--- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}}
135
137
+++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}}
@@ -138,7 +140,7 @@ func multiLineDiffResult(diff string, x, y interface{}) Result {
138
140
}
139
141
140
142
// Len succeeds if the sequence has the expected length.
141
- func Len (seq interface {} , expected int ) Comparison {
143
+ func Len (seq any , expected int ) Comparison {
142
144
return func () (result Result ) {
143
145
defer func () {
144
146
if e := recover (); e != nil {
@@ -163,7 +165,7 @@ func Len(seq interface{}, expected int) Comparison {
163
165
// If collection is a Map, contains will succeed if item is a key in the map.
164
166
// If collection is a slice or array, item is compared to each item in the
165
167
// sequence using reflect.DeepEqual().
166
- func Contains (collection interface {} , item interface {} ) Comparison {
168
+ func Contains (collection any , item any ) Comparison {
167
169
return func () Result {
168
170
colValue := reflect .ValueOf (collection )
169
171
if ! colValue .IsValid () {
@@ -261,14 +263,14 @@ func formatErrorMessage(err error) string {
261
263
//
262
264
// Use NilError() for comparing errors. Use Len(obj, 0) for comparing slices,
263
265
// maps, and channels.
264
- func Nil (obj interface {} ) Comparison {
266
+ func Nil (obj any ) Comparison {
265
267
msgFunc := func (value reflect.Value ) string {
266
268
return fmt .Sprintf ("%v (type %s) is not nil" , reflect .Indirect (value ), value .Type ())
267
269
}
268
270
return isNil (obj , msgFunc )
269
271
}
270
272
271
- func isNil (obj interface {} , msgFunc func (reflect.Value ) string ) Comparison {
273
+ func isNil (obj any , msgFunc func (reflect.Value ) string ) Comparison {
272
274
return func () Result {
273
275
if obj == nil {
274
276
return ResultSuccess
@@ -309,7 +311,7 @@ func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison {
309
311
// Fails if err does not implement the reflect.Type.
310
312
//
311
313
// Deprecated: Use ErrorIs
312
- func ErrorType (err error , expected interface {} ) Comparison {
314
+ func ErrorType (err error , expected any ) Comparison {
313
315
return func () Result {
314
316
switch expectedType := expected .(type ) {
315
317
case func (error ) bool :
0 commit comments