Skip to content
This repository has been archived by the owner on Feb 14, 2024. It is now read-only.

Commit

Permalink
Enable model upload via REPL #27
Browse files Browse the repository at this point in the history
  • Loading branch information
SichangHe authored Jul 19, 2023
2 parents 51943fa + d893bcd commit 51bd753
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 1 deletion.
7 changes: 7 additions & 0 deletions backend/train/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,10 @@ class Meta:
class PostServerDataSerializer(serializers.Serializer):
id = serializers.IntegerField()
start_fresh = serializers.BooleanField(required=False, default=False) # type: ignore


# Always change together with `upload` in `fed_kit.py`.
class UploadDataSerializer(serializers.Serializer):
name = serializers.CharField(max_length=256)
layers_sizes = serializers.ListField(child=serializers.IntegerField(min_value=0))
data_type = serializers.CharField(max_length=256)
6 changes: 5 additions & 1 deletion backend/train/urls.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from django.urls import path
from train.views import *

urlpatterns = [path("advertised", advertise_model), path("server", request_server)]
urlpatterns = [
path("advertised", advertise_model),
path("server", request_server),
path("upload", upload_file),
]
56 changes: 56 additions & 0 deletions backend/train/views.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import logging
from typing import OrderedDict

from django.core.files.uploadedfile import UploadedFile
from rest_framework import permissions
from rest_framework.decorators import api_view, permission_classes
from rest_framework.request import MultiValueDict
from rest_framework.response import Response
from rest_framework.status import HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND
from rest_framework.views import Request
from train.models import TFLiteModel, TrainingDataType
from train.scheduler import server
from train.serializers import *

from backend.settings import BASE_DIR

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -53,3 +57,55 @@ def request_server(request: Request):
return Response("Model not found", HTTP_404_NOT_FOUND)
response = server(model, data["start_fresh"])
return Response(response.__dict__)


def file_in_request(request: Request):
files = request.FILES
if isinstance(files, MultiValueDict):
file = files.get("file")
if isinstance(file, UploadedFile):
return file


@api_view(["POST"])
@permission_classes((permissions.AllowAny,))
def upload_file(request: Request):
# Deserialize request data.
serializer = UploadDataSerializer(data=request.data) # type: ignore
if not serializer.is_valid():
logger.error(serializer.errors)
return Response(serializer.errors, HTTP_400_BAD_REQUEST)
data: OrderedDict = serializer.validated_data # type: ignore
name = data["name"]
data_type_name = data["data_type"]
# Validate unique file name.
try:
model = TFLiteModel.objects.get(name=data["name"])
return Response("Model name used", HTTP_400_BAD_REQUEST)
except TFLiteModel.DoesNotExist:
pass
# Get model file.
file = file_in_request(request)
if file is None:
return Response("No file in request.", HTTP_400_BAD_REQUEST)
# Get `data_type`.
try:
data_type = TrainingDataType.objects.get(name=data_type_name)
except TrainingDataType.DoesNotExist:
logger.warn(f"upload: Creating new data_type `{data_type_name}`.")
data_type = TrainingDataType(name=data_type_name)
data_type.save()
# Save model file.
path = f"static/{name}--{file.name}" # Guaranteed unique.
with open(BASE_DIR / path, "wb") as fd:
fd.write(file.file.read())
# Save model.
model = TFLiteModel(
name=name,
file_path=f"/{path}",
layers_sizes=data["layers_sizes"],
data_type=data_type,
)
model.save()

return Response("ok")
33 changes: 33 additions & 0 deletions fed_kit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""FedKit backend operations package.
Use this package from an interactive python shell:
```python
import fed_kit
response = fed_kit.upload("test.tflite", "test_model", [100, 200, 300], "test_type")
print(response)
print(response.text)
```
"""
import requests

DEFAULT_URL = "http://localhost:8000/"


# Always change together with `UploadDataSerializer` in `train.serializers`.
def upload(
file: str,
name: str,
layers_sizes: list[int],
data_type: str,
base: str = DEFAULT_URL,
):
"""Upload model `file` and store it as `name` on the backend."""
url = base + "/train/upload"
files = {"file": open(file, "rb")}
data = {
"name": name,
"layers_sizes": layers_sizes,
"data_type": data_type,
}
return requests.post(url, data=data, files=files)

0 comments on commit 51bd753

Please sign in to comment.