Skip to content

Commit 5e217c2

Browse files
ekoSean Sorrell
authored and
Sean Sorrell
committed
Added support of custom directives
based on #446 and work by @eko
1 parent 4423f25 commit 5e217c2

File tree

4 files changed

+241
-2
lines changed

4 files changed

+241
-2
lines changed

graphql.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ type Schema struct {
8383
useStringDescriptions bool
8484
disableIntrospection bool
8585
subscribeResolverTimeout time.Duration
86+
visitors map[string]types.DirectiveVisitor
8687
}
8788

8889
func (s *Schema) ASTSchema() *types.Schema {
@@ -169,6 +170,14 @@ func SubscribeResolverTimeout(timeout time.Duration) SchemaOpt {
169170
}
170171
}
171172

173+
// DirectiveVisitors allows to pass custom directive visitors that will be able to handle
174+
// your GraphQL schema directives.
175+
func DirectiveVisitors(visitors map[string]types.DirectiveVisitor) SchemaOpt {
176+
return func(s *Schema) {
177+
s.visitors = visitors
178+
}
179+
}
180+
172181
// Response represents a typical response of a GraphQL server. It may be encoded to JSON directly or
173182
// it may be further processed to a custom response type, for example to include custom error data.
174183
// Errors are intentionally serialized first based on the advice in https://github.com/facebook/graphql/commit/7b40390d48680b15cb93e02d46ac5eb249689876#diff-757cea6edf0288677a9eea4cfc801d87R107
@@ -258,6 +267,7 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str
258267
Tracer: s.tracer,
259268
Logger: s.logger,
260269
PanicHandler: s.panicHandler,
270+
Visitors: s.visitors,
261271
}
262272
varTypes := make(map[string]*introspection.Type)
263273
for _, v := range op.Vars {

graphql_test.go

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/graph-gophers/graphql-go/gqltesting"
1515
"github.com/graph-gophers/graphql-go/introspection"
1616
"github.com/graph-gophers/graphql-go/trace/tracer"
17+
"github.com/graph-gophers/graphql-go/types"
1718
)
1819

1920
type helloWorldResolver1 struct{}
@@ -48,6 +49,27 @@ func (r *helloSnakeResolver2) SayHello(ctx context.Context, args struct{ FullNam
4849
return "Hello " + args.FullName + "!", nil
4950
}
5051

52+
type customDirectiveVisitor struct {
53+
beforeWasCalled bool
54+
}
55+
56+
func (v *customDirectiveVisitor) Before(ctx context.Context, directive *types.Directive, input interface{}) error {
57+
v.beforeWasCalled = true
58+
return nil
59+
}
60+
61+
func (v *customDirectiveVisitor) After(ctx context.Context, directive *types.Directive, output interface{}) (interface{}, error) {
62+
if v.beforeWasCalled == false {
63+
return nil, errors.New("Before directive visitor method wasn't called.")
64+
}
65+
66+
if value, ok := directive.Arguments.Get("customAttribute"); ok {
67+
return fmt.Sprintf("Directive '%s' (with arg '%s') modified result: %s", directive.Name.Name, value.String(), output.(string)), nil
68+
} else {
69+
return fmt.Sprintf("Directive '%s' modified result: %s", directive.Name.Name, output.(string)), nil
70+
}
71+
}
72+
5173
type theNumberResolver struct {
5274
number int32
5375
}
@@ -191,7 +213,6 @@ func TestHelloWorld(t *testing.T) {
191213
}
192214
`,
193215
},
194-
195216
{
196217
Schema: graphql.MustParseSchema(`
197218
schema {
@@ -216,6 +237,67 @@ func TestHelloWorld(t *testing.T) {
216237
})
217238
}
218239

