From 80eea77e5253d453ba1f6754beb3d5691eaad50a Mon Sep 17 00:00:00 2001 From: Jonathan Hiles Date: Wed, 14 Jun 2023 04:30:26 +1000 Subject: [PATCH] Allowed to overwrite resource id in serializer (#1127) Co-authored-by: Oliver Sauder --- AUTHORS | 1 + CHANGELOG.md | 17 +++++++ docs/usage.md | 38 ++++++++++++++++ rest_framework_json_api/relations.py | 14 +++--- rest_framework_json_api/renderers.py | 3 +- rest_framework_json_api/utils.py | 13 ++++++ tests/test_relations.py | 25 ++++++++++- tests/test_utils.py | 17 +++++++ tests/test_views.py | 67 ++++++++++++++++++++++++++++ 9 files changed, 187 insertions(+), 8 deletions(-) diff --git a/AUTHORS b/AUTHORS index 797ab52f..fbd3fe6e 100644 --- a/AUTHORS +++ b/AUTHORS @@ -20,6 +20,7 @@ Jeppe Fihl-Pearson Jerel Unruh Jonas Kiefer Jonas Metzener +Jonathan Hiles Jonathan Senecal Joseba Mendivil Kal diff --git a/CHANGELOG.md b/CHANGELOG.md index 50e8406c..972a3daf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,23 @@ any parts of the framework not mentioned in the documentation should generally b * Replaced `OrderedDict` with `dict` which is also ordered since Python 3.7. * Compound document "include" parameter is only included in the OpenAPI schema if serializer implements `included_serializers`. +* Allowed overwriting of resource id by defining an `id` field on the serializer. + + Example: + ```python + class CustomIdSerializer(serializers.Serializer): + id = serializers.CharField(source='name') + body = serializers.CharField() + ``` + +* Allowed overwriting resource id on resource related fields by creating custom `ResourceRelatedField`. + + Example: + ```python + class CustomResourceRelatedField(relations.ResourceRelatedField): + def get_resource_id(self, value): + return value.name + ``` ### Fixed diff --git a/docs/usage.md b/docs/usage.md index 5da8846a..a7dadcce 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -278,6 +278,44 @@ class MyModelSerializer(serializers.ModelSerializer): # ... ``` +### Overwriting the resource object's id + +Per default the primary key property `pk` on the instance is used as the resource identifier. + +It is possible to overwrite the resource id by defining an `id` field on the serializer like: + +```python +class UserSerializer(serializers.ModelSerializer): + id = serializers.CharField(source='email') + name = serializers.CharField() + + class Meta: + model = User +``` + +This also works on generic serializers. + +In case you also use a model as a resource related field make sure to overwrite `get_resource_id` by creating a custom `ResourceRelatedField` class: + +```python +class UserResourceRelatedField(ResourceRelatedField): + def get_resource_id(self, value): + return value.email + +class GroupSerializer(serializers.ModelSerializer): + user = UserResourceRelatedField(queryset=User.objects) + name = serializers.CharField() + + class Meta: + model = Group +``` + +
+ Note: + When using different id than primary key, make sure that your view + manages it properly by overwriting `get_object`. +
+ ### Setting resource identifier object type You may manually set resource identifier object type by using `resource_name` property on views, serializers, or diff --git a/rest_framework_json_api/relations.py b/rest_framework_json_api/relations.py index bb360ebb..32547253 100644 --- a/rest_framework_json_api/relations.py +++ b/rest_framework_json_api/relations.py @@ -247,17 +247,21 @@ def to_internal_value(self, data): return super().to_internal_value(data["id"]) def to_representation(self, value): - if getattr(self, "pk_field", None) is not None: - pk = self.pk_field.to_representation(value.pk) - else: - pk = value.pk - + pk = self.get_resource_id(value) resource_type = self.get_resource_type_from_included_serializer() if resource_type is None or not self._skip_polymorphic_optimization: resource_type = get_resource_type_from_instance(value) return {"type": resource_type, "id": str(pk)} + def get_resource_id(self, value): + """ + Get resource id of related field. + + Per default pk of value is returned. + """ + return super().to_representation(value) + def get_resource_type_from_included_serializer(self): """ Check to see it this resource has a different resource_name when diff --git a/rest_framework_json_api/renderers.py b/rest_framework_json_api/renderers.py index 7263b96b..f7660208 100644 --- a/rest_framework_json_api/renderers.py +++ b/rest_framework_json_api/renderers.py @@ -443,10 +443,9 @@ def build_json_resource_obj( # Determine type from the instance if the underlying model is polymorphic if force_type_resolution: resource_name = utils.get_resource_type_from_instance(resource_instance) - resource_id = force_str(resource_instance.pk) if resource_instance else None resource_data = { "type": resource_name, - "id": resource_id, + "id": utils.get_resource_id(resource_instance, resource), "attributes": cls.extract_attributes(fields, resource), } relationships = cls.extract_relationships(fields, resource, resource_instance) diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py index dab8a3bb..2e57fbbd 100644 --- a/rest_framework_json_api/utils.py +++ b/rest_framework_json_api/utils.py @@ -304,6 +304,19 @@ def get_resource_type_from_serializer(serializer): ) +def get_resource_id(resource_instance, resource): + """Returns the resource identifier for a given instance (`id` takes priority over `pk`).""" + if resource and "id" in resource: + return resource["id"] and encoding.force_str(resource["id"]) or None + if resource_instance: + return ( + hasattr(resource_instance, "pk") + and encoding.force_str(resource_instance.pk) + or None + ) + return None + + def get_included_resources(request, serializer=None): """Build a list of included resources.""" include_resources_param = request.query_params.get("include") if request else None diff --git a/tests/test_relations.py b/tests/test_relations.py index 74721cfa..ad4bebcb 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -10,9 +10,10 @@ HyperlinkedRelatedField, SerializerMethodHyperlinkedRelatedField, ) +from rest_framework_json_api.serializers import ModelSerializer, ResourceRelatedField from rest_framework_json_api.utils import format_link_segment from rest_framework_json_api.views import RelationshipView -from tests.models import BasicModel +from tests.models import BasicModel, ForeignKeySource, ForeignKeyTarget from tests.serializers import ( ForeignKeySourceSerializer, ManyToManySourceReadOnlySerializer, @@ -46,6 +47,28 @@ def test_serialize( assert serializer.data["target"] == expected + def test_get_resource_id(self, foreign_key_target): + class CustomResourceRelatedField(ResourceRelatedField): + def get_resource_id(self, value): + return value.name + + class CustomPkFieldSerializer(ModelSerializer): + target = CustomResourceRelatedField( + queryset=ForeignKeyTarget.objects, pk_field="name" + ) + + class Meta: + model = ForeignKeySource + fields = ("target",) + + serializer = CustomPkFieldSerializer(instance={"target": foreign_key_target}) + expected = { + "type": "ForeignKeyTarget", + "id": "Target", + } + + assert serializer.data["target"] == expected + @pytest.mark.parametrize( "format_type,pluralize_type,resource_type", [ diff --git a/tests/test_utils.py b/tests/test_utils.py index 038e8ce9..f2a3d176 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -14,6 +14,7 @@ format_resource_type, format_value, get_related_resource_type, + get_resource_id, get_resource_name, get_resource_type_from_serializer, undo_format_field_name, @@ -392,6 +393,22 @@ class SerializerWithoutResourceName(serializers.Serializer): ) +@pytest.mark.parametrize( + "resource_instance, resource, expected", + [ + (None, None, None), + (object(), {}, None), + (BasicModel(id=5), None, "5"), + (BasicModel(id=9), {}, "9"), + (None, {"id": 11}, "11"), + (object(), {"pk": 11}, None), + (BasicModel(id=6), {"id": 11}, "11"), + ], +) +def test_get_resource_id(resource_instance, resource, expected): + assert get_resource_id(resource_instance, resource) == expected + + @pytest.mark.parametrize( "message,pointer,response,result", [ diff --git a/tests/test_views.py b/tests/test_views.py index 42680d6a..47fec02a 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -183,6 +183,50 @@ def test_patch(self, client): } } + @pytest.mark.urls(__name__) + def test_post_with_missing_id(self, client): + data = { + "data": { + "id": None, + "type": "custom", + "attributes": {"body": "hello"}, + } + } + + url = reverse("custom") + + response = client.post(url, data=data) + assert response.status_code == status.HTTP_200_OK + assert response.json() == { + "data": { + "type": "custom", + "id": None, + "attributes": {"body": "hello"}, + } + } + + @pytest.mark.urls(__name__) + def test_patch_with_custom_id(self, client): + data = { + "data": { + "id": 2_193_102, + "type": "custom", + "attributes": {"body": "hello"}, + } + } + + url = reverse("custom-id") + + response = client.patch(url, data=data) + assert response.status_code == status.HTTP_200_OK + assert response.json() == { + "data": { + "type": "custom", + "id": "2176ce", # get_id() -> hex + "attributes": {"body": "hello"}, + } + } + # Routing setup @@ -202,6 +246,14 @@ class CustomModelSerializer(serializers.Serializer): id = serializers.IntegerField() +class CustomIdModelSerializer(serializers.Serializer): + id = serializers.SerializerMethodField() + body = serializers.CharField() + + def get_id(self, obj): + return hex(obj.id)[2:] + + class CustomAPIView(APIView): parser_classes = [JSONParser] renderer_classes = [JSONRenderer] @@ -211,11 +263,26 @@ def patch(self, request, *args, **kwargs): serializer = CustomModelSerializer(CustomModel(request.data)) return Response(status=status.HTTP_200_OK, data=serializer.data) + def post(self, request, *args, **kwargs): + serializer = CustomModelSerializer(request.data) + return Response(status=status.HTTP_200_OK, data=serializer.data) + + +class CustomIdAPIView(APIView): + parser_classes = [JSONParser] + renderer_classes = [JSONRenderer] + resource_name = "custom" + + def patch(self, request, *args, **kwargs): + serializer = CustomIdModelSerializer(CustomModel(request.data)) + return Response(status=status.HTTP_200_OK, data=serializer.data) + router = SimpleRouter() router.register(r"basic_models", BasicModelViewSet, basename="basic-model") urlpatterns = [ path("custom", CustomAPIView.as_view(), name="custom"), + path("custom-id", CustomIdAPIView.as_view(), name="custom-id"), ] urlpatterns += router.urls