From bba0dfc81c7282906ff5564bf71f94e4eb3c02be Mon Sep 17 00:00:00 2001 From: kshitijrajsharma Date: Tue, 3 Dec 2024 12:18:13 +0000 Subject: [PATCH] fix(log-production): fixes bug on epoch limit and log production --- backend/aiproject/settings.py | 11 ++++++-- backend/core/tasks.py | 49 ++++++++++++++++++++++++----------- backend/core/views.py | 33 +++++++++++++++-------- 3 files changed, 65 insertions(+), 28 deletions(-) diff --git a/backend/aiproject/settings.py b/backend/aiproject/settings.py index 163780da..4ecea1b9 100644 --- a/backend/aiproject/settings.py +++ b/backend/aiproject/settings.py @@ -56,8 +56,15 @@ # Limiter -EPOCHS_LIMIT = env("EPOCHS_LIMIT", default=30) -BATCH_SIZE_LIMIT = env("BATCH_SIZE_LIMIT", default=8) + +## YOLO +YOLO_EPOCHS_LIMIT = env("YOLO_EPOCHS_LIMIT", default=200) +YOLO_BATCH_SIZE_LIMIT = env("YOLO_BATCH_SIZE_LIMIT", default=8) + +## RAMP +RAMP_EPOCHS_LIMIT = env("RAMP_EPOCHS_LIMIT", default=40) +RAMP_BATCH_SIZE_LIMIT = env("RAMP_BATCH_SIZE_LIMIT", default=8) + TRAINING_WORKSPACE_DOWNLOAD_LIMIT = env( "TRAINING_WORKSPACE_DOWNLOAD_LIMIT", default=200 ) diff --git a/backend/core/tasks.py b/backend/core/tasks.py index d20fbaf7..95308f72 100644 --- a/backend/core/tasks.py +++ b/backend/core/tasks.py @@ -35,7 +35,14 @@ from django.utils import timezone from predictor import download_imagery, get_start_end_download_coords +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO +) + + logger = logging.getLogger(__name__) +logger.propagate = False + # from core.serializers import LabelFileSerializer @@ -363,15 +370,27 @@ def yolo_model_training( os.makedirs(output_path) shutil.copyfile(output_model_path, os.path.join(output_path, "checkpoint.pt")) - shutil.copyfile(os.path.join(os.path.dirname(output_model_path),'best.onnx'), os.path.join(output_path, "checkpoint.onnx")) + shutil.copyfile( + os.path.join(os.path.dirname(output_model_path), "best.onnx"), + os.path.join(output_path, "checkpoint.onnx"), + 
) # shutil.copyfile(os.path.dirname(output_model_path,'checkpoint.tflite'), os.path.join(output_path, "checkpoint.tflite")) - + shutil.copytree(preprocess_output, os.path.join(output_path, "preprocessed")) - os.makedirs(os.path.join(output_path,model),exist_ok=True) + os.makedirs(os.path.join(output_path, model), exist_ok=True) - shutil.copytree(os.path.join(yolo_data_dir,'images'), os.path.join(output_path,model, "images")) - shutil.copytree(os.path.join(yolo_data_dir,'labels'), os.path.join(output_path,model, "labels")) - shutil.copyfile(os.path.join(yolo_data_dir,'yolo_dataset.yaml'), os.path.join(output_path,model, "yolo_dataset.yaml")) + shutil.copytree( + os.path.join(yolo_data_dir, "images"), + os.path.join(output_path, model, "images"), + ) + shutil.copytree( + os.path.join(yolo_data_dir, "labels"), + os.path.join(output_path, model, "labels"), + ) + shutil.copyfile( + os.path.join(yolo_data_dir, "yolo_dataset.yaml"), + os.path.join(output_path, model, "yolo_dataset.yaml"), + ) graph_output_path = os.path.join( pathlib.Path(os.path.dirname(output_model_path)).parent, "iou_chart.png" @@ -454,10 +473,7 @@ def train_model( if training_instance.task_id is None or training_instance.task_id.strip() == "": training_instance.task_id = train_model.request.id training_instance.save() - log_file = os.path.join( - settings.LOG_PATH, f"run_{train_model.request.id}_log.txt" - ) - + log_file = os.path.join(settings.LOG_PATH, f"run_{train_model.request.id}.log") if model_instance.base_model == "YOLO_V8_V1" and settings.YOLO_HOME is None: raise ValueError("YOLO Home is not configured") @@ -465,11 +481,14 @@ def train_model( raise ValueError("Ramp Home is not configured") try: - with open(log_file, "w") as f: - # redirect stdout to the log file + with open(log_file, "a") as f: + # redirect stdout to the log file sys.stdout = f - training_input_image_source, aoi_serializer, serialized_field = prepare_data( - training_instance, dataset_id, feedback, zoom_level, source_imagery 
+ logging.info("Training Started") + training_input_image_source, aoi_serializer, serialized_field = ( + prepare_data( + training_instance, dataset_id, feedback, zoom_level, source_imagery + ) ) if model_instance.base_model in ("YOLO_V8_V1", "YOLO_V8_V2"): @@ -499,7 +518,7 @@ def train_model( input_boundary_width, ) - logger.info(f"Training task {training_id} completed successfully") + logging.info(f"Training task {training_id} completed successfully") return response except Exception as ex: diff --git a/backend/core/views.py b/backend/core/views.py index 4d48e600..a34005e3 100644 --- a/backend/core/views.py +++ b/backend/core/views.py @@ -150,16 +150,25 @@ def create(self, validated_data): epochs = validated_data["epochs"] batch_size = validated_data["batch_size"] + if model.base_model == "RAMP": + if epochs > settings.RAMP_EPOCHS_LIMIT: + raise ValidationError( + f"Epochs can't be greater than {settings.RAMP_EPOCHS_LIMIT} on this server" + ) + if batch_size > settings.RAMP_BATCH_SIZE_LIMIT: + raise ValidationError( + f"Batch size can't be greater than {settings.RAMP_BATCH_SIZE_LIMIT} on this server" + ) + if model.base_model in ["YOLO_V8_V1","YOLO_V8_V2"]: - if epochs > settings.EPOCHS_LIMIT: - raise ValidationError( - f"Epochs can't be greater than {settings.EPOCHS_LIMIT} on this server" - ) - if batch_size > settings.BATCH_SIZE_LIMIT: - raise ValidationError( - f"Batch size can't be greater than {settings.BATCH_SIZE_LIMIT} on this server" - ) - + if epochs > settings.YOLO_EPOCHS_LIMIT: + raise ValidationError( + f"Epochs can't be greater than {settings.YOLO_EPOCHS_LIMIT} on this server" + ) + if batch_size > settings.YOLO_BATCH_SIZE_LIMIT: + raise ValidationError( + f"Batch size can't be greater than {settings.YOLO_BATCH_SIZE_LIMIT} on this server" + ) user = self.context["request"].user validated_data["user"] = user # create the model instance @@ -553,11 +562,13 @@ def run_task_status(request, run_id: str): } ) elif task_result.state == "PENDING" or 
task_result.state == "STARTED": - log_file = os.path.join(settings.LOG_PATH, f"run_{run_id}_log.txt") + log_file = os.path.join(settings.LOG_PATH, f"run_{run_id}.log") try: # read the last 10 lines of the log file + cmd = ["tail", "-n", str(settings.LOG_LINE_STREAM_TRUNCATE_VALUE), log_file] + # every argv item for check_output must be a string; the settings value may be an int output = subprocess.check_output( - ["tail", "-n", settings.LOG_LINE_STREAM_TRUNCATE_VALUE, log_file] + cmd ).decode("utf-8") except Exception as e: output = str(e)