From b4975799aaa18905e704182cead4f990aaad896a Mon Sep 17 00:00:00 2001 From: Peder Hovdan Andresen <107681714+pederhan@users.noreply.github.com> Date: Tue, 28 May 2024 14:04:31 +0200 Subject: [PATCH] Handle JSON in Zone POST requests (#538) * Handle JSON in Zone POST requests * Invert content type checks * Refactor nameserver extraction to helper function * Fix using builtin as annotation --- mreg/api/v1/views_zones.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/mreg/api/v1/views_zones.py b/mreg/api/v1/views_zones.py index ac3bac6b..faa7b222 100644 --- a/mreg/api/v1/views_zones.py +++ b/mreg/api/v1/views_zones.py @@ -1,3 +1,4 @@ +from typing import List import django.core.exceptions from django.db import transaction @@ -7,6 +8,7 @@ from rest_framework import (generics, renderers, status) from rest_framework.decorators import (api_view, renderer_classes) from rest_framework.exceptions import ParseError +from rest_framework.request import Request from rest_framework.response import Response from mreg.models.base import NameServer @@ -55,6 +57,13 @@ def _validate_nameservers(names): done.add(name) +def _get_request_nameservers(request: Request, field: str = "primary_ns") -> List[str]: + """Extract nameservers from the request data.""" + if request.content_type == "application/json": + return request.data.get(field, []) + return request.data.getlist(field, []) + + class ZoneList(generics.ListCreateAPIView): """ get: @@ -72,13 +81,13 @@ def get_queryset(self): qs = super().get_queryset() return self.filterset(data=self.request.GET, queryset=qs).qs - def post(self, request, *args, **kwargs): + def post(self, request: Request, *args, **kwargs): qs = self.get_queryset() if qs.filter(name=request.data["name"]).exists(): content = {'ERROR': 'Zone name already in use'} return Response(content, status=status.HTTP_409_CONFLICT) # A copy is required since the original is immutable - nameservers = request.data.getlist('primary_ns') + nameservers = _get_request_nameservers(request) _validate_nameservers(nameservers) data = request.data.copy() data['primary_ns'] = nameservers[0] @@ -121,13 +130,12 @@ def get_queryset(self): self.queryset = self.parentzone.delegations.all().order_by('id') return self.filterset(data=self.request.GET, queryset=self.queryset).qs - def post(self, request, *args, **kwargs): + def post(self, request: Request, *args, **kwargs): qs = self.get_queryset() if qs.filter(name=request.data[self.lookup_field]).exists(): content = {'ERROR': 'Zone name already in use'} return Response(content, status=status.HTTP_409_CONFLICT) - - nameservers = request.data.getlist('nameservers') + nameservers = _get_request_nameservers(request, "nameservers") _validate_nameservers(nameservers) data = request.data.copy() data['zone'] = self.parentzone.pk @@ -292,14 +300,14 @@ def get(self, request, *args, **kwargs): zone = self.get_object() return Response([ns.name for ns in zone.nameservers.all()], status=status.HTTP_200_OK) - def patch(self, request, *args, **kwargs): + def patch(self, request: Request, *args, **kwargs): if 'primary_ns' not in request.data: return Response({'ERROR': 'No nameserver found in body'}, status=status.HTTP_400_BAD_REQUEST) zone = self.get_object() - nameservers = request.data.getlist('primary_ns') + nameservers = _get_request_nameservers(request) _validate_nameservers(nameservers) zone.update_nameservers(nameservers) - zone.primary_ns = request.data.getlist('primary_ns')[0] + zone.primary_ns = nameservers[0] zone.updated = True self.perform_update(zone) return Response(status=status.HTTP_204_NO_CONTENT, headers={'Location': request.path})