Skip to content

Commit

Permalink
fix(//cpp/api): Better inital condition for the dataloader iterator to
Browse files Browse the repository at this point in the history
address datarace issue

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Apr 24, 2020
1 parent 5c0d737 commit 8d22bdd
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion cpp/api/include/trtorch/ptq.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Int8Calibrator : Algorithm {
using Batch = typename DataLoader::super::BatchType;
public:
Int8Calibrator(DataLoaderUniquePtr dataloader, const std::string& cache_file_path, bool use_cache)
: dataloader_(dataloader.get()), it_(dataloader_->begin()), cache_file_path_(cache_file_path), use_cache_(use_cache) {}
: dataloader_(dataloader.get()), it_(dataloader_->end()), cache_file_path_(cache_file_path), use_cache_(use_cache) {}

int getBatchSize() const override {
// HACK: TRTorch only uses explict batch sizing, INT8 Calibrator does not
Expand Down
7 changes: 7 additions & 0 deletions cpp/ptq/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ int main(int argc, const char* argv[]) {
auto execution_timer = timers::PreciseCPUTimer();
auto images = (*(*eval_dataloader).begin()).data.to(torch::kCUDA);

execution_timer.start();
mod.forward({images});
execution_timer.stop();
std::cout << "Latency of JIT model FP32 (Batch Size 32): " << execution_timer.milliseconds() << "ms" << std::endl;

execution_timer.reset();

execution_timer.start();
trt_mod.forward({images});
execution_timer.stop();
Expand Down

0 comments on commit 8d22bdd

Please sign in to comment.