240+
func TestCustomDirective(t *testing.T) {
241+
t.Parallel()
242+
243+
gqltesting.RunTests(t, []*gqltesting.Test{
244+
{
245+
Schema: graphql.MustParseSchema(`
246+
directive @customDirective on FIELD_DEFINITION
247+
248+
schema {
249+
query: Query
250+
}
251+
252+
type Query {
253+
hello_html: String! @customDirective
254+
}
255+
`, &helloSnakeResolver1{},
256+
graphql.DirectiveVisitors(map[string]types.DirectiveVisitor{
257+
"customDirective": &customDirectiveVisitor{},
258+
})),
259+
Query: `
260+
{
261+
hello_html
262+
}
263+
`,
264+
ExpectedResult: `
265+
{
266+
"hello_html": "Directive 'customDirective' modified result: Hello snake!"
267+
}
268+
`,
269+
},
270+
{
271+
Schema: graphql.MustParseSchema(`
272+
directive @customDirective(
273+
customAttribute: String!
274+
) on FIELD_DEFINITION
275+
276+
schema {
277+
query: Query
278+
}
279+
280+
type Query {
281+
say_hello(full_name: String!): String! @customDirective(customAttribute: hi)
282+
}
283+
`, &helloSnakeResolver1{},
284+
graphql.DirectiveVisitors(map[string]types.DirectiveVisitor{
285+
"customDirective": &customDirectiveVisitor{},
286+
})),
287+
Query: `
288+
{
289+
say_hello(full_name: "Johnny")
290+
}
291+
`,
292+
ExpectedResult: `
293+
{
294+
"say_hello": "Directive 'customDirective' (with arg 'hi') modified result: Hello Johnny!"
295+
}
296+
`,
297+
},
298+
})
299+
}
300+
219301
func TestHelloSnake(t *testing.T) {
220302
t.Parallel()
221303

@@ -4550,3 +4632,69 @@ func TestQueryService(t *testing.T) {
45504632
},
45514633
})
45524634
}
4635+
4636+
type StructFieldResolver struct {
4637+
Hello string
4638+
}
4639+
4640+
func TestStructFieldResolver(t *testing.T) {
4641+
gqltesting.RunTests(t, []*gqltesting.Test{
4642+
{
4643+
Schema: graphql.MustParseSchema(`
4644+
schema {
4645+
query: Query
4646+
}
4647+
4648+
type Query {
4649+
hello: String!
4650+
}
4651+
`, &StructFieldResolver{Hello: "Hello world!"}, graphql.UseFieldResolvers()),
4652+
Query: `
4653+
{
4654+
hello
4655+
}
4656+
`,
4657+
ExpectedResult: `
4658+
{
4659+
"hello": "Hello world!"
4660+
}
4661+
`,
4662+
},
4663+
})
4664+
}
4665+
4666+
func TestDirectiveStructFieldResolver(t *testing.T) {
4667+
schemaOpt := []graphql.SchemaOpt{
4668+
graphql.DirectiveVisitors(map[string]types.DirectiveVisitor{
4669+
"customDirective": &customDirectiveVisitor{},
4670+
}),
4671+
graphql.UseFieldResolvers(),
4672+
}
4673+
4674+
gqltesting.RunTests(t, []*gqltesting.Test{
4675+
4676+
{
4677+
Schema: graphql.MustParseSchema(`
4678+
directive @customDirective on FIELD_DEFINITION
4679+
4680+
schema {
4681+
query: Query
4682+
}
4683+
4684+
type Query {
4685+
hello: String! @customDirective
4686+
}
4687+
`, &StructFieldResolver{Hello: "Hello world!"}, schemaOpt...),
4688+
Query: `
4689+
{
4690+
hello
4691+
}
4692+
`,
4693+
ExpectedResult: `
4694+
{
4695+
"hello": "Directive 'customDirective' modified result: Hello world!"
4696+
}
4697+
`,
4698+
}})
4699+
4700+
}

internal/exec/exec.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type Request struct {
2525
Logger log.Logger
2626
PanicHandler errors.PanicHandler
2727
SubscribeResolverTimeout time.Duration
28+
Visitors map[string]types.DirectiveVisitor
2829
}
2930

