Skip to content

Commit

Permalink
Implement support for base experiments
Browse files Browse the repository at this point in the history
A base experiment means continuing from where a previous experiment left off: the same model is used, and potentially the same model configuration. The previous time series data is also loaded, and new data is accumulated on top of it.
  • Loading branch information
Lehdari committed Apr 27, 2023
1 parent b380fa5 commit e2834c7
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 77 deletions.
8 changes: 7 additions & 1 deletion include/App.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@ class App {
gui::Gui _gui;

void trainingControl();
void resetExperiment(); // shouldn't be called when training thread is running

// Apply changes made in the GUI (e.g. experiment name, base experiment) to the given experiment config
void updateExperimentConfig(nlohmann::json& experimentConfig);

// Reset the whole experiment, create new config and instantiate new model.
// Shouldn't be called when training thread is running.
void resetExperiment();
};


Expand Down
8 changes: 5 additions & 3 deletions include/gui/State.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ struct State {
STOPPED = 0,
ONGOING = 1,
PAUSED = 2
} trainingStatus {TrainingStatus::STOPPED};
} trainingStatus {TrainingStatus::STOPPED};

char experimentName[256] {"ex_{time}_{version}"};
std::string modelTypeName {"AutoEncoderModel"}; // type name of the model to be trained
char experimentName[256] {"ex_{time}_{version}"};
std::string experimentBase {""};
nlohmann::json baseExperimentConfig;
std::string modelTypeName {"AutoEncoderModel"}; // type name of the model to be trained
};

} // namespace gui
3 changes: 2 additions & 1 deletion include/ml/Trainer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class Trainer {
void quit();

void configureExperiment(nlohmann::json&& experimentConfig);
void setupExperiment(); // called before starting the training loop
void saveExperiment();

// Access the model that is being trained
Expand Down Expand Up @@ -81,8 +82,8 @@ class Trainer {

bool startRecording();
void nextMap(size_t newBatchEntryId = 0); // proceed to next map
void setupExperiment(); // called before starting the training
void createExperimentDirectories() const;
void loadBaseExperimentTrainingInfo();
};

} // namespace ml
37 changes: 29 additions & 8 deletions src/App.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "gvizdoom/DoomGame.hpp"
#include "glad/glad.h"
#include "util/ExperimentUtils.hpp"

#include <chrono>
#include <filesystem>
Expand Down Expand Up @@ -79,7 +80,7 @@ App::App(Trainer* trainer) :

// Initialize gui
_gui.init(_window, &_glContext);
_gui.setCallback("newModelTypeSelected", [&](const gui::State& guiState){ resetExperiment(); });
_gui.setCallback("resetExperiment", [&](const gui::State& guiState){ resetExperiment(); });
_gui.update(_trainer->getTrainingInfo());

// Read gui layout from the layout file
Expand Down Expand Up @@ -165,8 +166,9 @@ void App::trainingControl()
break;

// no trainer thread running, launch it
(*_trainer->getExperimentConfig())["experiment_name"] = _gui.getState().experimentName;
(*_trainer->getExperimentConfig()).erase("experiment_root");
updateExperimentConfig(*_trainer->getExperimentConfig());
_trainer->setupExperiment(); // needed for updated training info
_gui.update(_trainer->getTrainingInfo()); // communicate potential changes in training info to gui
_trainerThread = std::thread(&ml::Trainer::loop, _trainer);
} break;
case gui::State::TrainingStatus::PAUSED: {
Expand All @@ -175,6 +177,19 @@ void App::trainingControl()
}
}

// Synchronize the given experiment config with the current GUI state.
//
// Writes "experiment_name" and "software_version", removes "experiment_root" and
// keeps "experiment_base_root" in step with the GUI's experiment base field: the
// entry is set when a base is given and removed when the field is empty.
//
// The contains() guards are required: erase() by key throws on a json value that
// is not an object, and this function is also called with a freshly constructed
// (null) config.
void App::updateExperimentConfig(nlohmann::json& experimentConfig)
{
    // TODO This is here so that when using macros in the experiment name a new directory is generated.
    // TODO It is planned to be overridable from the experiment configuration GUI section.
    if (experimentConfig.contains("experiment_root"))
        experimentConfig.erase("experiment_root");

    experimentConfig["experiment_name"] = _gui.getState().experimentName;

    if (!_gui.getState().experimentBase.empty()) {
        experimentConfig["experiment_base_root"] = experimentRootFromString(_gui.getState().experimentBase);
    }
    else if (experimentConfig.contains("experiment_base_root")) {
        // The base field was cleared in the GUI: drop the entry left over from an
        // earlier configuration so that no base experiment data gets loaded.
        experimentConfig.erase("experiment_base_root");
    }

    experimentConfig["software_version"] = GIT_VERSION;
}

