Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TMP] Reward model #13

Draft
wants to merge 12 commits into
base: develop
Choose a base branch
from
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
.idea/
cmake-build-*/
*.code-workspace
*.pk3
src/gitinfo.h
*.pt
models/
build/
8 changes: 5 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ project(DooT2)


# required for env with a RTX 4090
set(CMAKE_CUDA_ARCHITECTURES 89)
set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc)
set(CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda-11.8)
# set(CMAKE_CUDA_ARCHITECTURES 89)
# set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc)
# set(CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda-11.8)


# Add external dependencies
Expand Down Expand Up @@ -33,6 +33,8 @@ set(DOOT2_SOURCES
src/HeatmapActionModule.cpp
src/ModelProto.cpp
src/ResNeXtModule.cpp
src/RewardModel.cpp
src/RewardModelTrainer.cpp
src/SequenceStorage.cpp
)

Expand Down
20 changes: 15 additions & 5 deletions include/App.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@


#include "ActionManager.hpp"
#include "HeatmapActionModule.hpp"
#include "DoorTraversalActionModule.hpp"
#include "SequenceStorage.hpp"
#include "HeatmapActionModule.hpp"
#include "ModelProto.hpp"
#include "RewardModelTrainer.hpp"
#include "SequenceStorage.hpp"

