Skip to content

Commit 91d02c3

Browse files
committed
Add explicit AutoBinding capability, with tests
1 parent d07cde0 commit 91d02c3

File tree

13 files changed

+374
-46
lines changed

13 files changed

+374
-46
lines changed

internal/wire/analyze.go

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,18 @@ func verifyArgsUsed(set *ProviderSet, used []*providerSetSrc) []error {
287287
errs = append(errs, fmt.Errorf("unused provider %q", p.Pkg.Name()+"."+p.Name))
288288
}
289289
}
290+
for _, ab := range set.AutoBindings {
291+
found := false
292+
for _, u := range used {
293+
if u.AutoBinding == ab {
294+
found = true
295+
break
296+
}
297+
}
298+
if !found {
299+
errs = append(errs, fmt.Errorf("unused auto binding %q", types.TypeString(ab.Concrete, nil)))
300+
}
301+
}
290302
for _, v := range set.Values {
291303
found := false
292304
for _, u := range used {
@@ -458,49 +470,50 @@ func verifyAcyclic(providerMap *typeutil.Map, hasher typeutil.Hasher) []error {
458470
continue
459471
}
460472
pt := x.(*ProvidedType)
473+
var args []types.Type
461474
switch {
462475
case pt.IsValue():
463476
// Leaf: values do not have dependencies.
464477
case pt.IsArg():
465478
// Injector arguments do not have dependencies.
466-
case pt.IsProvider() || pt.IsField():
467-
var args []types.Type
468-
if pt.IsProvider() {
469-
for _, arg := range pt.Provider().Args {
470-
args = append(args, arg.Type)
471-
}
472-
} else {
473-
args = append(args, pt.Field().Parent)
479+
case pt.IsProvider():
480+
for _, arg := range pt.Provider().Args {
481+
args = append(args, arg.Type)
474482
}
475-
for _, a := range args {
476-
hasCycle := false
477-
for i, b := range curr {
478-
if types.Identical(a, b) {
479-
sb := new(strings.Builder)
480-
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
481-
for j := i; j < len(curr); j++ {
482-
t := providerMap.At(curr[j]).(*ProvidedType)
483-
if t.IsProvider() {
484-
p := t.Provider()
485-
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Pkg.Path(), p.Name)
486-
} else {
487-
p := t.Field()
488-
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Parent, p.Name)
489-
}
483+
case pt.IsAutoBinding():
484+
args = append(args, pt.AutoBinding().Concrete)
485+
case pt.IsField():
486+
args = append(args, pt.Field().Parent)
487+
default:
488+
panic("invalid provider map value")
489+
}
490+
491+
for _, a := range args {
492+
hasCycle := false
493+
for i, b := range curr {
494+
if types.Identical(a, b) {
495+
sb := new(strings.Builder)
496+
fmt.Fprintf(sb, "cycle for %s:\n", types.TypeString(a, nil))
497+
for j := i; j < len(curr); j++ {
498+
t := providerMap.At(curr[j]).(*ProvidedType)
499+
if t.IsProvider() {
500+
p := t.Provider()
501+
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Pkg.Path(), p.Name)
502+
} else {
503+
p := t.Field()
504+
fmt.Fprintf(sb, "%s (%s.%s) ->\n", types.TypeString(curr[j], nil), p.Parent, p.Name)
490505
}
491-
fmt.Fprintf(sb, "%s", types.TypeString(a, nil))
492-
ec.add(errors.New(sb.String()))
493-
hasCycle = true
494-
break
495506
}
496-
}
497-
if !hasCycle {
498-
next := append(append([]types.Type(nil), curr...), a)
499-
stk = append(stk, next)
507+
fmt.Fprintf(sb, "%s", types.TypeString(a, nil))
508+
ec.add(errors.New(sb.String()))
509+
hasCycle = true
510+
break
500511
}
501512
}
502-
default:
503-
panic("invalid provider map value")
513+
if !hasCycle {
514+
next := append(append([]types.Type(nil), curr...), a)
515+
stk = append(stk, next)
516+
}
504517
}
505518
}
506519
}

internal/wire/parse.go

Lines changed: 101 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535
// Exactly one of the fields will be set.
3636
type providerSetSrc struct {
3737
Provider *Provider
38+
AutoBinding *AutoBinding
3839
Binding *IfaceBinding
3940
Value *Value
4041
Import *ProviderSet
@@ -57,6 +58,8 @@ func (p *providerSetSrc) description(fset *token.FileSet, typ types.Type) string
5758
kind = "struct provider"
5859
}
5960
return fmt.Sprintf("%s %s(%s)", kind, quoted(p.Provider.Name), fset.Position(p.Provider.Pos))
61+
case p.AutoBinding != nil:
62+
return fmt.Sprintf("wire.AutoBind (%s)", fset.Position(p.AutoBinding.Pos))
6063
case p.Binding != nil:
6164
return fmt.Sprintf("wire.Bind (%s)", fset.Position(p.Binding.Pos))
6265
case p.Value != nil:
@@ -98,11 +101,12 @@ type ProviderSet struct {
98101
// variable.
99102
VarName string
100103

101-
Providers []*Provider
102-
Bindings []*IfaceBinding
103-
Values []*Value
104-
Fields []*Field
105-
Imports []*ProviderSet
104+
Providers []*Provider
105+
Bindings []*IfaceBinding
106+
AutoBindings []*AutoBinding
107+
Values []*Value
108+
Fields []*Field
109+
Imports []*ProviderSet
106110
// InjectorArgs is only filled in for wire.Build.
107111
InjectorArgs *InjectorArgs
108112

@@ -125,6 +129,22 @@ func (set *ProviderSet) Outputs() []types.Type {
125129
func (set *ProviderSet) For(t types.Type) ProvidedType {
126130
pt := set.providerMap.At(t)
127131
if pt == nil {
132+
// if t is an interface, we may have an AutoBinding that implements it.
133+
iface, ok := t.Underlying().(*types.Interface)
134+
if !ok {
135+
return ProvidedType{}
136+
}
137+
138+
for _, ab := range set.AutoBindings {
139+
if types.Implements(ab.Concrete, iface) {
140+
// cache for later
141+
pt := &ProvidedType{t: ab.Concrete, ab: ab}
142+
set.providerMap.Set(t, pt)
143+
set.srcMap.Set(t, &providerSetSrc{AutoBinding: ab})
144+
return *pt
145+
}
146+
}
147+
128148
return ProvidedType{}
129149
}
130150
return *pt.(*ProvidedType)
@@ -179,6 +199,17 @@ type Provider struct {
179199
HasErr bool
180200
}
181201

202+
// AutoBinding records the signature of a provider eligible for auto-binding
203+
// to interfaces it implements. A provider is a single Go object, either a
204+
// function or a named type.
205+
type AutoBinding struct {
206+
// Concrete is always a type that implements N number of interfaces.
207+
Concrete types.Type
208+
209+
// Pos is the position where the binding was declared.
210+
Pos token.Pos
211+
}
212+
182213
// ProviderInput describes an incoming edge in the provider graph.
183214
type ProviderInput struct {
184215
Type types.Type
@@ -520,7 +551,7 @@ func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec {
520551
}
521552

522553
// processExpr converts an expression into a Wire structure. It may return a
523-
// *Provider, an *IfaceBinding, a *ProviderSet, a *Value or a []*Field.
554+
// *Provider, an *AutoBinding, an *IfaceBinding, a *ProviderSet, a *Value or a []*Field.
524555
func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Expr, varName string) (interface{}, []error) {
525556
exprPos := oc.fset.Position(expr.Pos())
526557
expr = astutil.Unparen(expr)
@@ -546,6 +577,12 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
546577
case "NewSet":
547578
pset, errs := oc.processNewSet(info, pkgPath, call, nil, varName)
548579
return pset, notePositionAll(exprPos, errs)
580+
case "AutoBind":
581+
abs, err := processAutoBind(oc.fset, info, call)
582+
if err != nil {
583+
return nil, []error{notePosition(exprPos, err)}
584+
}
585+
return abs, nil
549586
case "Bind":
550587
b, err := processBind(oc.fset, info, call)
551588
if err != nil {
@@ -607,6 +644,8 @@ func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast
607644
continue
608645
}
609646
switch item := item.(type) {
647+
case *AutoBinding:
648+
pset.AutoBindings = append(pset.AutoBindings, item)
610649
case *Provider:
611650
pset.Providers = append(pset.Providers, item)
612651
case *ProviderSet:
@@ -880,6 +919,41 @@ func isPrevented(tag string) bool {
880919
return reflect.StructTag(tag).Get("wire") == "-"
881920
}
882921

922+
func processAutoBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*AutoBinding, error) {
923+
// Assumes that call.Fun is wire.AutoBind.
924+
925+
if len(call.Args) != 1 {
926+
return nil, notePosition(fset.Position(call.Pos()),
927+
errors.New("call to AutoBind takes exactly one argument"))
928+
}
929+
const firstArgReqFormat = "first argument to AutoBind must be a pointer to a type; found %s"
930+
typ := info.TypeOf(call.Args[0])
931+
ptr, ok := typ.(*types.Pointer)
932+
if !ok {
933+
return nil, notePosition(fset.Position(call.Pos()),
934+
fmt.Errorf(firstArgReqFormat, types.TypeString(typ, nil)))
935+
}
936+
937+
switch ptr.Elem().Underlying().(type) {
938+
case *types.Named,
939+
*types.Struct,
940+
*types.Basic:
941+
// good!
942+
943+
default:
944+
return nil, notePosition(fset.Position(call.Pos()),
945+
fmt.Errorf(firstArgReqFormat, types.TypeString(ptr, nil)))
946+
}
947+
948+
typeExpr := call.Args[0].(*ast.CallExpr)
949+
typeName := qualifiedIdentObject(info, typeExpr.Args[0]) // should be either an identifier or selector
950+
autoBinding := &AutoBinding{
951+
Concrete: ptr,
952+
Pos: typeName.Pos(),
953+
}
954+
return autoBinding, nil
955+
}
956+
883957
// processBind creates an interface binding from a wire.Bind call.
884958
func processBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*IfaceBinding, error) {
885959
// Assumes that call.Fun is wire.Bind.
@@ -1122,7 +1196,6 @@ func findInjectorBuild(info *types.Info, fn *ast.FuncDecl) (*ast.CallExpr, error
11221196
default:
11231197
invalid = true
11241198
}
1125-
11261199
}
11271200
if wireBuildCall == nil {
11281201
return nil, nil
@@ -1157,16 +1230,17 @@ func isProviderSetType(t types.Type) bool {
11571230
// none of the above, and returns true for IsNil.
11581231
type ProvidedType struct {
11591232
// t is the provided concrete type.
1160-
t types.Type
1161-
p *Provider
1162-
v *Value
1163-
a *InjectorArg
1164-
f *Field
1233+
t types.Type
1234+
p *Provider
1235+
ab *AutoBinding
1236+
v *Value
1237+
a *InjectorArg
1238+
f *Field
11651239
}
11661240

11671241
// IsNil reports whether pt is the zero value.
11681242
func (pt ProvidedType) IsNil() bool {
1169-
return pt.p == nil && pt.v == nil && pt.a == nil && pt.f == nil
1243+
return pt.p == nil && pt.ab == nil && pt.v == nil && pt.a == nil && pt.f == nil
11701244
}
11711245

11721246
// Type returns the output type.
@@ -1185,6 +1259,11 @@ func (pt ProvidedType) IsProvider() bool {
11851259
return pt.p != nil
11861260
}
11871261

1262+
// IsAutoBinding reports whether pt points to an AutoBinding.
1263+
func (pt ProvidedType) IsAutoBinding() bool {
1264+
return pt.ab != nil
1265+
}
1266+
11881267
// IsValue reports whether pt points to a Value.
11891268
func (pt ProvidedType) IsValue() bool {
11901269
return pt.v != nil
@@ -1209,6 +1288,15 @@ func (pt ProvidedType) Provider() *Provider {
12091288
return pt.p
12101289
}
12111290

1291+
// AutoBinding returns pt as a AutoBinding pointer. It panics if pt does not point
1292+
// to a AutoBinding.
1293+
func (pt ProvidedType) AutoBinding() *AutoBinding {
1294+
if pt.ab == nil {
1295+
panic("ProvidedType does not hold an AutoBinding")
1296+
}
1297+
return pt.ab
1298+
}
1299+
12121300
// Value returns pt as a Value pointer. It panics if pt does not point
12131301
// to a Value.
12141302
func (pt ProvidedType) Value() *Value {
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright 2018 The Wire Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package main
16+
17+
import (
18+
"fmt"
19+
20+
"github.com/google/wire"
21+
)
22+
23+
func main() {
24+
fmt.Println(injectFooer().Foo())
25+
}
26+
27+
type Fooer interface {
28+
Foo() string
29+
}
30+
31+
type Bar string
32+
33+
func (b *Bar) Foo() string {
34+
return string(*b)
35+
}
36+
37+
func provideBar() *Bar {
38+
b := new(Bar)
39+
*b = "Hello, World!"
40+
return b
41+
}
42+
43+
var Set = wire.NewSet(provideBar)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright 2018 The Wire Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
//+build wireinject
16+
17+
package main
18+
19+
import (
20+
"github.com/google/wire"
21+
)
22+
23+
func injectFooer() Fooer {
24+
wire.Build(Set, wire.AutoBind(new(Bar)))
25+
return nil
26+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
example.com/foo
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Hello, World!

0 commit comments

Comments
 (0)