Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 89 additions & 46 deletions drf_openapi/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,56 +17,105 @@
SwaggerUIRenderer as _SwaggerUIRenderer


def _get_field_required(field):
return getattr(field, 'required', True)


def _get_list_type(list_obj):
types = {type(x) for x in list_obj}

if len(types) != 1:
return "string"
else:
return {
int: "integer",
str: "string",
bool: "boolean",
float: "number"
}.get(next(iter(types)), "string")


def _parse_field_prop(field, source, dest=None, include_nulls=False):
result = {}
field_schema = getattr(field, 'schema', None)
if dest is None:
dest = source

attr = getattr(field, source, None)
if attr is not None or (attr is None and include_nulls):
result[dest] = attr

if attr is None and field_schema is not None:
attr = getattr(field_schema, source, None)

if attr is not None or (attr is None and include_nulls):
result[dest] = attr

return result


def _parse_field(field, add_name=True):
field_type = _get_field_type(field)

result = {
'description': _get_field_description(field),
'type': field_type,
'required': _get_field_required(field)
}

# name for all types
if add_name:
result.update(_parse_field_prop(field, 'name'))

# enum
result.update(_parse_field_prop(field, 'enum'))
if 'enum' in result:
result['type'] = _get_list_type(result['enum'])

# format
result.update(_parse_field_prop(field, 'format'))

# string
if field_type == "string":
result.update(_parse_field_prop(field, 'min_length', 'minLength'))
result.update(_parse_field_prop(field, 'max_length', 'maxLength'))
elif field_type == 'array':
if hasattr(field, 'schema'):
items = field.schema.items
else:
items = field.items

result['items'] = {'type': _get_field_type(items)}
if hasattr(items, 'properties'):
result['items']['properties'] = {name: _parse_field(prop) for name, prop in items.properties.items()}
result['items']['required'] = _get_field_required(items)
elif field_type == 'object':
if hasattr(field, 'schema'):
result['properties'] = {
name: _parse_field(prop) for name, prop in field.schema.properties.items()
}
result['required'] = _get_field_required(field.schema)
elif hasattr(field, 'properties'):
result['properties'] = {
name: _parse_field(prop) for name, prop in field.properties.items()
}
result['required'] = _get_field_required(field)

return result


class OpenApiFieldParser:

def __init__(self, link, field):
self.field = field
self.field_description = _get_field_description(field)
self.field_type = _get_field_type(field)
self.location = get_location(link, field)

@property
def location_string(self):
return 'formData' if self.location == 'form' else self.location

def parse_array_field(self):
parameter = {
'name': self.field.name,
'required': self.field.required,
'description': self.field_description,
'type': self.field_type,
}

items_type = _get_field_type(self.field.schema.items)
if items_type == 'object':
parameter['items'] = {
'type': items_type,
'properties': {
name: {
'description': _get_field_description(prop),
'type': _get_field_type(prop)
} for name, prop in self.field.schema.items.properties.items()
}
}
else:
parameter['items'] = {
'type': items_type,
'description': _get_field_description(self.field.schema.items)
}

return parameter

def as_parameter(self):
if self.field_type == 'array':
param = self.parse_array_field()
else:
param = {
'name': self.field.name,
'required': self.field.required,
'description': self.field_description,
'type': self.field_type
}

param = _parse_field(self.field, add_name=self.location_string == 'query')
param['in'] = self.location_string
return param

Expand All @@ -82,13 +131,7 @@ def as_body_parameter(self, encoding):
return param

def as_schema_property(self):
if self.field_type == 'array':
return self.parse_array_field()

return {
'description': self.field_description,
'type': self.field_type,
}
return _parse_field(self.field, add_name=False)


class OpenAPICodec(_OpenAPICodec):
Expand Down
6 changes: 5 additions & 1 deletion drf_openapi/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from rest_framework.pagination import PageNumberPagination, LimitOffsetPagination, CursorPagination
from rest_framework.schemas import SchemaGenerator
from rest_framework.schemas.generators import insert_into, distribute_links, LinkNode
from rest_framework.schemas.inspectors import get_pk_description, field_to_schema
from rest_framework.schemas.inspectors import get_pk_description
from .inspectors import field_to_schema

from drf_openapi.codec import _get_parameters

Expand Down Expand Up @@ -397,6 +398,9 @@ def get_response_object(self, response_serializer_class, description):

schema = res[0]['schema']
schema['properties'].update(nested_obj)
if 'required' in schema:
schema['required'] += [nested_field_name for nested_field_name in nested_obj if
getattr(serializer.fields[nested_field_name], 'required', True) is True]
response_schema = {
'description': description,
'schema': schema
Expand Down
82 changes: 82 additions & 0 deletions drf_openapi/inspectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from rest_framework.compat import coreschema
from django.utils.encoding import force_text
from rest_framework import serializers
from collections import OrderedDict


def field_to_schema(field):
title = force_text(field.label) if field.label else ''
description = force_text(field.help_text) if field.help_text else ''

if isinstance(field, serializers.CharField):
return coreschema.String(
title=title,
description=description,
max_length=getattr(field, 'max_length', None),
min_length=getattr(field, 'min_length', None)
)
elif isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = field_to_schema(field.child)
return coreschema.Array(
items=child_schema,
title=title,
description=description
)
elif isinstance(field, serializers.Serializer):
return coreschema.Object(
properties=OrderedDict([
(key, field_to_schema(value))
for key, value
in field.fields.items()
]),
required=[field_name for field_name, field_data in field.fields.items() if
getattr(field_data, 'required', True) is True],
title=title,
description=description
)
elif isinstance(field, serializers.ManyRelatedField):
return coreschema.Array(
items=coreschema.String(),
title=title,
description=description
)
elif isinstance(field, serializers.RelatedField):
return coreschema.String(title=title, description=description)
elif isinstance(field, serializers.MultipleChoiceField):
return coreschema.Array(
items=coreschema.Enum(enum=list(field.choices.keys())),
title=title,
description=description
)
elif isinstance(field, serializers.ChoiceField):
return coreschema.Enum(
enum=list(field.choices.keys()),
title=title,
description=description
)
elif isinstance(field, serializers.BooleanField):
return coreschema.Boolean(title=title, description=description)
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
return coreschema.Number(title=title, description=description)
elif isinstance(field, serializers.IntegerField):
return coreschema.Integer(title=title, description=description)
elif isinstance(field, serializers.DateField):
return coreschema.String(
title=title,
description=description,
format='date'
)
elif isinstance(field, serializers.DateTimeField):
return coreschema.String(
title=title,
description=description,
format='date-time'
)

if field.style.get('base_template') == 'textarea.html':
return coreschema.String(
title=title,
description=description,
format='textarea'
)
return coreschema.String(title=title, description=description)