3031
func (r *Request) handlePanic(ctx context.Context) {
@@ -208,8 +209,48 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f
208209
if f.field.ArgsPacker != nil {
209210
in = append(in, f.field.PackedArgs)
210211
}
212+
213+
// Before hook directive visitor
214+
if len(f.field.Directives) > 0 {
215+
for _, directive := range f.field.Directives {
216+
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
217+
values := make([]interface{}, 0, len(in))
218+
for _, inValue := range in {
219+
values = append(values, inValue.Interface())
220+
}
221+
222+
visitorErr := visitor.Before(ctx, directive, values)
223+
if visitorErr != nil {
224+
err := errors.Errorf("%s", visitorErr)
225+
err.Path = path.toSlice()
226+
err.ResolverError = visitorErr
227+
return err
228+
}
229+
}
230+
}
231+
}
232+
233+
// Call method
211234
callOut := res.Method(f.field.MethodIndex).Call(in)
212235
result = callOut[0]
236+
237+
// After hook directive visitor (when no error is returned from resolver)
238+
if !f.field.HasError && len(f.field.Directives) > 0 {
239+
for _, directive := range f.field.Directives {
240+
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
241+
returned, visitorErr := visitor.After(ctx, directive, result.Interface())
242+
if visitorErr != nil {
243+
err := errors.Errorf("%s", visitorErr)
244+
err.Path = path.toSlice()
245+
err.ResolverError = visitorErr
246+
return err
247+
} else {
248+
result = reflect.ValueOf(returned)
249+
}
250+
}
251+
}
252+
}
253+
213254
if f.field.HasError && !callOut[1].IsNil() {
214255
resolverErr := callOut[1].Interface().(error)
215256
err := errors.Errorf("%s", resolverErr)
@@ -225,7 +266,38 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f
225266
if res.Kind() == reflect.Ptr {
226267
res = res.Elem()
227268
}
269+
// Before hook directive visitor struct field
270+
if len(f.field.Directives) > 0 {
271+
for _, directive := range f.field.Directives {
272+
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
273+
// TODO check that directive arity == 0-that should be an error at schema init time
274+
visitorErr := visitor.Before(ctx, directive, nil)
275+
if visitorErr != nil {
276+
err := errors.Errorf("%s", visitorErr)
277+
err.Path = path.toSlice()
278+
err.ResolverError = visitorErr
279+
return err
280+
}
281+
}
282+
}
283+
}
228284
result = res.FieldByIndex(f.field.FieldIndex)
285+
// After hook directive visitor (when no error is returned from resolver)
286+
if !f.field.HasError && len(f.field.Directives) > 0 {
287+
for _, directive := range f.field.Directives {
288+
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
289+
returned, visitorErr := visitor.After(ctx, directive, result.Interface())
290+
if visitorErr != nil {
291+
err := errors.Errorf("%s", visitorErr)
292+
err.Path = path.toSlice()
293+
err.ResolverError = visitorErr
294+
return err
295+
} else {
296+
result = reflect.ValueOf(returned)
297+
}
298+
}
299+
}
300+
}
229301
}
230302
return nil
231303
}()

types/directive.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package types
22

3-
import "github.com/graph-gophers/graphql-go/errors"
3+
import (
4+
"context"
5+
6+
"github.com/graph-gophers/graphql-go/errors"
7+
)
48

59
// Directive is a representation of the GraphQL Directive.
610
//
@@ -24,6 +28,11 @@ type DirectiveDefinition struct {
2428

2529
type DirectiveList []*Directive
2630

31+
type DirectiveVisitor interface {
32+
Before(ctx context.Context, directive *Directive, input interface{}) error
33+
After(ctx context.Context, directive *Directive, output interface{}) (interface{}, error)
34+
}
35+
2736
// Returns the Directive in the DirectiveList by name or nil if not found.
2837
func (l DirectiveList) Get(name string) *Directive {
2938
for _, d := range l {

0 commit comments

Comments
 (0)