diff --git a/.gitignore b/.gitignore index 54643682..db7d4292 100644 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,7 @@ backend/**/__pycache__/** backend/**/migrations/** backend/media backend/data/* +backend/log/* backend/training/* backend/.env backend/config.txt diff --git a/backend/core/models.py b/backend/core/models.py index 4c7a27f4..b4ed7435 100644 --- a/backend/core/models.py +++ b/backend/core/models.py @@ -1,5 +1,6 @@ from django.contrib.gis.db import models as geomodels from django.contrib.postgres.fields import ArrayField +from django.core.validators import MaxValueValidator, MinValueValidator from django.db import models from login.models import OsmUser @@ -84,3 +85,23 @@ class Training(models.Model): accuracy = models.FloatField(null=True, blank=True) epochs = models.PositiveIntegerField() batch_size = models.PositiveIntegerField() + freeze_layers = models.BooleanField(default=False) + + +class Feedback(models.Model): + ACTION_TYPE = ( + ("CREATE", "CREATE"), + ("MODIFY", "MODIFY"), + ("ACCEPT", "ACCEPT"), + ("INITIAL", "INITIAL"), + ) + geom = geomodels.GeometryField(srid=4326) + training = models.ForeignKey(Training, to_field="id", on_delete=models.CASCADE) + created_at = models.DateTimeField(auto_now_add=True) + zoom_level = models.PositiveIntegerField( + validators=[MinValueValidator(18), MaxValueValidator(23)] + ) + action = models.CharField(choices=ACTION_TYPE, max_length=10) + last_modified = models.DateTimeField(auto_now=True) + user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE) + validated = models.BooleanField(default=False) diff --git a/backend/core/serializers.py b/backend/core/serializers.py index 3d9cb7bf..35cf792e 100644 --- a/backend/core/serializers.py +++ b/backend/core/serializers.py @@ -62,6 +62,25 @@ class Meta: ) +class FeedbackSerializer(GeoFeatureModelSerializer): + class Meta: + model = Feedback + geo_field = "geom" + fields = "__all__" + read_only_fields = ("created_at", "last_modified", "user") + partial = True + + def create(self, validated_data): + user = self.context["request"].user + validated_data["user"] = user + return super().create(validated_data) + + def to_representation(self, instance): + ret = super().to_representation(instance) + ret["properties"]["id"] = instance.id + return ret + + class LabelSerializer( GeoFeatureModelSerializer ): # serializers are used to translate models objects to api @@ -84,9 +103,14 @@ class Meta: model = Label geo_field = "geom" # this will be used as geometry in order to create geojson api , geofeatureserializer will let you create api in geojson # auto_bbox = True - fields = ( - "osm_id", - ) # defining all the fields to be included in curd for now , we can restrict few if we want + fields = ("osm_id",) + + +class FeedbackFileSerializer(GeoFeatureModelSerializer): + class Meta: + fields = ("training",) + model = Feedback + geo_field = "geom" class ImageDownloadSerializer(serializers.Serializer): @@ -106,6 +130,12 @@ def validate(self, data): return data +class FeedbackParamSerializer(serializers.Serializer): + training_id = serializers.IntegerField(required=True) + epochs = serializers.IntegerField(required=False) + batch_size = serializers.IntegerField(required=False) + freeze_layers = serializers.BooleanField(required=False) + class PredictionParamSerializer(serializers.Serializer): bbox = serializers.ListField(child=serializers.FloatField(), required=True) model_id = serializers.IntegerField(required=True) diff --git a/backend/core/tasks.py b/backend/core/tasks.py index 836d1b1c..8f020b81 100644 --- a/backend/core/tasks.py +++ b/backend/core/tasks.py @@ -10,13 +10,16 @@ import ramp.utils import tensorflow as tf from celery import shared_task -from core.models import AOI, Label, Training -from core.serializers import LabelFileSerializer +from core.models import AOI, Feedback, Label, Training +from core.serializers import FeedbackFileSerializer, LabelFileSerializer from core.utils import bbox, download_imagery, get_start_end_download_coords from django.conf import settings +from django.contrib.gis.db.models.aggregates import Extent +from django.contrib.gis.geos import GEOSGeometry from django.shortcuts import get_object_or_404 from django.utils import timezone from hot_fair_utilities import preprocess, train +from hot_fair_utilities.training import run_feedback logger = logging.getLogger(__name__) @@ -28,9 +31,15 @@ @shared_task def train_model( - dataset_id, training_id, epochs, batch_size, zoom_level, source_imagery + dataset_id, + training_id, + epochs, + batch_size, + zoom_level, + source_imagery, + feedback=None, + freeze_layers=False, ): - training_instance = get_object_or_404(Training, id=training_id) training_instance.status = "RUNNING" training_instance.started_at = timezone.now() @@ -46,12 +55,6 @@ def train_model( with open(log_file, "w") as f: # redirect stdout to the log file sys.stdout = f - try: - aois = AOI.objects.filter(dataset=dataset_id) - except AOI.DoesNotExist: - raise ValueError( - f"No AOI is attached with supplied dataset id:{dataset_id}, Create AOI first", - ) training_input_base_path = os.path.join( settings.TRAINING_WORKSPACE, f"dataset_{dataset_id}" ) @@ -61,16 +64,35 @@ def train_model( if os.path.exists(training_input_image_source): # always build dataset shutil.rmtree(training_input_image_source) os.makedirs(training_input_image_source) - for obj in aois: + if feedback: + feedback_objects = Feedback.objects.filter( + training__id=feedback, + validated=True, + ) + bbox_feedback = feedback_objects.aggregate(Extent("geom"))[ + "geom__extent" + ] + bbox_geo = GEOSGeometry( + f"POLYGON(({bbox_feedback[0]} {bbox_feedback[1]},{bbox_feedback[2]} {bbox_feedback[1]},{bbox_feedback[2]} {bbox_feedback[3]},{bbox_feedback[0]} {bbox_feedback[3]},{bbox_feedback[0]} {bbox_feedback[1]}))" + ) + print(training_input_image_source) + print(bbox_feedback) + with open( + os.path.join(training_input_image_source, "labels_bbox.geojson"), + "w", + encoding="utf-8", + ) as f: + f.write(bbox_geo.geojson) + for z in zoom_level: zm_level = z print( f"""Running Download process for - aoi : {obj.id} - dataset : {dataset_id} , zoom : {zm_level}""" + feedback {training_id} - dataset : {dataset_id} , zoom : {zm_level}""" ) try: tile_size = DEFAULT_TILE_SIZE # by default - bbox_coords = bbox(obj.geom.coords[0]) + bbox_coords = list(bbox_feedback) start, end = get_start_end_download_coords( bbox_coords, zm_level, tile_size ) @@ -85,19 +107,58 @@ def train_model( except Exception as ex: raise ex + else: + try: + aois = AOI.objects.filter(dataset=dataset_id) + except AOI.DoesNotExist: + raise ValueError( + f"No AOI is attached with supplied dataset id:{dataset_id}, Create AOI first", + ) + + for obj in aois: + bbox_coords = bbox(obj.geom.coords[0]) + for z in zoom_level: + zm_level = z + print( + f"""Running Download process for + aoi : {obj.id} - dataset : {dataset_id} , zoom : {zm_level}""" + ) + try: + tile_size = DEFAULT_TILE_SIZE # by default + + start, end = get_start_end_download_coords( + bbox_coords, zm_level, tile_size + ) + # start downloading + download_imagery( + start, + end, + zm_level, + base_path=training_input_image_source, + source=source_imagery, + ) + except Exception as ex: + raise ex + ## -----------LABEL GENERATOR--------- - aoi_list = [r.id for r in aois] - label = Label.objects.filter(aoi__in=aoi_list).values() - serialized_field = LabelFileSerializer(data=list(label), many=True) + logging.debug("Label Generator started") + if feedback: + feedback_objects = Feedback.objects.filter( + training__id=feedback, + validated=True, + ) + serialized_field = FeedbackFileSerializer(feedback_objects, many=True) + else: + aoi_list = [r.id for r in aois] + label = Label.objects.filter(aoi__in=aoi_list) + serialized_field = LabelFileSerializer(label, many=True) - if serialized_field.is_valid(raise_exception=True): - with open( - os.path.join(training_input_image_source, "labels.geojson"), - "w", - encoding="utf-8", - ) as f: - f.write(json.dumps(serialized_field.data)) - f.close() + with open( + os.path.join(training_input_image_source, "labels.geojson"), + "w", + encoding="utf-8", + ) as f: + f.write(json.dumps(serialized_field.data)) ## --------- Data Preparation ---------- base_path = os.path.join(settings.RAMP_HOME, "ramp-data", str(dataset_id)) @@ -130,14 +191,32 @@ def train_model( # train train_output = f"{base_path}/train" - final_accuracy, final_model_path = train( - input_path=preprocess_output, - output_path=train_output, - epoch_size=epochs, - batch_size=batch_size, - model="ramp", - model_home=os.environ["RAMP_HOME"], - ) + if feedback: + final_accuracy, final_model_path = run_feedback( + input_path=preprocess_output, + output_path=train_output, + feedback_base_model=os.path.join( + settings.TRAINING_WORKSPACE, + f"dataset_{dataset_id}", + "output", + f"training_{feedback}", + "checkpoint.tf", + ), + model_home=os.environ["RAMP_HOME"], + epoch_size=epochs, + batch_size=batch_size, + freeze_layers=freeze_layers, + ) + else: + final_accuracy, final_model_path = train( + input_path=preprocess_output, + output_path=train_output, + epoch_size=epochs, + batch_size=batch_size, + model="ramp", + model_home=os.environ["RAMP_HOME"], + freeze_layers=freeze_layers, + ) # copy final model to output output_path = os.path.join( diff --git a/backend/core/urls.py b/backend/core/urls.py index 4c17d37e..8abbea60 100644 --- a/backend/core/urls.py +++ b/backend/core/urls.py @@ -7,6 +7,8 @@ AOIViewSet, APIStatus, DatasetViewSet, + FeedbackView, + FeedbackViewset, GenerateGpxView, LabelViewSet, ModelViewSet, @@ -28,6 +30,7 @@ router.register(r"label", LabelViewSet) router.register(r"training", TrainingViewSet) router.register(r"model", ModelViewSet) +router.register(r"feedback", FeedbackViewset) urlpatterns = [ @@ -37,6 +40,7 @@ path("training/status//", run_task_status), path("training/publish//", publish_training), path("prediction/", PredictionView.as_view()), + path("apply/feedback/", FeedbackView.as_view()), path("status/", APIStatus.as_view()), path("geojson2osm/", geojson2osmconverter, name="geojson2osmconverter"), path("aoi/gpx//", GenerateGpxView.as_view()), diff --git a/backend/core/views.py b/backend/core/views.py index d1c3bc72..b6c9e5bb 100644 --- a/backend/core/views.py +++ b/backend/core/views.py @@ -38,10 +38,13 @@ from rest_framework.views import APIView from rest_framework_gis.filters import InBBoxFilter, TMSTileFilter -from .models import AOI, Dataset, Label, Model, Training +from .models import AOI, Dataset, Feedback, Label, Model, Training from .serializers import ( AOISerializer, DatasetSerializer, + FeedbackFileSerializer, + FeedbackParamSerializer, + FeedbackSerializer, LabelSerializer, ModelSerializer, PredictionParamSerializer, @@ -118,6 +121,7 @@ def create(self, validated_data): zoom_level=instance.zoom_level, source_imagery=instance.source_imagery or instance.model.dataset.source_imagery, + freeze_layers=instance.freeze_layers, ) if not instance.source_imagery: instance.source_imagery = instance.model.dataset.source_imagery @@ -139,6 +143,16 @@ class TrainingViewSet( filterset_fields = ["model", "status"] +class FeedbackViewset(viewsets.ModelViewSet): + authentication_classes = [OsmAuthentication] + permission_classes = [IsOsmAuthenticated] + permission_allowed_methods = ["GET"] + queryset = Feedback.objects.all() + http_method_names = ["get", "post", "patch", "delete"] + serializer_class = FeedbackSerializer # connecting serializer + filterset_fields = ["training", "user", "action", "validated"] + + class ModelViewSet( viewsets.ModelViewSet ): # This is ModelViewSet , will be tightly coupled with the models @@ -305,7 +319,56 @@ def run_task_status(request, run_id: str): return Response(result) -import multiprocessing +class FeedbackView(APIView): + authentication_classes = [OsmAuthentication] + permission_classes = [IsOsmAuthenticated] + + @swagger_auto_schema( + request_body=FeedbackParamSerializer, responses={status.HTTP_200_OK: "ok"} + ) + def post(self, request, *args, **kwargs): + res_serializer = FeedbackParamSerializer(data=request.data) + if res_serializer.is_valid(raise_exception=True): + deserialized_data = res_serializer.data + training_id = deserialized_data["training_id"] + training_instance = Training.objects.get(id=training_id) + + unique_zoom_levels = ( + Feedback.objects.filter(training__id=training_id, validated=True) + .values("zoom_level") + .distinct() + ) + zoom_level = [z["zoom_level"] for z in unique_zoom_levels] + epochs = deserialized_data.get("epochs", 20) + batch_size = deserialized_data.get("batch_size", 8) + instance = Training.objects.create( + model=training_instance.model, + status="SUBMITTED", + description=f"Feedback of Training {training_id}", + created_by=self.request.user, + zoom_level=zoom_level, + epochs=epochs, + batch_size=batch_size, + source_imagery=training_instance.source_imagery, + ) + + task = train_model.delay( + dataset_id=instance.model.dataset.id, + training_id=instance.id, + epochs=instance.epochs, + batch_size=instance.batch_size, + zoom_level=instance.zoom_level, + source_imagery=instance.source_imagery, + feedback=training_id, + freeze_layers=instance.freeze_layers, + ) + if not instance.source_imagery: + instance.source_imagery = instance.model.dataset.source_imagery + instance.task_id = task.id + instance.save() + print(f"Saved Feedback train model request to queue with id {task.id}") + return HttpResponse(status=200) + DEFAULT_TILE_SIZE = 256 @@ -421,6 +484,11 @@ def post(self, request, *args, **kwargs): ## TODO : can send osm xml format from here as well using geojson2osm return Response(geojson_data, status=status.HTTP_201_CREATED) + except ValueError as e: + if str(e) == "No Features Found": + return Response("No features found", status=204) + else: + return Response(str(e), status=500) except Exception as ex: print(ex) shutil.rmtree(temp_path) diff --git a/frontend/package.json b/frontend/package.json index 38faf73c..4604544b 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -8,6 +8,7 @@ "dependencies": { "@emotion/react": "^11.9.0", "@emotion/styled": "^11.8.1", + "@geoman-io/leaflet-geoman-free": "^2.14.2", "@id-sdk/util": "^3.0.0-pre.10", "@material-ui/core": "^4.12.4", "@material-ui/icons": "^4.11.3", @@ -15,6 +16,7 @@ "@mui/icons-material": "^5.6.1", "@mui/lab": "^5.0.0-alpha.77", "@mui/material": "^5.6.1", + "@mui/styles": "^5.12.0", "@mui/x-data-grid": "^5.17.12", "@testing-library/jest-dom": "^5.16.4", "@testing-library/react": "^12.1.4", @@ -23,15 +25,17 @@ "axios": "^0.26.1", "leaflet": "^1.7.1", "leaflet-draw": "^1.0.4", + "ol": "^7.3.0", "react": "^17.0.1", "react-dom": "^17.0.1", - "react-https-redirect": "^1.0.5", "react-dotenv": "^0.1.3", + "react-https-redirect": "^1.0.5", "react-leaflet": "3.2.1", "react-leaflet-draw": "0.19.8", "react-query": "^3.34.19", "react-router-dom": "^6.4.3", "react-scripts": "5.0.0", + "react-toastify": "^9.1.2", "uninstall": "^0.0.0", "web-vitals": "^2.1.4" }, @@ -59,4 +63,4 @@ "last 1 safari version" ] } -} \ No newline at end of file +} diff --git a/frontend/src/components/Layout/AIModels/AIModelEditor/AIModelEditor.js b/frontend/src/components/Layout/AIModels/AIModelEditor/AIModelEditor.js index d5a535a9..fd7d0af5 100644 --- a/frontend/src/components/Layout/AIModels/AIModelEditor/AIModelEditor.js +++ b/frontend/src/components/Layout/AIModels/AIModelEditor/AIModelEditor.js @@ -6,7 +6,7 @@ import { TextField, Typography, } from "@mui/material"; -import React, { useContext, useState } from "react"; +import React, { useContext, useState, useEffect } from "react"; import { useNavigate, useParams } from "react-router-dom"; import { modelStatus } from "../../../../utils"; import axios from "../../../../axios"; @@ -21,18 +21,27 @@ import { FormControl, FormLabel } from "@material-ui/core"; import AuthContext from "../../../../Context/AuthContext"; import Trainings from "./Trainings"; import DatasetCurrent from "./DatasetCurrent"; +import FeedbackToast from "./FeedbackToast"; +import FeedbackPopup from "./FeedbackPopup"; +import FormGroup from "@mui/material/FormGroup"; + const AIModelEditor = (props) => { let { id } = useParams(); const [error, setError] = useState(null); const [epochs, setEpochs] = useState(20); - const [zoomLevel, setZoomLevel] = useState([19]); + const [zoomLevel, setZoomLevel] = useState([19, 20]); const [popupOpen, setPopupOpen] = useState(false); - const [popupRowData, setPopupRowData] = useState(null); + const [sourceImagery, setSourceImagery] = React.useState(null); + const [freezeLayers, setFreezeLayers] = useState(false); + const [popupRowData, setPopupRowData] = useState(null); + const [feedbackCount, setFeedbackCount] = useState(0); + const [feedbackData, setFeedbackData] = useState(null); const [random, setRandom] = useState(Math.random()); const [batchSize, setBatchSize] = useState(8); const [description, setDescription] = useState(""); + const [feedbackPopupOpen, setFeedbackPopupOpen] = React.useState(false); const { accessToken } = useContext(AuthContext); const zoomLevels = [19, 20, 21]; const getModelById = async () => { @@ -66,12 +75,43 @@ const AIModelEditor = (props) => { const { data, isLoading, refetch } = useQuery("getModelById", getModelById, { refetchInterval: 60000, }); + const getFeedbackCount = async () => { + try { + const response = await axios.get( + `/feedback/?training=${data.published_training}` + ); + setFeedbackData(response.data); + const feedbackCount = response.data.features.length; + setFeedbackCount(feedbackCount); + } catch (error) { + console.error("Error fetching feedback information:", error); + } + }; + useEffect(() => { + if (data?.published_training) { + getFeedbackCount(); + } + }, [data]); + + const handleFeedbackClick = async (trainingId) => { + getFeedbackCount(); + if (sourceImagery === null) { + try { + const response = await axios.get(`/training/${trainingId}/`); + setSourceImagery(response.data.source_imagery); + } catch (error) { + console.error(error); + } + } + setFeedbackPopupOpen(true); + }; const saveTraining = async () => { try { const body = { epochs: epochs, batch_size: batchSize, + freeze_layers:freezeLayers, model: id, zoom_level: zoomLevel, description: description, @@ -84,8 +124,8 @@ const AIModelEditor = (props) => { if (res.error) { setError( res.error.response.statusText + - " / " + - JSON.stringify(res.error.response.data) + " / " + + JSON.stringify(res.error.response.data) ); return; } @@ -107,6 +147,13 @@ const AIModelEditor = (props) => { {data && ( + {data.published_training && ( + + )} Model ID: {data.id} @@ -217,33 +264,36 @@ const AIModelEditor = (props) => { /> - + Zoom Levels - {zoomLevels.map((level) => ( - { - if (e.target.checked) { - console.log(e.target.value); - console.log(level); - setZoomLevel([...zoomLevel, level]); - } else { - setZoomLevel(zoomLevel.filter((l) => l !== level)); - } - }} - name={`zoom-level-${level}`} - /> - } - label={`Zoom ${level}`} - /> - ))} + + {zoomLevels.map((level) => ( + { + if (e.target.checked) { + setZoomLevel([...zoomLevel, level]); + } else { + setZoomLevel(zoomLevel.filter((l) => l !== level)); + } + }} + name={`zoom-level-${level}`} + /> + } + label={`Zoom ${level}`} + /> + ))} + + + + { }} /> + + + Freeze Layers + + setFreezeLayers(e.target.checked)} + name="freeze-layers" + /> + } + label="Freeze Layers" + /> + + + - - + + + + + + + {error && ( @@ -302,6 +387,15 @@ const AIModelEditor = (props) => { row={popupRowData} /> )} + {feedbackPopupOpen && data.published_training && ( + setFeedbackPopupOpen(false)} + /> + )} ); }; diff --git a/frontend/src/components/Layout/AIModels/AIModelEditor/DatasetCurrent.js b/frontend/src/components/Layout/AIModels/AIModelEditor/DatasetCurrent.js index 5aca4ac2..91d16529 100644 --- a/frontend/src/components/Layout/AIModels/AIModelEditor/DatasetCurrent.js +++ b/frontend/src/components/Layout/AIModels/AIModelEditor/DatasetCurrent.js @@ -13,7 +13,6 @@ const DatasetCurrent = (props) => { if (res.error) { // setMapError(res.error.response.statusText); } else { - console.log("DatasetCurrent ", res.data); return res.data; } } catch (e) { diff --git a/frontend/src/components/Layout/AIModels/AIModelEditor/FeedbackMap.js b/frontend/src/components/Layout/AIModels/AIModelEditor/FeedbackMap.js new file mode 100644 index 00000000..050489bb --- /dev/null +++ b/frontend/src/components/Layout/AIModels/AIModelEditor/FeedbackMap.js @@ -0,0 +1,127 @@ +import React, { useState, useEffect, useContext } from "react"; +import { MapContainer, TileLayer, GeoJSON, useMap } from "react-leaflet"; +import "leaflet/dist/leaflet.css"; +import L from "leaflet"; +import axios from "../../../../axios"; +import AuthContext from "../../../../Context/AuthContext"; + +const FeedbackMap = ({ feedbackData, sourceImagery }) => { + const { accessToken } = useContext(AuthContext); + const [mapData, setMapData] = useState(feedbackData); + + useEffect(() => { + if (feedbackData?.features?.length > 0) { + setMapData(feedbackData); + } + }, [feedbackData]); + + useEffect(() => { + if (mapData?.features?.length > 0) { + const geoJSON = new L.GeoJSON(mapData, { + onEachFeature: (feature, layer) => { + const validated = feature.properties.validated || false; + const color = validated ? "green" : "red"; + layer.setStyle({ + color: color, + }); + layer.bindPopup(` + Action: ${feature.properties.action}
+ Created at: ${new Date( + feature.properties.created_at + ).toLocaleString()}
+ + + `); + + layer.on("popupopen", () => { + const validateButtonElement = document.getElementById( + `validate-${feature.properties.id}` + ); + validateButtonElement.addEventListener("click", () => { + const id = feature.properties.id; + const newValidated = !validated; + axios + .patch( + `/feedback/${id}/`, + { validated: newValidated }, + { headers: { "access-token": accessToken } } + ) + .then(() => { + feature.properties.validated = newValidated; + layer.setStyle({ color: newValidated ? "green" : "red" }); + validateButtonElement.innerHTML = newValidated + ? "Invalidate" + : "Validate"; + }) + .catch((error) => console.error(error)); + }); + + const deleteButtonElement = document.getElementById( + `discard-${feature.properties.id}` + ); + deleteButtonElement.addEventListener("click", () => { + const id = feature.properties.id; + axios + .delete(`/feedback/${id}/`, { + headers: { "access-token": accessToken }, + }) + .then(() => { + const filteredFeatures = mapData.features.filter( + (f) => f.properties.id !== id + ); + const newMapData = { + type: "FeatureCollection", + features: filteredFeatures, + }; + setMapData(newMapData); + }) + .catch((error) => console.error(error)); + }); + }); + }, + style: (feature) => ({ + color: feature.properties.validated ? "green" : "red", + }), + }); + if (geoJSONLayer) { + geoJSONLayer.remove(); + } + + setGeoJSONLayer(geoJSON); + } + }, [mapData]); + + const [geoJSONLayer, setGeoJSONLayer] = useState(null); + const ChangeMapView = ({ geoJSONLayer }) => { + const map = useMap(); + useEffect(() => { + if (geoJSONLayer) { + const bounds = geoJSONLayer.getBounds(); + if (bounds.isValid()) { + map.fitBounds(bounds); + } + geoJSONLayer.addTo(map); + } + }, [map, geoJSONLayer]); + return null; + }; + + return ( + + + + + ); +}; + +export default FeedbackMap; diff --git a/frontend/src/components/Layout/AIModels/AIModelEditor/FeedbackPopup.js b/frontend/src/components/Layout/AIModels/AIModelEditor/FeedbackPopup.js new file mode 100644 index 00000000..89fb4e15 --- /dev/null +++ b/frontend/src/components/Layout/AIModels/AIModelEditor/FeedbackPopup.js @@ -0,0 +1,260 @@ +import React, { useState, useContext } from "react"; +import { makeStyles } from "@mui/styles"; +import Dialog from "@mui/material/Dialog"; +import DialogContent from "@mui/material/DialogContent"; +import DialogTitle from "@mui/material/DialogTitle"; +import FeedbackMap from "./FeedbackMap"; +import LoadingButton from "@mui/lab/LoadingButton"; +import axios from "../../../../axios"; +import AuthContext from "../../../../Context/AuthContext"; +import Typography from "@mui/material/Typography"; +import Grid from "@mui/material/Grid"; +import TextField from "@mui/material/TextField"; +import { FormControl, FormLabel } from "@material-ui/core"; +import FormGroup from "@mui/material/FormGroup"; +import { Checkbox, FormControlLabel } from "@mui/material"; + + +const useStyles = makeStyles((theme) => ({ + content: { + padding: theme.spacing(2), + }, +})); + +const FeedbackPopup = ({ + feedbackData, + onClose, + sourceImagery, + trainingId, +}) => { + const classes = useStyles(); + const [open, setOpen] = useState(true); + const [loading, setLoading] = useState(false); + const [freezeLayers, setFreezeLayers] = useState(false); + + const { accessToken } = useContext(AuthContext); + const actionCounts = { + CREATE: 0, + MODIFY: 0, + ACCEPT: 0, + }; + const [epochs, setEpochs] = useState(2); + const [batchSize, setBatchSize] = useState(1); + + feedbackData.features.forEach((feature) => { + switch (feature.properties.action) { + case "CREATE": + actionCounts.CREATE++; + break; + case "MODIFY": + actionCounts.MODIFY++; + break; + case "ACCEPT": + actionCounts.ACCEPT++; + break; + default: + break; + } + }); + + const handleClose = () => { + setOpen(false); + onClose(); + }; + const handleApplyFeedback = () => { + setLoading(true); + axios + .post( + "/apply/feedback/", + { training_id: trainingId, epochs: epochs, batch_size: batchSize, freeze_layers: freezeLayers }, + { headers: { "access-token": accessToken } } + ) + .then((response) => { + console.log(response.data); + setLoading(false); + handleClose(); + }) + .catch((error) => { + console.error(error); + setLoading(false); + }); + }; + + return ( + + + {" "} + Published Model Feedbacks + + + + + + Feedback Information + + + Total Feedbacks: {feedbackData.features.length} + + + Validated:{" "} + { + feedbackData.features.filter( + (feature) => feature.properties.validated + ).length + } + + + Need Validation:{" "} + { + feedbackData.features.filter( + (feature) => !feature.properties.validated + ).length + } + + + + + Feedback Summary + + + Created / Modified : {actionCounts.CREATE} / {actionCounts.MODIFY} + + + Accepted: {actionCounts.ACCEPT} + + !feature.properties.validated + ).length < 1 + } + onClick={() => { + setLoading(true); + const feedbackIds = feedbackData.features.map( + (feature) => feature.properties.id + ); + Promise.all( + feedbackIds.map((id) => + axios.patch( + `/feedback/${id}/`, + { validated: true }, + { headers: { "access-token": accessToken } } + ) + ) + ) + .then(() => { + setLoading(false); + console.log("All feedback validated successfully!"); + }) + .catch((error) => { + setLoading(false); + console.error(error); + }); + }} + loading={loading} + > + Validate All + + !feature.properties.validated + ).length < 1 + } + onClick={() => { + setLoading(true); + const feedbackIds = feedbackData.features.map( + (feature) => feature.properties.id + ); + Promise.all( + feedbackIds.map((id) => + axios.delete(`/feedback/${id}/`, { + headers: { "access-token": accessToken }, + }) + ) + ) + .then(() => { + setLoading(false); + console.log("All feedback deleted successfully!"); + }) + .catch((error) => { + setLoading(false); + console.error(error); + }); + }} + loading={loading} + > + Discard All + + + + + setEpochs(Math.max(0, parseInt(e.target.value)))} + inputProps={{ min: 1, step: 1 }} + fullWidth + margin="normal" + /> + setBatchSize(Math.max(0, parseInt(e.target.value)))} + inputProps={{ min: 1, step: 1 }} + fullWidth + margin="normal" + /> + + + Freeze Layers + + setFreezeLayers(e.target.checked)} + name="freeze-layers" + /> + } + label="Freeze Layers" + /> + + + + + feature.properties.validated + ).length < 5 + } + > + Apply Validated Feedback to Model + + + + ); +}; + +export default FeedbackPopup; diff --git a/frontend/src/components/Layout/AIModels/AIModelEditor/FeedbackToast.js b/frontend/src/components/Layout/AIModels/AIModelEditor/FeedbackToast.js new file mode 100644 index 00000000..c767956e --- /dev/null +++ b/frontend/src/components/Layout/AIModels/AIModelEditor/FeedbackToast.js @@ -0,0 +1,89 @@ +import React from "react"; +import { makeStyles } from "@mui/styles"; +import Snackbar from "@mui/material/Snackbar"; +import IconButton from "@mui/material/IconButton"; +import CloseIcon from "@mui/icons-material/Close"; +import FeedbackPopup from "./FeedbackPopup"; +import { useNavigate } from "react-router-dom"; +import axios from "../../../../axios"; + +const useStyles = makeStyles((theme) => ({ + close: { + padding: theme.spacing(0.5), + }, +})); + +const FeedbackToast = ({ count, feedbackData, trainingId }) => { + const classes = useStyles(); + const navigate = useNavigate(); + const [open, setOpen] = React.useState(false); + const [popupOpen, setPopupOpen] = React.useState(false); + const [sourceImagery, setSourceImagery] = React.useState(null); + + const handleClick = async () => { + if (sourceImagery === null) { + try { + const response = await axios.get(`/training/${trainingId}/`); + setSourceImagery(response.data.source_imagery); + } catch (error) { + console.error(error); + } + } + setOpen(false); + setPopupOpen(true); + }; + + const handleClose = (event, reason) => { + if (reason === "clickaway") { + return; + } + + setOpen(false); + }; + + React.useEffect(() => { + if (count > 0) { + setOpen(true); + } + }, [count]); + + return ( + <> + + + + + + } + onClick={handleClick} + /> + {popupOpen && ( + setPopupOpen(false)} + /> + )} + + ); +}; + +export default FeedbackToast; diff --git a/frontend/src/components/Layout/AIModels/AIModelEditor/Popup.js b/frontend/src/components/Layout/AIModels/AIModelEditor/Popup.js index 5c93c4af..01f66c0a 100644 --- a/frontend/src/components/Layout/AIModels/AIModelEditor/Popup.js +++ b/frontend/src/components/Layout/AIModels/AIModelEditor/Popup.js @@ -104,7 +104,14 @@ const Popup = ({ open, handleClose, row }) => { useEffect(() => { setLoading(true); if (row.status === "FAILED" || row.status === "RUNNING") { + // Call getTrainingStatus every 3 seconds + const intervalId = setInterval(() => { + getTrainingStatus(row.task_id); + }, 3000); + getTrainingStatus(row.task_id).finally(() => setLoading(false)); + + return () => clearInterval(intervalId); } else if (row.status === "FINISHED") { getDatasetId(row.model).finally(() => setLoading(false)); } else { @@ -143,6 +150,9 @@ const Popup = ({ open, handleClose, row }) => {

Status: {row.status}

+

+ Freeze Layers: {row.freeze_layers} +

+
+ `; + const popup = L.popup() + .setLatLng(e.latlng) + .setContent(popupContent) + .openOn(e.target._map); + const popupElement = popup.getElement(); + popupElement + .querySelector("#rightButton") + .addEventListener("click", () => { + feature.properties.action = "ACCEPT"; + const bounds = layer.getBounds(); + const corners = [bounds.getSouthWest(), bounds.getNorthEast()]; + for (const corner of corners) { + const [tileX, tileY] = deg2tile( + corner.lat, + corner.lng, + predictionZoomlevel + ); + addTileBoundaryLayer( + mapref, + addedTiles, + tileX, + tileY, + predictionZoomlevel, + setAddedTiles + ); + } + }); + } + }); + }; + + useEffect(() => { + const map = mapref; + + map.pm.addControls({ + position: "topleft", + drawMarker: false, + drawPolygon: true, + drawCircleMarker: false, + drawCircle: false, + drawPolyline: false, + drawText: false, + editMode: true, + dragMode: true, + cutPolygon: false, + tooltips: true, + removalMode: true, + oneBlock: true, + allowSelfIntersection: false, + }); + map.on("pm:create", onPMCreate); + }, [mapref]); + + return ( + + ); +}; + +export default EditableGeoJSON; diff --git a/frontend/src/components/Layout/Start/Prediction/Prediction.js b/frontend/src/components/Layout/Start/Prediction/Prediction.js index 259329f7..36e7aa7b 100644 --- a/frontend/src/components/Layout/Start/Prediction/Prediction.js +++ b/frontend/src/components/Layout/Start/Prediction/Prediction.js @@ -13,6 +13,7 @@ import { InputLabel, Tooltip, MenuItem, + Link, Select, } from "@mui/material"; @@ -24,22 +25,42 @@ import { TileLayer, useMapEvents, } from "react-leaflet"; +import L from "leaflet"; import { useMutation, useQuery } from "react-query"; -import { useParams } from "react-router-dom"; +import { useNavigate, useParams } from "react-router-dom"; import axios from "../../../../axios"; import AuthContext from "../../../../Context/AuthContext"; -import { GeoJSON } from "react-leaflet"; +import Snackbar from "@mui/material/Snackbar"; + +import "@geoman-io/leaflet-geoman-free"; +import "@geoman-io/leaflet-geoman-free/dist/leaflet-geoman.css"; +import EditableGeoJSON from "./EditableGeoJSON"; const Prediction = () => { const { id } = useParams(); + const [feedbackSubmitted, setFeedbackSubmitted] = useState(false); + const [snackbarOpen, setSnackbarOpen] = useState(false); + const [predictions, setPredictions] = useState(null); + const [feedbackSubmittedCount, setFeedbackSubmittedCount] = useState(0); + const [addedTiles, setAddedTiles] = useState(new Set()); + + const handleCloseSnackbar = () => { + setSnackbarOpen(false); + }; + let tileBoundaryLayer = null; const [error, setError] = useState(false); const [josmLoading, setJosmLoading] = useState(false); + const [feedbackLoading, setFeedbackLoading] = useState(false); + const [predictionZoomlevel, setpredictionZoomlevel] = useState(null); const [apiCallInProgress, setApiCallInProgress] = useState(false); - const [confidence, setConfidence] = useState(50); - + const [confidence, setConfidence] = useState(90); + const [totalPredictionsCount, settotalPredictionsCount] = useState(0); + const [DeletedCount, setDeletedCount] = useState(0); + const [CreatedCount, setCreatedCount] = useState(0); + const [ModifiedCount, setModifiedCount] = useState(0); const [map, setMap] = useState(null); - const [zoom, setZoom] = useState(0); + const [zoom, setZoom] = useState(15); const [responseTime, setResponseTime] = useState(0); const [bounds, setBounds] = useState({}); @@ -168,11 +189,12 @@ const Prediction = () => { const { mutate: callPredict, - data: predictions, + data: returnedpredictions, isLoading: predictionLoading, } = useMutation(async () => { setApiCallInProgress(true); setResponseTime(0); + const headers = { "access-token": accessToken, }; @@ -193,6 +215,11 @@ const Prediction = () => { const endTime = new Date().getTime(); // measure end time setResponseTime(((endTime - startTime) / 1000).toFixed(0)); // calculate and store response time in seconds setApiCallInProgress(false); + if (res.status === 204) { + // Add this if statement + setError("No features found on requested bbox"); + return; + } if (res.error) { setError( `${res.error.response.statusText}, ${JSON.stringify( @@ -201,9 +228,56 @@ const Prediction = () => { ); return; } - return res.data; + setpredictionZoomlevel(zoom); + const updatedPredictions = addIdsToPredictions(res.data); + setPredictions(updatedPredictions); + settotalPredictionsCount(updatedPredictions.features.length); + setCreatedCount(0); + setModifiedCount(0); + setDeletedCount(0); + if (addedTiles.size > 0) { + console.log("Map has tileboundarylayer"); + } + return updatedPredictions; }); + const handleSubmitFeedback = async () => { + setFeedbackLoading(true); + console.log(predictions.features.length); + let count = 0; + try { + for (let i = 0; i < predictions.features.length; i++) { + console.log(predictions.features[i]); + const { geometry } = predictions.features[i]; + const { action } = predictions.features[i].properties; + if (action !== "INITIAL") { + // Add this if statement + const body = { + geom: geometry, + action, + training: modelInfo.trainingId, + zoom_level: predictionZoomlevel, + }; + console.log(body); + const headers = { + "access-token": accessToken, + Authorization: `Bearer ${accessToken}`, + }; + await axios.post("/feedback/", body, { headers }); + count = count + 1; + } + } + setFeedbackLoading(false); + setFeedbackSubmittedCount(count); + setSnackbarOpen(true); + setFeedbackSubmitted(true); + } catch (error) { + console.error(error); + setFeedbackLoading(false); + window.alert("An error occurred while submitting feedback."); + } + }; + async function openWithJosm() { setJosmLoading(true); if (!predictions) { @@ -211,9 +285,21 @@ const Prediction = () => { return; } + // Remove the "id" and "featuretype" properties from each feature in the "features" array + const modifiedPredictions = { + ...predictions, + features: predictions.features.map((feature) => { + const { id, action, ...newProps } = feature.properties; + return { + ...feature, + properties: newProps, + }; + }), + }; + try { const response = await axios.post("/geojson2osm/", { - geojson: predictions, + geojson: modifiedPredictions, }); if (response.status === 200) { const osmUrl = new URL("http://127.0.0.1:8111/load_data"); @@ -237,7 +323,7 @@ const Prediction = () => { setError("OSM XML conversion failed"); } } catch (error) { - setError(error.message); + setError("Couldn't Open JOSM , Check if JOSM is Open"); } finally { setJosmLoading(false); } @@ -256,6 +342,24 @@ const Prediction = () => { return null; } + function addIdsToPredictions(predictions) { + const features = predictions.features.map((feature, index) => { + return { + ...feature, + properties: { + ...feature.properties, + id: index, + action: "INITIAL", + }, + }; + }); + + settotalPredictionsCount(features.length); + + return { ...predictions, features }; + } + const navigate = useNavigate(); + return ( <> @@ -272,37 +376,28 @@ const Prediction = () => { whenCreated={setMap} > - - {/* code for TileLayer components */} - {oamImagery && dataset && ( - - - - )} - + {oamImagery && dataset && ( + + )} {predictions && ( - ({ - color: "darkred", - weight: 6, - fillPattern: { - weight: 1, - opacity: 1, - pattern: /\/\/\/\//, - strokeOpacity: 0.5, - strokeWeight: 1, - strokeColor: "red", - }, - })} + setPredictions={setPredictions} + mapref={map} + predictionZoomlevel={predictionZoomlevel} + addedTiles={addedTiles} + setAddedTiles={setAddedTiles} + setCreatedCount={setCreatedCount} + setModifiedCount={setModifiedCount} + setDeletedCount={setDeletedCount} + tileBoundaryLayer={tileBoundaryLayer} /> )} @@ -312,7 +407,7 @@ const Prediction = () => { 22} + disabled={zoom < 19 || !zoom || zoom > 22} loading={predictionLoading} onClick={() => { setError(false); @@ -321,77 +416,130 @@ const Prediction = () => { > Run Prediction - - - - Confidence: - - - - - - {map && ( - - Current Zoom: {JSON.stringify(zoom)} - - - Response: {responseTime} sec - - + + + + + Confidence: + + + + + + + + Current Zoom: {JSON.stringify(zoom)} + + {predictions && ( + + Predicted on: {predictionZoomlevel} Zoom + + )} + + Response: {responseTime} sec + + + {predictions && ( + + + Feedback + + + Initial Predictions: + {totalPredictionsCount} + + {CreatedCount > 0 && ( + + Total Created: + {CreatedCount} + + )} + {ModifiedCount > 0 && ( + + Total Modified: + {ModifiedCount} + + )} + {DeletedCount > 0 && ( + + Total Deleted: + {DeletedCount} + + )} + {CreatedCount + ModifiedCount + DeletedCount > 1 && + !feedbackSubmitted && ( + + Submit my feedback + + )} + + )} {loading ? ( ) : ( modelInfo && ( - - + + Loaded Model - - Model ID: {modelInfo.id} - - - Model Name: {modelInfo.name} + + { + e.preventDefault(); + navigate("/ai-models/" + modelInfo.id); + }} + color="inherit" + > + ID: {modelInfo.id} + + + Name: {modelInfo.name} - - Model Last Modified:{" "} + + Last Modified:{" "} {new Date(modelInfo.lastModified).toLocaleString()} - + Published Training - - Training ID: {modelInfo.trainingId} + + ID: {modelInfo.trainingId} - - Training Description:{" "} - {modelInfo.trainingDescription} + + Description: {modelInfo.trainingDescription} - - Training Zoom Level:{" "} - {modelInfo.trainingZoomLevel} + + Zoom Level: {modelInfo.trainingZoomLevel} - - Training Accuracy:{" "} - {modelInfo.trainingAccuracy} % + + Accuracy: {modelInfo.trainingAccuracy} % - - Model Size: {modelInfo.modelSize} MB + + Model Size: {modelInfo.modelSize} MB ) @@ -405,11 +553,23 @@ const Prediction = () => { color="secondary" onClick={openWithJosm} loading={josmLoading} + size="small" + sx={{ mt: 1 }} > Open Results with JOSM )} +
); diff --git a/frontend/src/components/Layout/TrainingDS/DatasetNew/DatasetNew.js b/frontend/src/components/Layout/TrainingDS/DatasetNew/DatasetNew.js index 0b14527d..d914bb24 100644 --- a/frontend/src/components/Layout/TrainingDS/DatasetNew/DatasetNew.js +++ b/frontend/src/components/Layout/TrainingDS/DatasetNew/DatasetNew.js @@ -1,109 +1,138 @@ -import Alert from '@material-ui/lab/Alert' -import { Button, Grid, TextField, Typography } from '@mui/material' -import React, { useContext, useState } from 'react' -import SaveIcon from '@material-ui/icons/Save'; -import { useMutation } from 'react-query'; -import axios from "../../../../axios" -import { useNavigate } from 'react-router-dom'; -import AuthContext from '../../../../Context/AuthContext'; -const DatasetNew = props => { - - const [error, setError] = useState(null) - const [DSName, setDSName] = useState("") +import Alert from "@material-ui/lab/Alert"; +import { Button, Grid, TextField, Typography } from "@mui/material"; +import React, { useContext, useState } from "react"; +import SaveIcon from "@material-ui/icons/Save"; +import { useMutation } from "react-query"; +import axios from "../../../../axios"; +import { useNavigate } from "react-router-dom"; +import AuthContext from "../../../../Context/AuthContext"; +const DatasetNew = (props) => { + const [error, setError] = useState(null); + const [DSName, setDSName] = useState(""); const navigate = useNavigate(); - const { accessToken } = useContext(AuthContext) + const [oamURL, setOAMURL] = useState(); + const { accessToken } = useContext(AuthContext); const saveDataset = async () => { try { - const body = { - "name": DSName, - "status": 0 - } + name: DSName, + status: 0, + source_imagery: oamURL, + }; const headers = { - "access-token": accessToken - } + "access-token": accessToken, + }; const res = await axios.post("/dataset/", body, { headers }); - if (res.error) - setError(res.error.response.statusText); + if (res.error) setError(res.error.response.statusText); - console.log("/training-datasets/", res) - navigate(`/training-datasets/${res.data.id}`) + console.log("/training-datasets/", res); + navigate(`/training-datasets/${res.data.id}`); return res.data; } catch (e) { console.log("isError"); setError(JSON.stringify(e)); } finally { - } }; const { mutate, isLoading } = useMutation(saveDataset); + return ( + <> +
+ + + + Create New Training Dataset + + - return <> -
- - - - Create New Training Dataset - - - - - - Enter a name that represents the training datasets. - It is recommedned to use a local name of the area where the training dataset exists - - - - - { - setDSName(e.target.value) - - }} - error={DSName.trim() === ""} - /> - - - - + + + Enter a name that represents the training datasets. It is + recommedned to use a local name of the area where the training + dataset exists + + - + + + + - + {error && ( + + {error} + + )} - {error && - - {error} - - } - - -
- - -} +
+ + ); +}; -export default DatasetNew; \ No newline at end of file +export default DatasetNew; diff --git a/frontend/src/index.css b/frontend/src/index.css index 6c6072e5..82e48df5 100644 --- a/frontend/src/index.css +++ b/frontend/src/index.css @@ -78,6 +78,26 @@ code { } +.feedback-button { + background-color: #4caf50; + margin-top: 2px; + color: #fff; + border: none; + border-radius: 4px; + padding: 8px 12px; + cursor: pointer; + outline: none; + font-weight: bold; + font-size: 1.0em; +} + +.feedback-button:last-child { + background-color: #f44336; + margin-left:2px; +} + + + .logo:hover { cursor: pointer;