diff --git a/README.md b/README.md index 6c9d2d8a..837656aa 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,7 @@ The goal of this project is to provide full support of the [GraphQL draft specification](https://facebook.github.io/graphql/draft) with a set of idiomatic, easy to use Go packages. -While still under heavy development (`internal` APIs are almost certainly subject to change), this library is -safe for production use. +While still under development (`internal` and `directives` APIs are almost certainly subject to change), this library is safe for production use. ## Features @@ -17,14 +16,15 @@ safe for production use. - handles panics in resolvers - parallel execution of resolvers - subscriptions - - [sample WS transport](https://github.com/graph-gophers/graphql-transport-ws) + - [sample WS transport](https://github.com/graph-gophers/graphql-transport-ws) +- directive visitors on fields (the API is subject to change in future versions) ## Roadmap We're trying out the GitHub Project feature to manage `graphql-go`'s [development roadmap](https://github.com/graph-gophers/graphql-go/projects/1). Feedback is welcome and appreciated. -## (Some) Documentation +## (Some) Documentation [![GoDoc](https://godoc.org/github.com/graph-gophers/graphql-go?status.svg)](https://godoc.org/github.com/graph-gophers/graphql-go) ### Getting started diff --git a/directives/doc.go b/directives/doc.go new file mode 100644 index 00000000..39a1c61e --- /dev/null +++ b/directives/doc.go @@ -0,0 +1,4 @@ +/* +package directives contains a Visitor Pattern implementation of Schema Directives for Fields. +*/ +package directives diff --git a/directives/visitor.go b/directives/visitor.go new file mode 100644 index 00000000..729b6724 --- /dev/null +++ b/directives/visitor.go @@ -0,0 +1,18 @@ +package directives + +import ( + "context" + + "github.com/graph-gophers/graphql-go/types" +) + +// Visitor defines the interface that clients should use to implement a Directive +// see the graphql.DirectiveVisitors() Schema Option. +type Visitor interface { + // Before() is always called when the operation includes a directive matching this implementation's name. + // When the first return value is true, the field resolver will not be called. + // Errors in Before() will prevent field resolution. + Before(ctx context.Context, directive *types.Directive, input interface{}) (skipResolver bool, err error) + // After is called if Before() *and* the field resolver do not error. + After(ctx context.Context, directive *types.Directive, output interface{}) (modified interface{}, err error) +} diff --git a/example/directives/authorization/README.md b/example/directives/authorization/README.md new file mode 100644 index 00000000..a99816c7 --- /dev/null +++ b/example/directives/authorization/README.md @@ -0,0 +1,128 @@ +# @hasRole directive + +## Overview +A simple example of naive authorization directive which returns an error if the user in the context doesn't have the required role. Make sure that in production applications you use thread-safe maps for roles as an instance of the user struct might be accessed from multiple goroutines. In this naive example we use a simeple map which is not thread-safe. The required role to access a resolver is passed as an argument to the directive, for example, `@hasRole(role: ADMIN)`. + +## Getting started +To run this server + +`go run ./example/directives/authorization/server/server.go` + +Navigate to https://localhost:8080 in your browser to interact with the GraphiQL UI. + +## Testing with curl +Access public resolver: +``` +$ curl 'http://localhost:8080/query' \ + -H 'Accept: application/json' \ + --data-raw '{"query":"# mutation {\nquery {\n publicGreet(name: \"John\")\n}","variables":null}' + +{"data":{"publicGreet":"Hello from the public resolver, John!"}} +``` +Try accessing protected resolver without required role: +``` +$ curl 'http://localhost:8080/query' \ + -H 'Accept: application/json' \ + --data-raw '{"query":"# mutation {\nquery {\n privateGreet(name: \"John\")\n}","variables":null}' +{"errors":[{"message":"access denied, \"admin\" role required","path":["privateGreet"]}],"data":null} +``` +Try accessing protected resolver again with appropriate role: +``` +$ curl 'http://localhost:8080/query' \ + -H 'Accept: application/json' \ + -H 'role: admin' \ + --data-raw '{"query":"# mutation {\nquery {\n privateGreet(name: \"John\")\n}","variables":null}' +{"data":{"privateGreet":"Hi from the protected resolver, John!"}} +``` + +## Implementation details + +1. Add directive definition to your shema: + ```graphql + directive @hasRole(role: Role!) on FIELD_DEFINITION + ``` + +2. Add directive to the protected fields in the schema: + ```graphql + type Query { + # other field resolvers here + privateGreet(name: String!): String! @hasRole(role: ADMIN) + } + ``` + +3. Define a user Go type which can have a slice of roles where each role is a string: + ```go + type User struct { + ID string + Roles map[string]struct{} + } + + func (u *User) AddRole(r string) { + if u.Roles == nil { + u.Roles = map[string]struct{}{} + } + u.Roles[r] = struct{}{} + } + + func (u *User) HasRole(r string) bool { + _, ok := u.Roles[r] + return ok + } + ``` + +4. Define a Go type which implements the DirevtiveVisitor interface: + ```go + type HasRoleDirective struct{} + + func (h *HasRoleDirective) Before(ctx context.Context, directive *types.Directive, input interface{}) (bool, error) { + u, ok := user.FromContext(ctx) + if !ok { + return true, fmt.Errorf("user not provided in cotext") + } + role := strings.ToLower((directive.Arguments.MustGet("role").String()) + if !u.HasRole(role) { + return true, fmt.Errorf("access denied, %q role required", role) + } + return false, nil + } + + // After is a no-op and returns the output unchanged. + func (h *HasRoleDirective) After(ctx context.Context, directive *types.Directive, output interface{}) (interface{}, error) { + return output, nil + } + ``` + +5. Pay attention to the schmema options. Directive visitors are added as schema option: + ```go + opts := []graphql.SchemaOpt{ + graphql.DirectiveVisitors(map[string]directives.Visitor{ + "hasRole": &authorization.HasRoleDirective{}, + }), + // other options go here + } + schema := graphql.MustParseSchema(authorization.Schema, &authorization.Resolver{}, opts...) + ``` + +6. Add a middleware to the HTTP handler which would read the `role` HTTP header and add that role to the slice of user roles. This naive middleware assumes that there is authentication proxy (e.g. Nginx, Envoy, Contour etc.) in front of this server which would authenticate the user and add their role in a header. In production application it would be fine if the same application handles the authentication and adds the user to the context. This is the middleware in this example: + ```go + func auth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + u := &user.User{} + role := r.Header.Get("role") + if role != "" { + u.AddRole(role) + } + ctx := user.AddToContext(context.Background(), u) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } + ``` + +7. Wrap the GraphQL handler with the auth middleware: + ```go + http.Handle("/query", auth(&relay.Handler{Schema: schema})) + ``` + +8. In order to access the private resolver add a role header like below: + +![accessing a private resolver using role header](graphiql-has-role-example.png) \ No newline at end of file diff --git a/example/directives/authorization/authorization.go b/example/directives/authorization/authorization.go new file mode 100644 index 00000000..fc9981a8 --- /dev/null +++ b/example/directives/authorization/authorization.go @@ -0,0 +1,57 @@ +// Package authorization contains a simple GraphQL schema using directives. +package authorization + +import ( + "context" + "fmt" + "strings" + + "github.com/graph-gophers/graphql-go/example/directives/authorization/user" + "github.com/graph-gophers/graphql-go/types" +) + +const Schema = ` + schema { + query: Query + } + + directive @hasRole(role: Role!) on FIELD_DEFINITION + + type Query { + publicGreet(name: String!): String! + privateGreet(name: String!): String! @hasRole(role: ADMIN) + } + + enum Role { + ADMIN + USER + } +` + +type HasRoleDirective struct{} + +func (h *HasRoleDirective) Before(ctx context.Context, directive *types.Directive, input interface{}) (bool, error) { + u, ok := user.FromContext(ctx) + if !ok { + return true, fmt.Errorf("user not provided in cotext") + } + role := strings.ToLower(directive.Arguments.MustGet("role").String()) + if !u.HasRole(role) { + return true, fmt.Errorf("access denied, %q role required", role) + } + return false, nil +} + +func (h *HasRoleDirective) After(ctx context.Context, directive *types.Directive, output interface{}) (interface{}, error) { + return output, nil +} + +type Resolver struct{} + +func (r *Resolver) PublicGreet(ctx context.Context, args struct{ Name string }) string { + return fmt.Sprintf("Hello from the public resolver, %s!", args.Name) +} + +func (r *Resolver) PrivateGreet(ctx context.Context, args struct{ Name string }) string { + return fmt.Sprintf("Hi from the protected resolver, %s!", args.Name) +} diff --git a/example/directives/authorization/graphiql-has-role-example.png b/example/directives/authorization/graphiql-has-role-example.png new file mode 100644 index 00000000..6fb008b4 Binary files /dev/null and b/example/directives/authorization/graphiql-has-role-example.png differ diff --git a/example/directives/authorization/server/server.go b/example/directives/authorization/server/server.go new file mode 100644 index 00000000..31b2263e --- /dev/null +++ b/example/directives/authorization/server/server.go @@ -0,0 +1,79 @@ +package main + +import ( + "context" + "log" + "net/http" + + "github.com/graph-gophers/graphql-go" + "github.com/graph-gophers/graphql-go/directives" + "github.com/graph-gophers/graphql-go/example/directives/authorization" + "github.com/graph-gophers/graphql-go/example/directives/authorization/user" + "github.com/graph-gophers/graphql-go/relay" +) + +func main() { + opts := []graphql.SchemaOpt{ + graphql.DirectiveVisitors(map[string]directives.Visitor{ + "hasRole": &authorization.HasRoleDirective{}, + }), + // other options go here + } + schema := graphql.MustParseSchema(authorization.Schema, &authorization.Resolver{}, opts...) + + http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write(page) + })) + + http.Handle("/query", auth(&relay.Handler{Schema: schema})) + + log.Fatal(http.ListenAndServe(":8080", nil)) +} + +func auth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + u := &user.User{} + role := r.Header.Get("role") + if role != "" { + u.AddRole(role) + } + ctx := user.AddToContext(context.Background(), u) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +var page = []byte(` + + + + GraphiQL + + + + + + +
Loading...
+ + + + +`) diff --git a/example/directives/authorization/user/user.go b/example/directives/authorization/user/user.go new file mode 100644 index 00000000..e8ba6111 --- /dev/null +++ b/example/directives/authorization/user/user.go @@ -0,0 +1,37 @@ +// package user contains a naive implementation of an user with roles. +// Each user can be assigned roles and added to/retrieved from context. +package user + +import ( + "context" +) + +type userKey string + +const contextKey userKey = "user" + +type User struct { + ID string + Roles map[string]struct{} +} + +func (u *User) AddRole(r string) { + if u.Roles == nil { + u.Roles = map[string]struct{}{} + } + u.Roles[r] = struct{}{} +} + +func (u *User) HasRole(r string) bool { + _, ok := u.Roles[r] + return ok +} + +func AddToContext(ctx context.Context, u *User) context.Context { + return context.WithValue(ctx, contextKey, u) +} + +func FromContext(ctx context.Context) (*User, bool) { + u, ok := ctx.Value(contextKey).(*User) + return u, ok +} diff --git a/graphql.go b/graphql.go index 891c0379..4d46adcc 100644 --- a/graphql.go +++ b/graphql.go @@ -6,6 +6,7 @@ import ( "fmt" "time" + "github.com/graph-gophers/graphql-go/directives" "github.com/graph-gophers/graphql-go/errors" "github.com/graph-gophers/graphql-go/internal/common" "github.com/graph-gophers/graphql-go/internal/exec" @@ -83,6 +84,7 @@ type Schema struct { useStringDescriptions bool disableIntrospection bool subscribeResolverTimeout time.Duration + visitors map[string]directives.Visitor } func (s *Schema) ASTSchema() *types.Schema { @@ -169,6 +171,14 @@ func SubscribeResolverTimeout(timeout time.Duration) SchemaOpt { } } +// DirectiveVisitors defines the implementation for each directive. +// Per the GraphQL specification, each Field Directive in the schema must have an implementation here. +func DirectiveVisitors(visitors map[string]directives.Visitor) SchemaOpt { + return func(s *Schema) { + s.visitors = visitors + } +} + // Response represents a typical response of a GraphQL server. It may be encoded to JSON directly or // it may be further processed to a custom response type, for example to include custom error data. // Errors are intentionally serialized first based on the advice in https://github.com/facebook/graphql/commit/7b40390d48680b15cb93e02d46ac5eb249689876#diff-757cea6edf0288677a9eea4cfc801d87R107 @@ -258,6 +268,7 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str Tracer: s.tracer, Logger: s.logger, PanicHandler: s.panicHandler, + Visitors: s.visitors, } varTypes := make(map[string]*introspection.Type) for _, v := range op.Vars { diff --git a/graphql_test.go b/graphql_test.go index c12334c8..5b18e1c9 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -9,11 +9,13 @@ import ( "time" "github.com/graph-gophers/graphql-go" + "github.com/graph-gophers/graphql-go/directives" gqlerrors "github.com/graph-gophers/graphql-go/errors" "github.com/graph-gophers/graphql-go/example/starwars" "github.com/graph-gophers/graphql-go/gqltesting" "github.com/graph-gophers/graphql-go/introspection" "github.com/graph-gophers/graphql-go/trace/tracer" + "github.com/graph-gophers/graphql-go/types" ) type helloWorldResolver1 struct{} @@ -48,6 +50,54 @@ func (r *helloSnakeResolver2) SayHello(ctx context.Context, args struct{ FullNam return "Hello " + args.FullName + "!", nil } +type structFieldResolver struct { + Hello string +} + +type customDirectiveVisitor struct { + beforeWasCalled bool +} + +func (v *customDirectiveVisitor) Before(ctx context.Context, directive *types.Directive, input interface{}) (bool, error) { + v.beforeWasCalled = true + return false, nil +} + +func (v *customDirectiveVisitor) After(ctx context.Context, directive *types.Directive, output interface{}) (interface{}, error) { + if v.beforeWasCalled == false { + return nil, errors.New("Before directive visitor method wasn't called.") + } + + if value, ok := directive.Arguments.Get("customAttribute"); ok { + return fmt.Sprintf("Directive '%s' (with arg '%s') modified result: %s", directive.Name.Name, value.String(), output.(string)), nil + } + return fmt.Sprintf("Directive '%s' modified result: %s", directive.Name.Name, output.(string)), nil +} + +type cachedDirectiveVisitor struct { + cachedValue interface{} +} + +func (v *cachedDirectiveVisitor) Before(ctx context.Context, directive *types.Directive, input interface{}) (bool, error) { + s := "valueFromCache" + v.cachedValue = s + return true, nil +} + +func (v *cachedDirectiveVisitor) After(ctx context.Context, directive *types.Directive, output interface{}) (interface{}, error) { + return v.cachedValue, nil +} + +type cachedDirectiveResolver struct { + t *testing.T +} + +func (r *cachedDirectiveResolver) Hello(ctx context.Context, args struct{ FullName string }) string { + r.t.Error("expected cached resolver to not be called, but it was") + + return "" +} + type theNumberResolver struct { number int32 } @@ -191,7 +241,6 @@ func TestHelloWorld(t *testing.T) { } `, }, - { Schema: graphql.MustParseSchema(` schema { @@ -216,6 +265,187 @@ func TestHelloWorld(t *testing.T) { }) } +func TestHelloWorldStructFieldResolver(t *testing.T) { + t.Parallel() + + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: graphql.MustParseSchema(` + schema { + query: Query + } + + type Query { + hello: String! + } + `, + &structFieldResolver{Hello: "Hello world!"}, + graphql.UseFieldResolvers()), + Query: ` + { + hello + } + `, + ExpectedResult: ` + { + "hello": "Hello world!" + } + `, + }, + }) +} + +func TestCustomDirective(t *testing.T) { + t.Parallel() + + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: graphql.MustParseSchema(` + directive @customDirective on FIELD_DEFINITION + + schema { + query: Query + } + + type Query { + hello_html: String! @customDirective + } + `, &helloSnakeResolver1{}, + graphql.DirectiveVisitors(map[string]directives.Visitor{ + "customDirective": &customDirectiveVisitor{}, + })), + Query: ` + { + hello_html + } + `, + ExpectedResult: ` + { + "hello_html": "Directive 'customDirective' modified result: Hello snake!" + } + `, + }, + { + Schema: graphql.MustParseSchema(` + directive @customDirective( + customAttribute: String! + ) on FIELD_DEFINITION + + schema { + query: Query + } + + type Query { + say_hello(full_name: String!): String! @customDirective(customAttribute: hi) + } + `, &helloSnakeResolver1{}, + graphql.DirectiveVisitors(map[string]directives.Visitor{ + "customDirective": &customDirectiveVisitor{}, + })), + Query: ` + { + say_hello(full_name: "Johnny") + } + `, + ExpectedResult: ` + { + "say_hello": "Directive 'customDirective' (with arg 'hi') modified result: Hello Johnny!" + } + `, + }, + { + Schema: graphql.MustParseSchema(` + directive @cached( + key: String! + ) on FIELD_DEFINITION + + schema { + query: Query + } + + type Query { + hello(full_name: String!): String! @cached(key: "notcheckedintest") + } + `, &cachedDirectiveResolver{t: t}, + graphql.DirectiveVisitors(map[string]directives.Visitor{ + "cached": &cachedDirectiveVisitor{}, + })), + Query: ` + { + hello(full_name: "Full Name") + } + `, + ExpectedResult: ` + { + "hello": "valueFromCache" + } + `, + }, + }) +} + +func TestCustomDirectiveStructFieldResolver(t *testing.T) { + t.Parallel() + + schemaOpt := []graphql.SchemaOpt{ + graphql.DirectiveVisitors(map[string]directives.Visitor{ + "customDirective": &customDirectiveVisitor{}, + }), + graphql.UseFieldResolvers(), + } + + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: graphql.MustParseSchema(` + directive @customDirective on FIELD_DEFINITION + + schema { + query: Query + } + + type Query { + hello: String! @customDirective + } + `, &structFieldResolver{Hello: "Hello world!"}, schemaOpt...), + Query: ` + { + hello + } + `, + ExpectedResult: ` + { + "hello": "Directive 'customDirective' modified result: Hello world!" + } + `, + }, + { + Schema: graphql.MustParseSchema(` + directive @customDirective( + customAttribute: String! + ) on FIELD_DEFINITION + + schema { + query: Query + } + + type Query { + hello: String! @customDirective(customAttribute: hi) + } + `, &structFieldResolver{Hello: "Hello world!"}, schemaOpt...), + Query: ` + { + hello + } + `, + ExpectedResult: ` + { + "hello": "Directive 'customDirective' (with arg 'hi') modified result: Hello world!" + } + `, + }, + }) +} + func TestHelloSnake(t *testing.T) { t.Parallel() diff --git a/internal/exec/exec.go b/internal/exec/exec.go index e9056c53..55b0a7fe 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/graph-gophers/graphql-go/directives" "github.com/graph-gophers/graphql-go/errors" "github.com/graph-gophers/graphql-go/internal/exec/resolvable" "github.com/graph-gophers/graphql-go/internal/exec/selected" @@ -25,6 +26,7 @@ type Request struct { Logger log.Logger PanicHandler errors.PanicHandler SubscribeResolverTimeout time.Duration + Visitors map[string]directives.Visitor } func (r *Request) handlePanic(ctx context.Context) { @@ -201,15 +203,68 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f res := f.resolver if f.field.UseMethodResolver() { - var in []reflect.Value + var ( + skipResolver bool + in []reflect.Value + callOut []reflect.Value + visitorErr error + ) if f.field.HasContext { in = append(in, reflect.ValueOf(traceCtx)) } if f.field.ArgsPacker != nil { in = append(in, f.field.PackedArgs) } - callOut := res.Method(f.field.MethodIndex).Call(in) - result = callOut[0] + + // Before hook directive visitor + if len(f.field.Directives) > 0 { + for _, directive := range f.field.Directives { + if visitor, ok := r.Visitors[directive.Name.Name]; ok { + values := make([]interface{}, 0, len(in)) + for _, inValue := range in { + values = append(values, inValue.Interface()) + } + skipResolver, visitorErr = visitor.Before(traceCtx, directive, values) + if visitorErr != nil { + err := errors.Errorf("%s", visitorErr) + err.Path = path.toSlice() + err.ResolverError = visitorErr + return err + } + } + } + } + + // Call resolver method unless a Before visitor tells us not to + if !skipResolver { + callOut = res.Method(f.field.MethodIndex).Call(in) + result = callOut[0] + } + + // After hook directive visitor (when no error is returned from resolver) + if !f.field.HasError && len(f.field.Directives) > 0 { + var modified interface{} + for _, directive := range f.field.Directives { + if visitor, ok := r.Visitors[directive.Name.Name]; ok { + if !skipResolver { + modified, visitorErr = visitor.After(traceCtx, directive, result.Interface()) + } else { + modified, visitorErr = visitor.After(traceCtx, directive, nil) + } + + if visitorErr != nil { + err := errors.Errorf("%s", visitorErr) + err.Path = path.toSlice() + err.ResolverError = visitorErr + return err + } + result = reflect.ValueOf(modified) + } + } + } + if skipResolver { + return nil + } if f.field.HasError && !callOut[1].IsNil() { resolverErr := callOut[1].Interface().(error) err := errors.Errorf("%s", resolverErr) @@ -221,11 +276,51 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f return err } } else { + var ( + skipResolver bool + visitorErr error + modified interface{} + ) // TODO extract out unwrapping ptr logic to a common place if res.Kind() == reflect.Ptr { res = res.Elem() } - result = res.FieldByIndex(f.field.FieldIndex) + // Before hook directive visitor struct field + if len(f.field.Directives) > 0 { + for _, directive := range f.field.Directives { + if visitor, ok := r.Visitors[directive.Name.Name]; ok { + skipResolver, visitorErr = visitor.Before(traceCtx, directive, nil) + if visitorErr != nil { + err := errors.Errorf("%s", visitorErr) + err.Path = path.toSlice() + err.ResolverError = visitorErr + return err + } + } + } + } + if !skipResolver { + result = res.FieldByIndex(f.field.FieldIndex) + } + // After hook directive visitor (when no error is returned from resolver) + if !f.field.HasError && len(f.field.Directives) > 0 { + for _, directive := range f.field.Directives { + if visitor, ok := r.Visitors[directive.Name.Name]; ok { + if !skipResolver { + modified, visitorErr = visitor.After(traceCtx, directive, result.Interface()) + } else { + modified, visitorErr = visitor.After(traceCtx, directive, nil) + } + if visitorErr != nil { + err := errors.Errorf("%s", visitorErr) + err.Path = path.toSlice() + err.ResolverError = visitorErr + return err + } + result = reflect.ValueOf(modified) + } + } + } } return nil }()