From d3a16a26bfb883b499be122bff9e30e19efef858 Mon Sep 17 00:00:00 2001 From: Yuki Yugui Sonoda Date: Sat, 9 May 2015 11:01:40 +0900 Subject: [PATCH 01/10] Explicitly disallow repeated field on field paths https://groups.google.com/d/msg/grpc-io/Xqx80hG0D44/yyqfF2gTEFMJ --- .../descriptor/services.go | 3 + .../descriptor/services_test.go | 186 ++++++++++++++++++ 2 files changed, 189 insertions(+) diff --git a/protoc-gen-grpc-gateway/descriptor/services.go b/protoc-gen-grpc-gateway/descriptor/services.go index 32e38111946..2aaa602adcc 100644 --- a/protoc-gen-grpc-gateway/descriptor/services.go +++ b/protoc-gen-grpc-gateway/descriptor/services.go @@ -231,6 +231,9 @@ func (r *Registry) resolveFiledPath(msg *Message, path string) ([]FieldPathCompo if f == nil { return nil, fmt.Errorf("no field %q found in %s", path, root.GetName()) } + if f.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REPEATED { + return nil, fmt.Errorf("repeated field not allowed in field path: %s in %s", f.GetName(), path) + } result = append(result, FieldPathComponent{Name: c, Target: f}) } return result, nil diff --git a/protoc-gen-grpc-gateway/descriptor/services_test.go b/protoc-gen-grpc-gateway/descriptor/services_test.go index cfed7dab86b..ba12bb147f4 100644 --- a/protoc-gen-grpc-gateway/descriptor/services_test.go +++ b/protoc-gen-grpc-gateway/descriptor/services_test.go @@ -776,3 +776,189 @@ func TestExtractServicesWithError(t *testing.T) { t.Log(err) } } + +func TestResolveFieldPath(t *testing.T) { + for _, spec := range []struct { + src string + path string + wantErr bool + }{ + { + src: ` + name: 'example.proto' + package: 'example' + message_type < + name: 'ExampleMessage' + field < + name: 'string' + type: TYPE_STRING + label: LABEL_OPTIONAL + number: 1 + > + > + `, + path: "string", + wantErr: false, + }, + // no such field + { + src: ` + name: 'example.proto' + package: 'example' + message_type < + name: 'ExampleMessage' + field < + name: 'string' + type: TYPE_STRING + label: LABEL_OPTIONAL + number: 1 + > + > + `, + path: "something_else", + wantErr: true, + }, + // repeated field + { + src: ` + name: 'example.proto' + package: 'example' + message_type < + name: 'ExampleMessage' + field < + name: 'string' + type: TYPE_STRING + label: LABEL_REPEATED + number: 1 + > + > + `, + path: "string", + wantErr: true, + }, + // nested field + { + src: ` + name: 'example.proto' + package: 'example' + message_type < + name: 'ExampleMessage' + field < + name: 'nested' + type: TYPE_MESSAGE + type_name: 'AnotherMessage' + label: LABEL_OPTIONAL + number: 1 + > + field < + name: 'terminal' + type: TYPE_BOOL + label: LABEL_OPTIONAL + number: 2 + > + > + message_type < + name: 'AnotherMessage' + field < + name: 'nested2' + type: TYPE_MESSAGE + type_name: 'ExampleMessage' + label: LABEL_OPTIONAL + number: 1 + > + > + `, + path: "nested.nested2.nested.nested2.nested.nested2.terminal", + wantErr: false, + }, + // non aggregate field on the path + { + src: ` + name: 'example.proto' + package: 'example' + message_type < + name: 'ExampleMessage' + field < + name: 'nested' + type: TYPE_MESSAGE + type_name: 'AnotherMessage' + label: LABEL_OPTIONAL + number: 1 + > + field < + name: 'terminal' + type: TYPE_BOOL + label: LABEL_OPTIONAL + number: 2 + > + > + message_type < + name: 'AnotherMessage' + field < + name: 'nested2' + type: TYPE_MESSAGE + type_name: 'ExampleMessage' + label: LABEL_OPTIONAL + number: 1 + > + > + `, + path: "nested.terminal.nested2", + wantErr: true, + }, + // repeated field + { + src: ` + name: 'example.proto' + package: 'example' + message_type < + name: 'ExampleMessage' + field < + name: 'nested' + type: TYPE_MESSAGE + type_name: 'AnotherMessage' + label: LABEL_OPTIONAL + number: 1 + > + field < + name: 'terminal' + type: TYPE_BOOL + label: LABEL_OPTIONAL + number: 2 + > + > + message_type < + name: 'AnotherMessage' + field < + name: 'nested2' + type: TYPE_MESSAGE + type_name: 'ExampleMessage' + label: LABEL_REPEATED + number: 1 + > + > + `, + path: "nested.nested2.terminal", + wantErr: true, + }, + } { + var file descriptor.FileDescriptorProto + if err := proto.UnmarshalText(spec.src, &file); err != nil { + t.Fatalf("proto.Unmarshal(%s) failed with %v; want success", spec.src, err) + } + reg := NewRegistry() + reg.loadFile(&file) + f, err := reg.LookupFile(file.GetName()) + if err != nil { + t.Fatalf("reg.LookupFile(%q) failed with %v; want success; on file=%s", file.GetName(), err, spec.src) + } + _, err = reg.resolveFiledPath(f.Messages[0], spec.path) + if got, want := err != nil, spec.wantErr; got != want { + if want { + t.Errorf("reg.resolveFiledPath(%q, %q) succeeded; want an error", f.Messages[0].GetName(), spec.path) + continue + } + t.Errorf("reg.resolveFiledPath(%q, %q) failed with %v; want success", f.Messages[0].GetName(), spec.path, err) + } + } +} From 3b2082e4aff1ccb731d902f6b620518d38f1feb7 Mon Sep 17 00:00:00 2001 From: Yuki Yugui Sonoda Date: Sat, 9 May 2015 12:08:02 +0900 Subject: [PATCH 02/10] Prepare for supporting addtional_bindings Update the data structure in protoc-gen-grpc-gateway/descriptor --- .../descriptor/services.go | 142 ++++++---- .../descriptor/services_test.go | 256 ++++++++++++++---- protoc-gen-grpc-gateway/descriptor/types.go | 17 +- 3 files changed, 298 insertions(+), 117 deletions(-) diff --git a/protoc-gen-grpc-gateway/descriptor/services.go b/protoc-gen-grpc-gateway/descriptor/services.go index 2aaa602adcc..b01f42acfdc 100644 --- a/protoc-gen-grpc-gateway/descriptor/services.go +++ b/protoc-gen-grpc-gateway/descriptor/services.go @@ -51,56 +51,6 @@ func (r *Registry) loadServices(targetFile string) error { } func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto, opts *options.HttpRule) (*Method, error) { - var ( - httpMethod string - pathTemplate string - ) - switch { - case opts.Get != "": - httpMethod = "GET" - pathTemplate = opts.Get - if opts.Body != "" { - return nil, fmt.Errorf("needs request body even though http method is GET: %s", md.GetName()) - } - - case opts.Put != "": - httpMethod = "PUT" - pathTemplate = opts.Put - - case opts.Post != "": - httpMethod = "POST" - pathTemplate = opts.Post - - case opts.Delete != "": - httpMethod = "DELETE" - pathTemplate = opts.Delete - if opts.Body != "" { - return nil, fmt.Errorf("needs request body even though http method is DELETE: %s", md.GetName()) - } - - case opts.Patch != "": - httpMethod = "PATCH" - pathTemplate = opts.Patch - - case opts.Custom != nil: - httpMethod = opts.Custom.Kind - pathTemplate = opts.Custom.Path - - default: - glog.Errorf("No pattern specified in google.api.HttpRule: %s", md.GetName()) - return nil, fmt.Errorf("none of pattern specified") - } - - parsed, err := httprule.Parse(pathTemplate) - if err != nil { - return nil, err - } - tmpl := parsed.Compile() - - if md.GetClientStreaming() && len(tmpl.Fields) > 0 { - return nil, fmt.Errorf("cannot use path parameter in client streaming") - } - requestType, err := r.LookupMsg(svc.File.GetPackage(), md.GetInputType()) if err != nil { return nil, err @@ -109,31 +59,105 @@ func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto, if err != nil { return nil, err } - meth := &Method{ Service: svc, MethodDescriptorProto: md, - PathTmpl: tmpl, - HTTPMethod: httpMethod, RequestType: requestType, ResponseType: responseType, } - for _, f := range tmpl.Fields { - param, err := r.newParam(meth, f) + newBinding := func(opts *options.HttpRule, idx int) (*Binding, error) { + var ( + httpMethod string + pathTemplate string + ) + switch { + case opts.Get != "": + httpMethod = "GET" + pathTemplate = opts.Get + if opts.Body != "" { + return nil, fmt.Errorf("needs request body even though http method is GET: %s", md.GetName()) + } + + case opts.Put != "": + httpMethod = "PUT" + pathTemplate = opts.Put + + case opts.Post != "": + httpMethod = "POST" + pathTemplate = opts.Post + + case opts.Delete != "": + httpMethod = "DELETE" + pathTemplate = opts.Delete + if opts.Body != "" { + return nil, fmt.Errorf("needs request body even though http method is DELETE: %s", md.GetName()) + } + + case opts.Patch != "": + httpMethod = "PATCH" + pathTemplate = opts.Patch + + case opts.Custom != nil: + httpMethod = opts.Custom.Kind + pathTemplate = opts.Custom.Path + + default: + glog.Errorf("No pattern specified in google.api.HttpRule: %s", md.GetName()) + return nil, fmt.Errorf("none of pattern specified") + } + + parsed, err := httprule.Parse(pathTemplate) if err != nil { return nil, err } - meth.PathParams = append(meth.PathParams, param) - } + tmpl := parsed.Compile() + + if md.GetClientStreaming() && len(tmpl.Fields) > 0 { + return nil, fmt.Errorf("cannot use path parameter in client streaming") + } + + b := &Binding{ + Method: meth, + Index: idx, + PathTmpl: tmpl, + HTTPMethod: httpMethod, + } + + for _, f := range tmpl.Fields { + param, err := r.newParam(meth, f) + if err != nil { + return nil, err + } + b.PathParams = append(b.PathParams, param) + } - // TODO(yugui) Handle query params + // TODO(yugui) Handle query params - meth.Body, err = r.newBody(meth, opts.Body) + b.Body, err = r.newBody(meth, opts.Body) + if err != nil { + return nil, err + } + + return b, nil + } + b, err := newBinding(opts, 0) if err != nil { return nil, err } + meth.Bindings = append(meth.Bindings, b) + for i, additional := range opts.GetAdditionalBindings() { + if len(additional.AdditionalBindings) > 0 { + return nil, fmt.Errorf("additional_binding in additional_binding not allowed: %s.%s", svc.GetName(), meth.GetName()) + } + b, err := newBinding(additional, i+1) + if err != nil { + return nil, err + } + meth.Bindings = append(meth.Bindings, b) + } + return meth, nil } diff --git a/protoc-gen-grpc-gateway/descriptor/services_test.go b/protoc-gen-grpc-gateway/descriptor/services_test.go index ba12bb147f4..0705af471c2 100644 --- a/protoc-gen-grpc-gateway/descriptor/services_test.go +++ b/protoc-gen-grpc-gateway/descriptor/services_test.go @@ -43,49 +43,68 @@ func testExtractServices(t *testing.T, input []*descriptor.FileDescriptorProto, t.Errorf("svcs[%d].Methods[%d].MethodDescriptorProto = %v; want %v; input = %v", i, j, got, want, input) continue } - if got, want := meth.PathTmpl, wantMeth.PathTmpl; !reflect.DeepEqual(got, want) { - t.Errorf("svcs[%d].Methods[%d].PathTmpl = %#v; want %#v; input = %v", i, j, got, want, input) - } - if got, want := meth.HTTPMethod, wantMeth.HTTPMethod; got != want { - t.Errorf("svcs[%d].Methods[%d].HTTPMethod = %q; want %q; input = %v", i, j, got, want, input) - } if got, want := meth.RequestType, wantMeth.RequestType; got.FQMN() != want.FQMN() { t.Errorf("svcs[%d].Methods[%d].RequestType = %s; want %s; input = %v", i, j, got.FQMN(), want.FQMN(), input) } if got, want := meth.ResponseType, wantMeth.ResponseType; got.FQMN() != want.FQMN() { t.Errorf("svcs[%d].Methods[%d].ResponseType = %s; want %s; input = %v", i, j, got.FQMN(), want.FQMN(), input) } - var k int - for k = 0; k < len(meth.PathParams) && k < len(wantMeth.PathParams); k++ { - param, wantParam := meth.PathParams[k], wantMeth.PathParams[k] - if got, want := param.FieldPath.String(), wantParam.FieldPath.String(); got != want { - t.Errorf("svcs[%d].Methods[%d].PathParams[%d].FieldPath.String() = %q; want %q; input = %v", i, j, k, got, want, input) - continue + for k = 0; k < len(meth.Bindings) && k < len(wantMeth.Bindings); k++ { + binding, wantBinding := meth.Bindings[k], wantMeth.Bindings[k] + if got, want := binding.Index, wantBinding.Index; got != want { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].Index = %d; want %d; input = %v", i, j, k, got, want, input) + } + if got, want := binding.PathTmpl, wantBinding.PathTmpl; !reflect.DeepEqual(got, want) { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].PathTmpl = %#v; want %#v; input = %v", i, j, k, got, want, input) } - for l := 0; l < len(param.FieldPath) && l < len(wantParam.FieldPath); l++ { - field, wantField := param.FieldPath[l].Target, wantParam.FieldPath[l].Target - if got, want := field.FieldDescriptorProto, wantField.FieldDescriptorProto; !proto.Equal(got, want) { - t.Errorf("svcs[%d].Methods[%d].PathParams[%d].FieldPath[%d].Target.FieldDescriptorProto = %v; want %v; input = %v", i, j, k, l, got, want, input) + if got, want := binding.HTTPMethod, wantBinding.HTTPMethod; got != want { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].HTTPMethod = %q; want %q; input = %v", i, j, k, got, want, input) + } + + var l int + for l = 0; l < len(binding.PathParams) && l < len(wantBinding.PathParams); l++ { + param, wantParam := binding.PathParams[l], wantBinding.PathParams[l] + if got, want := param.FieldPath.String(), wantParam.FieldPath.String(); got != want { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].PathParams[%d].FieldPath.String() = %q; want %q; input = %v", i, j, k, l, got, want, input) + continue + } + for l := 0; l < len(param.FieldPath) && l < len(wantParam.FieldPath); l++ { + field, wantField := param.FieldPath[l].Target, wantParam.FieldPath[l].Target + if got, want := field.FieldDescriptorProto, wantField.FieldDescriptorProto; !proto.Equal(got, want) { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].PathParams[%d].FieldPath[%d].Target.FieldDescriptorProto = %v; want %v; input = %v", i, j, k, l, got, want, input) + } } } - } - for ; k < len(meth.PathParams); k++ { - got := meth.PathParams[k].FieldPath.String() - t.Errorf("svcs[%d].Methods[%d].PathParams[%d] = %q; want it to be missing; input = %v", i, j, k, got, input) - } - for ; k < len(wantMeth.PathParams); k++ { - want := wantMeth.PathParams[k].FieldPath.String() - t.Errorf("svcs[%d].Methods[%d].PathParams[%d] missing; want %q; input = %v", i, j, k, want, input) - } + for ; l < len(binding.PathParams); l++ { + got := binding.PathParams[l].FieldPath.String() + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].PathParams[%d] = %q; want it to be missing; input = %v", i, j, k, l, got, input) + } + for ; l < len(wantBinding.PathParams); l++ { + want := wantBinding.PathParams[l].FieldPath.String() + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].PathParams[%d] missing; want %q; input = %v", i, j, k, l, want, input) + } - if got, want := (meth.Body != nil), (wantMeth.Body != nil); got != want { - if got { - t.Errorf("svcs[%d].Methods[%d].Body = %q; want it to be missing; input = %v", i, j, meth.Body.FieldPath.String(), input) - } else { - t.Errorf("svcs[%d].Methods[%d].Body missing; want %q; input = %v", i, j, wantMeth.Body.FieldPath.String(), input) + if got, want := (binding.Body != nil), (wantBinding.Body != nil); got != want { + if got { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].Body = %q; want it to be missing; input = %v", i, j, k, binding.Body.FieldPath.String(), input) + } else { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].Body missing; want %q; input = %v", i, j, k, wantBinding.Body.FieldPath.String(), input) + } + } else if binding.Body != nil { + if got, want := binding.Body.FieldPath.String(), wantBinding.Body.FieldPath.String(); got != want { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].Body = %q; want it to be missing; input = %v", i, j, k, got, want, input) + } } } + for ; k < len(meth.Bindings); k++ { + got := meth.Bindings[k] + t.Errorf("svcs[%d].Methods[%d].Bindings[%d] = %q; want it to be missing; input = %v", i, j, k, got, input) + } + for ; k < len(wantMeth.Bindings); k++ { + want := wantMeth.Bindings[k] + t.Errorf("svcs[%d].Methods[%d].Bindings[%d] missing; want %q; input = %v", i, j, k, want, input) + } } for ; j < len(svc.Methods); j++ { got := svc.Methods[j].MethodDescriptorProto @@ -117,11 +136,14 @@ func crossLinkFixture(f *File) *File { svc.File = f for _, m := range svc.Methods { m.Service = svc - for _, param := range m.PathParams { - param.Method = m - } - for _, param := range m.QueryParams { - param.Method = m + for _, b := range m.Bindings { + b.Method = m + for _, param := range b.PathParams { + param.Method = m + } + for _, param := range b.QueryParams { + param.Method = m + } } } } @@ -181,11 +203,15 @@ func TestExtractServicesSimple(t *testing.T) { Methods: []*Method{ { MethodDescriptorProto: fd.Service[0].Method[0], - PathTmpl: compilePath(t, "/v1/example/echo"), - HTTPMethod: "POST", RequestType: msg, ResponseType: msg, - Body: &Body{FieldPath: nil}, + Bindings: []*Binding{ + { + PathTmpl: compilePath(t, "/v1/example/echo"), + HTTPMethod: "POST", + Body: &Body{FieldPath: nil}, + }, + }, }, }, }, @@ -276,11 +302,15 @@ func TestExtractServicesCrossPackage(t *testing.T) { Methods: []*Method{ { MethodDescriptorProto: fds[0].Service[0].Method[0], - PathTmpl: compilePath(t, "/v1/example/to_s"), - HTTPMethod: "POST", RequestType: boolMsg, ResponseType: stringMsg, - Body: &Body{FieldPath: nil}, + Bindings: []*Binding{ + { + PathTmpl: compilePath(t, "/v1/example/to_s"), + HTTPMethod: "POST", + Body: &Body{FieldPath: nil}, + }, + }, }, }, }, @@ -365,15 +395,19 @@ func TestExtractServicesWithBodyPath(t *testing.T) { Methods: []*Method{ { MethodDescriptorProto: fd.Service[0].Method[0], - PathTmpl: compilePath(t, "/v1/example/echo"), - HTTPMethod: "POST", RequestType: msg, ResponseType: msg, - Body: &Body{ - FieldPath: FieldPath{ - { - Name: "nested", - Target: msg.Fields[0], + Bindings: []*Binding{ + { + PathTmpl: compilePath(t, "/v1/example/echo"), + HTTPMethod: "POST", + Body: &Body{ + FieldPath: FieldPath{ + { + Name: "nested", + Target: msg.Fields[0], + }, + }, }, }, }, @@ -439,19 +473,133 @@ func TestExtractServicesWithPathParam(t *testing.T) { Methods: []*Method{ { MethodDescriptorProto: fd.Service[0].Method[0], - PathTmpl: compilePath(t, "/v1/example/echo/{string=*}"), - HTTPMethod: "GET", RequestType: msg, ResponseType: msg, - PathParams: []Parameter{ + Bindings: []*Binding{ { - FieldPath: FieldPath{ + PathTmpl: compilePath(t, "/v1/example/echo/{string=*}"), + HTTPMethod: "GET", + PathParams: []Parameter{ { - Name: "string", + FieldPath: FieldPath{ + { + Name: "string", + Target: msg.Fields[0], + }, + }, Target: msg.Fields[0], }, }, - Target: msg.Fields[0], + }, + }, + }, + }, + }, + }, + } + + crossLinkFixture(file) + testExtractServices(t, []*descriptor.FileDescriptorProto{&fd}, "path/to/example.proto", file.Services) +} + +func TestExtractServicesWithAdditionalBinding(t *testing.T) { + src := ` + name: "path/to/example.proto", + package: "example" + message_type < + name: "StringMessage" + field < + name: "string" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + > + > + service < + name: "ExampleService" + method < + name: "Echo" + input_type: "StringMessage" + output_type: "StringMessage" + options < + [google.api.http] < + post: "/v1/example/echo" + body: "*" + additional_bindings < + get: "/v1/example/echo/{string}" + > + additional_bindings < + post: "/v2/example/echo" + body: "string" + > + > + > + > + > + ` + var fd descriptor.FileDescriptorProto + if err := proto.UnmarshalText(src, &fd); err != nil { + t.Fatalf("proto.UnmarshalText(%s, &fd) failed with %v; want success", src, err) + } + msg := &Message{ + DescriptorProto: fd.MessageType[0], + Fields: []*Field{ + { + FieldDescriptorProto: fd.MessageType[0].Field[0], + }, + }, + } + file := &File{ + FileDescriptorProto: &fd, + GoPkg: GoPackage{ + Path: "path/to/example.pb", + Name: "example_pb", + }, + Messages: []*Message{msg}, + Services: []*Service{ + { + ServiceDescriptorProto: fd.Service[0], + Methods: []*Method{ + { + MethodDescriptorProto: fd.Service[0].Method[0], + RequestType: msg, + ResponseType: msg, + Bindings: []*Binding{ + { + Index: 0, + PathTmpl: compilePath(t, "/v1/example/echo"), + HTTPMethod: "POST", + Body: &Body{FieldPath: nil}, + }, + { + Index: 1, + PathTmpl: compilePath(t, "/v1/example/echo/{string}"), + HTTPMethod: "GET", + PathParams: []Parameter{ + { + FieldPath: FieldPath{ + { + Name: "string", + Target: msg.Fields[0], + }, + }, + Target: msg.Fields[0], + }, + }, + Body: nil, + }, + { + Index: 2, + PathTmpl: compilePath(t, "/v2/example/echo"), + HTTPMethod: "POST", + Body: &Body{ + FieldPath: FieldPath{ + FieldPathComponent{ + Name: "string", + Target: msg.Fields[0], + }, + }, + }, }, }, }, diff --git a/protoc-gen-grpc-gateway/descriptor/types.go b/protoc-gen-grpc-gateway/descriptor/types.go index 6d8eeade127..6cc35dd00d2 100644 --- a/protoc-gen-grpc-gateway/descriptor/types.go +++ b/protoc-gen-grpc-gateway/descriptor/types.go @@ -102,14 +102,23 @@ type Method struct { Service *Service *descriptor.MethodDescriptorProto - // PathTmpl is path template where this method is mapped to. - PathTmpl httprule.Template - // HTTPMethod is the HTTP method which this method is mapped to. - HTTPMethod string // RequestType is the message type of requests to this method. RequestType *Message // ResponseType is the message type of responses from this method. ResponseType *Message + Bindings []*Binding +} + +// Binding describes how an HTTP endpoint is bound to a gRPC method. +type Binding struct { + // Method is the method which the endpoint is bound to. + Method *Method + // Index is a zero-origin index of the binding in the target method + Index int + // PathTmpl is path template where this method is mapped to. + PathTmpl httprule.Template + // HTTPMethod is the HTTP method which this method is mapped to. + HTTPMethod string // PathParams is the list of parameters provided in HTTP request paths. PathParams []Parameter // QueryParam is the list of parameters provided in HTTP query strings. From 620cd7cd698e8372e31ddaeaac00701dc5c7d025 Mon Sep 17 00:00:00 2001 From: Yuki Yugui Sonoda Date: Sat, 9 May 2015 12:28:08 +0900 Subject: [PATCH 03/10] Update templates to follow the change in descriptor/ --- .../descriptor/services_test.go | 8 +- .../gengateway/template.go | 34 ++-- .../gengateway/template_test.go | 149 ++++++++++-------- 3 files changed, 106 insertions(+), 85 deletions(-) diff --git a/protoc-gen-grpc-gateway/descriptor/services_test.go b/protoc-gen-grpc-gateway/descriptor/services_test.go index 0705af471c2..3c467ed7bf9 100644 --- a/protoc-gen-grpc-gateway/descriptor/services_test.go +++ b/protoc-gen-grpc-gateway/descriptor/services_test.go @@ -69,10 +69,10 @@ func testExtractServices(t *testing.T, input []*descriptor.FileDescriptorProto, t.Errorf("svcs[%d].Methods[%d].Bindings[%d].PathParams[%d].FieldPath.String() = %q; want %q; input = %v", i, j, k, l, got, want, input) continue } - for l := 0; l < len(param.FieldPath) && l < len(wantParam.FieldPath); l++ { - field, wantField := param.FieldPath[l].Target, wantParam.FieldPath[l].Target + for m := 0; m < len(param.FieldPath) && m < len(wantParam.FieldPath); m++ { + field, wantField := param.FieldPath[m].Target, wantParam.FieldPath[m].Target if got, want := field.FieldDescriptorProto, wantField.FieldDescriptorProto; !proto.Equal(got, want) { - t.Errorf("svcs[%d].Methods[%d].Bindings[%d].PathParams[%d].FieldPath[%d].Target.FieldDescriptorProto = %v; want %v; input = %v", i, j, k, l, got, want, input) + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].PathParams[%d].FieldPath[%d].Target.FieldDescriptorProto = %v; want %v; input = %v", i, j, k, l, m, got, want, input) } } } @@ -93,7 +93,7 @@ func testExtractServices(t *testing.T, input []*descriptor.FileDescriptorProto, } } else if binding.Body != nil { if got, want := binding.Body.FieldPath.String(), wantBinding.Body.FieldPath.String(); got != want { - t.Errorf("svcs[%d].Methods[%d].Bindings[%d].Body = %q; want it to be missing; input = %v", i, j, k, got, want, input) + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].Body = %q; want %q; input = %v", i, j, k, got, want, input) } } } diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index 387039a57e3..7640045e866 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -22,8 +22,10 @@ func applyTemplate(p param) (string, error) { for _, svc := range p.Services { for _, meth := range svc.Methods { methodSeen = true - if err := handlerTemplate.Execute(w, meth); err != nil { - return "", err + for _, b := range meth.Bindings { + if err := handlerTemplate.Execute(w, b); err != nil { + return "", err + } } } } @@ -61,7 +63,7 @@ var _ = json.Marshal `)) handlerTemplate = template.Must(template.New("handler").Parse(` -{{if .GetClientStreaming}} +{{if .Method.GetClientStreaming}} {{template "client-streaming-request-func" .}} {{else}} {{template "client-rpc-request-func" .}} @@ -69,22 +71,22 @@ var _ = json.Marshal `)) _ = template.Must(handlerTemplate.New("request-func-signature").Parse(strings.Replace(` -{{if .GetServerStreaming}} -func request_{{.Service.GetName}}_{{.GetName}}(ctx context.Context, client {{.Service.GetName}}Client, req *http.Request, pathParams map[string]string) ({{.Service.GetName}}_{{.GetName}}Client, error) +{{if .Method.GetServerStreaming}} +func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, client {{.Method.Service.GetName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.GetName}}_{{.Method.GetName}}Client, error) {{else}} -func request_{{.Service.GetName}}_{{.GetName}}(ctx context.Context, client {{.Service.GetName}}Client, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) +func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, client {{.Method.Service.GetName}}Client, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) {{end}}`, "\n", "", -1))) _ = template.Must(handlerTemplate.New("client-streaming-request-func").Parse(` {{template "request-func-signature" .}} { - stream, err := client.{{.GetName}}(ctx) + stream, err := client.{{.Method.GetName}}(ctx) if err != nil { glog.Errorf("Failed to start streaming: %v", err) return nil, err } dec := json.NewDecoder(req.Body) for { - var protoReq {{.RequestType.GoType .Service.File.GoPkg.Path}} + var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}} err = dec.Decode(&protoReq) if err == io.EOF { break @@ -98,7 +100,7 @@ func request_{{.Service.GetName}}_{{.GetName}}(ctx context.Context, client {{.Se return nil, err } } -{{if .GetServerStreaming}} +{{if .Method.GetServerStreaming}} if err = stream.CloseSend(); err != nil { glog.Errorf("Failed to terminate client stream: %v", err) return nil, err @@ -112,7 +114,7 @@ func request_{{.Service.GetName}}_{{.GetName}}(ctx context.Context, client {{.Se _ = template.Must(handlerTemplate.New("client-rpc-request-func").Parse(` {{template "request-func-signature" .}} { - var protoReq {{.RequestType.GoType .Service.File.GoPkg.Path}} + var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}} {{range $param := .QueryParams}} protoReq.{{$param.RHS "protoReq"}}, err = {{$param.ConvertFuncExpr}}(req.FormValue({{$param | printf "%q"}})) if err != nil { @@ -139,7 +141,7 @@ func request_{{.Service.GetName}}_{{.GetName}}(ctx context.Context, client {{.Se {{end}} {{end}} - return client.{{.GetName}}(ctx, &protoReq) + return client.{{.Method.GetName}}(ctx, &protoReq) }`)) trailerTemplate = template.Must(template.New("trailer").Parse(` @@ -174,8 +176,9 @@ func Register{{$svc.GetName}}HandlerFromEndpoint(ctx context.Context, mux *runti func Register{{$svc.GetName}}Handler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { client := New{{$svc.GetName}}Client(conn) {{range $m := $svc.Methods}} - mux.Handle({{$m.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_{{$svc.GetName}}_{{$m.GetName}}(ctx, client, req, pathParams) + {{range $b := $m.Bindings}} + mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -187,12 +190,15 @@ func Register{{$svc.GetName}}Handler(ctx context.Context, mux *runtime.ServeMux, {{end}} }) {{end}} + {{end}} return nil } var ( {{range $m := $svc.Methods}} - pattern_{{$svc.GetName}}_{{$m.GetName}} = runtime.MustPattern(runtime.NewPattern({{$m.PathTmpl.Version}}, {{$m.PathTmpl.OpCodes | printf "%#v"}}, {{$m.PathTmpl.Pool | printf "%#v"}}, {{$m.PathTmpl.Verb | printf "%q"}})) + {{range $b := $m.Bindings}} + pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = runtime.MustPattern(runtime.NewPattern({{$b.PathTmpl.Version}}, {{$b.PathTmpl.OpCodes | printf "%#v"}}, {{$b.PathTmpl.Pool | printf "%#v"}}, {{$b.PathTmpl.Verb | printf "%q"}})) + {{end}} {{end}} ) {{end}}`)) diff --git a/protoc-gen-grpc-gateway/gengateway/template_test.go b/protoc-gen-grpc-gateway/gengateway/template_test.go index 3f0ecfe9313..e66be979d73 100644 --- a/protoc-gen-grpc-gateway/gengateway/template_test.go +++ b/protoc-gen-grpc-gateway/gengateway/template_test.go @@ -18,11 +18,14 @@ func crossLinkFixture(f *descriptor.File) *descriptor.File { svc.File = f for _, m := range svc.Methods { m.Service = svc - for _, param := range m.PathParams { - param.Method = m - } - for _, param := range m.QueryParams { - param.Method = m + for _, b := range m.Bindings { + b.Method = m + for _, param := range b.PathParams { + param.Method = m + } + for _, param := range b.QueryParams { + param.Method = m + } } } } @@ -64,10 +67,14 @@ func TestApplyTemplateHeader(t *testing.T) { Methods: []*descriptor.Method{ { MethodDescriptorProto: meth, - HTTPMethod: "GET", RequestType: msg, ResponseType: msg, - Body: &descriptor.Body{FieldPath: nil}, + Bindings: []*descriptor.Binding{ + { + HTTPMethod: "GET", + Body: &descriptor.Body{FieldPath: nil}, + }, + }, }, }, }, @@ -129,11 +136,11 @@ func TestApplyTemplateRequestWithoutClientStreaming(t *testing.T) { }{ { serverStreaming: false, - sigWant: `func request_ExampleService_Echo(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) {`, + sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) {`, }, { serverStreaming: true, - sigWant: `func request_ExampleService_Echo(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, error) {`, + sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, error) {`, }, } { meth.ServerStreaming = proto.Bool(spec.serverStreaming) @@ -175,39 +182,43 @@ func TestApplyTemplateRequestWithoutClientStreaming(t *testing.T) { Methods: []*descriptor.Method{ { MethodDescriptorProto: meth, - HTTPMethod: "POST", - PathTmpl: httprule.Template{ - Version: 1, - OpCodes: []int{0, 0}, - }, - RequestType: msg, - ResponseType: msg, - PathParams: []descriptor.Parameter{ + RequestType: msg, + ResponseType: msg, + Bindings: []*descriptor.Binding{ { - FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ - { - Name: "nested", - Target: nestedField, - }, + HTTPMethod: "POST", + PathTmpl: httprule.Template{ + Version: 1, + OpCodes: []int{0, 0}, + }, + PathParams: []descriptor.Parameter{ { - Name: "int32", + FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ + { + Name: "nested", + Target: nestedField, + }, + { + Name: "int32", + Target: intField, + }, + }), Target: intField, }, - }), - Target: intField, - }, - }, - Body: &descriptor.Body{ - FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ - { - Name: "nested", - Target: nestedField, }, - { - Name: "bool", - Target: boolField, + Body: &descriptor.Body{ + FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ + { + Name: "nested", + Target: nestedField, + }, + { + Name: "bool", + Target: boolField, + }, + }), }, - }), + }, }, }, }, @@ -234,7 +245,7 @@ func TestApplyTemplateRequestWithoutClientStreaming(t *testing.T) { if want := `func RegisterExampleServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {`; !strings.Contains(got, want) { t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want) } - if want := `pattern_ExampleService_Echo = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), ""))`; !strings.Contains(got, want) { + if want := `pattern_ExampleService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), ""))`; !strings.Contains(got, want) { t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want) } } @@ -286,11 +297,11 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) { }{ { serverStreaming: false, - sigWant: `func request_ExampleService_Echo(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) {`, + sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) {`, }, { serverStreaming: true, - sigWant: `func request_ExampleService_Echo(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, error) {`, + sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, error) {`, }, } { meth.ServerStreaming = proto.Bool(spec.serverStreaming) @@ -332,39 +343,43 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) { Methods: []*descriptor.Method{ { MethodDescriptorProto: meth, - HTTPMethod: "POST", - PathTmpl: httprule.Template{ - Version: 1, - OpCodes: []int{0, 0}, - }, - RequestType: msg, - ResponseType: msg, - PathParams: []descriptor.Parameter{ + RequestType: msg, + ResponseType: msg, + Bindings: []*descriptor.Binding{ { - FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ - { - Name: "nested", - Target: nestedField, - }, + HTTPMethod: "POST", + PathTmpl: httprule.Template{ + Version: 1, + OpCodes: []int{0, 0}, + }, + PathParams: []descriptor.Parameter{ { - Name: "int32", + FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ + { + Name: "nested", + Target: nestedField, + }, + { + Name: "int32", + Target: intField, + }, + }), Target: intField, }, - }), - Target: intField, - }, - }, - Body: &descriptor.Body{ - FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ - { - Name: "nested", - Target: nestedField, }, - { - Name: "bool", - Target: boolField, + Body: &descriptor.Body{ + FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ + { + Name: "nested", + Target: nestedField, + }, + { + Name: "bool", + Target: boolField, + }, + }), }, - }), + }, }, }, }, @@ -385,7 +400,7 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) { if want := `func RegisterExampleServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {`; !strings.Contains(got, want) { t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want) } - if want := `pattern_ExampleService_Echo = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), ""))`; !strings.Contains(got, want) { + if want := `pattern_ExampleService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), ""))`; !strings.Contains(got, want) { t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want) } } From c78916ae8ff788885731e3edea909893ce935c2c Mon Sep 17 00:00:00 2001 From: Yuki Yugui Sonoda Date: Sat, 9 May 2015 12:41:52 +0900 Subject: [PATCH 04/10] Update examples --- examples/a_bit_of_everything.pb.gw.go | 95 +++++++++++++++++---------- examples/a_bit_of_everything.proto | 4 ++ examples/echo_service.pb.gw.go | 16 ++--- examples/integration_test.go | 50 ++++++++++++++ 4 files changed, 121 insertions(+), 44 deletions(-) diff --git a/examples/a_bit_of_everything.pb.gw.go b/examples/a_bit_of_everything.pb.gw.go index 27873818077..4f901c280ce 100644 --- a/examples/a_bit_of_everything.pb.gw.go +++ b/examples/a_bit_of_everything.pb.gw.go @@ -28,7 +28,7 @@ var _ io.Reader var _ = runtime.String var _ = json.Marshal -func request_ABitOfEverythingService_Create(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_ABitOfEverythingService_Create_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq ABitOfEverything var val string @@ -163,7 +163,7 @@ func request_ABitOfEverythingService_Create(ctx context.Context, client ABitOfEv return client.Create(ctx, &protoReq) } -func request_ABitOfEverythingService_CreateBody(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_ABitOfEverythingService_CreateBody_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq ABitOfEverything if err = json.NewDecoder(req.Body).Decode(&protoReq); err != nil { @@ -173,7 +173,7 @@ func request_ABitOfEverythingService_CreateBody(ctx context.Context, client ABit return client.CreateBody(ctx, &protoReq) } -func request_ABitOfEverythingService_BulkCreate(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_ABitOfEverythingService_BulkCreate_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { stream, err := client.BulkCreate(ctx) if err != nil { glog.Errorf("Failed to start streaming: %v", err) @@ -200,7 +200,7 @@ func request_ABitOfEverythingService_BulkCreate(ctx context.Context, client ABit } -func request_ABitOfEverythingService_Lookup(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_ABitOfEverythingService_Lookup_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq IdMessage var val string @@ -218,13 +218,13 @@ func request_ABitOfEverythingService_Lookup(ctx context.Context, client ABitOfEv return client.Lookup(ctx, &protoReq) } -func request_ABitOfEverythingService_List(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (ABitOfEverythingService_ListClient, error) { +func request_ABitOfEverythingService_List_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (ABitOfEverythingService_ListClient, error) { var protoReq EmptyMessage return client.List(ctx, &protoReq) } -func request_ABitOfEverythingService_Update(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_ABitOfEverythingService_Update_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq ABitOfEverything if err = json.NewDecoder(req.Body).Decode(&protoReq); err != nil { @@ -246,7 +246,7 @@ func request_ABitOfEverythingService_Update(ctx context.Context, client ABitOfEv return client.Update(ctx, &protoReq) } -func request_ABitOfEverythingService_Delete(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_ABitOfEverythingService_Delete_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq IdMessage var val string @@ -264,7 +264,7 @@ func request_ABitOfEverythingService_Delete(ctx context.Context, client ABitOfEv return client.Delete(ctx, &protoReq) } -func request_ABitOfEverythingService_Echo(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_ABitOfEverythingService_Echo_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq sub.StringMessage var val string @@ -282,7 +282,17 @@ func request_ABitOfEverythingService_Echo(ctx context.Context, client ABitOfEver return client.Echo(ctx, &protoReq) } -func request_ABitOfEverythingService_BulkEcho(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (ABitOfEverythingService_BulkEchoClient, error) { +func request_ABitOfEverythingService_Echo_1(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { + var protoReq sub.StringMessage + + if err = json.NewDecoder(req.Body).Decode(&protoReq.Value); err != nil { + return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) + } + + return client.Echo(ctx, &protoReq) +} + +func request_ABitOfEverythingService_BulkEcho_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (ABitOfEverythingService_BulkEchoClient, error) { stream, err := client.BulkEcho(ctx) if err != nil { glog.Errorf("Failed to start streaming: %v", err) @@ -343,8 +353,8 @@ func RegisterABitOfEverythingServiceHandlerFromEndpoint(ctx context.Context, mux func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { client := NewABitOfEverythingServiceClient(conn) - mux.Handle("POST", pattern_ABitOfEverythingService_Create, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_Create(ctx, client, req, pathParams) + mux.Handle("POST", pattern_ABitOfEverythingService_Create_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_Create_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -354,8 +364,8 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("POST", pattern_ABitOfEverythingService_CreateBody, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_CreateBody(ctx, client, req, pathParams) + mux.Handle("POST", pattern_ABitOfEverythingService_CreateBody_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_CreateBody_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -365,8 +375,8 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("POST", pattern_ABitOfEverythingService_BulkCreate, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_BulkCreate(ctx, client, req, pathParams) + mux.Handle("POST", pattern_ABitOfEverythingService_BulkCreate_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_BulkCreate_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -376,8 +386,8 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("GET", pattern_ABitOfEverythingService_Lookup, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_Lookup(ctx, client, req, pathParams) + mux.Handle("GET", pattern_ABitOfEverythingService_Lookup_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_Lookup_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -387,8 +397,8 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("GET", pattern_ABitOfEverythingService_List, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_List(ctx, client, req, pathParams) + mux.Handle("GET", pattern_ABitOfEverythingService_List_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_List_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -398,8 +408,19 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("PUT", pattern_ABitOfEverythingService_Update, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_Update(ctx, client, req, pathParams) + mux.Handle("PUT", pattern_ABitOfEverythingService_Update_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_Update_0(ctx, client, req, pathParams) + if err != nil { + runtime.HTTPError(w, err) + return + } + + runtime.ForwardResponseMessage(w, resp) + + }) + + mux.Handle("DELETE", pattern_ABitOfEverythingService_Delete_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_Delete_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -409,8 +430,8 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("DELETE", pattern_ABitOfEverythingService_Delete, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_Delete(ctx, client, req, pathParams) + mux.Handle("GET", pattern_ABitOfEverythingService_Echo_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_Echo_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -420,8 +441,8 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("GET", pattern_ABitOfEverythingService_Echo, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_Echo(ctx, client, req, pathParams) + mux.Handle("POST", pattern_ABitOfEverythingService_Echo_1, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_Echo_1(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -431,8 +452,8 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("POST", pattern_ABitOfEverythingService_BulkEcho, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_BulkEcho(ctx, client, req, pathParams) + mux.Handle("POST", pattern_ABitOfEverythingService_BulkEcho_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_BulkEcho_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -446,21 +467,23 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se } var ( - pattern_ABitOfEverythingService_Create = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 1, 0, 4, 1, 5, 4, 1, 0, 4, 1, 5, 5, 2, 6, 1, 0, 4, 1, 5, 7, 1, 0, 4, 1, 5, 8, 1, 0, 4, 1, 5, 9, 1, 0, 4, 1, 5, 10, 1, 0, 4, 1, 5, 11, 2, 12, 1, 0, 4, 2, 5, 13, 1, 0, 4, 1, 5, 14, 1, 0, 4, 1, 5, 15, 1, 0, 4, 1, 5, 16, 1, 0, 4, 1, 5, 17, 1, 0, 4, 1, 5, 18}, []string{"v1", "example", "a_bit_of_everything", "float_value", "double_value", "int64_value", "separator", "uint64_value", "int32_value", "fixed64_value", "fixed32_value", "bool_value", "strprefix", "string_value", "uint32_value", "sfixed32_value", "sfixed64_value", "sint32_value", "sint64_value"}, "")) + pattern_ABitOfEverythingService_Create_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 1, 0, 4, 1, 5, 4, 1, 0, 4, 1, 5, 5, 2, 6, 1, 0, 4, 1, 5, 7, 1, 0, 4, 1, 5, 8, 1, 0, 4, 1, 5, 9, 1, 0, 4, 1, 5, 10, 1, 0, 4, 1, 5, 11, 2, 12, 1, 0, 4, 2, 5, 13, 1, 0, 4, 1, 5, 14, 1, 0, 4, 1, 5, 15, 1, 0, 4, 1, 5, 16, 1, 0, 4, 1, 5, 17, 1, 0, 4, 1, 5, 18}, []string{"v1", "example", "a_bit_of_everything", "float_value", "double_value", "int64_value", "separator", "uint64_value", "int32_value", "fixed64_value", "fixed32_value", "bool_value", "strprefix", "string_value", "uint32_value", "sfixed32_value", "sfixed64_value", "sint32_value", "sint64_value"}, "")) + + pattern_ABitOfEverythingService_CreateBody_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "example", "a_bit_of_everything"}, "")) - pattern_ABitOfEverythingService_CreateBody = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "example", "a_bit_of_everything"}, "")) + pattern_ABitOfEverythingService_BulkCreate_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v1", "example", "a_bit_of_everything", "bulk"}, "")) - pattern_ABitOfEverythingService_BulkCreate = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v1", "example", "a_bit_of_everything", "bulk"}, "")) + pattern_ABitOfEverythingService_Lookup_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "a_bit_of_everything", "uuid"}, "")) - pattern_ABitOfEverythingService_Lookup = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "a_bit_of_everything", "uuid"}, "")) + pattern_ABitOfEverythingService_List_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "example", "a_bit_of_everything"}, "")) - pattern_ABitOfEverythingService_List = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "example", "a_bit_of_everything"}, "")) + pattern_ABitOfEverythingService_Update_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "a_bit_of_everything", "uuid"}, "")) - pattern_ABitOfEverythingService_Update = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "a_bit_of_everything", "uuid"}, "")) + pattern_ABitOfEverythingService_Delete_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "a_bit_of_everything", "uuid"}, "")) - pattern_ABitOfEverythingService_Delete = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "a_bit_of_everything", "uuid"}, "")) + pattern_ABitOfEverythingService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 1, 0, 4, 1, 5, 4}, []string{"v1", "example", "a_bit_of_everything", "echo", "value"}, "")) - pattern_ABitOfEverythingService_Echo = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 1, 0, 4, 1, 5, 4}, []string{"v1", "example", "a_bit_of_everything", "echo", "value"}, "")) + pattern_ABitOfEverythingService_Echo_1 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v2", "example", "echo"}, "")) - pattern_ABitOfEverythingService_BulkEcho = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v1", "example", "a_bit_of_everything", "echo"}, "")) + pattern_ABitOfEverythingService_BulkEcho_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v1", "example", "a_bit_of_everything", "echo"}, "")) ) diff --git a/examples/a_bit_of_everything.proto b/examples/a_bit_of_everything.proto index 434ff2b4e25..980509f558f 100644 --- a/examples/a_bit_of_everything.proto +++ b/examples/a_bit_of_everything.proto @@ -80,6 +80,10 @@ service ABitOfEverythingService { rpc Echo(gengo.grpc.gateway.examples.sub.StringMessage) returns (gengo.grpc.gateway.examples.sub.StringMessage) { option (google.api.http) = { get: "/v1/example/a_bit_of_everything/echo/{value}" + additional_bindings { + post: "/v2/example/echo" + body: "value" + } }; } rpc BulkEcho(stream gengo.grpc.gateway.examples.sub.StringMessage) returns (stream gengo.grpc.gateway.examples.sub.StringMessage) { diff --git a/examples/echo_service.pb.gw.go b/examples/echo_service.pb.gw.go index 6e47decf18b..2046e5717ad 100644 --- a/examples/echo_service.pb.gw.go +++ b/examples/echo_service.pb.gw.go @@ -27,7 +27,7 @@ var _ io.Reader var _ = runtime.String var _ = json.Marshal -func request_EchoService_Echo(ctx context.Context, client EchoServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_EchoService_Echo_0(ctx context.Context, client EchoServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq SimpleMessage var val string @@ -45,7 +45,7 @@ func request_EchoService_Echo(ctx context.Context, client EchoServiceClient, req return client.Echo(ctx, &protoReq) } -func request_EchoService_EchoBody(ctx context.Context, client EchoServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_EchoService_EchoBody_0(ctx context.Context, client EchoServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq SimpleMessage if err = json.NewDecoder(req.Body).Decode(&protoReq); err != nil { @@ -85,8 +85,8 @@ func RegisterEchoServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.Se func RegisterEchoServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { client := NewEchoServiceClient(conn) - mux.Handle("POST", pattern_EchoService_Echo, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_EchoService_Echo(ctx, client, req, pathParams) + mux.Handle("POST", pattern_EchoService_Echo_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_EchoService_Echo_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -96,8 +96,8 @@ func RegisterEchoServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn }) - mux.Handle("POST", pattern_EchoService_EchoBody, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_EchoService_EchoBody(ctx, client, req, pathParams) + mux.Handle("POST", pattern_EchoService_EchoBody_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_EchoService_EchoBody_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -111,7 +111,7 @@ func RegisterEchoServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn } var ( - pattern_EchoService_Echo = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "echo", "id"}, "")) + pattern_EchoService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "echo", "id"}, "")) - pattern_EchoService_EchoBody = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "example", "echo_body"}, "")) + pattern_EchoService_EchoBody_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "example", "echo_body"}, "")) ) diff --git a/examples/integration_test.go b/examples/integration_test.go index 7dfac5315de..ae2aa347f83 100644 --- a/examples/integration_test.go +++ b/examples/integration_test.go @@ -14,6 +14,7 @@ import ( gw "github.com/gengo/grpc-gateway/examples" server "github.com/gengo/grpc-gateway/examples/server" + sub "github.com/gengo/grpc-gateway/examples/sub" ) func TestIntegration(t *testing.T) { @@ -43,6 +44,7 @@ func TestIntegration(t *testing.T) { testABEBulkCreate(t) testABELookup(t) testABEList(t) + testAdditionalBindings(t) } func testEcho(t *testing.T) { @@ -377,3 +379,51 @@ func testABEList(t *testing.T) { t.Errorf("i == %d; want > 0", i) } } + +func testAdditionalBindings(t *testing.T) { + for i, f := range []func() *http.Response{ + func() *http.Response { + url := "http://localhost:8080/v1/example/a_bit_of_everything/echo/hello" + resp, err := http.Get(url) + if err != nil { + t.Errorf("http.Get(%q) failed with %v; want success", url, err) + return nil + } + return resp + }, + func() *http.Response { + url := "http://localhost:8080/v2/example/echo" + resp, err := http.Post(url, "application/json", strings.NewReader(`"hello"`)) + if err != nil { + t.Errorf("http.Post(%q, %q, %q) failed with %v; want success", url, "application/json", `"hello"`, err) + return nil + } + return resp + }, + } { + resp := f() + if resp == nil { + continue + } + + defer resp.Body.Close() + buf, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("iotuil.ReadAll(resp.Body) failed with %v; want success; i=%d", err, i) + return + } + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("resp.StatusCode = %d; want %d; i=%d", got, want, i) + t.Logf("%s", buf) + } + + var msg sub.StringMessage + if err := json.Unmarshal(buf, &msg); err != nil { + t.Errorf("json.Unmarshal(%s, &msg) failed with %v; want success; %i", buf, err, i) + return + } + if got, want := msg.GetValue(), "hello"; got != want { + t.Errorf("msg.GetValue() = %q; want %q", got, want) + } + } +} From e08af582bc268dc858b96aaa89ca0f6d9be1561d Mon Sep 17 00:00:00 2001 From: Yuki Yugui Sonoda Date: Tue, 12 May 2015 20:32:12 +0900 Subject: [PATCH 05/10] Prototype query string support --- Makefile | 4 +- examples/a_bit_of_everything.pb.gw.go | 27 +++++++++++ examples/a_bit_of_everything.proto | 3 ++ examples/integration_test.go | 9 ++++ .../descriptor => internal}/name.go | 10 ++-- internal/name_test.go | 22 +++++++++ .../descriptor/services_test.go | 3 -- protoc-gen-grpc-gateway/descriptor/types.go | 40 ++++++++++++++-- .../gengateway/template.go | 11 ++--- .../gengateway/template_test.go | 3 -- runtime/query.go | 46 +++++++++++++++++++ 11 files changed, 153 insertions(+), 25 deletions(-) rename {protoc-gen-grpc-gateway/descriptor => internal}/name.go (60%) create mode 100644 internal/name_test.go create mode 100644 runtime/query.go diff --git a/Makefile b/Makefile index 2fad9892b8c..fbde89d7cd5 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,9 @@ GO_PLUGIN=bin/protoc-gen-go GO_PLUGIN_PKG=github.com/golang/protobuf/protoc-gen-go GATEWAY_PLUGIN=bin/protoc-gen-grpc-gateway GATEWAY_PLUGIN_PKG=$(PKG)/protoc-gen-grpc-gateway -GATEWAY_PLUGIN_SRC= protoc-gen-grpc-gateway/descriptor/name.go \ +GATEWAY_PLUGIN_SRC= internal/doc.go \ + internal/name.go \ + internal/pattern.go \ protoc-gen-grpc-gateway/descriptor/registry.go \ protoc-gen-grpc-gateway/descriptor/services.go \ protoc-gen-grpc-gateway/descriptor/types.go \ diff --git a/examples/a_bit_of_everything.pb.gw.go b/examples/a_bit_of_everything.pb.gw.go index 4f901c280ce..7d09f5d4a8d 100644 --- a/examples/a_bit_of_everything.pb.gw.go +++ b/examples/a_bit_of_everything.pb.gw.go @@ -160,6 +160,10 @@ func request_ABitOfEverythingService_Create_0(ctx context.Context, client ABitOf return nil, err } + if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), []string{"float_value", "double_value", "int64_value", "uint64_value", "int32_value", "fixed64_value", "fixed32_value", "bool_value", "string_value", "uint32_value", "sfixed32_value", "sfixed64_value", "sint32_value", "sint64_value"}); err != nil { + return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) + } + return client.Create(ctx, &protoReq) } @@ -292,6 +296,16 @@ func request_ABitOfEverythingService_Echo_1(ctx context.Context, client ABitOfEv return client.Echo(ctx, &protoReq) } +func request_ABitOfEverythingService_Echo_2(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { + var protoReq sub.StringMessage + + if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), []string(nil)); err != nil { + return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) + } + + return client.Echo(ctx, &protoReq) +} + func request_ABitOfEverythingService_BulkEcho_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (ABitOfEverythingService_BulkEchoClient, error) { stream, err := client.BulkEcho(ctx) if err != nil { @@ -452,6 +466,17 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) + mux.Handle("GET", pattern_ABitOfEverythingService_Echo_2, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_Echo_2(ctx, client, req, pathParams) + if err != nil { + runtime.HTTPError(w, err) + return + } + + runtime.ForwardResponseMessage(w, resp) + + }) + mux.Handle("POST", pattern_ABitOfEverythingService_BulkEcho_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { resp, err := request_ABitOfEverythingService_BulkEcho_0(ctx, client, req, pathParams) if err != nil { @@ -485,5 +510,7 @@ var ( pattern_ABitOfEverythingService_Echo_1 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v2", "example", "echo"}, "")) + pattern_ABitOfEverythingService_Echo_2 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v2", "example", "echo"}, "")) + pattern_ABitOfEverythingService_BulkEcho_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v1", "example", "a_bit_of_everything", "echo"}, "")) ) diff --git a/examples/a_bit_of_everything.proto b/examples/a_bit_of_everything.proto index 980509f558f..c4cf1ee8b32 100644 --- a/examples/a_bit_of_everything.proto +++ b/examples/a_bit_of_everything.proto @@ -84,6 +84,9 @@ service ABitOfEverythingService { post: "/v2/example/echo" body: "value" } + additional_bindings { + get: "/v2/example/echo" + } }; } rpc BulkEcho(stream gengo.grpc.gateway.examples.sub.StringMessage) returns (stream gengo.grpc.gateway.examples.sub.StringMessage) { diff --git a/examples/integration_test.go b/examples/integration_test.go index ae2aa347f83..65fddc9907b 100644 --- a/examples/integration_test.go +++ b/examples/integration_test.go @@ -400,6 +400,15 @@ func testAdditionalBindings(t *testing.T) { } return resp }, + func() *http.Response { + url := "http://localhost:8080/v2/example/echo?value=hello" + resp, err := http.Get(url) + if err != nil { + t.Errorf("http.Get(%q) failed with %v; want success", url, err) + return nil + } + return resp + }, } { resp := f() if resp == nil { diff --git a/protoc-gen-grpc-gateway/descriptor/name.go b/internal/name.go similarity index 60% rename from protoc-gen-grpc-gateway/descriptor/name.go rename to internal/name.go index 78234461a73..a6d771936c9 100644 --- a/protoc-gen-grpc-gateway/descriptor/name.go +++ b/internal/name.go @@ -1,15 +1,11 @@ -package descriptor +package internal import ( - "regexp" "strings" ) -var ( - upperPattern = regexp.MustCompile("[A-Z]") -) - -func toCamel(str string) string { +// PascalFromSnake converts an identifier in snake_case into PascalCase. +func PascalFromSnake(str string) string { var components []string for _, c := range strings.Split(str, "_") { components = append(components, strings.Title(strings.ToLower(c))) diff --git a/internal/name_test.go b/internal/name_test.go new file mode 100644 index 00000000000..7b1b6b62bbe --- /dev/null +++ b/internal/name_test.go @@ -0,0 +1,22 @@ +package internal_test + +import ( + "testing" + + "github.com/gengo/grpc-gateway/internal" +) + +func TestPascalToSnake(t *testing.T) { + for _, spec := range []struct { + input, want string + }{ + {input: "value", want: "Value"}, + {input: "prefixed_value", want: "PrefixedValue"}, + {input: "foo_id", want: "FooId"}, + } { + got := internal.PascalFromSnake(spec.input) + if got != spec.want { + t.Errorf("internal.PascalFromSnake(%q) = %q; want %q", spec.input, got, spec.want) + } + } +} diff --git a/protoc-gen-grpc-gateway/descriptor/services_test.go b/protoc-gen-grpc-gateway/descriptor/services_test.go index 3c467ed7bf9..323147904f0 100644 --- a/protoc-gen-grpc-gateway/descriptor/services_test.go +++ b/protoc-gen-grpc-gateway/descriptor/services_test.go @@ -141,9 +141,6 @@ func crossLinkFixture(f *File) *File { for _, param := range b.PathParams { param.Method = m } - for _, param := range b.QueryParams { - param.Method = m - } } } } diff --git a/protoc-gen-grpc-gateway/descriptor/types.go b/protoc-gen-grpc-gateway/descriptor/types.go index 6cc35dd00d2..31ef5c358ee 100644 --- a/protoc-gen-grpc-gateway/descriptor/types.go +++ b/protoc-gen-grpc-gateway/descriptor/types.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/gengo/grpc-gateway/internal" "github.com/gengo/grpc-gateway/protoc-gen-grpc-gateway/httprule" descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor" ) @@ -121,12 +122,41 @@ type Binding struct { HTTPMethod string // PathParams is the list of parameters provided in HTTP request paths. PathParams []Parameter - // QueryParam is the list of parameters provided in HTTP query strings. - QueryParams []Parameter // Body describes parameters provided in HTTP request body. Body *Body } +// HasQueryParams returns if "b" has query_string params. +func (b *Binding) HasQueryParams() bool { + if b.Body != nil && len(b.Body.FieldPath) == 0 { + return false + } + fields := make(map[string]bool) + for _, f := range b.Method.RequestType.GetField() { + fields[f.GetName()] = true + } + if b.Body != nil { + delete(fields, b.Body.FieldPath.String()) + } + for _, p := range b.PathParams { + delete(fields, p.FieldPath.String()) + } + return len(fields) > 0 +} + +// ExplicitParams returns a list of explicitly bound parameters of "b", +// i.e. a union of field path for body and field paths for path parameters. +func (b *Binding) ExplicitParams() []string { + var result []string + if b.Body != nil { + result = append(result, b.Body.FieldPath.String()) + } + for _, p := range b.PathParams { + result = append(result, p.FieldPath.String()) + } + return result +} + // Field wraps descriptor.FieldDescriptorProto for richer features. type Field struct { // Message is the message type which this field belongs to. @@ -213,15 +243,15 @@ type FieldPathComponent struct { // RHS returns a right-hand-side expression in go for this field. func (c FieldPathComponent) RHS() string { - return toCamel(c.Name) + return internal.PascalFromSnake(c.Name) } // LHS returns a left-hand-side expression in go for this field. func (c FieldPathComponent) LHS() string { if c.Target.Message.File.proto2() { - return fmt.Sprintf("Get%s()", toCamel(c.Name)) + return fmt.Sprintf("Get%s()", internal.PascalFromSnake(c.Name)) } - return toCamel(c.Name) + return internal.PascalFromSnake(c.Name) } var ( diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index 7640045e866..e6a78fa2248 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -115,12 +115,6 @@ func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx cont _ = template.Must(handlerTemplate.New("client-rpc-request-func").Parse(` {{template "request-func-signature" .}} { var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}} - {{range $param := .QueryParams}} - protoReq.{{$param.RHS "protoReq"}}, err = {{$param.ConvertFuncExpr}}(req.FormValue({{$param | printf "%q"}})) - if err != nil { - return nil, err - } - {{end}} {{if .Body}} if err = json.NewDecoder(req.Body).Decode(&{{.Body.RHS "protoReq"}}); err != nil { return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) @@ -140,6 +134,11 @@ func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx cont } {{end}} {{end}} +{{if .HasQueryParams}} + if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), {{.ExplicitParams | printf "%#v"}}); err != nil { + return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) + } +{{end}} return client.{{.Method.GetName}}(ctx, &protoReq) }`)) diff --git a/protoc-gen-grpc-gateway/gengateway/template_test.go b/protoc-gen-grpc-gateway/gengateway/template_test.go index e66be979d73..f52287f16cc 100644 --- a/protoc-gen-grpc-gateway/gengateway/template_test.go +++ b/protoc-gen-grpc-gateway/gengateway/template_test.go @@ -23,9 +23,6 @@ func crossLinkFixture(f *descriptor.File) *descriptor.File { for _, param := range b.PathParams { param.Method = m } - for _, param := range b.QueryParams { - param.Method = m - } } } } diff --git a/runtime/query.go b/runtime/query.go new file mode 100644 index 00000000000..05635ba5902 --- /dev/null +++ b/runtime/query.go @@ -0,0 +1,46 @@ +package runtime + +import ( + "net/url" + "strings" + + "github.com/ajg/form" + "github.com/gengo/grpc-gateway/internal" + "github.com/golang/protobuf/proto" +) + +func isQueryParam(key string, filters []string) bool { + for _, f := range filters { + if strings.HasPrefix(key, f) { + switch l, m := len(key), len(f); { + case l == m: + return false + case key[m] == '.': + return false + } + } + } + return true +} + +func convertPath(path string) string { + var components []string + for _, c := range strings.Split(path, ".") { + components = append(components, internal.PascalFromSnake(c)) + } + return strings.Join(components, ".") +} + +// PopulateQueryParameters populates "values" into "msg". +// A value is ignored if its key starts with one of the elements in "filters". +// +// TODO(yugui) Use trie for filters? +func PopulateQueryParameters(msg proto.Message, values url.Values, filters []string) error { + filtered := make(url.Values) + for key, values := range values { + if isQueryParam(key, filters) { + filtered[convertPath(key)] = values + } + } + return form.DecodeValues(msg, filtered) +} From eeb1332429737b75e5000a1306c0360279abe5a1 Mon Sep 17 00:00:00 2001 From: Yuki Yugui Sonoda Date: Sat, 16 May 2015 22:03:40 +0900 Subject: [PATCH 06/10] Implement Double Arrary of sequences of strings --- internal/trie.go | 177 ++++++++++++++++++++ internal/trie_test.go | 372 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 549 insertions(+) create mode 100644 internal/trie.go create mode 100644 internal/trie_test.go diff --git a/internal/trie.go b/internal/trie.go new file mode 100644 index 00000000000..03f80d72f8a --- /dev/null +++ b/internal/trie.go @@ -0,0 +1,177 @@ +package internal + +import ( + "sort" +) + +// DoubleArray is a Double Array implementation of trie on sequences of strings. +type DoubleArray struct { + // Encoding keeps an encoding from string to int + Encoding map[string]int + // Base is the base array of Double Array + Base []int + // Check is the check array of Double Array + Check []int +} + +// NewDoubleArray builds a DoubleArray from a set of sequences of strings. +func NewDoubleArray(seqs [][]string) *DoubleArray { + da := &DoubleArray{Encoding: make(map[string]int)} + if len(seqs) == 0 { + return da + } + + encoded := registerTokens(da, seqs) + sort.Sort(byLex(encoded)) + + root := node{row: -1, col: -1, left: 0, right: len(encoded)} + addSeqs(da, encoded, 0, root) + + for i := len(da.Base); i > 0; i-- { + if da.Check[i-1] != 0 { + da.Base = da.Base[:i] + da.Check = da.Check[:i] + break + } + } + return da +} + +func registerTokens(da *DoubleArray, seqs [][]string) [][]int { + var result [][]int + for _, seq := range seqs { + var encoded []int + for _, token := range seq { + if _, ok := da.Encoding[token]; !ok { + da.Encoding[token] = len(da.Encoding) + } + encoded = append(encoded, da.Encoding[token]) + } + result = append(result, encoded) + } + for i := range result { + result[i] = append(result[i], len(da.Encoding)) + } + return result +} + +type node struct { + row, col int + left, right int +} + +func (n node) value(seqs [][]int) int { + return seqs[n.row][n.col] +} + +func (n node) children(seqs [][]int) []*node { + var result []*node + lastVal := int(-1) + last := new(node) + for i := n.left; i < n.right; i++ { + if lastVal == seqs[i][n.col+1] { + continue + } + last.right = i + last = &node{ + row: i, + col: n.col + 1, + left: i, + } + result = append(result, last) + } + last.right = n.right + return result +} + +func addSeqs(da *DoubleArray, seqs [][]int, pos int, n node) { + ensureSize(da, pos) + + children := n.children(seqs) + var i int + for i = 1; ; i++ { + ok := func() bool { + for _, child := range children { + code := child.value(seqs) + j := i + code + ensureSize(da, j) + if da.Check[j] != 0 { + return false + } + } + return true + }() + if ok { + break + } + } + da.Base[pos] = i + for _, child := range children { + code := child.value(seqs) + j := i + code + da.Check[j] = pos + 1 + } + terminator := len(da.Encoding) + for _, child := range children { + code := child.value(seqs) + if code == terminator { + continue + } + j := i + code + addSeqs(da, seqs, j, *child) + } +} + +func ensureSize(da *DoubleArray, i int) { + for i >= len(da.Base) { + da.Base = append(da.Base, make([]int, len(da.Base)+1)...) + da.Check = append(da.Check, make([]int, len(da.Check)+1)...) + } +} + +type byLex [][]int + +func (l byLex) Len() int { return len(l) } +func (l byLex) Swap(i, j int) { l[i], l[j] = l[j], l[i] } +func (l byLex) Less(i, j int) bool { + si := l[i] + sj := l[j] + var k int + for k = 0; k < len(si) && k < len(sj); k++ { + if si[k] < sj[k] { + return true + } + if si[k] > sj[k] { + return false + } + } + if k < len(sj) { + return true + } + return false +} + +// HasCommonPrefix determines if any sequence in the DoubleArray is a prefix of the given sequence. +func (da *DoubleArray) HasCommonPrefix(seq []string) bool { + if len(da.Base) == 0 { + return false + } + + var i int + for _, t := range seq { + code, ok := da.Encoding[t] + if !ok { + break + } + j := da.Base[i] + code + if len(da.Check) <= j || da.Check[j] != i+1 { + break + } + i = j + } + j := da.Base[i] + len(da.Encoding) + if len(da.Check) <= j || da.Check[j] != i+1 { + return false + } + return true +} diff --git a/internal/trie_test.go b/internal/trie_test.go new file mode 100644 index 00000000000..0422b693598 --- /dev/null +++ b/internal/trie_test.go @@ -0,0 +1,372 @@ +package internal_test + +import ( + "reflect" + "testing" + + "github.com/gengo/grpc-gateway/internal" +) + +func TestMaxCommonPrefix(t *testing.T) { + for _, spec := range []struct { + da internal.DoubleArray + tokens []string + want bool + }{ + { + da: internal.DoubleArray{}, + tokens: nil, + want: false, + }, + { + da: internal.DoubleArray{}, + tokens: []string{"foo"}, + want: false, + }, + { + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + }, + Base: []int{1, 1, 0}, + Check: []int{0, 1, 2}, + }, + tokens: nil, + want: false, + }, + { + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + }, + Base: []int{1, 1, 0}, + Check: []int{0, 1, 2}, + }, + tokens: []string{"foo"}, + want: true, + }, + { + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + }, + Base: []int{1, 1, 0}, + Check: []int{0, 1, 2}, + }, + tokens: []string{"bar"}, + want: false, + }, + { + // foo|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 1, 2, 0, 0}, + Check: []int{0, 1, 1, 2, 3}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^foo$ + // 4: ^bar$ + }, + tokens: []string{"foo"}, + want: true, + }, + { + // foo|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 1, 2, 0, 0}, + Check: []int{0, 1, 1, 2, 3}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^foo$ + // 4: ^bar$ + }, + tokens: []string{"bar"}, + want: true, + }, + { + // foo|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 1, 2, 0, 0}, + Check: []int{0, 1, 1, 2, 3}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^foo$ + // 4: ^bar$ + }, + tokens: []string{"something-else"}, + want: false, + }, + { + // foo|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 1, 2, 0, 0}, + Check: []int{0, 1, 1, 2, 3}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^foo$ + // 4: ^bar$ + }, + tokens: []string{"foo", "bar"}, + want: true, + }, + { + // foo|foo\.bar|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 3, 1, 0, 4, 0, 0}, + Check: []int{0, 1, 1, 3, 2, 2, 5}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^bar$ + // 4: ^foo.bar + // 5: ^foo$ + // 6: ^foo.bar$ + }, + tokens: []string{"foo"}, + want: true, + }, + { + // foo|foo\.bar|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 3, 1, 0, 4, 0, 0}, + Check: []int{0, 1, 1, 3, 2, 2, 5}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^bar$ + // 4: ^foo.bar + // 5: ^foo$ + // 6: ^foo.bar$ + }, + tokens: []string{"foo", "bar"}, + want: true, + }, + { + // foo|foo\.bar|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 3, 1, 0, 4, 0, 0}, + Check: []int{0, 1, 1, 3, 2, 2, 5}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^bar$ + // 4: ^foo.bar + // 5: ^foo$ + // 6: ^foo.bar$ + }, + tokens: []string{"bar"}, + want: true, + }, + { + // foo|foo\.bar|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 3, 1, 0, 4, 0, 0}, + Check: []int{0, 1, 1, 3, 2, 2, 5}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^bar$ + // 4: ^foo.bar + // 5: ^foo$ + // 6: ^foo.bar$ + }, + tokens: []string{"something-else"}, + want: false, + }, + { + // foo|foo\.bar|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 3, 1, 0, 4, 0, 0}, + Check: []int{0, 1, 1, 3, 2, 2, 5}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^bar$ + // 4: ^foo.bar + // 5: ^foo$ + // 6: ^foo.bar$ + }, + tokens: []string{"foo", "bar", "baz"}, + want: true, + }, + } { + got := spec.da.HasCommonPrefix(spec.tokens) + if got != spec.want { + t.Errorf("%#v.HasCommonPrefix(%v) = %v; want %v", spec.da, spec.tokens, got, spec.want) + } + } +} + +func TestAdd(t *testing.T) { + for _, spec := range []struct { + tokens [][]string + want internal.DoubleArray + }{ + { + want: internal.DoubleArray{ + Encoding: make(map[string]int), + }, + }, + { + tokens: [][]string{{"foo"}}, + want: internal.DoubleArray{ + Encoding: map[string]int{"foo": 0}, + Base: []int{1, 1, 0}, + Check: []int{0, 1, 2}, + // 0: ^ + // 1: ^foo + // 2: ^foo$ + }, + }, + { + tokens: [][]string{{"foo"}, {"bar"}}, + want: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 1, 2, 0, 0}, + Check: []int{0, 1, 1, 2, 3}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^foo$ + // 4: ^bar$ + }, + }, + { + tokens: [][]string{{"foo", "bar"}, {"foo", "baz"}}, + want: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + "baz": 2, + }, + Base: []int{1, 1, 1, 2, 0, 0}, + Check: []int{0, 1, 2, 2, 3, 4}, + // 0: ^ + // 1: ^foo + // 2: ^foo.bar + // 3: ^foo.baz + // 4: ^foo.bar$ + // 5: ^foo.baz$ + }, + }, + { + tokens: [][]string{{"foo", "bar"}, {"foo", "baz"}, {"qux"}}, + want: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + "baz": 2, + "qux": 3, + }, + Base: []int{1, 1, 1, 2, 3, 0, 0, 0}, + Check: []int{0, 1, 2, 2, 1, 3, 4, 5}, + // 0: ^ + // 1: ^foo + // 2: ^foo.bar + // 3: ^foo.baz + // 4: ^qux + // 5: ^foo.bar$ + // 6: ^foo.baz$ + // 7: ^qux$ + }, + }, + { + tokens: [][]string{ + {"foo", "bar"}, + {"foo", "baz", "bar"}, + {"qux", "foo"}, + }, + want: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + "baz": 2, + "qux": 3, + }, + Base: []int{1, 1, 1, 5, 8, 0, 3, 0, 5, 0}, + Check: []int{0, 1, 2, 2, 1, 3, 4, 7, 5, 9}, + // 0: ^ + // 1: ^foo + // 2: ^foo.bar + // 3: ^foo.baz + // 4: ^qux + // 5: ^foo.bar$ + // 6: ^foo.baz.bar + // 7: ^foo.baz.bar$ + // 8: ^qux.foo + // 9: ^qux.foo$ + }, + }, + } { + da := internal.NewDoubleArray(spec.tokens) + if got, want := da.Encoding, spec.want.Encoding; !reflect.DeepEqual(got, want) { + t.Errorf("da.Encoding = %v; want %v; tokens = %#v", got, want, spec.tokens) + } + if got, want := da.Base, spec.want.Base; !compareArray(got, want) { + t.Errorf("da.Base = %v; want %v; tokens = %#v", got, want, spec.tokens) + } + if got, want := da.Check, spec.want.Check; !compareArray(got, want) { + t.Errorf("da.Check = %v; want %v; tokens = %#v", got, want, spec.tokens) + } + } +} + +func compareArray(got, want []int) bool { + var i int + for i = 0; i < len(got) && i < len(want); i++ { + if got[i] != want[i] { + return false + } + } + if i < len(want) { + return false + } + for ; i < len(got); i++ { + if got[i] != 0 { + return false + } + } + return true +} From 9438224af4233840962faeb25beab059a46312de Mon Sep 17 00:00:00 2001 From: Yuki Yugui Sonoda Date: Mon, 18 May 2015 21:24:59 +0900 Subject: [PATCH 07/10] Reimplement url.Value unmarshaller * Use internal.DoubleArray to reduce time complexity. * Replace github.com/ajg/form with a custom impl because: 1. proto messages do not have "form" struct tag, thus they does not fit to the library very well. 2. field paths in google.api.http does not support so complex path that we need the library. --- runtime/query.go | 124 +++++++++++++---- runtime/query_test.go | 306 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 404 insertions(+), 26 deletions(-) create mode 100644 runtime/query_test.go diff --git a/runtime/query.go b/runtime/query.go index 05635ba5902..2ede4735754 100644 --- a/runtime/query.go +++ b/runtime/query.go @@ -1,46 +1,118 @@ package runtime import ( + "fmt" "net/url" + "reflect" "strings" - "github.com/ajg/form" "github.com/gengo/grpc-gateway/internal" + "github.com/golang/glog" "github.com/golang/protobuf/proto" ) -func isQueryParam(key string, filters []string) bool { - for _, f := range filters { - if strings.HasPrefix(key, f) { - switch l, m := len(key), len(f); { - case l == m: - return false - case key[m] == '.': - return false - } +// PopulateQueryParameters populates "values" into "msg". +// A value is ignored if its key starts with one of the elements in "filters". +func PopulateQueryParameters(msg proto.Message, values url.Values, filter *internal.DoubleArray) error { + for key, values := range values { + fieldPath := strings.Split(key, ".") + if filter.HasCommonPrefix(fieldPath) { + continue + } + if err := populateQueryParameter(msg, fieldPath, values); err != nil { + return err } } - return true + return nil } -func convertPath(path string) string { - var components []string - for _, c := range strings.Split(path, ".") { - components = append(components, internal.PascalFromSnake(c)) +func populateQueryParameter(msg proto.Message, fieldPath []string, values []string) error { + m := reflect.ValueOf(msg) + if m.Kind() != reflect.Ptr { + return fmt.Errorf("unexpected type %T: %v", msg, msg) + } + m = m.Elem() + for i, fieldName := range fieldPath { + isLast := i == len(fieldPath)-1 + if !isLast && m.Kind() != reflect.Struct { + return fmt.Errorf("non-aggregate type in the mid of path: %s", strings.Join(fieldPath, ".")) + } + f := m.FieldByName(internal.PascalFromSnake(fieldName)) + if !f.IsValid() { + glog.Warningf("field not found in %T: %s", msg, strings.Join(fieldPath, ".")) + return nil + } + + switch f.Kind() { + case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, reflect.String, reflect.Uint32, reflect.Uint64: + m = f + case reflect.Slice: + // TODO(yugui) Support []byte + if !isLast { + return fmt.Errorf("unexpected repeated field in %s", strings.Join(fieldPath, ".")) + } + return populateRepeatedField(f, values) + case reflect.Ptr: + if f.IsNil() { + m = reflect.New(f.Type().Elem()) + f.Set(m) + } + m = f.Elem() + continue + default: + return fmt.Errorf("unexpected type %s in %T", f.Type(), msg) + } } - return strings.Join(components, ".") + switch len(values) { + case 0: + return fmt.Errorf("no value of field: %s", strings.Join(fieldPath, ".")) + case 1: + default: + glog.Warningf("too many field values: %s", strings.Join(fieldPath, ".")) + } + return populateField(m, values[0]) } -// PopulateQueryParameters populates "values" into "msg". -// A value is ignored if its key starts with one of the elements in "filters". -// -// TODO(yugui) Use trie for filters? -func PopulateQueryParameters(msg proto.Message, values url.Values, filters []string) error { - filtered := make(url.Values) - for key, values := range values { - if isQueryParam(key, filters) { - filtered[convertPath(key)] = values +func populateRepeatedField(f reflect.Value, values []string) error { + elemType := f.Type().Elem() + conv, ok := convFromType[elemType.Kind()] + if !ok { + return fmt.Errorf("unsupported field type %s", elemType) + } + f.Set(reflect.MakeSlice(f.Type(), len(values), len(values))) + for i, v := range values { + result := conv.Call([]reflect.Value{reflect.ValueOf(v)}) + if err := result[1].Interface(); err != nil { + return err.(error) } + f.Index(i).Set(result[0]) + } + return nil +} + +func populateField(f reflect.Value, value string) error { + conv, ok := convFromType[f.Kind()] + if !ok { + return fmt.Errorf("unsupported field type %T", f) + } + result := conv.Call([]reflect.Value{reflect.ValueOf(value)}) + if err := result[1].Interface(); err != nil { + return err.(error) } - return form.DecodeValues(msg, filtered) + f.Set(result[0]) + return nil } + +var ( + convFromType = map[reflect.Kind]reflect.Value{ + reflect.String: reflect.ValueOf(String), + reflect.Bool: reflect.ValueOf(Bool), + reflect.Float64: reflect.ValueOf(Float64), + reflect.Float32: reflect.ValueOf(Float32), + reflect.Int64: reflect.ValueOf(Int64), + reflect.Int32: reflect.ValueOf(Int32), + reflect.Uint64: reflect.ValueOf(Uint64), + reflect.Uint32: reflect.ValueOf(Uint32), + // TODO(yugui) Support []byte + } +) diff --git a/runtime/query_test.go b/runtime/query_test.go new file mode 100644 index 00000000000..5c238d53540 --- /dev/null +++ b/runtime/query_test.go @@ -0,0 +1,306 @@ +package runtime_test + +import ( + "net/url" + "testing" + + "github.com/gengo/grpc-gateway/internal" + "github.com/gengo/grpc-gateway/runtime" + "github.com/golang/protobuf/proto" +) + +func TestPopulateParameters(t *testing.T) { + for _, spec := range []struct { + values url.Values + filter *internal.DoubleArray + want proto.Message + }{ + { + values: url.Values{ + "float_value": {"1.5"}, + "double_value": {"2.5"}, + "int64_value": {"-1"}, + "int32_value": {"-2"}, + "uint64_value": {"3"}, + "uint32_value": {"4"}, + "bool_value": {"true"}, + "string_value": {"str"}, + "repeated_value": {"a", "b", "c"}, + }, + filter: internal.NewDoubleArray(nil), + want: &proto3Message{ + FloatValue: 1.5, + DoubleValue: 2.5, + Int64Value: -1, + Int32Value: -2, + Uint64Value: 3, + Uint32Value: 4, + BoolValue: true, + StringValue: "str", + RepeatedValue: []string{"a", "b", "c"}, + }, + }, + { + values: url.Values{ + "float_value": {"1.5"}, + "double_value": {"2.5"}, + "int64_value": {"-1"}, + "int32_value": {"-2"}, + "uint64_value": {"3"}, + "uint32_value": {"4"}, + "bool_value": {"true"}, + "string_value": {"str"}, + "repeated_value": {"a", "b", "c"}, + }, + filter: internal.NewDoubleArray(nil), + want: &proto2Message{ + FloatValue: proto.Float32(1.5), + DoubleValue: proto.Float64(2.5), + Int64Value: proto.Int64(-1), + Int32Value: proto.Int32(-2), + Uint64Value: proto.Uint64(3), + Uint32Value: proto.Uint32(4), + BoolValue: proto.Bool(true), + StringValue: proto.String("str"), + RepeatedValue: []string{"a", "b", "c"}, + }, + }, + { + values: url.Values{ + "nested.nested.nested.repeated_value": {"a", "b", "c"}, + "nested.nested.nested.string_value": {"s"}, + "nested.nested.string_value": {"t"}, + "nested.string_value": {"u"}, + }, + filter: internal.NewDoubleArray(nil), + want: &proto3Message{ + Nested: &proto2Message{ + Nested: &proto3Message{ + Nested: &proto2Message{ + RepeatedValue: []string{"a", "b", "c"}, + StringValue: proto.String("s"), + }, + StringValue: "t", + }, + StringValue: proto.String("u"), + }, + }, + }, + { + values: url.Values{ + "uint64_value": {"1", "2", "3", "4", "5"}, + }, + filter: internal.NewDoubleArray(nil), + want: &proto3Message{ + Uint64Value: 1, + }, + }, + } { + msg := proto.Clone(spec.want) + msg.Reset() + err := runtime.PopulateQueryParameters(msg, spec.values, spec.filter) + if err != nil { + t.Errorf("runtime.PoplateQueryParameters(msg, %v, %v) failed with %v; want success", spec.values, spec.filter, err) + continue + } + if got, want := msg, spec.want; !proto.Equal(got, want) { + t.Errorf("runtime.PopulateQueryParameters(msg, %v, %v = %v; want %v", spec.values, spec.filter, got, want) + } + } +} + +func TestPopulateParametersWithFilters(t *testing.T) { + for _, spec := range []struct { + values url.Values + filter *internal.DoubleArray + want proto.Message + }{ + { + values: url.Values{ + "bool_value": {"true"}, + "string_value": {"str"}, + "repeated_value": {"a", "b", "c"}, + }, + filter: internal.NewDoubleArray([][]string{ + {"bool_value"}, {"repeated_value"}, + }), + want: &proto3Message{ + StringValue: "str", + }, + }, + { + values: url.Values{ + "nested.nested.bool_value": {"true"}, + "nested.nested.string_value": {"str"}, + "nested.string_value": {"str"}, + "string_value": {"str"}, + }, + filter: internal.NewDoubleArray([][]string{ + {"nested"}, + }), + want: &proto3Message{ + StringValue: "str", + }, + }, + { + values: url.Values{ + "nested.nested.bool_value": {"true"}, + "nested.nested.string_value": {"str"}, + "nested.string_value": {"str"}, + "string_value": {"str"}, + }, + filter: internal.NewDoubleArray([][]string{ + {"nested", "nested"}, + }), + want: &proto3Message{ + Nested: &proto2Message{ + StringValue: proto.String("str"), + }, + StringValue: "str", + }, + }, + { + values: url.Values{ + "nested.nested.bool_value": {"true"}, + "nested.nested.string_value": {"str"}, + "nested.string_value": {"str"}, + "string_value": {"str"}, + }, + filter: internal.NewDoubleArray([][]string{ + {"nested", "nested", "string_value"}, + }), + want: &proto3Message{ + Nested: &proto2Message{ + StringValue: proto.String("str"), + Nested: &proto3Message{ + BoolValue: true, + }, + }, + StringValue: "str", + }, + }, + } { + msg := proto.Clone(spec.want) + msg.Reset() + err := runtime.PopulateQueryParameters(msg, spec.values, spec.filter) + if err != nil { + t.Errorf("runtime.PoplateQueryParameters(msg, %v, %v) failed with %v; want success", spec.values, spec.filter, err) + continue + } + if got, want := msg, spec.want; !proto.Equal(got, want) { + t.Errorf("runtime.PopulateQueryParameters(msg, %v, %v = %v; want %v", spec.values, spec.filter, got, want) + } + } +} + +type proto3Message struct { + Nested *proto2Message `protobuf:"bytes,1,opt,name=nested" json:"nested,omitempty"` + FloatValue float32 `protobuf:"fixed32,2,opt,name=float_value" json:"float_value,omitempty"` + DoubleValue float64 `protobuf:"fixed64,3,opt,name=double_value" json:"double_value,omitempty"` + Int64Value int64 `protobuf:"varint,4,opt,name=int64_value" json:"int64_value,omitempty"` + Int32Value int32 `protobuf:"varint,5,opt,name=int32_value" json:"int32_value,omitempty"` + Uint64Value uint64 `protobuf:"varint,6,opt,name=uint64_value" json:"uint64_value,omitempty"` + Uint32Value uint32 `protobuf:"varint,7,opt,name=uint32_value" json:"uint32_value,omitempty"` + BoolValue bool `protobuf:"varint,8,opt,name=bool_value" json:"bool_value,omitempty"` + StringValue string `protobuf:"bytes,9,opt,name=string_value" json:"string_value,omitempty"` + RepeatedValue []string `protobuf:"bytes,10,rep,name=repeated_value" json:"repeated_value,omitempty"` +} + +func (m *proto3Message) Reset() { *m = proto3Message{} } +func (m *proto3Message) String() string { return proto.CompactTextString(m) } +func (*proto3Message) ProtoMessage() {} + +func (m *proto3Message) GetNested() *proto2Message { + if m != nil { + return m.Nested + } + return nil +} + +type proto2Message struct { + Nested *proto3Message `protobuf:"bytes,1,opt,name=nested" json:"nested,omitempty"` + FloatValue *float32 `protobuf:"fixed32,2,opt,name=float_value" json:"float_value,omitempty"` + DoubleValue *float64 `protobuf:"fixed64,3,opt,name=double_value" json:"double_value,omitempty"` + Int64Value *int64 `protobuf:"varint,4,opt,name=int64_value" json:"int64_value,omitempty"` + Int32Value *int32 `protobuf:"varint,5,opt,name=int32_value" json:"int32_value,omitempty"` + Uint64Value *uint64 `protobuf:"varint,6,opt,name=uint64_value" json:"uint64_value,omitempty"` + Uint32Value *uint32 `protobuf:"varint,7,opt,name=uint32_value" json:"uint32_value,omitempty"` + BoolValue *bool `protobuf:"varint,8,opt,name=bool_value" json:"bool_value,omitempty"` + StringValue *string `protobuf:"bytes,9,opt,name=string_value" json:"string_value,omitempty"` + RepeatedValue []string `protobuf:"bytes,10,rep,name=repeated_value" json:"repeated_value,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *proto2Message) Reset() { *m = proto2Message{} } +func (m *proto2Message) String() string { return proto.CompactTextString(m) } +func (*proto2Message) ProtoMessage() {} + +func (m *proto2Message) GetNested() *proto3Message { + if m != nil { + return m.Nested + } + return nil +} + +func (m *proto2Message) GetFloatValue() float32 { + if m != nil && m.FloatValue != nil { + return *m.FloatValue + } + return 0 +} + +func (m *proto2Message) GetDoubleValue() float64 { + if m != nil && m.DoubleValue != nil { + return *m.DoubleValue + } + return 0 +} + +func (m *proto2Message) GetInt64Value() int64 { + if m != nil && m.Int64Value != nil { + return *m.Int64Value + } + return 0 +} + +func (m *proto2Message) GetInt32Value() int32 { + if m != nil && m.Int32Value != nil { + return *m.Int32Value + } + return 0 +} + +func (m *proto2Message) GetUint64Value() uint64 { + if m != nil && m.Uint64Value != nil { + return *m.Uint64Value + } + return 0 +} + +func (m *proto2Message) GetUint32Value() uint32 { + if m != nil && m.Uint32Value != nil { + return *m.Uint32Value + } + return 0 +} + +func (m *proto2Message) GetBoolValue() bool { + if m != nil && m.BoolValue != nil { + return *m.BoolValue + } + return false +} + +func (m *proto2Message) GetStringValue() string { + if m != nil && m.StringValue != nil { + return *m.StringValue + } + return "" +} + +func (m *proto2Message) GetRepeatedValue() []string { + if m != nil { + return m.RepeatedValue + } + return nil +} From f4f141617c2a638ba27131e306d3a5779cef5a33 Mon Sep 17 00:00:00 2001 From: Yuki Yugui Sonoda Date: Tue, 19 May 2015 21:07:44 +0900 Subject: [PATCH 08/10] Update code template so that it uses Double Array. --- protoc-gen-grpc-gateway/descriptor/types.go | 18 ------- .../gengateway/generator.go | 1 + .../gengateway/template.go | 49 +++++++++++++++++-- 3 files changed, 47 insertions(+), 21 deletions(-) diff --git a/protoc-gen-grpc-gateway/descriptor/types.go b/protoc-gen-grpc-gateway/descriptor/types.go index 31ef5c358ee..089aca06c42 100644 --- a/protoc-gen-grpc-gateway/descriptor/types.go +++ b/protoc-gen-grpc-gateway/descriptor/types.go @@ -126,24 +126,6 @@ type Binding struct { Body *Body } -// HasQueryParams returns if "b" has query_string params. -func (b *Binding) HasQueryParams() bool { - if b.Body != nil && len(b.Body.FieldPath) == 0 { - return false - } - fields := make(map[string]bool) - for _, f := range b.Method.RequestType.GetField() { - fields[f.GetName()] = true - } - if b.Body != nil { - delete(fields, b.Body.FieldPath.String()) - } - for _, p := range b.PathParams { - delete(fields, p.FieldPath.String()) - } - return len(fields) > 0 -} - // ExplicitParams returns a list of explicitly bound parameters of "b", // i.e. a union of field path for body and field paths for path parameters. func (b *Binding) ExplicitParams() []string { diff --git a/protoc-gen-grpc-gateway/gengateway/generator.go b/protoc-gen-grpc-gateway/gengateway/generator.go index 78b8ba5fd05..e04dd7e9b9f 100644 --- a/protoc-gen-grpc-gateway/gengateway/generator.go +++ b/protoc-gen-grpc-gateway/gengateway/generator.go @@ -31,6 +31,7 @@ func New(reg *descriptor.Registry) *generator { "io", "net/http", "github.com/gengo/grpc-gateway/runtime", + "github.com/gengo/grpc-gateway/internal", "github.com/golang/glog", "github.com/golang/protobuf/proto", "golang.org/x/net/context", diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index e6a78fa2248..8d1d3e66a85 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -5,6 +5,7 @@ import ( "strings" "text/template" + "github.com/gengo/grpc-gateway/internal" "github.com/gengo/grpc-gateway/protoc-gen-grpc-gateway/descriptor" ) @@ -13,6 +14,42 @@ type param struct { Imports []descriptor.GoPackage } +type binding struct { + *descriptor.Binding +} + +// HasQueryParam determines if the binding needs parameters in query string. +// +// It sometimes returns true even though actually the binding does not need. +// But it is not serious because it just results in a small amount of extra codes generated. +func (b binding) HasQueryParam() bool { + if b.Body != nil && len(b.Body.FieldPath) == 0 { + return false + } + fields := make(map[string]bool) + for _, f := range b.Method.RequestType.Fields { + fields[f.GetName()] = true + } + if b.Body != nil { + delete(fields, b.Body.FieldPath.String()) + } + for _, p := range b.PathParams { + delete(fields, p.FieldPath.String()) + } + return len(fields) > 0 +} + +func (b binding) QueryParamFilter() *internal.DoubleArray { + var seqs [][]string + if b.Body != nil { + seqs = append(seqs, strings.Split(b.Body.FieldPath.String(), ".")) + } + for _, p := range b.PathParams { + seqs = append(seqs, strings.Split(p.FieldPath.String(), ".")) + } + return internal.NewDoubleArray(seqs) +} + func applyTemplate(p param) (string, error) { w := bytes.NewBuffer(nil) if err := headerTemplate.Execute(w, p); err != nil { @@ -23,7 +60,7 @@ func applyTemplate(p param) (string, error) { for _, meth := range svc.Methods { methodSeen = true for _, b := range meth.Bindings { - if err := handlerTemplate.Execute(w, b); err != nil { + if err := handlerTemplate.Execute(w, binding{Binding: b}); err != nil { return "", err } } @@ -60,6 +97,7 @@ var _ codes.Code var _ io.Reader var _ = runtime.String var _ = json.Marshal +var _ = internal.PascalFromSnake `)) handlerTemplate = template.Must(template.New("handler").Parse(` @@ -113,6 +151,11 @@ func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx cont `)) _ = template.Must(handlerTemplate.New("client-rpc-request-func").Parse(` +{{if .HasQueryParam}} +var ( + filter_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}} = {{.QueryParamFilter | printf "%#v"}} +) +{{end}} {{template "request-func-signature" .}} { var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}} {{if .Body}} @@ -134,8 +177,8 @@ func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx cont } {{end}} {{end}} -{{if .HasQueryParams}} - if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), {{.ExplicitParams | printf "%#v"}}); err != nil { +{{if .HasQueryParam}} + if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), filter_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}); err != nil { return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) } {{end}} From 34d3be982e64845a595160de3f6185a1d8be6542 Mon Sep 17 00:00:00 2001 From: Yuki Yugui Sonoda Date: Tue, 19 May 2015 21:10:28 +0900 Subject: [PATCH 09/10] Update generated codes in examples --- examples/a_bit_of_everything.pb.gw.go | 14 ++++++++++++-- examples/echo_service.pb.gw.go | 2 ++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/examples/a_bit_of_everything.pb.gw.go b/examples/a_bit_of_everything.pb.gw.go index 7d09f5d4a8d..28871407ab4 100644 --- a/examples/a_bit_of_everything.pb.gw.go +++ b/examples/a_bit_of_everything.pb.gw.go @@ -15,6 +15,7 @@ import ( "net/http" "github.com/gengo/grpc-gateway/examples/sub" + "github.com/gengo/grpc-gateway/internal" "github.com/gengo/grpc-gateway/runtime" "github.com/golang/glog" "github.com/golang/protobuf/proto" @@ -27,6 +28,11 @@ var _ codes.Code var _ io.Reader var _ = runtime.String var _ = json.Marshal +var _ = internal.PascalFromSnake + +var ( + filter_ABitOfEverythingService_Create_0 = &internal.DoubleArray{Encoding: map[string]int{"float_value": 0, "double_value": 1, "bool_value": 7, "sfixed64_value": 11, "sint32_value": 12, "sint64_value": 13, "int64_value": 2, "int32_value": 4, "string_value": 8, "uint32_value": 9, "sfixed32_value": 10, "fixed64_value": 5, "fixed32_value": 6, "uint64_value": 3}, Base: []int{1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, Check: []int{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}} +) func request_ABitOfEverythingService_Create_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq ABitOfEverything @@ -160,7 +166,7 @@ func request_ABitOfEverythingService_Create_0(ctx context.Context, client ABitOf return nil, err } - if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), []string{"float_value", "double_value", "int64_value", "uint64_value", "int32_value", "fixed64_value", "fixed32_value", "bool_value", "string_value", "uint32_value", "sfixed32_value", "sfixed64_value", "sint32_value", "sint64_value"}); err != nil { + if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), filter_ABitOfEverythingService_Create_0); err != nil { return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) } @@ -296,10 +302,14 @@ func request_ABitOfEverythingService_Echo_1(ctx context.Context, client ABitOfEv return client.Echo(ctx, &protoReq) } +var ( + filter_ABitOfEverythingService_Echo_2 = &internal.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} +) + func request_ABitOfEverythingService_Echo_2(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq sub.StringMessage - if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), []string(nil)); err != nil { + if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), filter_ABitOfEverythingService_Echo_2); err != nil { return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) } diff --git a/examples/echo_service.pb.gw.go b/examples/echo_service.pb.gw.go index 2046e5717ad..97b6fba4feb 100644 --- a/examples/echo_service.pb.gw.go +++ b/examples/echo_service.pb.gw.go @@ -14,6 +14,7 @@ import ( "io" "net/http" + "github.com/gengo/grpc-gateway/internal" "github.com/gengo/grpc-gateway/runtime" "github.com/golang/glog" "github.com/golang/protobuf/proto" @@ -26,6 +27,7 @@ var _ codes.Code var _ io.Reader var _ = runtime.String var _ = json.Marshal +var _ = internal.PascalFromSnake func request_EchoService_Echo_0(ctx context.Context, client EchoServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq SimpleMessage From 976f675102fdd268d3ce67bd08f4f72ebdf98f01 Mon Sep 17 00:00:00 2001 From: Yuki Yugui Sonoda Date: Wed, 20 May 2015 20:33:13 +0900 Subject: [PATCH 10/10] Support X-HTTP-Method-Override --- runtime/mux.go | 29 ++++++- runtime/mux_test.go | 195 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 221 insertions(+), 3 deletions(-) create mode 100644 runtime/mux_test.go diff --git a/runtime/mux.go b/runtime/mux.go index 0fbd365c640..3538fc3820f 100644 --- a/runtime/mux.go +++ b/runtime/mux.go @@ -48,6 +48,13 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { verb, components[l-1] = c[:idx], c[idx+1:] } + if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && isPathLengthFallback(r) { + r.Method = strings.ToUpper(override) + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + } for _, h := range s.handlers[r.Method] { pathParams, err := h.pat.Match(components, verb) if err != nil { @@ -58,21 +65,37 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // lookup other methods to determine if it is MethodNotAllowed + // lookup other methods to handle fallback from GET to POST and + // to determine if it is MethodNotAllowed or NotFound. for m, handlers := range s.handlers { if m == r.Method { continue } for _, h := range handlers { - if _, err := h.pat.Match(components, verb); err == nil { - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + pathParams, err := h.pat.Match(components, verb) + if err != nil { + continue + } + // X-HTTP-Method-Override is optional. Always allow fallback to POST. + if isPathLengthFallback(r) { + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + h.h(w, r, pathParams) return } + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return } } http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) } +func isPathLengthFallback(r *http.Request) bool { + return r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" +} + type handler struct { pat Pattern h HandlerFunc diff --git a/runtime/mux_test.go b/runtime/mux_test.go new file mode 100644 index 00000000000..7394bb13a94 --- /dev/null +++ b/runtime/mux_test.go @@ -0,0 +1,195 @@ +package runtime_test + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gengo/grpc-gateway/internal" + "github.com/gengo/grpc-gateway/runtime" +) + +func TestMuxServeHTTP(t *testing.T) { + type stubPattern struct { + method string + ops []int + pool []string + } + for _, spec := range []struct { + patterns []stubPattern + + reqMethod string + reqPath string + headers map[string]string + + respStatus int + respContent string + }{ + { + patterns: nil, + reqMethod: "GET", + reqPath: "/", + respStatus: http.StatusNotFound, + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "GET", + reqPath: "/foo", + respStatus: http.StatusOK, + respContent: "GET /foo", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "GET", + reqPath: "/bar", + respStatus: http.StatusNotFound, + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + { + method: "GET", + ops: []int{int(internal.OpPush), 0}, + }, + }, + reqMethod: "GET", + reqPath: "/foo", + respStatus: http.StatusOK, + respContent: "GET /foo", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + { + method: "POST", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "POST", + reqPath: "/foo", + respStatus: http.StatusOK, + respContent: "POST /foo", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "DELETE", + reqPath: "/foo", + respStatus: http.StatusMethodNotAllowed, + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "POST", + reqPath: "/foo", + headers: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + respStatus: http.StatusOK, + respContent: "GET /foo", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + { + method: "POST", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "POST", + reqPath: "/foo", + headers: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + "X-HTTP-Method-Override": "GET", + }, + respStatus: http.StatusOK, + respContent: "GET /foo", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "POST", + reqPath: "/foo", + headers: map[string]string{ + "Content-Type": "application/json", + }, + respStatus: http.StatusMethodNotAllowed, + }, + } { + mux := runtime.NewServeMux() + for _, p := range spec.patterns { + func(p stubPattern) { + pat, err := runtime.NewPattern(1, p.ops, p.pool, "") + if err != nil { + t.Fatalf("runtime.NewPattern(1, %#v, %#v, %q) failed with %v; want success", p.ops, p.pool, "", err) + } + mux.Handle(p.method, pat, func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { + fmt.Fprintf(w, "%s %s", p.method, pat.String()) + }) + }(p) + } + + url := fmt.Sprintf("http://host.example%s", spec.reqPath) + r, err := http.NewRequest(spec.reqMethod, url, bytes.NewReader(nil)) + if err != nil { + t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", spec.reqMethod, url, err) + } + for name, value := range spec.headers { + r.Header.Set(name, value) + } + w := httptest.NewRecorder() + mux.ServeHTTP(w, r) + + if got, want := w.Code, spec.respStatus; got != want { + t.Errorf("w.Code = %d; want %d; patterns=%v; req=%v", got, want, spec.patterns, r) + } + if spec.respContent != "" { + if got, want := w.Body.String(), spec.respContent; got != want { + t.Errorf("w.Body = %q; want %q; patterns=%v; req=%v", got, want, spec.patterns, r) + } + } + } +}