Skip to content

Commit d39110e

Browse files
committed
Use type constraints for assertions
1 parent 72a1942 commit d39110e

File tree

5 files changed

+30
-39
lines changed

5 files changed

+30
-39
lines changed

Diff for: assert/assert.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ func NilError(t TestingT, err error, msgAndArgs ...any) {
199199
// called from the goroutine running the test function, not from other
200200
// goroutines created during the test. Use Check with cmp.Equal from other
201201
// goroutines.
202-
func Equal(t TestingT, x, y interface{}, msgAndArgs ...any) {
202+
func Equal[ANY any](t TestingT, x, y ANY, msgAndArgs ...any) {
203203
if ht, ok := t.(helperT); ok {
204204
ht.Helper()
205205
}
@@ -218,7 +218,7 @@ func Equal(t TestingT, x, y interface{}, msgAndArgs ...any) {
218218
// called from the goroutine running the test function, not from other
219219
// goroutines created during the test. Use Check with cmp.DeepEqual from other
220220
// goroutines.
221-
func DeepEqual(t TestingT, x, y interface{}, opts ...gocmp.Option) {
221+
func DeepEqual[ANY any](t TestingT, x, y ANY, opts ...gocmp.Option) {
222222
if ht, ok := t.(helperT); ok {
223223
ht.Helper()
224224
}

Diff for: assert/assert_test.go

-7
Original file line numberDiff line numberDiff line change
@@ -283,13 +283,6 @@ func TestEqualFailure(t *testing.T) {
283283
expectFailNowed(t, fakeT, "assertion failed: 1 (actual int) != 3 (expected int)")
284284
}
285285

286-
func TestEqualFailureTypes(t *testing.T) {
287-
fakeT := &fakeTestingT{}
288-
289-
Equal(fakeT, 3, uint(3))
290-
expectFailNowed(t, fakeT, `assertion failed: 3 (int) != 3 (uint)`)
291-
}
292-
293286
func TestEqualFailureWithSelectorArgument(t *testing.T) {
294287
fakeT := &fakeTestingT{}
295288

Diff for: assert/cmp/compare.go

+16-14
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ type Comparison func() Result
2424
// The comparison can be customized using comparison Options.
2525
// Package http://pkg.go.dev/gotest.tools/v3/assert/opt provides some additional
2626
// commonly used Options.
27-
func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison {
27+
func DeepEqual[ANY any](x, y ANY, opts ...cmp.Option) Comparison {
2828
return func() (result Result) {
2929
defer func() {
3030
if panicmsg, handled := handleCmpPanic(recover()); handled {
@@ -63,7 +63,9 @@ func toResult(success bool, msg string) Result {
6363

6464
// RegexOrPattern may be either a *regexp.Regexp or a string that is a valid
6565
// regexp pattern.
66-
type RegexOrPattern interface{}
66+
type RegexOrPattern interface {
67+
~string | *regexp.Regexp
68+
}
6769

6870
// Regexp succeeds if value v matches regular expression re.
6971
//
@@ -72,15 +74,15 @@ type RegexOrPattern interface{}
7274
// assert.Assert(t, cmp.Regexp("^[0-9a-f]{32}$", str))
7375
// r := regexp.MustCompile("^[0-9a-f]{32}$")
7476
// assert.Assert(t, cmp.Regexp(r, str))
75-
func Regexp(re RegexOrPattern, v string) Comparison {
77+
func Regexp[R RegexOrPattern](re R, v string) Comparison {
7678
match := func(re *regexp.Regexp) Result {
7779
return toResult(
7880
re.MatchString(v),
7981
fmt.Sprintf("value %q does not match regexp %q", v, re.String()))
8082
}
8183

8284
return func() Result {
83-
switch regex := re.(type) {
85+
switch regex := any(re).(type) {
8486
case *regexp.Regexp:
8587
return match(regex)
8688
case string:
@@ -96,13 +98,13 @@ func Regexp(re RegexOrPattern, v string) Comparison {
9698
}
9799

98100
// Equal succeeds if x == y. See assert.Equal for full documentation.
99-
func Equal(x, y interface{}) Comparison {
101+
func Equal[ANY any](x, y ANY) Comparison {
100102
return func() Result {
101103
switch {
102-
case x == y:
104+
case any(x) == any(y):
103105
return ResultSuccess
104106
case isMultiLineStringCompare(x, y):
105-
diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)})
107+
diff := format.UnifiedDiff(format.DiffConfig{A: any(x).(string), B: any(y).(string)})
106108
return multiLineDiffResult(diff, x, y)
107109
}
108110
return ResultFailureTemplate(`
@@ -117,7 +119,7 @@ func Equal(x, y interface{}) Comparison {
117119
}
118120
}
119121

120-
func isMultiLineStringCompare(x, y interface{}) bool {
122+
func isMultiLineStringCompare(x, y any) bool {
121123
strX, ok := x.(string)
122124
if !ok {
123125
return false
@@ -129,7 +131,7 @@ func isMultiLineStringCompare(x, y interface{}) bool {
129131
return strings.Contains(strX, "\n") || strings.Contains(strY, "\n")
130132
}
131133

132-
func multiLineDiffResult(diff string, x, y interface{}) Result {
134+
func multiLineDiffResult(diff string, x, y any) Result {
133135
return ResultFailureTemplate(`
134136
--- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}}
135137
+++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}}
@@ -138,7 +140,7 @@ func multiLineDiffResult(diff string, x, y interface{}) Result {
138140
}
139141

140142
// Len succeeds if the sequence has the expected length.
141-
func Len(seq interface{}, expected int) Comparison {
143+
func Len(seq any, expected int) Comparison {
142144
return func() (result Result) {
143145
defer func() {
144146
if e := recover(); e != nil {
@@ -163,7 +165,7 @@ func Len(seq interface{}, expected int) Comparison {
163165
// If collection is a Map, contains will succeed if item is a key in the map.
164166
// If collection is a slice or array, item is compared to each item in the
165167
// sequence using reflect.DeepEqual().
166-
func Contains(collection interface{}, item interface{}) Comparison {
168+
func Contains(collection any, item any) Comparison {
167169
return func() Result {
168170
colValue := reflect.ValueOf(collection)
169171
if !colValue.IsValid() {
@@ -261,14 +263,14 @@ func formatErrorMessage(err error) string {
261263
//
262264
// Use NilError() for comparing errors. Use Len(obj, 0) for comparing slices,
263265
// maps, and channels.
264-
func Nil(obj interface{}) Comparison {
266+
func Nil(obj any) Comparison {
265267
msgFunc := func(value reflect.Value) string {
266268
return fmt.Sprintf("%v (type %s) is not nil", reflect.Indirect(value), value.Type())
267269
}
268270
return isNil(obj, msgFunc)
269271
}
270272

271-
func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison {
273+
func isNil(obj any, msgFunc func(reflect.Value) string) Comparison {
272274
return func() Result {
273275
if obj == nil {
274276
return ResultSuccess
@@ -309,7 +311,7 @@ func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison {
309311
// Fails if err does not implement the reflect.Type.
310312
//
311313
// Deprecated: Use ErrorIs
312-
func ErrorType(err error, expected interface{}) Comparison {
314+
func ErrorType(err error, expected any) Comparison {
313315
return func() Result {
314316
switch expectedType := expected.(type) {
315317
case func(error) bool:

Diff for: assert/cmp/compare_test.go

+11-15
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,15 @@ func TestDeepEqualWithUnexported(t *testing.T) {
4545
}
4646

4747
func TestRegexp(t *testing.T) {
48-
var testcases = []struct {
48+
type testCase struct {
4949
name string
50-
regex interface{}
50+
regex string
5151
value string
5252
match bool
5353
expErr string
54-
}{
54+
}
55+
56+
var testcases = []testCase{
5557
{
5658
name: "pattern string match",
5759
regex: "^[0-9]+$",
@@ -70,24 +72,12 @@ func TestRegexp(t *testing.T) {
7072
value: "2123423456",
7173
expErr: `value "2123423456" does not match regexp "^1"`,
7274
},
73-
{
74-
name: "regexp match",
75-
regex: regexp.MustCompile("^d[0-9a-f]{8}$"),
76-
value: "d1632beef",
77-
match: true,
78-
},
7975
{
8076
name: "invalid regexp",
8177
regex: "^1(",
8278
value: "2",
8379
expErr: "error parsing regexp: missing closing ): `^1(`",
8480
},
85-
{
86-
name: "invalid type",
87-
regex: struct{}{},
88-
value: "some string",
89-
expErr: "invalid type struct {} for regex pattern",
90-
},
9181
}
9282

9383
for _, tc := range testcases {
@@ -100,6 +90,12 @@ func TestRegexp(t *testing.T) {
10090
}
10191
})
10292
}
93+
94+
t.Run("regexp match", func(t *testing.T) {
95+
regex := regexp.MustCompile("^d[0-9a-f]{8}$")
96+
res := Regexp(regex, "d1632beef")()
97+
assertSuccess(t, res)
98+
})
10399
}
104100

105101
func TestLen(t *testing.T) {

Diff for: fs/example_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func ExampleNewFile() {
2929

3030
content, err := os.ReadFile(file.Path())
3131
assert.NilError(t, err)
32-
assert.Equal(t, "content\n", content)
32+
assert.Equal(t, "content\n", string(content))
3333
}
3434

3535
// Create a directory and subdirectory with files

0 commit comments

Comments
 (0)