From e2834c7fb1177717bd7b2d1a73f4355c08992975 Mon Sep 17 00:00:00 2001 From: Lehdari Date: Thu, 27 Apr 2023 23:35:10 +0300 Subject: [PATCH] Implement support for base experiments A base experiment basically means continuing from where the previous experiment left off: using the same model and potentially the same model configuration. The previous time series data will also be loaded and new data accumulated on top of it. --- include/App.hpp | 8 +- include/gui/State.hpp | 8 +- include/ml/Trainer.hpp | 3 +- src/App.cpp | 37 ++++++-- src/gui/TrainingWindow.cpp | 94 ++++++++++++++------ src/ml/Trainer.cpp | 81 +++++++++++------ src/ml/models/MultiLevelAutoEncoderModel.cpp | 26 ++++-- 7 files changed, 180 insertions(+), 77 deletions(-) diff --git a/include/App.hpp b/include/App.hpp index b10be69..6cd2f3e 100644 --- a/include/App.hpp +++ b/include/App.hpp @@ -50,7 +50,13 @@ class App { gui::Gui _gui; void trainingControl(); - void resetExperiment(); // shouldn't be called when training thread is running + + // For updating possible changes made in GUI + void updateExperimentConfig(nlohmann::json& experimentConfig); + + // Reset the whole experiment, create new config and instantiate new model. + // Shouldn't be called when training thread is running. 
+ void resetExperiment(); }; diff --git a/include/gui/State.hpp b/include/gui/State.hpp index dd841fe..d8c5090 100644 --- a/include/gui/State.hpp +++ b/include/gui/State.hpp @@ -44,10 +44,12 @@ struct State { STOPPED = 0, ONGOING = 1, PAUSED = 2 - } trainingStatus {TrainingStatus::STOPPED}; + } trainingStatus {TrainingStatus::STOPPED}; - char experimentName[256] {"ex_{time}_{version}"}; - std::string modelTypeName {"AutoEncoderModel"}; // type name of the model to be trained + char experimentName[256] {"ex_{time}_{version}"}; + std::string experimentBase {""}; + nlohmann::json baseExperimentConfig; + std::string modelTypeName {"AutoEncoderModel"}; // type name of the model to be trained }; } // namespace gui diff --git a/include/ml/Trainer.hpp b/include/ml/Trainer.hpp index e012e65..2040eef 100644 --- a/include/ml/Trainer.hpp +++ b/include/ml/Trainer.hpp @@ -50,6 +50,7 @@ class Trainer { void quit(); void configureExperiment(nlohmann::json&& experimentConfig); + void setupExperiment(); // called before starting the training loop void saveExperiment(); // Access the model that is being trained @@ -81,8 +82,8 @@ class Trainer { bool startRecording(); void nextMap(size_t newBatchEntryId = 0); // proceed to next map - void setupExperiment(); // called before starting the training void createExperimentDirectories() const; + void loadBaseExperimentTrainingInfo(); }; } // namespace ml diff --git a/src/App.cpp b/src/App.cpp index a275503..0f14b5c 100644 --- a/src/App.cpp +++ b/src/App.cpp @@ -16,6 +16,7 @@ #include "gvizdoom/DoomGame.hpp" #include "glad/glad.h" +#include "util/ExperimentUtils.hpp" #include #include @@ -79,7 +80,7 @@ App::App(Trainer* trainer) : // Initialize gui _gui.init(_window, &_glContext); - _gui.setCallback("newModelTypeSelected", [&](const gui::State& guiState){ resetExperiment(); }); + _gui.setCallback("resetExperiment", [&](const gui::State& guiState){ resetExperiment(); }); _gui.update(_trainer->getTrainingInfo()); // Read gui layout from the 
layout file @@ -165,8 +166,9 @@ void App::trainingControl() break; // no trainer thread running, launch it - (*_trainer->getExperimentConfig())["experiment_name"] = _gui.getState().experimentName; - (*_trainer->getExperimentConfig()).erase("experiment_root"); + updateExperimentConfig(*_trainer->getExperimentConfig()); + _trainer->setupExperiment(); // needed for updated training info + _gui.update(_trainer->getTrainingInfo()); // communicate potential changes in training info to gui _trainerThread = std::thread(&ml::Trainer::loop, _trainer); } break; case gui::State::TrainingStatus::PAUSED: { @@ -175,6 +177,19 @@ void App::trainingControl() } } +void App::updateExperimentConfig(nlohmann::json& experimentConfig) +{ + // TODO This is here so that when using macros in the experiment name a new directory is generated. + // TODO It is planned to be overridable from the experiment configuration GUI section. + if (experimentConfig.contains("experiment_root")) + experimentConfig.erase("experiment_root"); + + experimentConfig["experiment_name"] = _gui.getState().experimentName; + if (!_gui.getState().experimentBase.empty()) + experimentConfig["experiment_base_root"] = experimentRootFromString(_gui.getState().experimentBase); + experimentConfig["software_version"] = GIT_VERSION; +} + void App::resetExperiment() { if (_gui.getState().trainingStatus != gui::State::TrainingStatus::STOPPED) { @@ -184,12 +199,18 @@ void App::resetExperiment() // Setup experiment config nlohmann::json experimentConfig; - experimentConfig["experiment_name"] = _gui.getState().experimentName; + updateExperimentConfig(experimentConfig); experimentConfig["model_type"] = _gui.getState().modelTypeName; - modelTypeNameCallback(_gui.getState().modelTypeName, [&](){ // load the default model config - experimentConfig["model_config"] = getDefaultModelConfig(); - }); - experimentConfig["software_version"] = GIT_VERSION; + if (_gui.getState().baseExperimentConfig.contains("model_config")) { + 
experimentConfig["model_config"] = _gui.getState().baseExperimentConfig["model_config"]; + // model config may change, store also the original + experimentConfig["base_model_config"] = _gui.getState().baseExperimentConfig["model_config"]; + } + else { + modelTypeNameCallback(_gui.getState().modelTypeName, [&]() { // load the default model config + experimentConfig["model_config"] = getDefaultModelConfig(); + }); + } _trainer->configureExperiment(std::move(experimentConfig)); _gui.update(_trainer->getTrainingInfo()); diff --git a/src/gui/TrainingWindow.cpp b/src/gui/TrainingWindow.cpp index 0a14247..076a932 100644 --- a/src/gui/TrainingWindow.cpp +++ b/src/gui/TrainingWindow.cpp @@ -13,11 +13,15 @@ #include "ml/Models.hpp" #include "ml/ModelTypeUtils.hpp" #include "ml/Trainer.hpp" +#include "util/ExperimentUtils.hpp" #include "imgui.h" #include "misc/cpp/imgui_stdlib.h" +namespace fs = std::filesystem; + + gui::TrainingWindow::TrainingWindow(std::set* activeIds, State* guiState, int id) : Window(this, guiState, activeIds, id) { @@ -36,39 +40,70 @@ void gui::TrainingWindow::render(ml::Trainer* trainer) // Flag for disabling all settings when training's in progress bool trainingInProgress = _guiState->trainingStatus != State::TrainingStatus::STOPPED; - ImGui::BeginDisabled(trainingInProgress); - - // Experiment name input - ImGui::Text("Experiment name:"); - ImGui::SetNextItemWidth(windowSize.x); - if (ImGui::InputText("##ExperimentName", _guiState->experimentName, 255)) { - if (_guiState->callbacks.contains("newModelTypeSelected")) - _guiState->callbacks["newModelTypeSelected"](*_guiState); - } - // Model select - ImGui::Text("Model: "); - ImGui::SameLine(); - ImGui::SetNextItemWidth(windowSize.x - fontSize*6.35f); - if (ImGui::BeginCombo("##ModelSelector", _guiState->modelTypeName.c_str())) { - ml::modelForEachTypeCallback([&]() { - constexpr auto name = ml::ModelTypeInfo::name; - bool isSelected = (_guiState->modelTypeName == name); - if (ImGui::Selectable(name, 
isSelected)) { - // Call the callback function for new model type selection (in case it's defined) - if (_guiState->modelTypeName != name && _guiState->callbacks.contains("newModelTypeSelected")) { - _guiState->modelTypeName = name; - _guiState->callbacks["newModelTypeSelected"](*_guiState); + if (ImGui::CollapsingHeader("Experiment configuration")) { + ImGui::BeginDisabled(trainingInProgress); + + // Experiment name input + ImGui::Text("Experiment name:"); + ImGui::SetNextItemWidth(windowSize.x - fontSize * 2.0f); + ImGui::InputText("##ExperimentName", _guiState->experimentName, 255); + + // Experiment base input + ImGui::Text("Experiment base:"); + ImGui::SetNextItemWidth(windowSize.x - fontSize * 2.0f); + if (ImGui::InputText("##ExperimentBase", &_guiState->experimentBase) && + !_guiState->experimentBase.empty()) { + auto baseRoot = experimentRootFromString(_guiState->experimentBase); + if (fs::exists(baseRoot)) { // valid experiment root passed, let's check if there's a config + fs::path baseExperimentConfigFilename = baseRoot / "experiment_config.json"; + if (fs::exists(baseExperimentConfigFilename)) { // config found, parse it + std::ifstream baseExperimentConfigFile(baseExperimentConfigFilename); + _guiState->baseExperimentConfig = nlohmann::json::parse(baseExperimentConfigFile); + if (_guiState->baseExperimentConfig.contains("model_type")) { // reset the model + _guiState->modelTypeName = _guiState->baseExperimentConfig["model_type"]; + if (_guiState->callbacks.contains("resetExperiment")) + _guiState->callbacks["resetExperiment"](*_guiState); + } + else + printf("WARNING: No model_type specified in the base experiment config\n"); // TODO logging } + else + printf("WARNING: No experiment config from %s found!\n", baseExperimentConfigFilename.c_str()); } - if (isSelected) - ImGui::SetItemDefaultFocus(); - }); + } + if (ImGui::Button("Reload model config")) { // reload model config from the base experiment + if 
(_guiState->baseExperimentConfig.contains("model_config")) + (*trainer->getExperimentConfig())["model_config"] = _guiState->baseExperimentConfig["model_config"]; + else + printf("WARNING: No base experiment model config found\n"); // TODO logging + } - ImGui::EndCombo(); - } + // Model select + ImGui::BeginDisabled(!_guiState->experimentBase.empty()); // Same model type forced when using a base experiment + ImGui::Text("Model:"); + ImGui::SetNextItemWidth(windowSize.x - fontSize * 2.0f); + if (ImGui::BeginCombo("##ModelSelector", _guiState->modelTypeName.c_str())) { + ml::modelForEachTypeCallback([&]() { + constexpr auto name = ml::ModelTypeInfo::name; + bool isSelected = (_guiState->modelTypeName == name); + if (ImGui::Selectable(name, isSelected)) { + // Call the callback function for new model type selection (in case it's defined) + if (_guiState->modelTypeName != name && _guiState->callbacks.contains("resetExperiment")) { + _guiState->modelTypeName = name; + _guiState->callbacks["resetExperiment"](*_guiState); + } + } + if (isSelected) + ImGui::SetItemDefaultFocus(); + }); - ImGui::EndDisabled(); + ImGui::EndCombo(); + } + ImGui::EndDisabled(); + + ImGui::EndDisabled(); + } // Model config table if (ImGui::CollapsingHeader("Model configuration")) { @@ -98,7 +133,8 @@ void gui::TrainingWindow::render(ml::Trainer* trainer) ImGui::Checkbox("##value", &v); paramValue = v; } break; - case nlohmann::json::value_t::number_integer: { + case nlohmann::json::value_t::number_integer: + case nlohmann::json::value_t::number_unsigned: { int v = paramValue.get(); ImGui::SetNextItemWidth(columnWidth); ImGui::InputInt("##value", &v); diff --git a/src/ml/Trainer.cpp b/src/ml/Trainer.cpp index 5ba86df..385a95a 100644 --- a/src/ml/Trainer.cpp +++ b/src/ml/Trainer.cpp @@ -73,8 +73,6 @@ void Trainer::loop() if (_agentModel == nullptr) throw std::runtime_error("Agent model must not be nullptr"); - setupExperiment(); - auto& doomGame = DoomGame::instance(); _playerInitPos = 
doomGame.getGameState(); @@ -184,22 +182,48 @@ void Trainer::configureExperiment(nlohmann::json&& experimentConfig) if (!experimentConfig.contains("model_config")) throw std::runtime_error("Experiment config does not contain mandatory entry \"model_config\""); + std::string previousModelType; + if (_experimentConfig.contains("model_type")) + previousModelType = _experimentConfig["model_type"]; + _experimentConfig = std::move(experimentConfig); - // Reset training info - _trainingInfo.reset(); // TODO load past data if basis experiment is used + // Reset training info, load base experiment data if one is specified + loadBaseExperimentTrainingInfo(); - // delete the previous model - _model.reset(); + // Check if the model type has changed, reinstantiate the model in that case + if (previousModelType != _experimentConfig["model_type"].get()) { + // delete the previous model + _model.reset(); - // instantiate new model of desired type using the config above - ml::modelTypeNameCallback(_experimentConfig["model_type"], [&](){ - _model = std::make_unique(); - }); + // instantiate new model of desired type using the config above + ml::modelTypeNameCallback(_experimentConfig["model_type"], [&]() { + _model = std::make_unique(); + }); + } _model->setTrainingInfo(&_trainingInfo); } +void Trainer::setupExperiment() +{ + // Format the experiment name + _experimentConfig["experiment_name"] = formatExperimentName(_experimentConfig["experiment_name"], + _experimentConfig["model_config"]); + + if (!_experimentConfig.contains("experiment_root")) { + printf("INFO: No \"experiment_root\" defined. 
Using experiment_name: \"%s\"\n", + _experimentConfig["experiment_name"].get().c_str()); + _experimentConfig["experiment_root"] = _experimentConfig["experiment_name"]; + } + + createExperimentDirectories(); + loadBaseExperimentTrainingInfo(); + + _model->setTrainingInfo(&_trainingInfo); + _model->init(_experimentConfig); +} + void Trainer::saveExperiment() { printf("Saving the experiment\n"); // TODO logging @@ -278,23 +302,6 @@ void Trainer::nextMap(size_t newBatchEntryId) _encoderModel->reset(); } -void Trainer::setupExperiment() -{ - // Format the experiment name - _experimentConfig["experiment_name"] = formatExperimentName(_experimentConfig["experiment_name"], - _experimentConfig["model_config"]); - - if (!_experimentConfig.contains("experiment_root")) { - printf("INFO: No \"experiment_root\" defined. Using experiment_name: \"%s\"\n", - _experimentConfig["experiment_name"].get().c_str()); - _experimentConfig["experiment_root"] = _experimentConfig["experiment_name"]; - } - - createExperimentDirectories(); - - _model->init(_experimentConfig); -} - void Trainer::createExperimentDirectories() const { fs::path experimentDir = doot2::experimentsDirectory / _experimentConfig["experiment_root"]; @@ -304,3 +311,23 @@ void Trainer::createExperimentDirectories() const throw std::runtime_error("Could not create the directory \""+experimentDir.string()+"\""); } } + +void Trainer::loadBaseExperimentTrainingInfo() +{ + _trainingInfo.reset(); + + if (!_experimentConfig.contains("experiment_base_root")) + return; // no base experiment specified, return + + fs::path timeSeriesPath = _experimentConfig["experiment_base_root"].get() / "time_series.json"; + if (!fs::exists(timeSeriesPath)) { + printf("WARNING: Base experiment specified in config but time series data was not found (%s).", + timeSeriesPath.c_str()); // TODO logging + return; + } + + // TODO handle potential exceptions + std::ifstream timeSeriesFile(timeSeriesPath); + auto timeSeriesJson = 
nlohmann::json::parse(timeSeriesFile); + _trainingInfo.timeSeries.write()->fromJson(timeSeriesJson); +} diff --git a/src/ml/models/MultiLevelAutoEncoderModel.cpp b/src/ml/models/MultiLevelAutoEncoderModel.cpp index 0b5e107..5e7356b 100644 --- a/src/ml/models/MultiLevelAutoEncoderModel.cpp +++ b/src/ml/models/MultiLevelAutoEncoderModel.cpp @@ -140,26 +140,36 @@ void MultiLevelAutoEncoderModel::init(const nlohmann::json& experimentConfig) if (modelConfig.contains("frame_decoder_filename")) _frameDecoderFilename = experimentRoot / modelConfig["frame_decoder_filename"].get(); + // Separate loading paths in case a base experiment is specified + fs::path frameEncoderFilename = _frameEncoderFilename; + fs::path frameDecoderFilename = _frameDecoderFilename; + if (experimentConfig.contains("experiment_base_root") && experimentConfig.contains("base_model_config")) { + frameEncoderFilename = experimentConfig["experiment_base_root"].get() / + experimentConfig["base_model_config"]["frame_encoder_filename"].get(); + frameDecoderFilename = experimentConfig["experiment_base_root"].get() / + experimentConfig["base_model_config"]["frame_decoder_filename"].get(); + } + // Load frame encoder - if (fs::exists(_frameEncoderFilename)) { - printf("Loading frame encoder model from %s\n", _frameEncoderFilename.c_str()); // TODO logging + if (fs::exists(frameEncoderFilename)) { + printf("Loading frame encoder model from %s\n", frameEncoderFilename.c_str()); // TODO logging serialize::InputArchive inputArchive; - inputArchive.load_from(_frameEncoderFilename); + inputArchive.load_from(frameEncoderFilename); _frameEncoder->load(inputArchive); } else { - printf("No %s found. Initializing new frame encoder model.\n", _frameEncoderFilename.c_str()); // TODO logging + printf("No %s found. 
Initializing new frame encoder model.\n", frameEncoderFilename.c_str()); // TODO logging } // Load frame decoder - if (fs::exists(_frameDecoderFilename)) { - printf("Loading frame decoder model from %s\n", _frameDecoderFilename.c_str()); // TODO logging + if (fs::exists(frameDecoderFilename)) { + printf("Loading frame decoder model from %s\n", frameDecoderFilename.c_str()); // TODO logging serialize::InputArchive inputArchive; - inputArchive.load_from(_frameDecoderFilename); + inputArchive.load_from(frameDecoderFilename); _frameDecoder->load(inputArchive); } else { - printf("No %s found. Initializing new frame decoder model.\n", _frameDecoderFilename.c_str()); // TODO logging + printf("No %s found. Initializing new frame decoder model.\n", frameDecoderFilename.c_str()); // TODO logging } // Setup hyperparameters