diff --git a/drf_openapi/codec.py b/drf_openapi/codec.py index 8cee482..9873988 100644 --- a/drf_openapi/codec.py +++ b/drf_openapi/codec.py @@ -17,56 +17,105 @@ SwaggerUIRenderer as _SwaggerUIRenderer +def _get_field_required(field): + return getattr(field, 'required', True) + + +def _get_list_type(list_obj): + types = {type(x) for x in list_obj} + + if len(types) != 1: + return "string" + else: + return { + int: "integer", + str: "string", + bool: "boolean", + float: "number" + }.get(next(iter(types)), "string") + + +def _parse_field_prop(field, source, dest=None, include_nulls=False): + result = {} + field_schema = getattr(field, 'schema', None) + if dest is None: + dest = source + + attr = getattr(field, source, None) + if attr is not None or (attr is None and include_nulls): + result[dest] = attr + + if attr is None and field_schema is not None: + attr = getattr(field_schema, source, None) + + if attr is not None or (attr is None and include_nulls): + result[dest] = attr + + return result + + +def _parse_field(field, add_name=True): + field_type = _get_field_type(field) + + result = { + 'description': _get_field_description(field), + 'type': field_type, + 'required': _get_field_required(field) + } + + # name for all types + if add_name: + result.update(_parse_field_prop(field, 'name')) + + # enum + result.update(_parse_field_prop(field, 'enum')) + if 'enum' in result: + result['type'] = _get_list_type(result['enum']) + + # format + result.update(_parse_field_prop(field, 'format')) + + # string + if field_type == "string": + result.update(_parse_field_prop(field, 'min_length', 'minLength')) + result.update(_parse_field_prop(field, 'max_length', 'maxLength')) + elif field_type == 'array': + if hasattr(field, 'schema'): + items = field.schema.items + else: + items = field.items + + result['items'] = {'type': _get_field_type(items)} + if hasattr(items, 'properties'): + result['items']['properties'] = {name: _parse_field(prop) for name, prop in items.properties.items()} + result['items']['required'] = _get_field_required(items) + elif field_type == 'object': + if hasattr(field, 'schema'): + result['properties'] = { + name: _parse_field(prop) for name, prop in field.schema.properties.items() + } + result['required'] = _get_field_required(field.schema) + elif hasattr(field, 'properties'): + result['properties'] = { + name: _parse_field(prop) for name, prop in field.properties.items() + } + result['required'] = _get_field_required(field) + + return result + + class OpenApiFieldParser: def __init__(self, link, field): self.field = field - self.field_description = _get_field_description(field) - self.field_type = _get_field_type(field) self.location = get_location(link, field) @property def location_string(self): return 'formData' if self.location == 'form' else self.location - def parse_array_field(self): - parameter = { - 'name': self.field.name, - 'required': self.field.required, - 'description': self.field_description, - 'type': self.field_type, - } - - items_type = _get_field_type(self.field.schema.items) - if items_type == 'object': - parameter['items'] = { - 'type': items_type, - 'properties': { - name: { - 'description': _get_field_description(prop), - 'type': _get_field_type(prop) - } for name, prop in self.field.schema.items.properties.items() - } - } - else: - parameter['items'] = { - 'type': items_type, - 'description': _get_field_description(self.field.schema.items) - } - - return parameter - def as_parameter(self): - if self.field_type == 'array': - param = self.parse_array_field() - else: - param = { - 'name': self.field.name, - 'required': self.field.required, - 'description': self.field_description, - 'type': self.field_type - } - + param = _parse_field(self.field, add_name=self.location_string == 'query') param['in'] = self.location_string return param @@ -82,13 +131,7 @@ def as_body_parameter(self, encoding): return param def as_schema_property(self): - if self.field_type == 'array': - return self.parse_array_field() - - return { - 'description': self.field_description, - 'type': self.field_type, - } + return _parse_field(self.field, add_name=False) class OpenAPICodec(_OpenAPICodec): diff --git a/drf_openapi/entities.py b/drf_openapi/entities.py index 70bba2d..46e7bb3 100644 --- a/drf_openapi/entities.py +++ b/drf_openapi/entities.py @@ -14,7 +14,8 @@ from rest_framework.pagination import PageNumberPagination, LimitOffsetPagination, CursorPagination from rest_framework.schemas import SchemaGenerator from rest_framework.schemas.generators import insert_into, distribute_links, LinkNode -from rest_framework.schemas.inspectors import get_pk_description, field_to_schema +from rest_framework.schemas.inspectors import get_pk_description +from .inspectors import field_to_schema from drf_openapi.codec import _get_parameters @@ -397,6 +398,9 @@ def get_response_object(self, response_serializer_class, description): schema = res[0]['schema'] schema['properties'].update(nested_obj) + if 'required' in schema: + schema['required'] += [nested_field_name for nested_field_name in nested_obj if + getattr(serializer.fields[nested_field_name], 'required', True) is True] response_schema = { 'description': description, 'schema': schema diff --git a/drf_openapi/inspectors.py b/drf_openapi/inspectors.py new file mode 100644 index 0000000..9a751f7 --- /dev/null +++ b/drf_openapi/inspectors.py @@ -0,0 +1,82 @@ +from rest_framework.compat import coreschema +from django.utils.encoding import force_text +from rest_framework import serializers +from collections import OrderedDict + + +def field_to_schema(field): + title = force_text(field.label) if field.label else '' + description = force_text(field.help_text) if field.help_text else '' + + if isinstance(field, serializers.CharField): + return coreschema.String( + title=title, + description=description, + max_length=getattr(field, 'max_length', None), + min_length=getattr(field, 'min_length', None) + ) + elif isinstance(field, (serializers.ListSerializer, serializers.ListField)): + child_schema = field_to_schema(field.child) + return coreschema.Array( + items=child_schema, + title=title, + description=description + ) + elif isinstance(field, serializers.Serializer): + return coreschema.Object( + properties=OrderedDict([ + (key, field_to_schema(value)) + for key, value + in field.fields.items() + ]), + required=[field_name for field_name, field_data in field.fields.items() if + getattr(field_data, 'required', True) is True], + title=title, + description=description + ) + elif isinstance(field, serializers.ManyRelatedField): + return coreschema.Array( + items=coreschema.String(), + title=title, + description=description + ) + elif isinstance(field, serializers.RelatedField): + return coreschema.String(title=title, description=description) + elif isinstance(field, serializers.MultipleChoiceField): + return coreschema.Array( + items=coreschema.Enum(enum=list(field.choices.keys())), + title=title, + description=description + ) + elif isinstance(field, serializers.ChoiceField): + return coreschema.Enum( + enum=list(field.choices.keys()), + title=title, + description=description + ) + elif isinstance(field, serializers.BooleanField): + return coreschema.Boolean(title=title, description=description) + elif isinstance(field, (serializers.DecimalField, serializers.FloatField)): + return coreschema.Number(title=title, description=description) + elif isinstance(field, serializers.IntegerField): + return coreschema.Integer(title=title, description=description) + elif isinstance(field, serializers.DateField): + return coreschema.String( + title=title, + description=description, + format='date' + ) + elif isinstance(field, serializers.DateTimeField): + return coreschema.String( + title=title, + description=description, + format='date-time' + ) + + if field.style.get('base_template') == 'textarea.html': + return coreschema.String( + title=title, + description=description, + format='textarea' + ) + return coreschema.String(title=title, description=description)