#include <SDL.h>
#include <opencv2/core/mat.hpp>
Expand All @@ -27,7 +28,11 @@
class App {
public:
App();
// TODO RO5
App(const App&) = delete;
App(App&&) = delete;
App& operator=(const App&) = delete;
App& operator=(App&&) = delete;

~App();

void loop();
Expand All @@ -53,8 +58,13 @@ class App {
size_t _batchEntryId;
bool _newPatchReady;

ModelProto _model;

ModelProto _modelEdec;
RewardModelTrainer _modelReward;

torch::Device _torchDevice;
FrameEncoder _frameEncoder;
FrameDecoder _frameDecoder;
bool _trainRewardModel;

void nextMap(); // proceed to next map
};
2 changes: 2 additions & 0 deletions include/ModelProto.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
#include "FrameDecoder.hpp"
#include "FlowDecoder.hpp"


#include <vector>
#include <memory>
#include <atomic>
#include <mutex>
#include <thread>
#include <torch/torch.h>


class SequenceStorage;
Expand Down
18 changes: 18 additions & 0 deletions include/RewardModel.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#pragma once

#include <torch/torch.h>
#include <torch/nn/modules/rnn.h>

class RewardModelImpl : public torch::nn::Module {
public:
RewardModelImpl();
torch::Tensor forward(
torch::Tensor encodings,
torch::Tensor actions,
torch::Tensor rewards);
private:
const int64_t _inputSize;
const int64_t _hiddenSize;
torch::nn::LSTM _lstm;
};
TORCH_MODULE(RewardModel);
24 changes: 24 additions & 0 deletions include/RewardModelTrainer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include "ActionConverter.hpp"
#include "RewardModel.hpp"

#include <torch/torch.h>

class SequenceStorage;

class RewardModelTrainer {
public:
RewardModelTrainer();
RewardModelTrainer(const RewardModelTrainer&) = delete;
RewardModelTrainer(RewardModelTrainer&&) = delete;
RewardModelTrainer& operator=(const RewardModelTrainer&) = delete;
RewardModelTrainer& operator=(RewardModelTrainer&&) = delete;

void train(SequenceStorage& storage);
private:
RewardModel _rewardModel;
torch::optim::Adam _optimizer;
float _learningRate{1e-3};
ActionConverter<float> _actionConverter;
};
10 changes: 10 additions & 0 deletions include/SequenceStorage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class SequenceStorage {
// Map encoding data to a torch tensor (BW)
const torch::Tensor mapEncodingData();

// Map rewards to a torch tensor (B)
const torch::Tensor mapRewards();

friend class SequenceStorage;

private:
Expand All @@ -58,6 +61,7 @@ class SequenceStorage {

float* const _frameData;
float* const _encodingData;

const SequenceStorage::Settings& _settings;
};

Expand Down Expand Up @@ -95,8 +99,14 @@ class SequenceStorage {
// Map encoding data to a torch tensor (LBW)
const torch::Tensor mapEncodingData();

// Map rewards to a torch tensor (LB)
const torch::Tensor mapRewards();

const Settings& settings() const noexcept;

// Reinitialize all data to default values (0 and such)
void reset();

private:
Settings _settings;
uint64_t _frameSize; // size of a frame in elements
Expand Down
6 changes: 6 additions & 0 deletions include/TensorUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,9 @@ INLINE void copyToTensor(const std::vector<T_Data>& vector, torch::Tensor& tenso
memcpy(tensor.data_ptr<T_Data>(), vector.data(), vector.size()*sizeof(T_Data));
}
}

INLINE void printTensor(const torch::Tensor& t, const std::string& msg = std::string()) {
printf("%s: %ld %ld %ld %ld\n",
msg.c_str(),
t.sizes()[0], t.sizes()[1], t.sizes()[2], t.sizes()[3]);
}
112 changes: 96 additions & 16 deletions src/App.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,20 @@
//

#include "App.hpp"

#include "Constants.hpp"
#include <filesystem>
#include "gvizdoom/DoomGame.hpp"

#include <opencv2/highgui.hpp>

#include "Constants.hpp"


using namespace doot2;
using namespace gvizdoom;
using namespace torch;
using namespace torch::indexing;
namespace fs = std::filesystem;


App::App() :
_rnd (1507715517),
Expand All @@ -27,12 +32,14 @@ App::App() :
_quit (false),
_heatmapActionModule (HeatmapActionModule::Settings{256, 32.0f}),
_doorTraversalActionModule (false),
_sequenceStorage (SequenceStorage::Settings{batchSize, sequenceLength, true, false, frameWidth, frameHeight, ImageFormat::BGRA}),
_sequenceStorage (SequenceStorage::Settings{batchSize, sequenceLength, false, true, 0, 0, ImageFormat::BGRA, encodingLength}),
_positionPlot (1024, 1024, CV_32FC3, cv::Scalar(0.0f)),
_initPlayerPos (0.0f, 0.0f),
_frameId (0),
_batchEntryId (0),
_newPatchReady (false)
_newPatchReady (false),
_torchDevice (torch::cuda::is_available() ? kCUDA : kCPU),
_trainRewardModel (true)
{
auto& doomGame = DoomGame::instance();

Expand Down Expand Up @@ -71,6 +78,38 @@ App::App() :
// Setup ActionManager
_actionManager.addModule(&_doorTraversalActionModule);
_actionManager.addModule(&_heatmapActionModule);

// Load frame encoder
if (fs::exists(frameEncoderFilename)) {
printf("App: Loading frame encoder model from %s\n", frameEncoderFilename); // TODO logging
serialize::InputArchive inputArchive;
inputArchive.load_from(frameEncoderFilename);
_frameEncoder->load(inputArchive);
// Use the inference mode
_frameEncoder->eval();
}
else {
printf("No %s found. Initializing a new frame encoder model.\n", frameEncoderFilename); // TODO logging
}

_frameEncoder->to(_torchDevice);

// Load frame decoder
if (fs::exists(frameDecoderFilename)) {
printf("App: Loading frame encoder model from %s\n", frameDecoderFilename); // TODO logging
serialize::InputArchive inputArchive;
inputArchive.load_from(frameDecoderFilename);
_frameDecoder->load(inputArchive);
// Use the inference mode
_frameDecoder->eval();
}
else {
printf("No %s found. Initializing a new frame encoder model.\n", frameDecoderFilename); // TODO logging
}

_frameDecoder->to(_torchDevice);


}

App::~App()
Expand All @@ -91,9 +130,12 @@ void App::loop()

Vec2f playerPosScreen(0.0f, 0.0f);

size_t recordBeginFrameId = 768+_rnd()%512;
size_t recordBeginFrameId = 16+_rnd()%16;
size_t recordEndFrameId = recordBeginFrameId+64;

// BHWC
torch::Tensor pixelBuffer{torch::zeros({1, 480, 640, 4})};

while (!_quit) {
while(SDL_PollEvent(&event)) {
if (event.type == SDL_QUIT ||
Expand All @@ -116,7 +158,7 @@ void App::loop()
// Update the game state, restart if required
if (_frameId >= recordEndFrameId || doomGame.update(action)) {
nextMap();
recordBeginFrameId = 768+_rnd()%512;
recordBeginFrameId = 16+_rnd()%16;
recordEndFrameId = recordBeginFrameId+64;
continue;
}
Expand All @@ -136,10 +178,44 @@ void App::loop()
auto recordFrameId = _frameId - recordBeginFrameId;
auto batch = _sequenceStorage[recordFrameId];
batch.actions[_batchEntryId] = action;
Image<uint8_t> frame(doomGame.getScreenWidth(), doomGame.getScreenHeight(), ImageFormat::BGRA);
frame.copyFrom(doomGame.getPixelsBGRA());
convertImage(frame, batch.frames[_batchEntryId]);
batch.rewards[_batchEntryId] = 0.0; // TODO no rewards for now

// Convert the game frame from uint8 to float
const auto imageFormat{ImageFormat::BGRA};
Image<uint8_t> frameUint8(doomGame.getScreenWidth(), doomGame.getScreenHeight(), imageFormat);
Image<float> frameFloat(doomGame.getScreenWidth(), doomGame.getScreenHeight(), imageFormat);
frameUint8.copyFrom(doomGame.getPixelsBGRA());
convertImage(frameUint8, frameFloat);

// Copy the float frame to a torch::Tensor
const auto nPixels = doomGame.getScreenWidth() * doomGame.getScreenHeight() * getImageFormatNChannels(imageFormat);
copyToTensor(frameFloat.data(), nPixels, pixelBuffer);

// upload to GPU and permute to BCHW
torch::Tensor pixelBufferGpu = pixelBuffer.to(_torchDevice);
pixelBufferGpu = pixelBufferGpu.permute({0,3,1,2});

// encode
torch::Tensor encoding = _frameEncoder(pixelBufferGpu);

// Check sanity with decoder
#if CHECK_SANITY_DECODER
torch::Tensor decoding = _frameDecoder(encoding);
decoding = decoding.permute({0,2,3,1}).contiguous();

cv::Mat decodingOpencv(480, 640, CV_32FC4);
copyFromTensor(decoding.to(torch::kCPU), (float*)decodingOpencv.ptr<float>(0), 640*480*4);

cv::imshow("app-decoding", decodingOpencv);
#endif
// store encoding to the sequence storage
copyFromTensor(encoding.to(torch::kCPU), batch.encodings[_batchEntryId], encodingLength);

// Update relative player position
playerPosRelative(0) = doomGame.getGameState<GameState::PlayerPos>()(0) - _initPlayerPos(0);
playerPosRelative(1) = _initPlayerPos(1) - doomGame.getGameState<GameState::PlayerPos>()(1); // invert y

// Reward is negative heatmap value
batch.rewards[_batchEntryId] = -_heatmapActionModule.sample(playerPosRelative, true);
}

// Render screen
Expand Down Expand Up @@ -191,12 +267,16 @@ void App::loop()

// Train
if (_newPatchReady) {
// Create copy of the sequence storage
auto sequenceStorageCopy(_sequenceStorage);

printf("Training...\n");
_model.waitForTrainingFinish();
_model.trainAsync(std::move(sequenceStorageCopy));
if (_trainRewardModel) {
_modelReward.train(_sequenceStorage);
} else {
// Create copy of the sequence storage
auto sequenceStorageCopy(_sequenceStorage);

printf("Training...\n");
_modelEdec.waitForTrainingFinish();
_modelEdec.trainAsync(std::move(sequenceStorageCopy));
}
_newPatchReady = false;
}

Expand Down
3 changes: 2 additions & 1 deletion src/ModelProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
//

#include "ModelProto.hpp"

#include "Constants.hpp"
#include "SequenceStorage.hpp"

#include <opencv2/core/mat.hpp> // TODO temp
Expand All @@ -23,7 +25,6 @@ static constexpr double learningRate = 1.0e-3; // TODO
static constexpr int64_t nTrainingEpochs = 10;

using namespace doot2;

using namespace torch;
namespace tf = torch::nn::functional;
namespace fs = std::filesystem;
Expand Down
Loading