void App::resetExperiment()
{
if (_gui.getState().trainingStatus != gui::State::TrainingStatus::STOPPED) {
Expand All @@ -184,12 +199,18 @@ void App::resetExperiment()

// Setup experiment config
nlohmann::json experimentConfig;
experimentConfig["experiment_name"] = _gui.getState().experimentName;
updateExperimentConfig(experimentConfig);
experimentConfig["model_type"] = _gui.getState().modelTypeName;
modelTypeNameCallback(_gui.getState().modelTypeName, [&]<typename T_Model>(){ // load the default model config
experimentConfig["model_config"] = getDefaultModelConfig<T_Model>();
});
experimentConfig["software_version"] = GIT_VERSION;
if (_gui.getState().baseExperimentConfig.contains("model_config")) {
experimentConfig["model_config"] = _gui.getState().baseExperimentConfig["model_config"];
// model config may change, store also the original
experimentConfig["base_model_config"] = _gui.getState().baseExperimentConfig["model_config"];
}
else {
modelTypeNameCallback(_gui.getState().modelTypeName, [&]<typename T_Model>() { // load the default model config
experimentConfig["model_config"] = getDefaultModelConfig<T_Model>();
});
}

_trainer->configureExperiment(std::move(experimentConfig));
_gui.update(_trainer->getTrainingInfo());
Expand Down
94 changes: 65 additions & 29 deletions src/gui/TrainingWindow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@
#include "ml/Models.hpp"
#include "ml/ModelTypeUtils.hpp"
#include "ml/Trainer.hpp"
#include "util/ExperimentUtils.hpp"

#include "imgui.h"
#include "misc/cpp/imgui_stdlib.h"


namespace fs = std::filesystem;


gui::TrainingWindow::TrainingWindow(std::set<int>* activeIds, State* guiState, int id) :
Window(this, guiState, activeIds, id)
{
Expand All @@ -36,39 +40,70 @@ void gui::TrainingWindow::render(ml::Trainer* trainer)

// Flag for disabling all settings when training's in progress
bool trainingInProgress = _guiState->trainingStatus != State::TrainingStatus::STOPPED;
ImGui::BeginDisabled(trainingInProgress);

// Experiment name input
ImGui::Text("Experiment name:");
ImGui::SetNextItemWidth(windowSize.x);
if (ImGui::InputText("##ExperimentName", _guiState->experimentName, 255)) {
if (_guiState->callbacks.contains("newModelTypeSelected"))
_guiState->callbacks["newModelTypeSelected"](*_guiState);
}

// Model select
ImGui::Text("Model: ");
ImGui::SameLine();
ImGui::SetNextItemWidth(windowSize.x - fontSize*6.35f);
if (ImGui::BeginCombo("##ModelSelector", _guiState->modelTypeName.c_str())) {
ml::modelForEachTypeCallback([&]<typename T_Model>() {
constexpr auto name = ml::ModelTypeInfo<T_Model>::name;
bool isSelected = (_guiState->modelTypeName == name);
if (ImGui::Selectable(name, isSelected)) {
// Call the callback function for new model type selection (in case it's defined)
if (_guiState->modelTypeName != name && _guiState->callbacks.contains("newModelTypeSelected")) {
_guiState->modelTypeName = name;
_guiState->callbacks["newModelTypeSelected"](*_guiState);
if (ImGui::CollapsingHeader("Experiment configuration")) {
ImGui::BeginDisabled(trainingInProgress);

// Experiment name input
ImGui::Text("Experiment name:");
ImGui::SetNextItemWidth(windowSize.x - fontSize * 2.0f);
ImGui::InputText("##ExperimentName", _guiState->experimentName, 255);

// Experiment base input
ImGui::Text("Experiment base:");
ImGui::SetNextItemWidth(windowSize.x - fontSize * 2.0f);
if (ImGui::InputText("##ExperimentBase", &_guiState->experimentBase) &&
!_guiState->experimentBase.empty()) {
auto baseRoot = experimentRootFromString(_guiState->experimentBase);
if (fs::exists(baseRoot)) { // valid experiment root passed, let's check if there's a config
fs::path baseExperimentConfigFilename = baseRoot / "experiment_config.json";
if (fs::exists(baseExperimentConfigFilename)) { // config found, parse it
std::ifstream baseExperimentConfigFile(baseExperimentConfigFilename);
_guiState->baseExperimentConfig = nlohmann::json::parse(baseExperimentConfigFile);
if (_guiState->baseExperimentConfig.contains("model_type")) { // reset the model
_guiState->modelTypeName = _guiState->baseExperimentConfig["model_type"];
if (_guiState->callbacks.contains("resetExperiment"))
_guiState->callbacks["resetExperiment"](*_guiState);
}
else
printf("WARNING: No model_type specified in the base experiment config\n"); // TODO logging
}
else
printf("WARNING: No experiment config from %s found!\n", baseExperimentConfigFilename.c_str());
}
if (isSelected)
ImGui::SetItemDefaultFocus();
});
}
if (ImGui::Button("Reload model config")) { // reload model config from the base experiment
if (_guiState->baseExperimentConfig.contains("model_config"))
(*trainer->getExperimentConfig())["model_config"] = _guiState->baseExperimentConfig["model_config"];
else
printf("WARNING: No base experiment model config found\n"); // TODO logging
}

ImGui::EndCombo();
}
// Model select
ImGui::BeginDisabled(!_guiState->experimentBase.empty()); // Same model type forced when using a base experiment
ImGui::Text("Model:");
ImGui::SetNextItemWidth(windowSize.x - fontSize * 2.0f);
if (ImGui::BeginCombo("##ModelSelector", _guiState->modelTypeName.c_str())) {
ml::modelForEachTypeCallback([&]<typename T_Model>() {
constexpr auto name = ml::ModelTypeInfo<T_Model>::name;
bool isSelected = (_guiState->modelTypeName == name);
if (ImGui::Selectable(name, isSelected)) {
// Call the callback function for new model type selection (in case it's defined)
if (_guiState->modelTypeName != name && _guiState->callbacks.contains("resetExperiment")) {
_guiState->modelTypeName = name;
_guiState->callbacks["resetExperiment"](*_guiState);
}
}
if (isSelected)
ImGui::SetItemDefaultFocus();
});

ImGui::EndDisabled();
ImGui::EndCombo();
}
ImGui::EndDisabled();

ImGui::EndDisabled();
}

// Model config table
if (ImGui::CollapsingHeader("Model configuration")) {
Expand Down Expand Up @@ -98,7 +133,8 @@ void gui::TrainingWindow::render(ml::Trainer* trainer)
ImGui::Checkbox("##value", &v);
paramValue = v;
} break;
case nlohmann::json::value_t::number_integer: {
case nlohmann::json::value_t::number_integer:
case nlohmann::json::value_t::number_unsigned: {
int v = paramValue.get<int>();
ImGui::SetNextItemWidth(columnWidth);
ImGui::InputInt("##value", &v);
Expand Down
81 changes: 54 additions & 27 deletions src/ml/Trainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,6 @@ void Trainer::loop()
if (_agentModel == nullptr)
throw std::runtime_error("Agent model must not be nullptr");

setupExperiment();

auto& doomGame = DoomGame::instance();
_playerInitPos = doomGame.getGameState<GameState::PlayerPos>();

Expand Down Expand Up @@ -184,22 +182,48 @@ void Trainer::configureExperiment(nlohmann::json&& experimentConfig)
if (!experimentConfig.contains("model_config"))
throw std::runtime_error("Experiment config does not contain mandatory entry \"model_config\"");

std::string previousModelType;
if (_experimentConfig.contains("model_type"))
previousModelType = _experimentConfig["model_type"];

_experimentConfig = std::move(experimentConfig);

// Reset training info
_trainingInfo.reset(); // TODO load past data if basis experiment is used
// Reset training info, load base experiment data if one is specified
loadBaseExperimentTrainingInfo();

// delete the previous model
_model.reset();
// Check if the model type has changed, reinstantiate the model in that case
if (previousModelType != _experimentConfig["model_type"].get<std::string>()) {
// delete the previous model
_model.reset();

// instantiate new model of desired type using the config above
ml::modelTypeNameCallback(_experimentConfig["model_type"], [&]<typename T_Model>(){
_model = std::make_unique<T_Model>();
});
// instantiate new model of desired type using the config above
ml::modelTypeNameCallback(_experimentConfig["model_type"], [&]<typename T_Model>() {
_model = std::make_unique<T_Model>();
});
}

_model->setTrainingInfo(&_trainingInfo);
}

// Prepare the configured experiment for training. Called before starting the
// training loop.
//
// Steps: format the experiment name, fall back to the experiment name as the
// experiment root when none is configured, create the experiment directories,
// (re)load the base experiment training info and initialize the model.
void Trainer::setupExperiment()
{
    // Resolve the final experiment name first; everything below depends on it
    auto formattedName = formatExperimentName(_experimentConfig["experiment_name"],
        _experimentConfig["model_config"]);
    _experimentConfig["experiment_name"] = std::move(formattedName);

    const bool rootConfigured = _experimentConfig.contains("experiment_root");
    if (!rootConfigured) {
        // No explicit root given: the experiment name doubles as the root
        const std::string experimentName = _experimentConfig["experiment_name"].get<std::string>();
        printf("INFO: No \"experiment_root\" defined. Using experiment_name: \"%s\"\n",
            experimentName.c_str());
        _experimentConfig["experiment_root"] = _experimentConfig["experiment_name"];
    }

    createExperimentDirectories();
    loadBaseExperimentTrainingInfo();

    // Hand the training info to the model before initializing it with the config
    _model->setTrainingInfo(&_trainingInfo);
    _model->init(_experimentConfig);
}

void Trainer::saveExperiment()
{
printf("Saving the experiment\n"); // TODO logging
Expand Down Expand Up @@ -278,23 +302,6 @@ void Trainer::nextMap(size_t newBatchEntryId)
_encoderModel->reset();
}

// Prepare the configured experiment for training. Called before starting the
// training.
//
// Steps: format the experiment name, fall back to the experiment name as the
// experiment root when none is configured, create the experiment directories
// and initialize the model.
void Trainer::setupExperiment()
{
    // Resolve the final experiment name first; everything below depends on it
    auto formattedName = formatExperimentName(_experimentConfig["experiment_name"],
        _experimentConfig["model_config"]);
    _experimentConfig["experiment_name"] = std::move(formattedName);

    const bool rootConfigured = _experimentConfig.contains("experiment_root");
    if (!rootConfigured) {
        // No explicit root given: the experiment name doubles as the root
        const std::string experimentName = _experimentConfig["experiment_name"].get<std::string>();
        printf("INFO: No \"experiment_root\" defined. Using experiment_name: \"%s\"\n",
            experimentName.c_str());
        _experimentConfig["experiment_root"] = _experimentConfig["experiment_name"];
    }

    createExperimentDirectories();

    _model->init(_experimentConfig);
}

void Trainer::createExperimentDirectories() const
{
fs::path experimentDir = doot2::experimentsDirectory / _experimentConfig["experiment_root"];
Expand All @@ -304,3 +311,23 @@ void Trainer::createExperimentDirectories() const
throw std::runtime_error("Could not create the directory \""+experimentDir.string()+"\"");
}
}

void Trainer::loadBaseExperimentTrainingInfo()
{
_trainingInfo.reset();

if (!_experimentConfig.contains("experiment_base_root"))
return; // no base experiment specified, return

fs::path timeSeriesPath = _experimentConfig["experiment_base_root"].get<fs::path>() / "time_series.json";
if (!fs::exists(timeSeriesPath)) {
printf("WARNING: Base experiment specified in config but time series data was not found (%s).",
timeSeriesPath.c_str()); // TODO logging
return;
}

// TODO handle potential exceptions
std::ifstream timeSeriesFile(timeSeriesPath);
auto timeSeriesJson = nlohmann::json::parse(timeSeriesFile);
_trainingInfo.timeSeries.write()->fromJson<double>(timeSeriesJson);
}
Loading

0 comments on commit e2834c7

Please sign in to comment.