diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py index 89469e26..e5b8d21a 100644 --- a/src/drf_yasg/inspectors/field.py +++ b/src/drf_yasg/inspectors/field.py @@ -4,6 +4,8 @@ import operator import typing import uuid +import pkg_resources +from packaging import version from collections import OrderedDict from decimal import Decimal from inspect import signature as inspect_signature @@ -20,6 +22,9 @@ ) from .base import FieldInspector, NotHandled, SerializerInspector, call_view_method + +drf_version = pkg_resources.get_distribution("djangorestframework").version + logger = logging.getLogger(__name__) @@ -37,7 +42,7 @@ def add_manual_parameters(self, serializer, parameters): is called only when the serializer is converted into a list of parameters for use in a form data request. :param serializer: serializer instance - :param list[openapi.Parameter] parameters: generated parameters + :param list[openapi.Parameter] parameters: genereated parameters :return: modified parameters :rtype: list[openapi.Parameter] """ @@ -176,7 +181,7 @@ def get_queryset_from_view(view, serializer=None): """Try to get the queryset of the given view :param view: the view instance or class - :param serializer: if given, will check that the view's get_serializer_class return matches this serializer + :param serializer: if given, will check that the view's get_serializer_class return matches this serialzier :return: queryset or ``None`` """ try: @@ -376,7 +381,6 @@ def decimal_field_type(field): (models.AutoField, (openapi.TYPE_INTEGER, None)), (models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)), (models.BooleanField, (openapi.TYPE_BOOLEAN, None)), - (models.NullBooleanField, (openapi.TYPE_BOOLEAN, None)), (models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)), (models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)), (models.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)), @@ -390,7 +394,7 @@ def decimal_field_type(field): (models.TimeField, (openapi.TYPE_STRING, None)), (models.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)), (models.CharField, (openapi.TYPE_STRING, None)), -] +] ip_format = {'ipv4': openapi.FORMAT_IPV4, 'ipv6': openapi.FORMAT_IPV6} @@ -402,7 +406,6 @@ def decimal_field_type(field): (serializers.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)), (serializers.RegexField, (openapi.TYPE_STRING, None)), (serializers.CharField, (openapi.TYPE_STRING, None)), - (serializers.BooleanField, (openapi.TYPE_BOOLEAN, None)), (serializers.NullBooleanField, (openapi.TYPE_BOOLEAN, None)), (serializers.IntegerField, (openapi.TYPE_INTEGER, None)), (serializers.FloatField, (openapi.TYPE_NUMBER, None)), @@ -413,6 +416,15 @@ def decimal_field_type(field): (serializers.ModelField, (openapi.TYPE_STRING, None)), ] +if version.parse(drf_version) < version.parse("3.14.0"): + model_field_to_basic_type.append( + (models.NullBooleanField, (openapi.TYPE_BOOLEAN, None)) + ) + + serializer_field_to_basic_type.append( + (serializers.BooleanField, (openapi.TYPE_BOOLEAN, None)), + ) + basic_type_info = serializer_field_to_basic_type + model_field_to_basic_type @@ -840,3 +852,4 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, ** return ref return NotHandled +