diff --git a/Makefile b/Makefile index 2fad9892b8c..fbde89d7cd5 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,9 @@ GO_PLUGIN=bin/protoc-gen-go GO_PLUGIN_PKG=github.com/golang/protobuf/protoc-gen-go GATEWAY_PLUGIN=bin/protoc-gen-grpc-gateway GATEWAY_PLUGIN_PKG=$(PKG)/protoc-gen-grpc-gateway -GATEWAY_PLUGIN_SRC= protoc-gen-grpc-gateway/descriptor/name.go \ +GATEWAY_PLUGIN_SRC= internal/doc.go \ + internal/name.go \ + internal/pattern.go \ protoc-gen-grpc-gateway/descriptor/registry.go \ protoc-gen-grpc-gateway/descriptor/services.go \ protoc-gen-grpc-gateway/descriptor/types.go \ diff --git a/examples/a_bit_of_everything.pb.gw.go b/examples/a_bit_of_everything.pb.gw.go index 27873818077..28871407ab4 100644 --- a/examples/a_bit_of_everything.pb.gw.go +++ b/examples/a_bit_of_everything.pb.gw.go @@ -15,6 +15,7 @@ import ( "net/http" "github.com/gengo/grpc-gateway/examples/sub" + "github.com/gengo/grpc-gateway/internal" "github.com/gengo/grpc-gateway/runtime" "github.com/golang/glog" "github.com/golang/protobuf/proto" @@ -27,8 +28,13 @@ var _ codes.Code var _ io.Reader var _ = runtime.String var _ = json.Marshal +var _ = internal.PascalFromSnake -func request_ABitOfEverythingService_Create(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +var ( + filter_ABitOfEverythingService_Create_0 = &internal.DoubleArray{Encoding: map[string]int{"float_value": 0, "double_value": 1, "bool_value": 7, "sfixed64_value": 11, "sint32_value": 12, "sint64_value": 13, "int64_value": 2, "int32_value": 4, "string_value": 8, "uint32_value": 9, "sfixed32_value": 10, "fixed64_value": 5, "fixed32_value": 6, "uint64_value": 3}, Base: []int{1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, Check: []int{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}} +) + +func request_ABitOfEverythingService_Create_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq ABitOfEverything var val string @@ -160,10 +166,14 @@ func request_ABitOfEverythingService_Create(ctx context.Context, client ABitOfEv return nil, err } + if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), filter_ABitOfEverythingService_Create_0); err != nil { + return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) + } + return client.Create(ctx, &protoReq) } -func request_ABitOfEverythingService_CreateBody(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_ABitOfEverythingService_CreateBody_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq ABitOfEverything if err = json.NewDecoder(req.Body).Decode(&protoReq); err != nil { @@ -173,7 +183,7 @@ func request_ABitOfEverythingService_CreateBody(ctx context.Context, client ABit return client.CreateBody(ctx, &protoReq) } -func request_ABitOfEverythingService_BulkCreate(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_ABitOfEverythingService_BulkCreate_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { stream, err := client.BulkCreate(ctx) if err != nil { glog.Errorf("Failed to start streaming: %v", err) @@ -200,7 +210,7 @@ func request_ABitOfEverythingService_BulkCreate(ctx context.Context, client ABit } -func request_ABitOfEverythingService_Lookup(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_ABitOfEverythingService_Lookup_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq IdMessage var val string @@ -218,13 +228,13 @@ func request_ABitOfEverythingService_Lookup(ctx context.Context, client ABitOfEv return client.Lookup(ctx, &protoReq) } -func request_ABitOfEverythingService_List(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (ABitOfEverythingService_ListClient, error) { +func request_ABitOfEverythingService_List_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (ABitOfEverythingService_ListClient, error) { var protoReq EmptyMessage return client.List(ctx, &protoReq) } -func request_ABitOfEverythingService_Update(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_ABitOfEverythingService_Update_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq ABitOfEverything if err = json.NewDecoder(req.Body).Decode(&protoReq); err != nil { @@ -246,7 +256,7 @@ func request_ABitOfEverythingService_Update(ctx context.Context, client ABitOfEv return client.Update(ctx, &protoReq) } -func request_ABitOfEverythingService_Delete(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_ABitOfEverythingService_Delete_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq IdMessage var val string @@ -264,7 +274,7 @@ func request_ABitOfEverythingService_Delete(ctx context.Context, client ABitOfEv return client.Delete(ctx, &protoReq) } -func request_ABitOfEverythingService_Echo(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_ABitOfEverythingService_Echo_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq sub.StringMessage var val string @@ -282,7 +292,31 @@ func request_ABitOfEverythingService_Echo(ctx context.Context, client ABitOfEver return client.Echo(ctx, &protoReq) } -func request_ABitOfEverythingService_BulkEcho(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (ABitOfEverythingService_BulkEchoClient, error) { +func request_ABitOfEverythingService_Echo_1(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { + var protoReq sub.StringMessage + + if err = json.NewDecoder(req.Body).Decode(&protoReq.Value); err != nil { + return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) + } + + return client.Echo(ctx, &protoReq) +} + +var ( + filter_ABitOfEverythingService_Echo_2 = &internal.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} +) + +func request_ABitOfEverythingService_Echo_2(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { + var protoReq sub.StringMessage + + if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), filter_ABitOfEverythingService_Echo_2); err != nil { + return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) + } + + return client.Echo(ctx, &protoReq) +} + +func request_ABitOfEverythingService_BulkEcho_0(ctx context.Context, client ABitOfEverythingServiceClient, req *http.Request, pathParams map[string]string) (ABitOfEverythingService_BulkEchoClient, error) { stream, err := client.BulkEcho(ctx) if err != nil { glog.Errorf("Failed to start streaming: %v", err) @@ -343,8 +377,8 @@ func RegisterABitOfEverythingServiceHandlerFromEndpoint(ctx context.Context, mux func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { client := NewABitOfEverythingServiceClient(conn) - mux.Handle("POST", pattern_ABitOfEverythingService_Create, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_Create(ctx, client, req, pathParams) + mux.Handle("POST", pattern_ABitOfEverythingService_Create_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_Create_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -354,8 +388,8 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("POST", pattern_ABitOfEverythingService_CreateBody, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_CreateBody(ctx, client, req, pathParams) + mux.Handle("POST", pattern_ABitOfEverythingService_CreateBody_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_CreateBody_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -365,8 +399,8 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("POST", pattern_ABitOfEverythingService_BulkCreate, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_BulkCreate(ctx, client, req, pathParams) + mux.Handle("POST", pattern_ABitOfEverythingService_BulkCreate_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_BulkCreate_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -376,8 +410,8 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("GET", pattern_ABitOfEverythingService_Lookup, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_Lookup(ctx, client, req, pathParams) + mux.Handle("GET", pattern_ABitOfEverythingService_Lookup_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_Lookup_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -387,8 +421,8 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("GET", pattern_ABitOfEverythingService_List, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_List(ctx, client, req, pathParams) + mux.Handle("GET", pattern_ABitOfEverythingService_List_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_List_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -398,8 +432,19 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("PUT", pattern_ABitOfEverythingService_Update, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_Update(ctx, client, req, pathParams) + mux.Handle("PUT", pattern_ABitOfEverythingService_Update_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_Update_0(ctx, client, req, pathParams) + if err != nil { + runtime.HTTPError(w, err) + return + } + + runtime.ForwardResponseMessage(w, resp) + + }) + + mux.Handle("DELETE", pattern_ABitOfEverythingService_Delete_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_Delete_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -409,8 +454,8 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("DELETE", pattern_ABitOfEverythingService_Delete, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_Delete(ctx, client, req, pathParams) + mux.Handle("GET", pattern_ABitOfEverythingService_Echo_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_Echo_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -420,8 +465,8 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("GET", pattern_ABitOfEverythingService_Echo, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_Echo(ctx, client, req, pathParams) + mux.Handle("POST", pattern_ABitOfEverythingService_Echo_1, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_Echo_1(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -431,8 +476,19 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se }) - mux.Handle("POST", pattern_ABitOfEverythingService_BulkEcho, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_ABitOfEverythingService_BulkEcho(ctx, client, req, pathParams) + mux.Handle("GET", pattern_ABitOfEverythingService_Echo_2, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_Echo_2(ctx, client, req, pathParams) + if err != nil { + runtime.HTTPError(w, err) + return + } + + runtime.ForwardResponseMessage(w, resp) + + }) + + mux.Handle("POST", pattern_ABitOfEverythingService_BulkEcho_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_ABitOfEverythingService_BulkEcho_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -446,21 +502,25 @@ func RegisterABitOfEverythingServiceHandler(ctx context.Context, mux *runtime.Se } var ( - pattern_ABitOfEverythingService_Create = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 1, 0, 4, 1, 5, 4, 1, 0, 4, 1, 5, 5, 2, 6, 1, 0, 4, 1, 5, 7, 1, 0, 4, 1, 5, 8, 1, 0, 4, 1, 5, 9, 1, 0, 4, 1, 5, 10, 1, 0, 4, 1, 5, 11, 2, 12, 1, 0, 4, 2, 5, 13, 1, 0, 4, 1, 5, 14, 1, 0, 4, 1, 5, 15, 1, 0, 4, 1, 5, 16, 1, 0, 4, 1, 5, 17, 1, 0, 4, 1, 5, 18}, []string{"v1", "example", "a_bit_of_everything", "float_value", "double_value", "int64_value", "separator", "uint64_value", "int32_value", "fixed64_value", "fixed32_value", "bool_value", "strprefix", "string_value", "uint32_value", "sfixed32_value", "sfixed64_value", "sint32_value", "sint64_value"}, "")) + pattern_ABitOfEverythingService_Create_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 1, 0, 4, 1, 5, 4, 1, 0, 4, 1, 5, 5, 2, 6, 1, 0, 4, 1, 5, 7, 1, 0, 4, 1, 5, 8, 1, 0, 4, 1, 5, 9, 1, 0, 4, 1, 5, 10, 1, 0, 4, 1, 5, 11, 2, 12, 1, 0, 4, 2, 5, 13, 1, 0, 4, 1, 5, 14, 1, 0, 4, 1, 5, 15, 1, 0, 4, 1, 5, 16, 1, 0, 4, 1, 5, 17, 1, 0, 4, 1, 5, 18}, []string{"v1", "example", "a_bit_of_everything", "float_value", "double_value", "int64_value", "separator", "uint64_value", "int32_value", "fixed64_value", "fixed32_value", "bool_value", "strprefix", "string_value", "uint32_value", "sfixed32_value", "sfixed64_value", "sint32_value", "sint64_value"}, "")) + + pattern_ABitOfEverythingService_CreateBody_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "example", "a_bit_of_everything"}, "")) + + pattern_ABitOfEverythingService_BulkCreate_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v1", "example", "a_bit_of_everything", "bulk"}, "")) - pattern_ABitOfEverythingService_CreateBody = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "example", "a_bit_of_everything"}, "")) + pattern_ABitOfEverythingService_Lookup_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "a_bit_of_everything", "uuid"}, "")) - pattern_ABitOfEverythingService_BulkCreate = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v1", "example", "a_bit_of_everything", "bulk"}, "")) + pattern_ABitOfEverythingService_List_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "example", "a_bit_of_everything"}, "")) - pattern_ABitOfEverythingService_Lookup = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "a_bit_of_everything", "uuid"}, "")) + pattern_ABitOfEverythingService_Update_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "a_bit_of_everything", "uuid"}, "")) - pattern_ABitOfEverythingService_List = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "example", "a_bit_of_everything"}, "")) + pattern_ABitOfEverythingService_Delete_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "a_bit_of_everything", "uuid"}, "")) - pattern_ABitOfEverythingService_Update = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "a_bit_of_everything", "uuid"}, "")) + pattern_ABitOfEverythingService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 1, 0, 4, 1, 5, 4}, []string{"v1", "example", "a_bit_of_everything", "echo", "value"}, "")) - pattern_ABitOfEverythingService_Delete = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "a_bit_of_everything", "uuid"}, "")) + pattern_ABitOfEverythingService_Echo_1 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v2", "example", "echo"}, "")) - pattern_ABitOfEverythingService_Echo = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 1, 0, 4, 1, 5, 4}, []string{"v1", "example", "a_bit_of_everything", "echo", "value"}, "")) + pattern_ABitOfEverythingService_Echo_2 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v2", "example", "echo"}, "")) - pattern_ABitOfEverythingService_BulkEcho = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v1", "example", "a_bit_of_everything", "echo"}, "")) + pattern_ABitOfEverythingService_BulkEcho_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"v1", "example", "a_bit_of_everything", "echo"}, "")) ) diff --git a/examples/a_bit_of_everything.proto b/examples/a_bit_of_everything.proto index 434ff2b4e25..c4cf1ee8b32 100644 --- a/examples/a_bit_of_everything.proto +++ b/examples/a_bit_of_everything.proto @@ -80,6 +80,13 @@ service ABitOfEverythingService { rpc Echo(gengo.grpc.gateway.examples.sub.StringMessage) returns (gengo.grpc.gateway.examples.sub.StringMessage) { option (google.api.http) = { get: "/v1/example/a_bit_of_everything/echo/{value}" + additional_bindings { + post: "/v2/example/echo" + body: "value" + } + additional_bindings { + get: "/v2/example/echo" + } }; } rpc BulkEcho(stream gengo.grpc.gateway.examples.sub.StringMessage) returns (stream gengo.grpc.gateway.examples.sub.StringMessage) { diff --git a/examples/echo_service.pb.gw.go b/examples/echo_service.pb.gw.go index 6e47decf18b..97b6fba4feb 100644 --- a/examples/echo_service.pb.gw.go +++ b/examples/echo_service.pb.gw.go @@ -14,6 +14,7 @@ import ( "io" "net/http" + "github.com/gengo/grpc-gateway/internal" "github.com/gengo/grpc-gateway/runtime" "github.com/golang/glog" "github.com/golang/protobuf/proto" @@ -26,8 +27,9 @@ var _ codes.Code var _ io.Reader var _ = runtime.String var _ = json.Marshal +var _ = internal.PascalFromSnake -func request_EchoService_Echo(ctx context.Context, client EchoServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_EchoService_Echo_0(ctx context.Context, client EchoServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq SimpleMessage var val string @@ -45,7 +47,7 @@ func request_EchoService_Echo(ctx context.Context, client EchoServiceClient, req return client.Echo(ctx, &protoReq) } -func request_EchoService_EchoBody(ctx context.Context, client EchoServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { +func request_EchoService_EchoBody_0(ctx context.Context, client EchoServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) { var protoReq SimpleMessage if err = json.NewDecoder(req.Body).Decode(&protoReq); err != nil { @@ -85,8 +87,8 @@ func RegisterEchoServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.Se func RegisterEchoServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { client := NewEchoServiceClient(conn) - mux.Handle("POST", pattern_EchoService_Echo, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_EchoService_Echo(ctx, client, req, pathParams) + mux.Handle("POST", pattern_EchoService_Echo_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_EchoService_Echo_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -96,8 +98,8 @@ func RegisterEchoServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn }) - mux.Handle("POST", pattern_EchoService_EchoBody, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_EchoService_EchoBody(ctx, client, req, pathParams) + mux.Handle("POST", pattern_EchoService_EchoBody_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_EchoService_EchoBody_0(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -111,7 +113,7 @@ func RegisterEchoServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn } var ( - pattern_EchoService_Echo = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "echo", "id"}, "")) + pattern_EchoService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"v1", "example", "echo", "id"}, "")) - pattern_EchoService_EchoBody = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "example", "echo_body"}, "")) + pattern_EchoService_EchoBody_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"v1", "example", "echo_body"}, "")) ) diff --git a/examples/integration_test.go b/examples/integration_test.go index 7dfac5315de..65fddc9907b 100644 --- a/examples/integration_test.go +++ b/examples/integration_test.go @@ -14,6 +14,7 @@ import ( gw "github.com/gengo/grpc-gateway/examples" server "github.com/gengo/grpc-gateway/examples/server" + sub "github.com/gengo/grpc-gateway/examples/sub" ) func TestIntegration(t *testing.T) { @@ -43,6 +44,7 @@ func TestIntegration(t *testing.T) { testABEBulkCreate(t) testABELookup(t) testABEList(t) + testAdditionalBindings(t) } func testEcho(t *testing.T) { @@ -377,3 +379,60 @@ func testABEList(t *testing.T) { t.Errorf("i == %d; want > 0", i) } } + +func testAdditionalBindings(t *testing.T) { + for i, f := range []func() *http.Response{ + func() *http.Response { + url := "http://localhost:8080/v1/example/a_bit_of_everything/echo/hello" + resp, err := http.Get(url) + if err != nil { + t.Errorf("http.Get(%q) failed with %v; want success", url, err) + return nil + } + return resp + }, + func() *http.Response { + url := "http://localhost:8080/v2/example/echo" + resp, err := http.Post(url, "application/json", strings.NewReader(`"hello"`)) + if err != nil { + t.Errorf("http.Post(%q, %q, %q) failed with %v; want success", url, "application/json", `"hello"`, err) + return nil + } + return resp + }, + func() *http.Response { + url := "http://localhost:8080/v2/example/echo?value=hello" + resp, err := http.Get(url) + if err != nil { + t.Errorf("http.Get(%q) failed with %v; want success", url, err) + return nil + } + return resp + }, + } { + resp := f() + if resp == nil { + continue + } + + defer resp.Body.Close() + buf, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("iotuil.ReadAll(resp.Body) failed with %v; want success; i=%d", err, i) + return + } + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("resp.StatusCode = %d; want %d; i=%d", got, want, i) + t.Logf("%s", buf) + } + + var msg sub.StringMessage + if err := json.Unmarshal(buf, &msg); err != nil { + t.Errorf("json.Unmarshal(%s, &msg) failed with %v; want success; %i", buf, err, i) + return + } + if got, want := msg.GetValue(), "hello"; got != want { + t.Errorf("msg.GetValue() = %q; want %q", got, want) + } + } +} diff --git a/protoc-gen-grpc-gateway/descriptor/name.go b/internal/name.go similarity index 60% rename from protoc-gen-grpc-gateway/descriptor/name.go rename to internal/name.go index 78234461a73..a6d771936c9 100644 --- a/protoc-gen-grpc-gateway/descriptor/name.go +++ b/internal/name.go @@ -1,15 +1,11 @@ -package descriptor +package internal import ( - "regexp" "strings" ) -var ( - upperPattern = regexp.MustCompile("[A-Z]") -) - -func toCamel(str string) string { +// PascalFromSnake converts an identifier in snake_case into PascalCase. +func PascalFromSnake(str string) string { var components []string for _, c := range strings.Split(str, "_") { components = append(components, strings.Title(strings.ToLower(c))) diff --git a/internal/name_test.go b/internal/name_test.go new file mode 100644 index 00000000000..7b1b6b62bbe --- /dev/null +++ b/internal/name_test.go @@ -0,0 +1,22 @@ +package internal_test + +import ( + "testing" + + "github.com/gengo/grpc-gateway/internal" +) + +func TestPascalToSnake(t *testing.T) { + for _, spec := range []struct { + input, want string + }{ + {input: "value", want: "Value"}, + {input: "prefixed_value", want: "PrefixedValue"}, + {input: "foo_id", want: "FooId"}, + } { + got := internal.PascalFromSnake(spec.input) + if got != spec.want { + t.Errorf("internal.PascalFromSnake(%q) = %q; want %q", spec.input, got, spec.want) + } + } +} diff --git a/internal/trie.go b/internal/trie.go new file mode 100644 index 00000000000..03f80d72f8a --- /dev/null +++ b/internal/trie.go @@ -0,0 +1,177 @@ +package internal + +import ( + "sort" +) + +// DoubleArray is a Double Array implementation of trie on sequences of strings. +type DoubleArray struct { + // Encoding keeps an encoding from string to int + Encoding map[string]int + // Base is the base array of Double Array + Base []int + // Check is the check array of Double Array + Check []int +} + +// NewDoubleArray builds a DoubleArray from a set of sequences of strings. +func NewDoubleArray(seqs [][]string) *DoubleArray { + da := &DoubleArray{Encoding: make(map[string]int)} + if len(seqs) == 0 { + return da + } + + encoded := registerTokens(da, seqs) + sort.Sort(byLex(encoded)) + + root := node{row: -1, col: -1, left: 0, right: len(encoded)} + addSeqs(da, encoded, 0, root) + + for i := len(da.Base); i > 0; i-- { + if da.Check[i-1] != 0 { + da.Base = da.Base[:i] + da.Check = da.Check[:i] + break + } + } + return da +} + +func registerTokens(da *DoubleArray, seqs [][]string) [][]int { + var result [][]int + for _, seq := range seqs { + var encoded []int + for _, token := range seq { + if _, ok := da.Encoding[token]; !ok { + da.Encoding[token] = len(da.Encoding) + } + encoded = append(encoded, da.Encoding[token]) + } + result = append(result, encoded) + } + for i := range result { + result[i] = append(result[i], len(da.Encoding)) + } + return result +} + +type node struct { + row, col int + left, right int +} + +func (n node) value(seqs [][]int) int { + return seqs[n.row][n.col] +} + +func (n node) children(seqs [][]int) []*node { + var result []*node + lastVal := int(-1) + last := new(node) + for i := n.left; i < n.right; i++ { + if lastVal == seqs[i][n.col+1] { + continue + } + last.right = i + last = &node{ + row: i, + col: n.col + 1, + left: i, + } + result = append(result, last) + } + last.right = n.right + return result +} + +func addSeqs(da *DoubleArray, seqs [][]int, pos int, n node) { + ensureSize(da, pos) + + children := n.children(seqs) + var i int + for i = 1; ; i++ { + ok := func() bool { + for _, child := range children { + code := child.value(seqs) + j := i + code + ensureSize(da, j) + if da.Check[j] != 0 { + return false + } + } + return true + }() + if ok { + break + } + } + da.Base[pos] = i + for _, child := range children { + code := child.value(seqs) + j := i + code + da.Check[j] = pos + 1 + } + terminator := len(da.Encoding) + for _, child := range children { + code := child.value(seqs) + if code == terminator { + continue + } + j := i + code + addSeqs(da, seqs, j, *child) + } +} + +func ensureSize(da *DoubleArray, i int) { + for i >= len(da.Base) { + da.Base = append(da.Base, make([]int, len(da.Base)+1)...) + da.Check = append(da.Check, make([]int, len(da.Check)+1)...) + } +} + +type byLex [][]int + +func (l byLex) Len() int { return len(l) } +func (l byLex) Swap(i, j int) { l[i], l[j] = l[j], l[i] } +func (l byLex) Less(i, j int) bool { + si := l[i] + sj := l[j] + var k int + for k = 0; k < len(si) && k < len(sj); k++ { + if si[k] < sj[k] { + return true + } + if si[k] > sj[k] { + return false + } + } + if k < len(sj) { + return true + } + return false +} + +// HasCommonPrefix determines if any sequence in the DoubleArray is a prefix of the given sequence. +func (da *DoubleArray) HasCommonPrefix(seq []string) bool { + if len(da.Base) == 0 { + return false + } + + var i int + for _, t := range seq { + code, ok := da.Encoding[t] + if !ok { + break + } + j := da.Base[i] + code + if len(da.Check) <= j || da.Check[j] != i+1 { + break + } + i = j + } + j := da.Base[i] + len(da.Encoding) + if len(da.Check) <= j || da.Check[j] != i+1 { + return false + } + return true +} diff --git a/internal/trie_test.go b/internal/trie_test.go new file mode 100644 index 00000000000..0422b693598 --- /dev/null +++ b/internal/trie_test.go @@ -0,0 +1,372 @@ +package internal_test + +import ( + "reflect" + "testing" + + "github.com/gengo/grpc-gateway/internal" +) + +func TestMaxCommonPrefix(t *testing.T) { + for _, spec := range []struct { + da internal.DoubleArray + tokens []string + want bool + }{ + { + da: internal.DoubleArray{}, + tokens: nil, + want: false, + }, + { + da: internal.DoubleArray{}, + tokens: []string{"foo"}, + want: false, + }, + { + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + }, + Base: []int{1, 1, 0}, + Check: []int{0, 1, 2}, + }, + tokens: nil, + want: false, + }, + { + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + }, + Base: []int{1, 1, 0}, + Check: []int{0, 1, 2}, + }, + tokens: []string{"foo"}, + want: true, + }, + { + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + }, + Base: []int{1, 1, 0}, + Check: []int{0, 1, 2}, + }, + tokens: []string{"bar"}, + want: false, + }, + { + // foo|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 1, 2, 0, 0}, + Check: []int{0, 1, 1, 2, 3}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^foo$ + // 4: ^bar$ + }, + tokens: []string{"foo"}, + want: true, + }, + { + // foo|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 1, 2, 0, 0}, + Check: []int{0, 1, 1, 2, 3}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^foo$ + // 4: ^bar$ + }, + tokens: []string{"bar"}, + want: true, + }, + { + // foo|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 1, 2, 0, 0}, + Check: []int{0, 1, 1, 2, 3}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^foo$ + // 4: ^bar$ + }, + tokens: []string{"something-else"}, + want: false, + }, + { + // foo|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 1, 2, 0, 0}, + Check: []int{0, 1, 1, 2, 3}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^foo$ + // 4: ^bar$ + }, + tokens: []string{"foo", "bar"}, + want: true, + }, + { + // foo|foo\.bar|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 3, 1, 0, 4, 0, 0}, + Check: []int{0, 1, 1, 3, 2, 2, 5}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^bar$ + // 4: ^foo.bar + // 5: ^foo$ + // 6: ^foo.bar$ + }, + tokens: []string{"foo"}, + want: true, + }, + { + // foo|foo\.bar|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 3, 1, 0, 4, 0, 0}, + Check: []int{0, 1, 1, 3, 2, 2, 5}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^bar$ + // 4: ^foo.bar + // 5: ^foo$ + // 6: ^foo.bar$ + }, + tokens: []string{"foo", "bar"}, + want: true, + }, + { + // foo|foo\.bar|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 3, 1, 0, 4, 0, 0}, + Check: []int{0, 1, 1, 3, 2, 2, 5}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^bar$ + // 4: ^foo.bar + // 5: ^foo$ + // 6: ^foo.bar$ + }, + tokens: []string{"bar"}, + want: true, + }, + { + // foo|foo\.bar|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 3, 1, 0, 4, 0, 0}, + Check: []int{0, 1, 1, 3, 2, 2, 5}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^bar$ + // 4: ^foo.bar + // 5: ^foo$ + // 6: ^foo.bar$ + }, + tokens: []string{"something-else"}, + want: false, + }, + { + // foo|foo\.bar|bar + da: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 3, 1, 0, 4, 0, 0}, + Check: []int{0, 1, 1, 3, 2, 2, 5}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^bar$ + // 4: ^foo.bar + // 5: ^foo$ + // 6: ^foo.bar$ + }, + tokens: []string{"foo", "bar", "baz"}, + want: true, + }, + } { + got := spec.da.HasCommonPrefix(spec.tokens) + if got != spec.want { + t.Errorf("%#v.HasCommonPrefix(%v) = %v; want %v", spec.da, spec.tokens, got, spec.want) + } + } +} + +func TestAdd(t *testing.T) { + for _, spec := range []struct { + tokens [][]string + want internal.DoubleArray + }{ + { + want: internal.DoubleArray{ + Encoding: make(map[string]int), + }, + }, + { + tokens: [][]string{{"foo"}}, + want: internal.DoubleArray{ + Encoding: map[string]int{"foo": 0}, + Base: []int{1, 1, 0}, + Check: []int{0, 1, 2}, + // 0: ^ + // 1: ^foo + // 2: ^foo$ + }, + }, + { + tokens: [][]string{{"foo"}, {"bar"}}, + want: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + }, + Base: []int{1, 1, 2, 0, 0}, + Check: []int{0, 1, 1, 2, 3}, + // 0: ^ + // 1: ^foo + // 2: ^bar + // 3: ^foo$ + // 4: ^bar$ + }, + }, + { + tokens: [][]string{{"foo", "bar"}, {"foo", "baz"}}, + want: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + "baz": 2, + }, + Base: []int{1, 1, 1, 2, 0, 0}, + Check: []int{0, 1, 2, 2, 3, 4}, + // 0: ^ + // 1: ^foo + // 2: ^foo.bar + // 3: ^foo.baz + // 4: ^foo.bar$ + // 5: ^foo.baz$ + }, + }, + { + tokens: [][]string{{"foo", "bar"}, {"foo", "baz"}, {"qux"}}, + want: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + "baz": 2, + "qux": 3, + }, + Base: []int{1, 1, 1, 2, 3, 0, 0, 0}, + Check: []int{0, 1, 2, 2, 1, 3, 4, 5}, + // 0: ^ + // 1: ^foo + // 2: ^foo.bar + // 3: ^foo.baz + // 4: ^qux + // 5: ^foo.bar$ + // 6: ^foo.baz$ + // 7: ^qux$ + }, + }, + { + tokens: [][]string{ + {"foo", "bar"}, + {"foo", "baz", "bar"}, + {"qux", "foo"}, + }, + want: internal.DoubleArray{ + Encoding: map[string]int{ + "foo": 0, + "bar": 1, + "baz": 2, + "qux": 3, + }, + Base: []int{1, 1, 1, 5, 8, 0, 3, 0, 5, 0}, + Check: []int{0, 1, 2, 2, 1, 3, 4, 7, 5, 9}, + // 0: ^ + // 1: ^foo + // 2: ^foo.bar + // 3: ^foo.baz + // 4: ^qux + // 5: ^foo.bar$ + // 6: ^foo.baz.bar + // 7: ^foo.baz.bar$ + // 8: ^qux.foo + // 9: ^qux.foo$ + }, + }, + } { + da := internal.NewDoubleArray(spec.tokens) + if got, want := da.Encoding, spec.want.Encoding; !reflect.DeepEqual(got, want) { + t.Errorf("da.Encoding = %v; want %v; tokens = %#v", got, want, spec.tokens) + } + if got, want := da.Base, spec.want.Base; !compareArray(got, want) { + t.Errorf("da.Base = %v; want %v; tokens = %#v", got, want, spec.tokens) + } + if got, want := da.Check, spec.want.Check; !compareArray(got, want) { + t.Errorf("da.Check = %v; want %v; tokens = %#v", got, want, spec.tokens) + } + } +} + +func compareArray(got, want []int) bool { + var i int + for i = 0; i < len(got) && i < len(want); i++ { + if got[i] != want[i] { + return false + } + } + if i < len(want) { + return false + } + for ; i < len(got); i++ { + if got[i] != 0 { + return false + } + } + return true +} diff --git a/protoc-gen-grpc-gateway/descriptor/services.go b/protoc-gen-grpc-gateway/descriptor/services.go index 32e38111946..b01f42acfdc 100644 --- a/protoc-gen-grpc-gateway/descriptor/services.go +++ b/protoc-gen-grpc-gateway/descriptor/services.go @@ -51,56 +51,6 @@ func (r *Registry) loadServices(targetFile string) error { } func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto, opts *options.HttpRule) (*Method, error) { - var ( - httpMethod string - pathTemplate string - ) - switch { - case opts.Get != "": - httpMethod = "GET" - pathTemplate = opts.Get - if opts.Body != "" { - return nil, fmt.Errorf("needs request body even though http method is GET: %s", md.GetName()) - } - - case opts.Put != "": - httpMethod = "PUT" - pathTemplate = opts.Put - - case opts.Post != "": - httpMethod = "POST" - pathTemplate = opts.Post - - case opts.Delete != "": - httpMethod = "DELETE" - pathTemplate = opts.Delete - if opts.Body != "" { - return nil, fmt.Errorf("needs request body even though http method is DELETE: %s", md.GetName()) - } - - case opts.Patch != "": - httpMethod = "PATCH" - pathTemplate = opts.Patch - - case opts.Custom != nil: - httpMethod = opts.Custom.Kind - pathTemplate = opts.Custom.Path - - default: - glog.Errorf("No pattern specified in google.api.HttpRule: %s", md.GetName()) - return nil, fmt.Errorf("none of pattern specified") - } - - parsed, err := httprule.Parse(pathTemplate) - if err != nil { - return nil, err - } - tmpl := parsed.Compile() - - if md.GetClientStreaming() && len(tmpl.Fields) > 0 { - return nil, fmt.Errorf("cannot use path parameter in client streaming") - } - requestType, err := r.LookupMsg(svc.File.GetPackage(), md.GetInputType()) if err != nil { return nil, err @@ -109,31 +59,105 @@ func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto, if err != nil { return nil, err } - meth := &Method{ Service: svc, MethodDescriptorProto: md, - PathTmpl: tmpl, - HTTPMethod: httpMethod, RequestType: requestType, ResponseType: responseType, } - for _, f := range tmpl.Fields { - param, err := r.newParam(meth, f) + newBinding := func(opts *options.HttpRule, idx int) (*Binding, error) { + var ( + httpMethod string + pathTemplate string + ) + switch { + case opts.Get != "": + httpMethod = "GET" + pathTemplate = opts.Get + if opts.Body != "" { + return nil, fmt.Errorf("needs request body even though http method is GET: %s", md.GetName()) + } + + case opts.Put != "": + httpMethod = "PUT" + pathTemplate = opts.Put + + case opts.Post != "": + httpMethod = "POST" + pathTemplate = opts.Post + + case opts.Delete != "": + httpMethod = "DELETE" + pathTemplate = opts.Delete + if opts.Body != "" { + return nil, fmt.Errorf("needs request body even though http method is DELETE: %s", md.GetName()) + } + + case opts.Patch != "": + httpMethod = "PATCH" + pathTemplate = opts.Patch + + case opts.Custom != nil: + httpMethod = opts.Custom.Kind + pathTemplate = opts.Custom.Path + + default: + glog.Errorf("No pattern specified in google.api.HttpRule: %s", md.GetName()) + return nil, fmt.Errorf("none of pattern specified") + } + + parsed, err := httprule.Parse(pathTemplate) if err != nil { return nil, err } - meth.PathParams = append(meth.PathParams, param) - } + tmpl := parsed.Compile() + + if md.GetClientStreaming() && len(tmpl.Fields) > 0 { + return nil, fmt.Errorf("cannot use path parameter in client streaming") + } + + b := &Binding{ + Method: meth, + Index: idx, + PathTmpl: tmpl, + HTTPMethod: httpMethod, + } + + for _, f := range tmpl.Fields { + param, err := r.newParam(meth, f) + if err != nil { + return nil, err + } + b.PathParams = append(b.PathParams, param) + } - // TODO(yugui) Handle query params + // TODO(yugui) Handle query params - meth.Body, err = r.newBody(meth, opts.Body) + b.Body, err = r.newBody(meth, opts.Body) + if err != nil { + return nil, err + } + + return b, nil + } + b, err := newBinding(opts, 0) if err != nil { return nil, err } + meth.Bindings = append(meth.Bindings, b) + for i, additional := range opts.GetAdditionalBindings() { + if len(additional.AdditionalBindings) > 0 { + return nil, fmt.Errorf("additional_binding in additional_binding not allowed: %s.%s", svc.GetName(), meth.GetName()) + } + b, err := newBinding(additional, i+1) + if err != nil { + return nil, err + } + meth.Bindings = append(meth.Bindings, b) + } + return meth, nil } @@ -231,6 +255,9 @@ func (r *Registry) resolveFiledPath(msg *Message, path string) ([]FieldPathCompo if f == nil { return nil, fmt.Errorf("no field %q found in %s", path, root.GetName()) } + if f.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REPEATED { + return nil, fmt.Errorf("repeated field not allowed in field path: %s in %s", f.GetName(), path) + } result = append(result, FieldPathComponent{Name: c, Target: f}) } return result, nil diff --git a/protoc-gen-grpc-gateway/descriptor/services_test.go b/protoc-gen-grpc-gateway/descriptor/services_test.go index cfed7dab86b..323147904f0 100644 --- a/protoc-gen-grpc-gateway/descriptor/services_test.go +++ b/protoc-gen-grpc-gateway/descriptor/services_test.go @@ -43,49 +43,68 @@ func testExtractServices(t *testing.T, input []*descriptor.FileDescriptorProto, t.Errorf("svcs[%d].Methods[%d].MethodDescriptorProto = %v; want %v; input = %v", i, j, got, want, input) continue } - if got, want := meth.PathTmpl, wantMeth.PathTmpl; !reflect.DeepEqual(got, want) { - t.Errorf("svcs[%d].Methods[%d].PathTmpl = %#v; want %#v; input = %v", i, j, got, want, input) - } - if got, want := meth.HTTPMethod, wantMeth.HTTPMethod; got != want { - t.Errorf("svcs[%d].Methods[%d].HTTPMethod = %q; want %q; input = %v", i, j, got, want, input) - } if got, want := meth.RequestType, wantMeth.RequestType; got.FQMN() != want.FQMN() { t.Errorf("svcs[%d].Methods[%d].RequestType = %s; want %s; input = %v", i, j, got.FQMN(), want.FQMN(), input) } if got, want := meth.ResponseType, wantMeth.ResponseType; got.FQMN() != want.FQMN() { t.Errorf("svcs[%d].Methods[%d].ResponseType = %s; want %s; input = %v", i, j, got.FQMN(), want.FQMN(), input) } - var k int - for k = 0; k < len(meth.PathParams) && k < len(wantMeth.PathParams); k++ { - param, wantParam := meth.PathParams[k], wantMeth.PathParams[k] - if got, want := param.FieldPath.String(), wantParam.FieldPath.String(); got != want { - t.Errorf("svcs[%d].Methods[%d].PathParams[%d].FieldPath.String() = %q; want %q; input = %v", i, j, k, got, want, input) - continue + for k = 0; k < len(meth.Bindings) && k < len(wantMeth.Bindings); k++ { + binding, wantBinding := meth.Bindings[k], wantMeth.Bindings[k] + if got, want := binding.Index, wantBinding.Index; got != want { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].Index = %d; want %d; input = %v", i, j, k, got, want, input) + } + if got, want := binding.PathTmpl, wantBinding.PathTmpl; !reflect.DeepEqual(got, want) { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].PathTmpl = %#v; want %#v; input = %v", i, j, k, got, want, input) + } + if got, want := binding.HTTPMethod, wantBinding.HTTPMethod; got != want { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].HTTPMethod = %q; want %q; input = %v", i, j, k, got, want, input) } - for l := 0; l < len(param.FieldPath) && l < len(wantParam.FieldPath); l++ { - field, wantField := param.FieldPath[l].Target, wantParam.FieldPath[l].Target - if got, want := field.FieldDescriptorProto, wantField.FieldDescriptorProto; !proto.Equal(got, want) { - t.Errorf("svcs[%d].Methods[%d].PathParams[%d].FieldPath[%d].Target.FieldDescriptorProto = %v; want %v; input = %v", i, j, k, l, got, want, input) + + var l int + for l = 0; l < len(binding.PathParams) && l < len(wantBinding.PathParams); l++ { + param, wantParam := binding.PathParams[l], wantBinding.PathParams[l] + if got, want := param.FieldPath.String(), wantParam.FieldPath.String(); got != want { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].PathParams[%d].FieldPath.String() = %q; want %q; input = %v", i, j, k, l, got, want, input) + continue + } + for m := 0; m < len(param.FieldPath) && m < len(wantParam.FieldPath); m++ { + field, wantField := param.FieldPath[m].Target, wantParam.FieldPath[m].Target + if got, want := field.FieldDescriptorProto, wantField.FieldDescriptorProto; !proto.Equal(got, want) { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].PathParams[%d].FieldPath[%d].Target.FieldDescriptorProto = %v; want %v; input = %v", i, j, k, l, m, got, want, input) + } } } - } - for ; k < len(meth.PathParams); k++ { - got := meth.PathParams[k].FieldPath.String() - t.Errorf("svcs[%d].Methods[%d].PathParams[%d] = %q; want it to be missing; input = %v", i, j, k, got, input) - } - for ; k < len(wantMeth.PathParams); k++ { - want := wantMeth.PathParams[k].FieldPath.String() - t.Errorf("svcs[%d].Methods[%d].PathParams[%d] missing; want %q; input = %v", i, j, k, want, input) - } + for ; l < len(binding.PathParams); l++ { + got := binding.PathParams[l].FieldPath.String() + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].PathParams[%d] = %q; want it to be missing; input = %v", i, j, k, l, got, input) + } + for ; l < len(wantBinding.PathParams); l++ { + want := wantBinding.PathParams[l].FieldPath.String() + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].PathParams[%d] missing; want %q; input = %v", i, j, k, l, want, input) + } - if got, want := (meth.Body != nil), (wantMeth.Body != nil); got != want { - if got { - t.Errorf("svcs[%d].Methods[%d].Body = %q; want it to be missing; input = %v", i, j, meth.Body.FieldPath.String(), input) - } else { - t.Errorf("svcs[%d].Methods[%d].Body missing; want %q; input = %v", i, j, wantMeth.Body.FieldPath.String(), input) + if got, want := (binding.Body != nil), (wantBinding.Body != nil); got != want { + if got { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].Body = %q; want it to be missing; input = %v", i, j, k, binding.Body.FieldPath.String(), input) + } else { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].Body missing; want %q; input = %v", i, j, k, wantBinding.Body.FieldPath.String(), input) + } + } else if binding.Body != nil { + if got, want := binding.Body.FieldPath.String(), wantBinding.Body.FieldPath.String(); got != want { + t.Errorf("svcs[%d].Methods[%d].Bindings[%d].Body = %q; want %q; input = %v", i, j, k, got, want, input) + } } } + for ; k < len(meth.Bindings); k++ { + got := meth.Bindings[k] + t.Errorf("svcs[%d].Methods[%d].Bindings[%d] = %q; want it to be missing; input = %v", i, j, k, got, input) + } + for ; k < len(wantMeth.Bindings); k++ { + want := wantMeth.Bindings[k] + t.Errorf("svcs[%d].Methods[%d].Bindings[%d] missing; want %q; input = %v", i, j, k, want, input) + } } for ; j < len(svc.Methods); j++ { got := svc.Methods[j].MethodDescriptorProto @@ -117,11 +136,11 @@ func crossLinkFixture(f *File) *File { svc.File = f for _, m := range svc.Methods { m.Service = svc - for _, param := range m.PathParams { - param.Method = m - } - for _, param := range m.QueryParams { - param.Method = m + for _, b := range m.Bindings { + b.Method = m + for _, param := range b.PathParams { + param.Method = m + } } } } @@ -181,11 +200,15 @@ func TestExtractServicesSimple(t *testing.T) { Methods: []*Method{ { MethodDescriptorProto: fd.Service[0].Method[0], - PathTmpl: compilePath(t, "/v1/example/echo"), - HTTPMethod: "POST", RequestType: msg, ResponseType: msg, - Body: &Body{FieldPath: nil}, + Bindings: []*Binding{ + { + PathTmpl: compilePath(t, "/v1/example/echo"), + HTTPMethod: "POST", + Body: &Body{FieldPath: nil}, + }, + }, }, }, }, @@ -276,11 +299,15 @@ func TestExtractServicesCrossPackage(t *testing.T) { Methods: []*Method{ { MethodDescriptorProto: fds[0].Service[0].Method[0], - PathTmpl: compilePath(t, "/v1/example/to_s"), - HTTPMethod: "POST", RequestType: boolMsg, ResponseType: stringMsg, - Body: &Body{FieldPath: nil}, + Bindings: []*Binding{ + { + PathTmpl: compilePath(t, "/v1/example/to_s"), + HTTPMethod: "POST", + Body: &Body{FieldPath: nil}, + }, + }, }, }, }, @@ -365,15 +392,19 @@ func TestExtractServicesWithBodyPath(t *testing.T) { Methods: []*Method{ { MethodDescriptorProto: fd.Service[0].Method[0], - PathTmpl: compilePath(t, "/v1/example/echo"), - HTTPMethod: "POST", RequestType: msg, ResponseType: msg, - Body: &Body{ - FieldPath: FieldPath{ - { - Name: "nested", - Target: msg.Fields[0], + Bindings: []*Binding{ + { + PathTmpl: compilePath(t, "/v1/example/echo"), + HTTPMethod: "POST", + Body: &Body{ + FieldPath: FieldPath{ + { + Name: "nested", + Target: msg.Fields[0], + }, + }, }, }, }, @@ -439,19 +470,133 @@ func TestExtractServicesWithPathParam(t *testing.T) { Methods: []*Method{ { MethodDescriptorProto: fd.Service[0].Method[0], - PathTmpl: compilePath(t, "/v1/example/echo/{string=*}"), - HTTPMethod: "GET", RequestType: msg, ResponseType: msg, - PathParams: []Parameter{ + Bindings: []*Binding{ { - FieldPath: FieldPath{ + PathTmpl: compilePath(t, "/v1/example/echo/{string=*}"), + HTTPMethod: "GET", + PathParams: []Parameter{ { - Name: "string", + FieldPath: FieldPath{ + { + Name: "string", + Target: msg.Fields[0], + }, + }, Target: msg.Fields[0], }, }, - Target: msg.Fields[0], + }, + }, + }, + }, + }, + }, + } + + crossLinkFixture(file) + testExtractServices(t, []*descriptor.FileDescriptorProto{&fd}, "path/to/example.proto", file.Services) +} + +func TestExtractServicesWithAdditionalBinding(t *testing.T) { + src := ` + name: "path/to/example.proto", + package: "example" + message_type < + name: "StringMessage" + field < + name: "string" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + > + > + service < + name: "ExampleService" + method < + name: "Echo" + input_type: "StringMessage" + output_type: "StringMessage" + options < + [google.api.http] < + post: "/v1/example/echo" + body: "*" + additional_bindings < + get: "/v1/example/echo/{string}" + > + additional_bindings < + post: "/v2/example/echo" + body: "string" + > + > + > + > + > + ` + var fd descriptor.FileDescriptorProto + if err := proto.UnmarshalText(src, &fd); err != nil { + t.Fatalf("proto.UnmarshalText(%s, &fd) failed with %v; want success", src, err) + } + msg := &Message{ + DescriptorProto: fd.MessageType[0], + Fields: []*Field{ + { + FieldDescriptorProto: fd.MessageType[0].Field[0], + }, + }, + } + file := &File{ + FileDescriptorProto: &fd, + GoPkg: GoPackage{ + Path: "path/to/example.pb", + Name: "example_pb", + }, + Messages: []*Message{msg}, + Services: []*Service{ + { + ServiceDescriptorProto: fd.Service[0], + Methods: []*Method{ + { + MethodDescriptorProto: fd.Service[0].Method[0], + RequestType: msg, + ResponseType: msg, + Bindings: []*Binding{ + { + Index: 0, + PathTmpl: compilePath(t, "/v1/example/echo"), + HTTPMethod: "POST", + Body: &Body{FieldPath: nil}, + }, + { + Index: 1, + PathTmpl: compilePath(t, "/v1/example/echo/{string}"), + HTTPMethod: "GET", + PathParams: []Parameter{ + { + FieldPath: FieldPath{ + { + Name: "string", + Target: msg.Fields[0], + }, + }, + Target: msg.Fields[0], + }, + }, + Body: nil, + }, + { + Index: 2, + PathTmpl: compilePath(t, "/v2/example/echo"), + HTTPMethod: "POST", + Body: &Body{ + FieldPath: FieldPath{ + FieldPathComponent{ + Name: "string", + Target: msg.Fields[0], + }, + }, + }, }, }, }, @@ -776,3 +921,189 @@ func TestExtractServicesWithError(t *testing.T) { t.Log(err) } } + +func TestResolveFieldPath(t *testing.T) { + for _, spec := range []struct { + src string + path string + wantErr bool + }{ + { + src: ` + name: 'example.proto' + package: 'example' + message_type < + name: 'ExampleMessage' + field < + name: 'string' + type: TYPE_STRING + label: LABEL_OPTIONAL + number: 1 + > + > + `, + path: "string", + wantErr: false, + }, + // no such field + { + src: ` + name: 'example.proto' + package: 'example' + message_type < + name: 'ExampleMessage' + field < + name: 'string' + type: TYPE_STRING + label: LABEL_OPTIONAL + number: 1 + > + > + `, + path: "something_else", + wantErr: true, + }, + // repeated field + { + src: ` + name: 'example.proto' + package: 'example' + message_type < + name: 'ExampleMessage' + field < + name: 'string' + type: TYPE_STRING + label: LABEL_REPEATED + number: 1 + > + > + `, + path: "string", + wantErr: true, + }, + // nested field + { + src: ` + name: 'example.proto' + package: 'example' + message_type < + name: 'ExampleMessage' + field < + name: 'nested' + type: TYPE_MESSAGE + type_name: 'AnotherMessage' + label: LABEL_OPTIONAL + number: 1 + > + field < + name: 'terminal' + type: TYPE_BOOL + label: LABEL_OPTIONAL + number: 2 + > + > + message_type < + name: 'AnotherMessage' + field < + name: 'nested2' + type: TYPE_MESSAGE + type_name: 'ExampleMessage' + label: LABEL_OPTIONAL + number: 1 + > + > + `, + path: "nested.nested2.nested.nested2.nested.nested2.terminal", + wantErr: false, + }, + // non aggregate field on the path + { + src: ` + name: 'example.proto' + package: 'example' + message_type < + name: 'ExampleMessage' + field < + name: 'nested' + type: TYPE_MESSAGE + type_name: 'AnotherMessage' + label: LABEL_OPTIONAL + number: 1 + > + field < + name: 'terminal' + type: TYPE_BOOL + label: LABEL_OPTIONAL + number: 2 + > + > + message_type < + name: 'AnotherMessage' + field < + name: 'nested2' + type: TYPE_MESSAGE + type_name: 'ExampleMessage' + label: LABEL_OPTIONAL + number: 1 + > + > + `, + path: "nested.terminal.nested2", + wantErr: true, + }, + // repeated field + { + src: ` + name: 'example.proto' + package: 'example' + message_type < + name: 'ExampleMessage' + field < + name: 'nested' + type: TYPE_MESSAGE + type_name: 'AnotherMessage' + label: LABEL_OPTIONAL + number: 1 + > + field < + name: 'terminal' + type: TYPE_BOOL + label: LABEL_OPTIONAL + number: 2 + > + > + message_type < + name: 'AnotherMessage' + field < + name: 'nested2' + type: TYPE_MESSAGE + type_name: 'ExampleMessage' + label: LABEL_REPEATED + number: 1 + > + > + `, + path: "nested.nested2.terminal", + wantErr: true, + }, + } { + var file descriptor.FileDescriptorProto + if err := proto.UnmarshalText(spec.src, &file); err != nil { + t.Fatalf("proto.Unmarshal(%s) failed with %v; want success", spec.src, err) + } + reg := NewRegistry() + reg.loadFile(&file) + f, err := reg.LookupFile(file.GetName()) + if err != nil { + t.Fatalf("reg.LookupFile(%q) failed with %v; want success; on file=%s", file.GetName(), err, spec.src) + } + _, err = reg.resolveFiledPath(f.Messages[0], spec.path) + if got, want := err != nil, spec.wantErr; got != want { + if want { + t.Errorf("reg.resolveFiledPath(%q, %q) succeeded; want an error", f.Messages[0].GetName(), spec.path) + continue + } + t.Errorf("reg.resolveFiledPath(%q, %q) failed with %v; want success", f.Messages[0].GetName(), spec.path, err) + } + } +} diff --git a/protoc-gen-grpc-gateway/descriptor/types.go b/protoc-gen-grpc-gateway/descriptor/types.go index 6d8eeade127..089aca06c42 100644 --- a/protoc-gen-grpc-gateway/descriptor/types.go +++ b/protoc-gen-grpc-gateway/descriptor/types.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/gengo/grpc-gateway/internal" "github.com/gengo/grpc-gateway/protoc-gen-grpc-gateway/httprule" descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor" ) @@ -102,22 +103,42 @@ type Method struct { Service *Service *descriptor.MethodDescriptorProto - // PathTmpl is path template where this method is mapped to. - PathTmpl httprule.Template - // HTTPMethod is the HTTP method which this method is mapped to. - HTTPMethod string // RequestType is the message type of requests to this method. RequestType *Message // ResponseType is the message type of responses from this method. ResponseType *Message + Bindings []*Binding +} + +// Binding describes how an HTTP endpoint is bound to a gRPC method. +type Binding struct { + // Method is the method which the endpoint is bound to. + Method *Method + // Index is a zero-origin index of the binding in the target method + Index int + // PathTmpl is path template where this method is mapped to. + PathTmpl httprule.Template + // HTTPMethod is the HTTP method which this method is mapped to. + HTTPMethod string // PathParams is the list of parameters provided in HTTP request paths. PathParams []Parameter - // QueryParam is the list of parameters provided in HTTP query strings. - QueryParams []Parameter // Body describes parameters provided in HTTP request body. Body *Body } +// ExplicitParams returns a list of explicitly bound parameters of "b", +// i.e. a union of field path for body and field paths for path parameters. +func (b *Binding) ExplicitParams() []string { + var result []string + if b.Body != nil { + result = append(result, b.Body.FieldPath.String()) + } + for _, p := range b.PathParams { + result = append(result, p.FieldPath.String()) + } + return result +} + // Field wraps descriptor.FieldDescriptorProto for richer features. type Field struct { // Message is the message type which this field belongs to. @@ -204,15 +225,15 @@ type FieldPathComponent struct { // RHS returns a right-hand-side expression in go for this field. func (c FieldPathComponent) RHS() string { - return toCamel(c.Name) + return internal.PascalFromSnake(c.Name) } // LHS returns a left-hand-side expression in go for this field. func (c FieldPathComponent) LHS() string { if c.Target.Message.File.proto2() { - return fmt.Sprintf("Get%s()", toCamel(c.Name)) + return fmt.Sprintf("Get%s()", internal.PascalFromSnake(c.Name)) } - return toCamel(c.Name) + return internal.PascalFromSnake(c.Name) } var ( diff --git a/protoc-gen-grpc-gateway/gengateway/generator.go b/protoc-gen-grpc-gateway/gengateway/generator.go index 78b8ba5fd05..e04dd7e9b9f 100644 --- a/protoc-gen-grpc-gateway/gengateway/generator.go +++ b/protoc-gen-grpc-gateway/gengateway/generator.go @@ -31,6 +31,7 @@ func New(reg *descriptor.Registry) *generator { "io", "net/http", "github.com/gengo/grpc-gateway/runtime", + "github.com/gengo/grpc-gateway/internal", "github.com/golang/glog", "github.com/golang/protobuf/proto", "golang.org/x/net/context", diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index 387039a57e3..8d1d3e66a85 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -5,6 +5,7 @@ import ( "strings" "text/template" + "github.com/gengo/grpc-gateway/internal" "github.com/gengo/grpc-gateway/protoc-gen-grpc-gateway/descriptor" ) @@ -13,6 +14,42 @@ type param struct { Imports []descriptor.GoPackage } +type binding struct { + *descriptor.Binding +} + +// HasQueryParam determines if the binding needs parameters in query string. +// +// It sometimes returns true even though actually the binding does not need. +// But it is not serious because it just results in a small amount of extra codes generated. +func (b binding) HasQueryParam() bool { + if b.Body != nil && len(b.Body.FieldPath) == 0 { + return false + } + fields := make(map[string]bool) + for _, f := range b.Method.RequestType.Fields { + fields[f.GetName()] = true + } + if b.Body != nil { + delete(fields, b.Body.FieldPath.String()) + } + for _, p := range b.PathParams { + delete(fields, p.FieldPath.String()) + } + return len(fields) > 0 +} + +func (b binding) QueryParamFilter() *internal.DoubleArray { + var seqs [][]string + if b.Body != nil { + seqs = append(seqs, strings.Split(b.Body.FieldPath.String(), ".")) + } + for _, p := range b.PathParams { + seqs = append(seqs, strings.Split(p.FieldPath.String(), ".")) + } + return internal.NewDoubleArray(seqs) +} + func applyTemplate(p param) (string, error) { w := bytes.NewBuffer(nil) if err := headerTemplate.Execute(w, p); err != nil { @@ -22,8 +59,10 @@ func applyTemplate(p param) (string, error) { for _, svc := range p.Services { for _, meth := range svc.Methods { methodSeen = true - if err := handlerTemplate.Execute(w, meth); err != nil { - return "", err + for _, b := range meth.Bindings { + if err := handlerTemplate.Execute(w, binding{Binding: b}); err != nil { + return "", err + } } } } @@ -58,10 +97,11 @@ var _ codes.Code var _ io.Reader var _ = runtime.String var _ = json.Marshal +var _ = internal.PascalFromSnake `)) handlerTemplate = template.Must(template.New("handler").Parse(` -{{if .GetClientStreaming}} +{{if .Method.GetClientStreaming}} {{template "client-streaming-request-func" .}} {{else}} {{template "client-rpc-request-func" .}} @@ -69,22 +109,22 @@ var _ = json.Marshal `)) _ = template.Must(handlerTemplate.New("request-func-signature").Parse(strings.Replace(` -{{if .GetServerStreaming}} -func request_{{.Service.GetName}}_{{.GetName}}(ctx context.Context, client {{.Service.GetName}}Client, req *http.Request, pathParams map[string]string) ({{.Service.GetName}}_{{.GetName}}Client, error) +{{if .Method.GetServerStreaming}} +func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, client {{.Method.Service.GetName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.GetName}}_{{.Method.GetName}}Client, error) {{else}} -func request_{{.Service.GetName}}_{{.GetName}}(ctx context.Context, client {{.Service.GetName}}Client, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) +func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, client {{.Method.Service.GetName}}Client, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) {{end}}`, "\n", "", -1))) _ = template.Must(handlerTemplate.New("client-streaming-request-func").Parse(` {{template "request-func-signature" .}} { - stream, err := client.{{.GetName}}(ctx) + stream, err := client.{{.Method.GetName}}(ctx) if err != nil { glog.Errorf("Failed to start streaming: %v", err) return nil, err } dec := json.NewDecoder(req.Body) for { - var protoReq {{.RequestType.GoType .Service.File.GoPkg.Path}} + var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}} err = dec.Decode(&protoReq) if err == io.EOF { break @@ -98,7 +138,7 @@ func request_{{.Service.GetName}}_{{.GetName}}(ctx context.Context, client {{.Se return nil, err } } -{{if .GetServerStreaming}} +{{if .Method.GetServerStreaming}} if err = stream.CloseSend(); err != nil { glog.Errorf("Failed to terminate client stream: %v", err) return nil, err @@ -111,14 +151,13 @@ func request_{{.Service.GetName}}_{{.GetName}}(ctx context.Context, client {{.Se `)) _ = template.Must(handlerTemplate.New("client-rpc-request-func").Parse(` +{{if .HasQueryParam}} +var ( + filter_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}} = {{.QueryParamFilter | printf "%#v"}} +) +{{end}} {{template "request-func-signature" .}} { - var protoReq {{.RequestType.GoType .Service.File.GoPkg.Path}} - {{range $param := .QueryParams}} - protoReq.{{$param.RHS "protoReq"}}, err = {{$param.ConvertFuncExpr}}(req.FormValue({{$param | printf "%q"}})) - if err != nil { - return nil, err - } - {{end}} + var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}} {{if .Body}} if err = json.NewDecoder(req.Body).Decode(&{{.Body.RHS "protoReq"}}); err != nil { return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) @@ -138,8 +177,13 @@ func request_{{.Service.GetName}}_{{.GetName}}(ctx context.Context, client {{.Se } {{end}} {{end}} +{{if .HasQueryParam}} + if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), filter_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}); err != nil { + return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) + } +{{end}} - return client.{{.GetName}}(ctx, &protoReq) + return client.{{.Method.GetName}}(ctx, &protoReq) }`)) trailerTemplate = template.Must(template.New("trailer").Parse(` @@ -174,8 +218,9 @@ func Register{{$svc.GetName}}HandlerFromEndpoint(ctx context.Context, mux *runti func Register{{$svc.GetName}}Handler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { client := New{{$svc.GetName}}Client(conn) {{range $m := $svc.Methods}} - mux.Handle({{$m.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_{{$svc.GetName}}_{{$m.GetName}}(ctx, client, req, pathParams) + {{range $b := $m.Bindings}} + mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + resp, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(ctx, client, req, pathParams) if err != nil { runtime.HTTPError(w, err) return @@ -187,12 +232,15 @@ func Register{{$svc.GetName}}Handler(ctx context.Context, mux *runtime.ServeMux, {{end}} }) {{end}} + {{end}} return nil } var ( {{range $m := $svc.Methods}} - pattern_{{$svc.GetName}}_{{$m.GetName}} = runtime.MustPattern(runtime.NewPattern({{$m.PathTmpl.Version}}, {{$m.PathTmpl.OpCodes | printf "%#v"}}, {{$m.PathTmpl.Pool | printf "%#v"}}, {{$m.PathTmpl.Verb | printf "%q"}})) + {{range $b := $m.Bindings}} + pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = runtime.MustPattern(runtime.NewPattern({{$b.PathTmpl.Version}}, {{$b.PathTmpl.OpCodes | printf "%#v"}}, {{$b.PathTmpl.Pool | printf "%#v"}}, {{$b.PathTmpl.Verb | printf "%q"}})) + {{end}} {{end}} ) {{end}}`)) diff --git a/protoc-gen-grpc-gateway/gengateway/template_test.go b/protoc-gen-grpc-gateway/gengateway/template_test.go index 3f0ecfe9313..f52287f16cc 100644 --- a/protoc-gen-grpc-gateway/gengateway/template_test.go +++ b/protoc-gen-grpc-gateway/gengateway/template_test.go @@ -18,11 +18,11 @@ func crossLinkFixture(f *descriptor.File) *descriptor.File { svc.File = f for _, m := range svc.Methods { m.Service = svc - for _, param := range m.PathParams { - param.Method = m - } - for _, param := range m.QueryParams { - param.Method = m + for _, b := range m.Bindings { + b.Method = m + for _, param := range b.PathParams { + param.Method = m + } } } } @@ -64,10 +64,14 @@ func TestApplyTemplateHeader(t *testing.T) { Methods: []*descriptor.Method{ { MethodDescriptorProto: meth, - HTTPMethod: "GET", RequestType: msg, ResponseType: msg, - Body: &descriptor.Body{FieldPath: nil}, + Bindings: []*descriptor.Binding{ + { + HTTPMethod: "GET", + Body: &descriptor.Body{FieldPath: nil}, + }, + }, }, }, }, @@ -129,11 +133,11 @@ func TestApplyTemplateRequestWithoutClientStreaming(t *testing.T) { }{ { serverStreaming: false, - sigWant: `func request_ExampleService_Echo(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) {`, + sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) {`, }, { serverStreaming: true, - sigWant: `func request_ExampleService_Echo(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, error) {`, + sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, error) {`, }, } { meth.ServerStreaming = proto.Bool(spec.serverStreaming) @@ -175,39 +179,43 @@ func TestApplyTemplateRequestWithoutClientStreaming(t *testing.T) { Methods: []*descriptor.Method{ { MethodDescriptorProto: meth, - HTTPMethod: "POST", - PathTmpl: httprule.Template{ - Version: 1, - OpCodes: []int{0, 0}, - }, - RequestType: msg, - ResponseType: msg, - PathParams: []descriptor.Parameter{ + RequestType: msg, + ResponseType: msg, + Bindings: []*descriptor.Binding{ { - FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ - { - Name: "nested", - Target: nestedField, - }, + HTTPMethod: "POST", + PathTmpl: httprule.Template{ + Version: 1, + OpCodes: []int{0, 0}, + }, + PathParams: []descriptor.Parameter{ { - Name: "int32", + FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ + { + Name: "nested", + Target: nestedField, + }, + { + Name: "int32", + Target: intField, + }, + }), Target: intField, }, - }), - Target: intField, - }, - }, - Body: &descriptor.Body{ - FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ - { - Name: "nested", - Target: nestedField, }, - { - Name: "bool", - Target: boolField, + Body: &descriptor.Body{ + FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ + { + Name: "nested", + Target: nestedField, + }, + { + Name: "bool", + Target: boolField, + }, + }), }, - }), + }, }, }, }, @@ -234,7 +242,7 @@ func TestApplyTemplateRequestWithoutClientStreaming(t *testing.T) { if want := `func RegisterExampleServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {`; !strings.Contains(got, want) { t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want) } - if want := `pattern_ExampleService_Echo = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), ""))`; !strings.Contains(got, want) { + if want := `pattern_ExampleService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), ""))`; !strings.Contains(got, want) { t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want) } } @@ -286,11 +294,11 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) { }{ { serverStreaming: false, - sigWant: `func request_ExampleService_Echo(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) {`, + sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (msg proto.Message, err error) {`, }, { serverStreaming: true, - sigWant: `func request_ExampleService_Echo(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, error) {`, + sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, error) {`, }, } { meth.ServerStreaming = proto.Bool(spec.serverStreaming) @@ -332,39 +340,43 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) { Methods: []*descriptor.Method{ { MethodDescriptorProto: meth, - HTTPMethod: "POST", - PathTmpl: httprule.Template{ - Version: 1, - OpCodes: []int{0, 0}, - }, - RequestType: msg, - ResponseType: msg, - PathParams: []descriptor.Parameter{ + RequestType: msg, + ResponseType: msg, + Bindings: []*descriptor.Binding{ { - FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ - { - Name: "nested", - Target: nestedField, - }, + HTTPMethod: "POST", + PathTmpl: httprule.Template{ + Version: 1, + OpCodes: []int{0, 0}, + }, + PathParams: []descriptor.Parameter{ { - Name: "int32", + FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ + { + Name: "nested", + Target: nestedField, + }, + { + Name: "int32", + Target: intField, + }, + }), Target: intField, }, - }), - Target: intField, - }, - }, - Body: &descriptor.Body{ - FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ - { - Name: "nested", - Target: nestedField, }, - { - Name: "bool", - Target: boolField, + Body: &descriptor.Body{ + FieldPath: descriptor.FieldPath([]descriptor.FieldPathComponent{ + { + Name: "nested", + Target: nestedField, + }, + { + Name: "bool", + Target: boolField, + }, + }), }, - }), + }, }, }, }, @@ -385,7 +397,7 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) { if want := `func RegisterExampleServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {`; !strings.Contains(got, want) { t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want) } - if want := `pattern_ExampleService_Echo = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), ""))`; !strings.Contains(got, want) { + if want := `pattern_ExampleService_Echo_0 = runtime.MustPattern(runtime.NewPattern(1, []int{0, 0}, []string(nil), ""))`; !strings.Contains(got, want) { t.Errorf("applyTemplate(%#v) = %s; want to contain %s", file, got, want) } } diff --git a/runtime/mux.go b/runtime/mux.go index 0fbd365c640..3538fc3820f 100644 --- a/runtime/mux.go +++ b/runtime/mux.go @@ -48,6 +48,13 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { verb, components[l-1] = c[:idx], c[idx+1:] } + if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && isPathLengthFallback(r) { + r.Method = strings.ToUpper(override) + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + } for _, h := range s.handlers[r.Method] { pathParams, err := h.pat.Match(components, verb) if err != nil { @@ -58,21 +65,37 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // lookup other methods to determine if it is MethodNotAllowed + // lookup other methods to handle fallback from GET to POST and + // to determine if it is MethodNotAllowed or NotFound. for m, handlers := range s.handlers { if m == r.Method { continue } for _, h := range handlers { - if _, err := h.pat.Match(components, verb); err == nil { - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + pathParams, err := h.pat.Match(components, verb) + if err != nil { + continue + } + // X-HTTP-Method-Override is optional. Always allow fallback to POST. + if isPathLengthFallback(r) { + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + h.h(w, r, pathParams) return } + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return } } http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) } +func isPathLengthFallback(r *http.Request) bool { + return r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" +} + type handler struct { pat Pattern h HandlerFunc diff --git a/runtime/mux_test.go b/runtime/mux_test.go new file mode 100644 index 00000000000..7394bb13a94 --- /dev/null +++ b/runtime/mux_test.go @@ -0,0 +1,195 @@ +package runtime_test + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gengo/grpc-gateway/internal" + "github.com/gengo/grpc-gateway/runtime" +) + +func TestMuxServeHTTP(t *testing.T) { + type stubPattern struct { + method string + ops []int + pool []string + } + for _, spec := range []struct { + patterns []stubPattern + + reqMethod string + reqPath string + headers map[string]string + + respStatus int + respContent string + }{ + { + patterns: nil, + reqMethod: "GET", + reqPath: "/", + respStatus: http.StatusNotFound, + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "GET", + reqPath: "/foo", + respStatus: http.StatusOK, + respContent: "GET /foo", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "GET", + reqPath: "/bar", + respStatus: http.StatusNotFound, + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + { + method: "GET", + ops: []int{int(internal.OpPush), 0}, + }, + }, + reqMethod: "GET", + reqPath: "/foo", + respStatus: http.StatusOK, + respContent: "GET /foo", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + { + method: "POST", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "POST", + reqPath: "/foo", + respStatus: http.StatusOK, + respContent: "POST /foo", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "DELETE", + reqPath: "/foo", + respStatus: http.StatusMethodNotAllowed, + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "POST", + reqPath: "/foo", + headers: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + }, + respStatus: http.StatusOK, + respContent: "GET /foo", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + { + method: "POST", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "POST", + reqPath: "/foo", + headers: map[string]string{ + "Content-Type": "application/x-www-form-urlencoded", + "X-HTTP-Method-Override": "GET", + }, + respStatus: http.StatusOK, + respContent: "GET /foo", + }, + { + patterns: []stubPattern{ + { + method: "GET", + ops: []int{int(internal.OpLitPush), 0}, + pool: []string{"foo"}, + }, + }, + reqMethod: "POST", + reqPath: "/foo", + headers: map[string]string{ + "Content-Type": "application/json", + }, + respStatus: http.StatusMethodNotAllowed, + }, + } { + mux := runtime.NewServeMux() + for _, p := range spec.patterns { + func(p stubPattern) { + pat, err := runtime.NewPattern(1, p.ops, p.pool, "") + if err != nil { + t.Fatalf("runtime.NewPattern(1, %#v, %#v, %q) failed with %v; want success", p.ops, p.pool, "", err) + } + mux.Handle(p.method, pat, func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { + fmt.Fprintf(w, "%s %s", p.method, pat.String()) + }) + }(p) + } + + url := fmt.Sprintf("http://host.example%s", spec.reqPath) + r, err := http.NewRequest(spec.reqMethod, url, bytes.NewReader(nil)) + if err != nil { + t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", spec.reqMethod, url, err) + } + for name, value := range spec.headers { + r.Header.Set(name, value) + } + w := httptest.NewRecorder() + mux.ServeHTTP(w, r) + + if got, want := w.Code, spec.respStatus; got != want { + t.Errorf("w.Code = %d; want %d; patterns=%v; req=%v", got, want, spec.patterns, r) + } + if spec.respContent != "" { + if got, want := w.Body.String(), spec.respContent; got != want { + t.Errorf("w.Body = %q; want %q; patterns=%v; req=%v", got, want, spec.patterns, r) + } + } + } +} diff --git a/runtime/query.go b/runtime/query.go new file mode 100644 index 00000000000..2ede4735754 --- /dev/null +++ b/runtime/query.go @@ -0,0 +1,118 @@ +package runtime + +import ( + "fmt" + "net/url" + "reflect" + "strings" + + "github.com/gengo/grpc-gateway/internal" + "github.com/golang/glog" + "github.com/golang/protobuf/proto" +) + +// PopulateQueryParameters populates "values" into "msg". +// A value is ignored if its key starts with one of the elements in "filters". +func PopulateQueryParameters(msg proto.Message, values url.Values, filter *internal.DoubleArray) error { + for key, values := range values { + fieldPath := strings.Split(key, ".") + if filter.HasCommonPrefix(fieldPath) { + continue + } + if err := populateQueryParameter(msg, fieldPath, values); err != nil { + return err + } + } + return nil +} + +func populateQueryParameter(msg proto.Message, fieldPath []string, values []string) error { + m := reflect.ValueOf(msg) + if m.Kind() != reflect.Ptr { + return fmt.Errorf("unexpected type %T: %v", msg, msg) + } + m = m.Elem() + for i, fieldName := range fieldPath { + isLast := i == len(fieldPath)-1 + if !isLast && m.Kind() != reflect.Struct { + return fmt.Errorf("non-aggregate type in the mid of path: %s", strings.Join(fieldPath, ".")) + } + f := m.FieldByName(internal.PascalFromSnake(fieldName)) + if !f.IsValid() { + glog.Warningf("field not found in %T: %s", msg, strings.Join(fieldPath, ".")) + return nil + } + + switch f.Kind() { + case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, reflect.String, reflect.Uint32, reflect.Uint64: + m = f + case reflect.Slice: + // TODO(yugui) Support []byte + if !isLast { + return fmt.Errorf("unexpected repeated field in %s", strings.Join(fieldPath, ".")) + } + return populateRepeatedField(f, values) + case reflect.Ptr: + if f.IsNil() { + m = reflect.New(f.Type().Elem()) + f.Set(m) + } + m = f.Elem() + continue + default: + return fmt.Errorf("unexpected type %s in %T", f.Type(), msg) + } + } + switch len(values) { + case 0: + return fmt.Errorf("no value of field: %s", strings.Join(fieldPath, ".")) + case 1: + default: + glog.Warningf("too many field values: %s", strings.Join(fieldPath, ".")) + } + return populateField(m, values[0]) +} + +func populateRepeatedField(f reflect.Value, values []string) error { + elemType := f.Type().Elem() + conv, ok := convFromType[elemType.Kind()] + if !ok { + return fmt.Errorf("unsupported field type %s", elemType) + } + f.Set(reflect.MakeSlice(f.Type(), len(values), len(values))) + for i, v := range values { + result := conv.Call([]reflect.Value{reflect.ValueOf(v)}) + if err := result[1].Interface(); err != nil { + return err.(error) + } + f.Index(i).Set(result[0]) + } + return nil +} + +func populateField(f reflect.Value, value string) error { + conv, ok := convFromType[f.Kind()] + if !ok { + return fmt.Errorf("unsupported field type %T", f) + } + result := conv.Call([]reflect.Value{reflect.ValueOf(value)}) + if err := result[1].Interface(); err != nil { + return err.(error) + } + f.Set(result[0]) + return nil +} + +var ( + convFromType = map[reflect.Kind]reflect.Value{ + reflect.String: reflect.ValueOf(String), + reflect.Bool: reflect.ValueOf(Bool), + reflect.Float64: reflect.ValueOf(Float64), + reflect.Float32: reflect.ValueOf(Float32), + reflect.Int64: reflect.ValueOf(Int64), + reflect.Int32: reflect.ValueOf(Int32), + reflect.Uint64: reflect.ValueOf(Uint64), + reflect.Uint32: reflect.ValueOf(Uint32), + // TODO(yugui) Support []byte + } +) diff --git a/runtime/query_test.go b/runtime/query_test.go new file mode 100644 index 00000000000..5c238d53540 --- /dev/null +++ b/runtime/query_test.go @@ -0,0 +1,306 @@ +package runtime_test + +import ( + "net/url" + "testing" + + "github.com/gengo/grpc-gateway/internal" + "github.com/gengo/grpc-gateway/runtime" + "github.com/golang/protobuf/proto" +) + +func TestPopulateParameters(t *testing.T) { + for _, spec := range []struct { + values url.Values + filter *internal.DoubleArray + want proto.Message + }{ + { + values: url.Values{ + "float_value": {"1.5"}, + "double_value": {"2.5"}, + "int64_value": {"-1"}, + "int32_value": {"-2"}, + "uint64_value": {"3"}, + "uint32_value": {"4"}, + "bool_value": {"true"}, + "string_value": {"str"}, + "repeated_value": {"a", "b", "c"}, + }, + filter: internal.NewDoubleArray(nil), + want: &proto3Message{ + FloatValue: 1.5, + DoubleValue: 2.5, + Int64Value: -1, + Int32Value: -2, + Uint64Value: 3, + Uint32Value: 4, + BoolValue: true, + StringValue: "str", + RepeatedValue: []string{"a", "b", "c"}, + }, + }, + { + values: url.Values{ + "float_value": {"1.5"}, + "double_value": {"2.5"}, + "int64_value": {"-1"}, + "int32_value": {"-2"}, + "uint64_value": {"3"}, + "uint32_value": {"4"}, + "bool_value": {"true"}, + "string_value": {"str"}, + "repeated_value": {"a", "b", "c"}, + }, + filter: internal.NewDoubleArray(nil), + want: &proto2Message{ + FloatValue: proto.Float32(1.5), + DoubleValue: proto.Float64(2.5), + Int64Value: proto.Int64(-1), + Int32Value: proto.Int32(-2), + Uint64Value: proto.Uint64(3), + Uint32Value: proto.Uint32(4), + BoolValue: proto.Bool(true), + StringValue: proto.String("str"), + RepeatedValue: []string{"a", "b", "c"}, + }, + }, + { + values: url.Values{ + "nested.nested.nested.repeated_value": {"a", "b", "c"}, + "nested.nested.nested.string_value": {"s"}, + "nested.nested.string_value": {"t"}, + "nested.string_value": {"u"}, + }, + filter: internal.NewDoubleArray(nil), + want: &proto3Message{ + Nested: &proto2Message{ + Nested: &proto3Message{ + Nested: &proto2Message{ + RepeatedValue: []string{"a", "b", "c"}, + StringValue: proto.String("s"), + }, + StringValue: "t", + }, + StringValue: proto.String("u"), + }, + }, + }, + { + values: url.Values{ + "uint64_value": {"1", "2", "3", "4", "5"}, + }, + filter: internal.NewDoubleArray(nil), + want: &proto3Message{ + Uint64Value: 1, + }, + }, + } { + msg := proto.Clone(spec.want) + msg.Reset() + err := runtime.PopulateQueryParameters(msg, spec.values, spec.filter) + if err != nil { + t.Errorf("runtime.PoplateQueryParameters(msg, %v, %v) failed with %v; want success", spec.values, spec.filter, err) + continue + } + if got, want := msg, spec.want; !proto.Equal(got, want) { + t.Errorf("runtime.PopulateQueryParameters(msg, %v, %v = %v; want %v", spec.values, spec.filter, got, want) + } + } +} + +func TestPopulateParametersWithFilters(t *testing.T) { + for _, spec := range []struct { + values url.Values + filter *internal.DoubleArray + want proto.Message + }{ + { + values: url.Values{ + "bool_value": {"true"}, + "string_value": {"str"}, + "repeated_value": {"a", "b", "c"}, + }, + filter: internal.NewDoubleArray([][]string{ + {"bool_value"}, {"repeated_value"}, + }), + want: &proto3Message{ + StringValue: "str", + }, + }, + { + values: url.Values{ + "nested.nested.bool_value": {"true"}, + "nested.nested.string_value": {"str"}, + "nested.string_value": {"str"}, + "string_value": {"str"}, + }, + filter: internal.NewDoubleArray([][]string{ + {"nested"}, + }), + want: &proto3Message{ + StringValue: "str", + }, + }, + { + values: url.Values{ + "nested.nested.bool_value": {"true"}, + "nested.nested.string_value": {"str"}, + "nested.string_value": {"str"}, + "string_value": {"str"}, + }, + filter: internal.NewDoubleArray([][]string{ + {"nested", "nested"}, + }), + want: &proto3Message{ + Nested: &proto2Message{ + StringValue: proto.String("str"), + }, + StringValue: "str", + }, + }, + { + values: url.Values{ + "nested.nested.bool_value": {"true"}, + "nested.nested.string_value": {"str"}, + "nested.string_value": {"str"}, + "string_value": {"str"}, + }, + filter: internal.NewDoubleArray([][]string{ + {"nested", "nested", "string_value"}, + }), + want: &proto3Message{ + Nested: &proto2Message{ + StringValue: proto.String("str"), + Nested: &proto3Message{ + BoolValue: true, + }, + }, + StringValue: "str", + }, + }, + } { + msg := proto.Clone(spec.want) + msg.Reset() + err := runtime.PopulateQueryParameters(msg, spec.values, spec.filter) + if err != nil { + t.Errorf("runtime.PoplateQueryParameters(msg, %v, %v) failed with %v; want success", spec.values, spec.filter, err) + continue + } + if got, want := msg, spec.want; !proto.Equal(got, want) { + t.Errorf("runtime.PopulateQueryParameters(msg, %v, %v = %v; want %v", spec.values, spec.filter, got, want) + } + } +} + +type proto3Message struct { + Nested *proto2Message `protobuf:"bytes,1,opt,name=nested" json:"nested,omitempty"` + FloatValue float32 `protobuf:"fixed32,2,opt,name=float_value" json:"float_value,omitempty"` + DoubleValue float64 `protobuf:"fixed64,3,opt,name=double_value" json:"double_value,omitempty"` + Int64Value int64 `protobuf:"varint,4,opt,name=int64_value" json:"int64_value,omitempty"` + Int32Value int32 `protobuf:"varint,5,opt,name=int32_value" json:"int32_value,omitempty"` + Uint64Value uint64 `protobuf:"varint,6,opt,name=uint64_value" json:"uint64_value,omitempty"` + Uint32Value uint32 `protobuf:"varint,7,opt,name=uint32_value" json:"uint32_value,omitempty"` + BoolValue bool `protobuf:"varint,8,opt,name=bool_value" json:"bool_value,omitempty"` + StringValue string `protobuf:"bytes,9,opt,name=string_value" json:"string_value,omitempty"` + RepeatedValue []string `protobuf:"bytes,10,rep,name=repeated_value" json:"repeated_value,omitempty"` +} + +func (m *proto3Message) Reset() { *m = proto3Message{} } +func (m *proto3Message) String() string { return proto.CompactTextString(m) } +func (*proto3Message) ProtoMessage() {} + +func (m *proto3Message) GetNested() *proto2Message { + if m != nil { + return m.Nested + } + return nil +} + +type proto2Message struct { + Nested *proto3Message `protobuf:"bytes,1,opt,name=nested" json:"nested,omitempty"` + FloatValue *float32 `protobuf:"fixed32,2,opt,name=float_value" json:"float_value,omitempty"` + DoubleValue *float64 `protobuf:"fixed64,3,opt,name=double_value" json:"double_value,omitempty"` + Int64Value *int64 `protobuf:"varint,4,opt,name=int64_value" json:"int64_value,omitempty"` + Int32Value *int32 `protobuf:"varint,5,opt,name=int32_value" json:"int32_value,omitempty"` + Uint64Value *uint64 `protobuf:"varint,6,opt,name=uint64_value" json:"uint64_value,omitempty"` + Uint32Value *uint32 `protobuf:"varint,7,opt,name=uint32_value" json:"uint32_value,omitempty"` + BoolValue *bool `protobuf:"varint,8,opt,name=bool_value" json:"bool_value,omitempty"` + StringValue *string `protobuf:"bytes,9,opt,name=string_value" json:"string_value,omitempty"` + RepeatedValue []string `protobuf:"bytes,10,rep,name=repeated_value" json:"repeated_value,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *proto2Message) Reset() { *m = proto2Message{} } +func (m *proto2Message) String() string { return proto.CompactTextString(m) } +func (*proto2Message) ProtoMessage() {} + +func (m *proto2Message) GetNested() *proto3Message { + if m != nil { + return m.Nested + } + return nil +} + +func (m *proto2Message) GetFloatValue() float32 { + if m != nil && m.FloatValue != nil { + return *m.FloatValue + } + return 0 +} + +func (m *proto2Message) GetDoubleValue() float64 { + if m != nil && m.DoubleValue != nil { + return *m.DoubleValue + } + return 0 +} + +func (m *proto2Message) GetInt64Value() int64 { + if m != nil && m.Int64Value != nil { + return *m.Int64Value + } + return 0 +} + +func (m *proto2Message) GetInt32Value() int32 { + if m != nil && m.Int32Value != nil { + return *m.Int32Value + } + return 0 +} + +func (m *proto2Message) GetUint64Value() uint64 { + if m != nil && m.Uint64Value != nil { + return *m.Uint64Value + } + return 0 +} + +func (m *proto2Message) GetUint32Value() uint32 { + if m != nil && m.Uint32Value != nil { + return *m.Uint32Value + } + return 0 +} + +func (m *proto2Message) GetBoolValue() bool { + if m != nil && m.BoolValue != nil { + return *m.BoolValue + } + return false +} + +func (m *proto2Message) GetStringValue() string { + if m != nil && m.StringValue != nil { + return *m.StringValue + } + return "" +} + +func (m *proto2Message) GetRepeatedValue() []string { + if m != nil { + return m.RepeatedValue + } + return nil +}