Skip to content

Commit

Permalink
Handle JSON in Zone POST requests (#538)
Browse files Browse the repository at this point in the history
* Handle JSON in Zone POST requests

* Invert content type checks

* Refactor nameserver extraction to helper function

* Fix using builtin as annotation
  • Loading branch information
pederhan authored May 28, 2024
1 parent fa6ca20 commit b497579
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions mreg/api/v1/views_zones.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
import django.core.exceptions

from django.db import transaction
Expand All @@ -7,6 +8,7 @@
from rest_framework import (generics, renderers, status)
from rest_framework.decorators import (api_view, renderer_classes)
from rest_framework.exceptions import ParseError
from rest_framework.request import Request
from rest_framework.response import Response

from mreg.models.base import NameServer
Expand Down Expand Up @@ -55,6 +57,13 @@ def _validate_nameservers(names):
done.add(name)


def _get_request_nameservers(request: Request, field: str = "primary_ns") -> List[str]:
"""Extract nameservers from the request data."""
if request.content_type == "application/json":
return request.data.get(field, [])
return request.data.getlist(field, [])


class ZoneList(generics.ListCreateAPIView):
"""
get:
Expand All @@ -72,13 +81,13 @@ def get_queryset(self):
qs = super().get_queryset()
return self.filterset(data=self.request.GET, queryset=qs).qs

def post(self, request, *args, **kwargs):
def post(self, request: Request, *args, **kwargs):
qs = self.get_queryset()
if qs.filter(name=request.data["name"]).exists():
content = {'ERROR': 'Zone name already in use'}
return Response(content, status=status.HTTP_409_CONFLICT)
# A copy is required since the original is immutable
nameservers = request.data.getlist('primary_ns')
nameservers = _get_request_nameservers(request)
_validate_nameservers(nameservers)
data = request.data.copy()
data['primary_ns'] = nameservers[0]
Expand Down Expand Up @@ -121,13 +130,12 @@ def get_queryset(self):
self.queryset = self.parentzone.delegations.all().order_by('id')
return self.filterset(data=self.request.GET, queryset=self.queryset).qs

def post(self, request, *args, **kwargs):
def post(self, request: Request, *args, **kwargs):
qs = self.get_queryset()
if qs.filter(name=request.data[self.lookup_field]).exists():
content = {'ERROR': 'Zone name already in use'}
return Response(content, status=status.HTTP_409_CONFLICT)

nameservers = request.data.getlist('nameservers')
nameservers = _get_request_nameservers(request, "nameservers")
_validate_nameservers(nameservers)
data = request.data.copy()
data['zone'] = self.parentzone.pk
Expand Down Expand Up @@ -292,14 +300,14 @@ def get(self, request, *args, **kwargs):
zone = self.get_object()
return Response([ns.name for ns in zone.nameservers.all()], status=status.HTTP_200_OK)

def patch(self, request, *args, **kwargs):
def patch(self, request: Request, *args, **kwargs):
if 'primary_ns' not in request.data:
return Response({'ERROR': 'No nameserver found in body'}, status=status.HTTP_400_BAD_REQUEST)
zone = self.get_object()
nameservers = request.data.getlist('primary_ns')
nameservers = _get_request_nameservers(request)
_validate_nameservers(nameservers)
zone.update_nameservers(nameservers)
zone.primary_ns = request.data.getlist('primary_ns')[0]
zone.primary_ns = nameservers[0]
zone.updated = True
self.perform_update(zone)
return Response(status=status.HTTP_204_NO_CONTENT, headers={'Location': request.path})
Expand Down

0 comments on commit b497579

Please sign in to comment.