diff --git a/flask_apispec/annotations.py b/flask_apispec/annotations.py index 6e1992c..f3b16e8 100644 --- a/flask_apispec/annotations.py +++ b/flask_apispec/annotations.py @@ -37,8 +37,8 @@ def wrapper(func): return activate(func) return wrapper - -def marshal_with(schema, code='default', description='', inherit=None, apply=None): +def marshal_with(schema, code='default', description='', content_type=None, + inherit=None, apply=None): """Marshal the return value of the decorated view function using the specified schema. @@ -57,6 +57,7 @@ def get_pet(pet_id): :param schema: :class:`Schema ` class or instance, or `None` :param code: Optional HTTP response code :param description: Optional response description + :param content_type: Optional response content type header (only used in OpenAPI 3.x) :param inherit: Inherit schemas from parent classes :param apply: Marshal response with specified schema """ @@ -65,6 +66,7 @@ def wrapper(func): code: { 'schema': schema or {}, 'description': description, + 'content_type': content_type, }, } annotate(func, 'schemas', [options], inherit=inherit, apply=apply) diff --git a/flask_apispec/apidoc.py b/flask_apispec/apidoc.py index de539cd..08f644f 100644 --- a/flask_apispec/apidoc.py +++ b/flask_apispec/apidoc.py @@ -93,7 +93,26 @@ def get_parameters(self, rule, view, docs, parent=None): def get_responses(self, view, parent=None): annotation = resolve_annotations(view, 'schemas', parent) - return merge_recursive(annotation.options) + options = [] + for option in annotation.options: + exploded = {} + for status_code, meta in option.items(): + if self.spec.openapi_version.major < 3: + meta.pop('content_type', None) + exploded[status_code] = meta + else: + content_type = meta['content_type'] or 'application/json' + exploded[status_code] = { + 'content': { + content_type: { + 'schema': meta['schema'] + } + } + } + if meta['description']: + exploded[status_code]['description'] = meta['description'] + options.append(exploded) + return merge_recursive(options) class ViewConverter(Converter): diff --git a/tests/test_openapi.py b/tests/test_openapi.py index afa00fa..8d941df 100644 --- a/tests/test_openapi.py +++ b/tests/test_openapi.py @@ -24,6 +24,15 @@ def spec(marshmallow_plugin): plugins=[marshmallow_plugin], ) +@pytest.fixture +def spec_oapi3(marshmallow_plugin): + return APISpec( + title='title', + version='v1', + openapi_version='3.0', + plugins=[marshmallow_plugin], + ) + @pytest.fixture() def openapi(marshmallow_plugin): return marshmallow_plugin.openapi @@ -88,6 +97,51 @@ def test_responses(self, schemas, path, openapi): def test_tags(self, path): assert path['get']['tags'] == ['band'] +class TestFunctionView_OpenAPI3: + + @pytest.fixture + def function_view(self, app, models, schemas): + @app.route('/bands//') + @doc(tags=['band']) + @use_kwargs({'name': fields.Str(missing='queen')}, locations=('query',)) + @marshal_with(schemas.BandSchema, description='a band', content_type='text/json') + def get_band(band_id): + return models.Band(name='slowdive', genre='spacerock') + + return get_band + + @pytest.fixture + def path(self, app, spec_oapi3, function_view): + converter = ViewConverter(app=app, spec=spec_oapi3) + paths = converter.convert(function_view) + for path in paths: + spec_oapi3.path(**path) + return spec_oapi3._paths['/bands/{band_id}/'] + + def test_params(self, app, path): + params = path['get']['parameters'] + rule = app.url_map._rules_by_endpoint['get_band'][0] + expected = ( + [{ + 'in': 'query', + 'name': 'name', + 'required': False, + 'schema': { + 'type': 'string', + 'default': 'queen', + } + }] + rule_to_params(rule) + ) + assert params == expected + + def test_responses(self, schemas, path, openapi): + response = path['get']['responses']['default'] + assert response['description'] == 'a band' + assert response['content'] == {'text/json': {'schema': {'$ref': ref_path(openapi.spec) + 'Band'}}} + + def test_tags(self, path): + assert path['get']['tags'] == ['band'] + class TestArgSchema: @pytest.fixture