diff --git a/di/di.go b/di/di.go index 3075dba..fd3958e 100644 --- a/di/di.go +++ b/di/di.go @@ -1,25 +1,45 @@ package di import ( - "github.com/pubgo/dix/dix_internal" + "reflect" + + "github.com/pubgo/dix/dixinternal" ) -var _dix = dix_internal.New(dix_internal.WithValuesNull()) +var _dix = dixinternal.New(dixinternal.WithValuesNull()) + +// Example: +// +// c := di.New() +// c.Provide(func() *Config { return &Config{Endpoint: "localhost:..."} }) // Configuration +// c.Provide(NewDB) // Database connection +// c.Provide(NewHTTPServer) // Server +// +// c.Invoke(func(server *http.Server) { // Application startup +// server.ListenAndServe() +// }) +// +// For more usage details, see the documentation for the Container type. -// Provide 注册对象构造器 +// Provide registers an object constructor func Provide(data any) { _dix.Provide(data) } -// Inject 注入对象 +// Inject injects objects // -// data: <*struct>或 -func Inject[T any](data T, opts ...dix_internal.Option) T { - _ = _dix.Inject(data, opts...) +// data: <*struct> or +func Inject[T any](data T, opts ...dixinternal.Option) T { + vp := reflect.ValueOf(data) + if vp.Kind() == reflect.Struct { + _ = _dix.Inject(&data, opts...) + } else { + _ = _dix.Inject(data, opts...) + } return data } // Graph Dix graph -func Graph() *dix_internal.Graph { +func Graph() *dixinternal.Graph { return _dix.Graph() } diff --git a/dix.go b/dix.go index cfbabf5..e1458e4 100644 --- a/dix.go +++ b/dix.go @@ -2,27 +2,27 @@ package dix import ( "reflect" - - "github.com/pubgo/dix/dix_internal" + + "github.com/pubgo/dix/dixinternal" ) const ( - InjectMethodPrefix = dix_internal.InjectMethodPrefix + InjectMethodPrefix = dixinternal.InjectMethodPrefix ) type ( - Option = dix_internal.Option - Options = dix_internal.Options - Dix = dix_internal.Dix - Graph = dix_internal.Graph + Option = dixinternal.Option + Options = dixinternal.Options + Dix = dixinternal.Dix + Graph = dixinternal.Graph ) func WithValuesNull() Option { - return dix_internal.WithValuesNull() + return dixinternal.WithValuesNull() } func New(opts ...Option) *Dix { - return dix_internal.New(opts...) + return dixinternal.New(opts...) } func Inject[T any](di *Dix, data T, opts ...Option) T { diff --git a/dix_internal/api.go b/dix_internal/api.go deleted file mode 100644 index 01141e6..0000000 --- a/dix_internal/api.go +++ /dev/null @@ -1,21 +0,0 @@ -package dix_internal - -// New Dix new -func New(opts ...Option) *Dix { - return newDix(opts...) -} - -func (x *Dix) Provide(param any) { - x.provide(param) -} - -func (x *Dix) Inject(param any, opts ...Option) any { - return x.inject(param, opts...) -} - -func (x *Dix) Graph() *Graph { - return &Graph{ - Objects: x.objectGraph(), - Providers: x.providerGraph(), - } -} diff --git a/dix_internal/graph.go b/dix_internal/graph.go deleted file mode 100644 index b263bb4..0000000 --- a/dix_internal/graph.go +++ /dev/null @@ -1,51 +0,0 @@ -package dix_internal - -import ( - "bytes" - "fmt" - "io" - - "github.com/pubgo/funk/stack" -) - -func fPrintln(writer io.Writer, msg string) { - _, _ = fmt.Fprintln(writer, msg) -} - -func (x *Dix) providerGraph() string { - b := &bytes.Buffer{} - fPrintln(b, "digraph G {") - - fPrintln(b, "\tsubgraph providers {") - fPrintln(b, "\t\tlabel=providers") - for providerOutputType, nodes := range x.providers { - for _, n := range nodes { - fn := stack.CallerWithFunc(n.fn).String() - fPrintln(b, fmt.Sprintf("\t\t"+`"%s" -> "%s"`, fn, providerOutputType)) - for _, in := range n.input { - fPrintln(b, fmt.Sprintf("\t\t"+`"%s" -> "%s"`, in.typ, fn)) - } - } - } - fPrintln(b, "\t}") - - fPrintln(b, "}") - return b.String() -} - -func (x *Dix) objectGraph() string { - b := &bytes.Buffer{} - fPrintln(b, "digraph G {") - fPrintln(b, "\tsubgraph objects {") - fPrintln(b, "\t\tlabel=objects") - for k, objects := range x.objects { - for g, values := range objects { - for _, v := range values { - fPrintln(b, fmt.Sprintf("\t\t"+`"%s" -> "%s -> %s"`, k, g, v.Type().String())) - } - } - } - fPrintln(b, "\t}") - fPrintln(b, "}") - return b.String() -} diff --git a/dix_internal/aaa.go b/dixinternal/aaa.go similarity index 57% rename from dix_internal/aaa.go rename to dixinternal/aaa.go index 031c272..7c23edf 100644 --- a/dix_internal/aaa.go +++ b/dixinternal/aaa.go @@ -1,4 +1,4 @@ -package dix_internal +package dixinternal import ( "reflect" @@ -7,10 +7,10 @@ import ( ) const ( - // defaultKey 默认的 namespace + // defaultKey default namespace defaultKey = "default" - // InjectMethodPrefix 可以对对象进行 Inject, 只要这个对象的方法中包含了以`InjectMethodPrefix`为前缀的方法 + // InjectMethodPrefix can inject objects, as long as the method of this object contains a prefix of `InjectMethodPrefix` InjectMethodPrefix = "DixInject" ) @@ -27,6 +27,6 @@ type Graph struct { var logger = log.GetLogger("dix") -func SetLogLevel(lvl log.Level) { - logger = logger.WithLevel(lvl) +func SetLog(setter func(logger log.Logger) log.Logger) { + logger = setter(logger) } diff --git a/dixinternal/api.go b/dixinternal/api.go new file mode 100644 index 0000000..2766e57 --- /dev/null +++ b/dixinternal/api.go @@ -0,0 +1,37 @@ +package dixinternal + +import ( + "reflect" + + "github.com/pubgo/funk/assert" + "github.com/pubgo/funk/errors" +) + +// New Dix new +func New(opts ...Option) *Dix { + return newDix(opts...) +} + +func (x *Dix) Provide(param any) { + x.provide(param) +} + +func (x *Dix) Inject(param any, opts ...Option) any { + if dep, ok := x.isCycle(); ok { + logger.Error(). + Str("cycle_path", dep). + Str("component", reflect.TypeOf(param).String()). + Msg("dependency cycle detected") + assert.Must(errors.New("circular dependency: " + dep)) + } + + assert.Must(x.inject(param, opts...)) + return param +} + +func (x *Dix) Graph() *Graph { + return &Graph{ + Objects: x.objectGraph(), + Providers: x.providerGraph(), + } +} diff --git a/dix_internal/cycle-check.go b/dixinternal/cycle-check.go similarity index 59% rename from dix_internal/cycle-check.go rename to dixinternal/cycle-check.go index 5ed081b..419ddcc 100644 --- a/dix_internal/cycle-check.go +++ b/dixinternal/cycle-check.go @@ -1,4 +1,4 @@ -package dix_internal +package dixinternal import ( "reflect" @@ -7,14 +7,17 @@ import ( func (x *Dix) buildDependencyGraph() map[reflect.Type]map[reflect.Type]bool { graph := make(map[reflect.Type]map[reflect.Type]bool) - for typ, nodes := range x.providers { - for _, n := range nodes { - if graph[typ] == nil { - graph[typ] = make(map[reflect.Type]bool) - } - for _, input := range n.input { - for _, provider := range x.getAllProvideInput(input.typ) { - graph[typ][provider.typ] = true + // Pre-allocate map capacity to reduce rehash + for outTyp := range x.providers { + graph[outTyp] = make(map[reflect.Type]bool) + } + + // Build dependency graph + for outTyp, nodes := range x.providers { + for _, providerNode := range nodes { + for _, input := range providerNode.inputList { + for _, provider := range x.getProvideAllInputs(input.typ) { + graph[outTyp][provider.typ] = true } } } diff --git a/dix_internal/dix.go b/dixinternal/dix.go similarity index 83% rename from dix_internal/dix.go rename to dixinternal/dix.go index 20428ac..58b9c42 100644 --- a/dix_internal/dix.go +++ b/dixinternal/dix.go @@ -1,10 +1,11 @@ -package dix_internal +package dixinternal import ( "fmt" - "os" + "path/filepath" "reflect" "strings" + "time" "github.com/pubgo/funk/assert" "github.com/pubgo/funk/errors" @@ -48,64 +49,65 @@ func (x *Dix) Option() Options { return x.option } -func (x *Dix) evalProvider(typ outputType, opt Options) map[group][]value { - switch typ.Kind() { +func (x *Dix) getOutputTypeValues(outTyp outputType, opt Options) map[group][]value { + switch outTyp.Kind() { case reflect.Ptr, reflect.Interface, reflect.Func: default: assert.Must(errors.Err{ Msg: "provider type kind error, the supported type kinds are ", - Detail: fmt.Sprintf("type=%s kind=%s", typ, typ.Kind()), + Detail: fmt.Sprintf("type=%s kind=%s", outTyp, outTyp.Kind()), }) } - if len(x.providers[typ]) == 0 { + if len(x.providers[outTyp]) == 0 { logger.Warn(). - Str("type", typ.String()). - Str("kind", typ.Kind().String()). + Str("type", outTyp.String()). + Str("kind", outTyp.Kind().String()). Msg("provider not found, please check whether the provider imports or type error") - // return make(map[group][]value) } - if x.objects[typ] == nil { - x.objects[typ] = make(map[group][]value) + if x.objects[outTyp] == nil { + x.objects[outTyp] = make(map[group][]value) } - logger.Debug(). - Str("type", typ.String()). - Str("kind", typ.Kind().String()). - Int("providers", len(x.providers[typ])). - Msg("eval type value") - for _, n := range x.providers[typ] { + for _, n := range x.providers[outTyp] { if x.initializer[n.fn] { continue } var input []reflect.Value - for _, in := range n.input { - val := x.getValue(in.typ, opt, in.isMap, in.isList, typ) + for _, in := range n.inputList { + val := x.getValue(in.typ, opt, in.isMap, in.isList, outTyp) input = append(input, val) } + var now = time.Now() + var fnStack = stack.CallerWithFunc(n.fn) fnCall := n.call(input) + logger.Debug(). + Str("cost", time.Since(now).String()). + Str("provider", fnStack.String()). + Msgf("eval provider func %s.%s", filepath.Base(fnStack.Pkg), fnStack.Name) + x.initializer[n.fn] = true objects := make(map[outputType]map[group][]value) - for k, oo := range handleOutput(typ, fnCall[0]) { + for outT, groupValue := range handleOutput(outTyp, fnCall[0]) { if n.output.isMap { - if _, ok := objects[k]; ok { + if _, ok := objects[outT]; ok { logger.Info(). - Str("type", typ.String()). - Str("key", k.String()). + Str("type", outTyp.String()). + Str("key", outT.String()). Msg("type value exists") } } - if objects[k] == nil { - objects[k] = make(map[group][]value) + if objects[outT] == nil { + objects[outT] = make(map[group][]value) } - for g, o := range oo { - objects[k][g] = append(objects[k][g], o...) + for g, o := range groupValue { + objects[outT][g] = append(objects[outT][g], o...) } } @@ -120,7 +122,7 @@ func (x *Dix) evalProvider(typ outputType, opt Options) map[group][]value { } } - return x.objects[typ] + return x.objects[outTyp] } func (x *Dix) getProviderStack(typ reflect.Type) []string { @@ -132,9 +134,15 @@ func (x *Dix) getProviderStack(typ reflect.Type) []string { } func (x *Dix) getValue(typ reflect.Type, opt Options, isMap, isList bool, parents ...reflect.Type) reflect.Value { + if typ.Kind() == reflect.Struct { + v := reflect.New(typ) + x.injectStruct(v.Elem(), opt) + return v.Elem() + } + + valMap := x.getOutputTypeValues(typ, opt) switch { case isMap: - valMap := x.evalProvider(typ, opt) if !opt.AllowValuesNull && len(valMap) == 0 { logger.Panic(). Any("options", opt). @@ -147,7 +155,6 @@ func (x *Dix) getValue(typ reflect.Type, opt Options, isMap, isList bool, parent return makeMap(typ, valMap, isList) case isList: - valMap := x.evalProvider(typ, opt) if !opt.AllowValuesNull && len(valMap[defaultKey]) == 0 { err := &errors.Err{ Msg: "provider value not found", @@ -165,12 +172,7 @@ func (x *Dix) getValue(typ reflect.Type, opt Options, isMap, isList bool, parent } return makeList(typ, valMap[defaultKey]) - case typ.Kind() == reflect.Struct: - v := reflect.New(typ) - x.injectStruct(v.Elem(), opt) - return v.Elem() default: - valMap := x.evalProvider(typ, opt) if valList, ok := valMap[defaultKey]; !ok || len(valList) == 0 { logger.Panic(). Any("options", opt). @@ -266,12 +268,14 @@ func (x *Dix) injectStruct(vp reflect.Value, opt Options) { } } -func (x *Dix) inject(param interface{}, opts ...Option) interface{} { - defer recovery.Raise(func(err error) error { +func (x *Dix) inject(param interface{}, opts ...Option) (gErr error) { + defer recovery.Err(&gErr, func(err error) error { return errors.WrapKV(err, "param", fmt.Sprintf("%#v", param)) }) - assert.If(param == nil, "param is null") + if param == nil { + return errors.New("nil injection parameter") + } var opt Options for i := range opts { @@ -313,11 +317,11 @@ func (x *Dix) inject(param interface{}, opts ...Option) interface{} { }) x.injectStruct(vp, opt) - return param + return nil } func (x *Dix) handleProvide(fnVal reflect.Value, out reflect.Type, in []*inType) { - n := &node{fn: fnVal, input: in} + n := &node{fn: fnVal, inputList: in} switch outTyp := out; outTyp.Kind() { case reflect.Slice: n.output = &outType{isList: true, typ: outTyp.Elem()} @@ -333,7 +337,6 @@ func (x *Dix) handleProvide(fnVal reflect.Value, out reflect.Type, in []*inType) n.output = &outType{typ: outTyp} x.providers[n.output.typ] = append(x.providers[n.output.typ], n) case reflect.Struct: - logger.Debug().Str("name", outTyp.Name()).Msg("struct info") for i := 0; i < outTyp.NumField(); i++ { x.handleProvide(fnVal, outTyp.Field(i).Type, in) } @@ -342,14 +345,14 @@ func (x *Dix) handleProvide(fnVal reflect.Value, out reflect.Type, in []*inType) } } -func (x *Dix) getAllProvideInput(typ reflect.Type) []*inType { +func (x *Dix) getProvideAllInputs(typ reflect.Type) []*inType { var input []*inType switch inTye := typ; inTye.Kind() { case reflect.Interface, reflect.Ptr, reflect.Func: input = append(input, &inType{typ: inTye}) case reflect.Struct: for j := 0; j < inTye.NumField(); j++ { - input = append(input, x.getAllProvideInput(inTye.Field(j).Type)...) + input = append(input, x.getProvideAllInputs(inTye.Field(j).Type)...) } case reflect.Map: tt := &inType{typ: inTye.Elem(), isMap: true, isList: inTye.Elem().Kind() == reflect.Slice} @@ -384,6 +387,11 @@ func (x *Dix) getProvideInput(typ reflect.Type) []*inType { return input } +// Provide registers the constructor with the container. +// The constructor must be a function that returns at least one value (or an error). +// Arguments of the constructor are treated as dependencies, +// and return values are treated as results that can be injected elsewhere. +// Provide panics if the constructor is not a function or does not have the required signature. func (x *Dix) provide(param interface{}) { defer recovery.Raise(func(err error) error { return errors.WrapKV(err, "param", fmt.Sprintf("%#v", param)) @@ -412,10 +420,4 @@ func (x *Dix) provide(param interface{}) { // The return value can only have one // TODO Add the second parameter, support for error x.handleProvide(fnVal, typ.Out(0), input) - - dep, ok := x.isCycle() - if ok { - logger.Fatal().Str("cycle", dep).Msg("provider circular dependency") - os.Exit(1) - } } diff --git a/dixinternal/graph.go b/dixinternal/graph.go new file mode 100644 index 0000000..b463b9f --- /dev/null +++ b/dixinternal/graph.go @@ -0,0 +1,107 @@ +package dixinternal + +import ( + "bytes" + "fmt" + + "github.com/pubgo/funk/stack" +) + +// DotRenderer implements DOT format graph rendering +type DotRenderer struct { + buf *bytes.Buffer + indent string + cache map[string]string +} + +func NewDotRenderer() *DotRenderer { + return &DotRenderer{ + buf: &bytes.Buffer{}, + indent: "", + cache: make(map[string]string), + } +} + +func (d *DotRenderer) writef(format string, args ...interface{}) { + _, _ = fmt.Fprintf(d.buf, d.indent+format+"\n", args...) +} + +func (d *DotRenderer) RenderNode(name string, attrs map[string]string) { + d.writef("%s [label=\"%s\"%s]", name, name, d.formatAttrs(attrs)) +} + +func (d *DotRenderer) RenderEdge(from, to string, attrs map[string]string) { + d.writef("%s -> %s%s", from, to, d.formatAttrs(attrs)) +} + +func (d *DotRenderer) BeginSubgraph(name, label string) { + d.writef("subgraph %s {", name) + d.indent += "\t" + d.writef("label=\"%s\"", label) +} + +func (d *DotRenderer) EndSubgraph() { + d.indent = d.indent[:len(d.indent)-1] + d.writef("}") +} + +func (d *DotRenderer) String() string { + return d.buf.String() +} + +func (d *DotRenderer) formatAttrs(attrs map[string]string) string { + if len(attrs) == 0 { + return "" + } + + var result bytes.Buffer + result.WriteString(" [") + first := true + for k, v := range attrs { + if !first { + result.WriteString(",") + } + first = false + fmt.Fprintf(&result, "%s=\"%s\"", k, v) + } + result.WriteString("]") + return result.String() +} + +func (x *Dix) providerGraph() string { + d := NewDotRenderer() + d.writef("digraph G {") + d.BeginSubgraph("cluster_providers", "providers") + + for providerOutputType, nodes := range x.providers { + for _, n := range nodes { + fn := stack.CallerWithFunc(n.fn).String() + d.RenderEdge(fn, providerOutputType.String(), nil) + for _, in := range n.inputList { + d.RenderEdge(in.typ.String(), fn, nil) + } + } + } + + d.EndSubgraph() + d.writef("}") + return d.String() +} + +func (x *Dix) objectGraph() string { + d := NewDotRenderer() + d.writef("digraph G {") + d.BeginSubgraph("cluster_objects", "objects") + + for k, objects := range x.objects { + for g, values := range objects { + for _, v := range values { + d.RenderEdge(k.String(), fmt.Sprintf("%s -> %s", g, v.Type().String()), nil) + } + } + } + + d.EndSubgraph() + d.writef("}") + return d.String() +} diff --git a/dix_internal/node.go b/dixinternal/node.go similarity index 63% rename from dix_internal/node.go rename to dixinternal/node.go index 3ca247f..37dec56 100644 --- a/dix_internal/node.go +++ b/dixinternal/node.go @@ -1,8 +1,9 @@ -package dix_internal +package dixinternal import ( "fmt" "reflect" + "strings" "github.com/pubgo/funk/errors" "github.com/pubgo/funk/recovery" @@ -38,9 +39,9 @@ type outType struct { } type node struct { - fn reflect.Value - input []*inType - output *outType + fn reflect.Value + inputList []*inType + output *outType } func (n node) call(in []reflect.Value) []reflect.Value { @@ -48,10 +49,31 @@ func (n node) call(in []reflect.Value) []reflect.Value { return errors.WrapTag(err, errors.T("msg", "failed to handle provider invoke"), errors.T("fn_stack", stack.CallerWithFunc(n.fn).String()), + errors.T("fn_type", n.fn.Type().String()), errors.T("input", fmt.Sprintf("%v", in)), errors.T("input_data", reflectValueToString(in)), + errors.T("input_types", reflectTypesToString(n.inputList)), + errors.T("output_type", n.output.typ.String()), ) }) return n.fn.Call(in) } + +// reflectTypesToString converts input type list to readable string +func reflectTypesToString(types []*inType) string { + var result strings.Builder + for i, t := range types { + if i > 0 { + result.WriteString(", ") + } + result.WriteString(t.typ.String()) + if t.isMap { + result.WriteString("(map)") + } + if t.isList { + result.WriteString("(list)") + } + } + return result.String() +} diff --git a/dix_internal/option.go b/dixinternal/option.go similarity index 83% rename from dix_internal/option.go rename to dixinternal/option.go index b6c3045..509e9e3 100644 --- a/dix_internal/option.go +++ b/dixinternal/option.go @@ -1,9 +1,9 @@ -package dix_internal +package dixinternal type ( Option func(opts *Options) Options struct { - // 允许结果为nil + // AllowValuesNull allows result to be nil AllowValuesNull bool } ) diff --git a/dix_internal/util.go b/dixinternal/util.go similarity index 66% rename from dix_internal/util.go rename to dixinternal/util.go index c7cdf36..4b1f26e 100644 --- a/dix_internal/util.go +++ b/dixinternal/util.go @@ -1,8 +1,9 @@ -package dix_internal +package dixinternal import ( "fmt" "reflect" + "slices" "strings" ) @@ -27,7 +28,7 @@ func makeMap(typ reflect.Type, data map[string][]reflect.Value, valueList bool) mapVal := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(""), typ)) for index, values := range data { - // 最后一个值作为默认值 + // The last value as the default value val := values[len(values)-1] if valueList { val = reflect.MakeSlice(typ, 0, len(values)) @@ -46,15 +47,15 @@ func reflectValueToString(values []reflect.Value) []string { return data } -func handleOutput(outType outputType, out reflect.Value) map[outputType]map[group][]value { +func handleOutput(outType outputType, providerOutTyp reflect.Value) map[outputType]map[group][]value { rr := make(map[outputType]map[group][]value) - if !out.IsValid() || out.IsZero() { + if !providerOutTyp.IsValid() || providerOutTyp.IsZero() { return rr } - switch out.Kind() { + switch providerOutTyp.Kind() { case reflect.Map: - outType = out.Type().Elem() + outType = providerOutTyp.Type().Elem() isList := outType.Kind() == reflect.Slice if isList { outType = outType.Elem() @@ -64,13 +65,13 @@ func handleOutput(outType outputType, out reflect.Value) map[outputType]map[grou rr[outType] = make(map[group][]value) } - for _, k := range out.MapKeys() { + for _, k := range providerOutTyp.MapKeys() { mapK := strings.TrimSpace(k.String()) if mapK == "" { mapK = defaultKey } - val := out.MapIndex(k) + val := providerOutTyp.MapIndex(k) if !val.IsValid() || val.IsNil() { continue } @@ -89,13 +90,13 @@ func handleOutput(outType outputType, out reflect.Value) map[outputType]map[grou } } case reflect.Slice: - outType = out.Type().Elem() + outType = providerOutTyp.Type().Elem() if rr[outType] == nil { rr[outType] = make(map[group][]value) } - for i := 0; i < out.Len(); i++ { - val := out.Index(i) + for i := 0; i < providerOutTyp.Len(); i++ { + val := providerOutTyp.Index(i) if !val.IsValid() || val.IsNil() { continue } @@ -103,8 +104,8 @@ func handleOutput(outType outputType, out reflect.Value) map[outputType]map[grou rr[outType][defaultKey] = append(rr[outType][defaultKey], val) } case reflect.Struct: - for i := 0; i < out.NumField(); i++ { - for typ, vv := range handleOutput(out.Field(i).Type(), out.Field(i)) { + for i := 0; i < providerOutTyp.NumField(); i++ { + for typ, vv := range handleOutput(providerOutTyp.Field(i).Type(), providerOutTyp.Field(i)) { if rr[typ] == nil { rr[typ] = vv } else { @@ -119,8 +120,8 @@ func handleOutput(outType outputType, out reflect.Value) map[outputType]map[grou rr[outType] = make(map[group][]value) } - if out.IsValid() && !out.IsNil() { - rr[outType][defaultKey] = []value{out} + if providerOutTyp.IsValid() && !providerOutTyp.IsNil() { + rr[outType][defaultKey] = []value{providerOutTyp} } } return rr @@ -128,18 +129,15 @@ func handleOutput(outType outputType, out reflect.Value) map[outputType]map[grou func detectCycle(graph map[reflect.Type]map[reflect.Type]bool) []reflect.Type { visited := make(map[reflect.Type]bool) - recursionStack := make(map[reflect.Type]bool) - var cycle []reflect.Type - - var dfs func(reflect.Type, []reflect.Type) - dfs = func(t reflect.Type, path []reflect.Type) { + var dfs func(reflect.Type, map[reflect.Type]bool, []reflect.Type) []reflect.Type + dfs = func(t reflect.Type, recursionStack map[reflect.Type]bool, path []reflect.Type) []reflect.Type { if recursionStack[t] { - cycle = append([]reflect.Type(nil), path...) - return + return slices.Clone(path) } + if visited[t] { - return + return nil } visited[t] = true @@ -147,17 +145,23 @@ func detectCycle(graph map[reflect.Type]map[reflect.Type]bool) []reflect.Type { defer delete(recursionStack, t) for dep := range graph[t] { - dfs(dep, append(path, dep)) + cycle := dfs(dep, recursionStack, append(slices.Clone(path), dep)) if len(cycle) > 0 { - return + return cycle } } + return nil } for t := range graph { - if !visited[t] { - dfs(t, []reflect.Type{t}) + if visited[t] { + continue + } + + cycle := dfs(t, make(map[reflect.Type]bool), []reflect.Type{t}) + if len(cycle) > 0 { + return cycle } } - return cycle + return nil } diff --git a/example/cycle/main.go b/example/cycle/main.go index 5b25feb..7c39cfc 100644 --- a/example/cycle/main.go +++ b/example/cycle/main.go @@ -2,12 +2,9 @@ package main import ( "fmt" - "strings" - + "github.com/pubgo/dix/di" - "github.com/pubgo/funk/generic" "github.com/pubgo/funk/recovery" - "github.com/pubgo/funk/try" ) func main() { @@ -32,18 +29,8 @@ func main() { return new(B) }) - err := try.Try(func() error { - di.Provide(func(*A) *C { - return new(C) - }) - return nil + di.Provide(func(*A) *C { + return new(C) }) - - if !generic.IsNil(err) { - if strings.Contains(err.Error(), "provider circular dependency") { - return - } - - panic(err) - } + di.Inject(func(*C) {}) }