diff --git a/iRPE/DETR-with-iRPE/README.md b/iRPE/DETR-with-iRPE/README.md index 654ade04..8e3405ce 100644 --- a/iRPE/DETR-with-iRPE/README.md +++ b/iRPE/DETR-with-iRPE/README.md @@ -37,7 +37,7 @@ pip install -r ./requirements.txt Although iRPE can be implemented by PyTorch native functions, the backward speed of PyTorch index function is very slow. We implement CUDA operators for more efficient training and recommend to build it. `nvcc` is necessary to build CUDA operators. ```bash -cd rpe_ops/ +cd models/rpe_attention/rpe_ops/ python setup.py install --user ``` @@ -132,12 +132,12 @@ If we want a image relative position encoding with contextual product shared-hea ## Training - Train a DETR-ResNet50 with iRPE (contextual product shared-head `9 x 9` buckets) for **150 epochs**: ```bash -python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --lr_drop 100 --epochs 150 --coco_path ./coco_data --enc_rpe2d rpe-2.0-product-ctx-1-k --output_dir ./output' +torchrun --nproc_per_node=8 main.py --lr_drop 100 --epochs 150 --coco_path ./coco_data --enc_rpe2d rpe-2.0-product-ctx-1-k --output_dir ./output' ``` - Train a DETR-ResNet50 with iRPE (contextual product shared-head `9 x 9` buckets) for **300 epochs**: ```bash -python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --lr_drop 200 --epochs 300 --coco_path ./coco_data --enc_rpe2d rpe-2.0-product-ctx-1-k --output_dir ./output' +torchrun --nproc_per_node=8 main.py --lr_drop 200 --epochs 300 --coco_path ./coco_data --enc_rpe2d rpe-2.0-product-ctx-1-k --output_dir ./output' ``` where `--nproc_per_node 8` means using 8 GPUs to train the model. `/coco_data` is the dataset folder, and `./output` is the model checkpoint folder. @@ -145,7 +145,7 @@ where `--nproc_per_node 8` means using 8 GPUs to train the model. `/coco_data` i ## Evaluation The step is similar to training. Add the checkpoint path and the flag `--eval --resume `. ```bash -python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --lr_drop 100 --epochs 150 --coco_path ./coco_data --enc_rpe2d rpe-2.0-product-ctx-1-k --output_dir ./output --eval --resume rpe-2.0-product-ctx-1-k.pth' +torchrun --nproc_per_node=8 main.py --lr_drop 100 --epochs 150 --coco_path ./coco_data --enc_rpe2d rpe-2.0-product-ctx-1-k --output_dir ./output --eval --resume rpe-2.0-product-ctx-1-k.pth' ``` ## Code Structure @@ -157,7 +157,7 @@ File | Description [`models/rpe_attention/irpe.py`](./models/rpe_attention/irpe.py) | The implementation of image relative position encoding [`models/rpe_attention/multi_head_attention.py`](./models/rpe_attention/multi_head_attention.py) | The nn.Module `MultiheadAttention` with iRPE [`models/rpe_attention/rpe_attention_function.py`](./models/rpe_attention/rpe_attention_function.py) | The function `rpe_multi_head_attention_forward` with iRPE -[`rpe_ops`](./rpe_ops) | The CUDA implementation of iRPE operators for efficient training +[`models/rpe_attention/rpe_ops`](./models/rpe_attention/rpe_ops) | The CUDA implementation of iRPE operators for efficient training # Citing iRPE If this project is helpful for you, please cite it. Thank you! : ) diff --git a/iRPE/DETR-with-iRPE/models/rpe_attention/irpe.py b/iRPE/DETR-with-iRPE/models/rpe_attention/irpe.py index 85e7d5bf..065148fd 100644 --- a/iRPE/DETR-with-iRPE/models/rpe_attention/irpe.py +++ b/iRPE/DETR-with-iRPE/models/rpe_attention/irpe.py @@ -1,17 +1,21 @@ """The implementation of iRPE (image relative position encoding).""" from easydict import EasyDict as edict +import os +import sys import math import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +sys.path.append(os.path.dirname(__file__)) try: from rpe_ops.rpe_index import RPEIndexFunction -except ImportError: +except ImportError as e: RPEIndexFunction = None import warnings RED_STR = "\033[91m{}\033[00m" - warnings.warn(RED_STR.format("[WARNING] The module `rpe_ops` is not built. \ + warnings.warn(RED_STR.format("[WARNING] {e}. \ +The module `rpe_ops` is not built. \ For better training performance, please build `rpe_ops`."),) diff --git a/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_ops/README.md b/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_ops/README.md new file mode 100644 index 00000000..e2952ef8 --- /dev/null +++ b/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_ops/README.md @@ -0,0 +1,35 @@ +# 2D RPE Operators + +## Build iRPE operators implemented by CUDA. +Although iRPE can be implemented by PyTorch native functions, the backward speed of PyTorch index function is very slow. We implement CUDA operators for more efficient training and recommend to build it. `nvcc` is necessary to build CUDA operators. +```bash +cd rpe_ops/ +python setup.py install --user +``` + +## rpe\_index +The function [`rpe_index`](./rpe_index.py#L5) is equal to +```python +def rpe_index(input, index): + '''Y[b, h, i, j] = input[b, h, i, index[i, j]] + + Parameters + ---------- + input: torch.Tensor, float32 + The shape is (B, H, L_query, num_buckets) + index: torch.Tensor, int32 + The shape is (L_query, L_key) + + where B is the batch size, and H is the number of attention heads. + + Returns + ------- + Y: torch.Tensor, float32 + The shape is (B, H, L_query, L_key) + ''' + L_query, L_key = index.shape + num_buckets = input.size(-1) + B = len(input) + offset = torch.arange(0, L_query * num_buckets, num_buckets).view(-1, 1) + return input.flatten(2)[:, :, (index + offset).flatten()].view(B, -1, L_query, L_key) +``` diff --git a/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_ops/rpe_index.cpp b/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_ops/rpe_index.cpp new file mode 100644 index 00000000..766142bd --- /dev/null +++ b/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_ops/rpe_index.cpp @@ -0,0 +1,142 @@ +#include + +#include +#include + +using index_t = int; + +at::Tensor rpe_index_forward_cpu(torch::Tensor input, torch::Tensor index) { + /* + - Inputs + input: float32 (B, H, L_query, num_buckets) + index: index_t (L_query, L_key) + - Outputs + Y: float32 (B, H, L_query, L_key) + */ + AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor"); + AT_ASSERTM(index.device().is_cpu(), "index must be a CPU tensor"); + AT_ASSERTM(input.ndimension() == 4, "input must be a 4D tensor"); + AT_ASSERTM(index.ndimension() == 2, "index must be a 2D tensor"); + AT_ASSERTM(index.scalar_type() == at::kInt, "index must be Int type"); + const index_t B = input.size(0); + const index_t H = input.size(1); + const index_t num_buckets = input.size(3); + const index_t L_query = index.size(0); + const index_t L_key = index.size(1); + const index_t L_qk = L_query * L_key; + at::Tensor Y = at::empty({B, H, L_query, L_key}, input.options()); + auto input_ = input.contiguous(); + auto index_ = index.contiguous(); + const index_t grain_size = 3000; + const index_t numel = Y.numel(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "rpe_index_forward_cpu", [&] { + const scalar_t *p_input = input_.data_ptr(); + const index_t *p_index = index_.data_ptr(); + scalar_t *p_Y = Y.data_ptr(); + at::parallel_for(0, numel, grain_size, [&](index_t begin, index_t end) { + /* + // we optimize the following function to + // reduce the number of operators, namely divide and multiply. + for (index_t i = begin; i < end; ++i) { + p_Y[i] = p_input[i / L_key * num_buckets + p_index[i % L_qk]]; + } + */ + + index_t aligned_begin = (begin + L_qk - 1) / L_qk * L_qk; + if (aligned_begin > end) aligned_begin = end; + index_t aligned_end = end / L_qk * L_qk; + for (index_t i = begin; i < aligned_begin; ++i) { + p_Y[i] = p_input[i / L_key * num_buckets + p_index[i % L_qk]]; + } + + // [aligned_begin, aligned_end) + // where aligned_begin % L_qk == 0, aligned_end % L_qk == 0 + index_t base = aligned_begin / L_key * num_buckets; + const index_t base_end = aligned_end / L_key * num_buckets; + index_t i = aligned_begin; + while (base < base_end) { + for (index_t q = 0, j = 0; q < L_query; ++q) { + for (index_t k = 0; k < L_key; ++k) { + p_Y[i++] = p_input[base + p_index[j++]]; + } + base += num_buckets; + } + } + + for (index_t i = aligned_end; i < end; ++i) { + p_Y[i] = p_input[i / L_key * num_buckets + p_index[i % L_qk]]; + } + }); + }); + return Y; +} + +template +inline scalar_t cpuAtomicAdd(scalar_t *address, const scalar_t val) { +#pragma omp critical + *address += val; + return *address; +} + +void rpe_index_backward_cpu(torch::Tensor grad_input, torch::Tensor grad_output, + torch::Tensor index) { + /* + - Inputs + grad_output: float32 (B, H, L_query, L_key) + index: index_t (L_query, L_key) + - Outputs + grad_input: float32 (B, H, L_query, num_buckets) + */ + AT_ASSERTM(grad_input.device().is_cpu(), "grad_input must be a CPU tensor"); + AT_ASSERTM(grad_output.device().is_cpu(), "grad_output must be a CPU tensor"); + AT_ASSERTM(index.device().is_cpu(), "grad_index must be a CPU tensor"); + AT_ASSERTM(grad_input.ndimension() == 4, "input must be a 4D tensor"); + AT_ASSERTM(grad_output.ndimension() == 4, "input must be a 4D tensor"); + AT_ASSERTM(index.ndimension() == 2, "index must be a 2D tensor"); + AT_ASSERTM(index.scalar_type() == at::kInt, "index must be Int type"); + + const index_t num_buckets = grad_input.size(3); + const index_t L_query = index.size(0); + const index_t L_key = index.size(1); + const index_t L_qk = L_query * L_key; + + auto grad_input_ = grad_input.contiguous(); + auto grad_output_ = grad_output.contiguous(); + auto index_ = index.contiguous(); + + const index_t grain_size = 3000; + const index_t numel = grad_output.numel(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_input.scalar_type(), "rpe_index_backward_atomic_cpu", [&] { + scalar_t *p_grad_input = grad_input_.data_ptr(); + const index_t *p_index = index_.data_ptr(); + const scalar_t *p_grad_output = grad_output_.data_ptr(); + at::parallel_for(0, numel, grain_size, [&](index_t begin, index_t end) { + for (index_t i = begin; i < end; ++i) { + const index_t input_i = i / L_key * num_buckets + p_index[i % L_qk]; + const scalar_t v = p_grad_output[i]; + cpuAtomicAdd(p_grad_input + input_i, v); + } + }); + }); +} + +std::string version() { + return "1.2.0"; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("version", &version, "The version of the package `rpe_index_cpp`"); + m.def("forward_cpu", &rpe_index_forward_cpu, "2D RPE Index Forward (CPU)"); + m.def("backward_cpu", &rpe_index_backward_cpu, "2D RPE Index Backward (CPU)"); + +#if defined(WITH_CUDA) + at::Tensor rpe_index_forward_gpu(torch::Tensor input, torch::Tensor index); + void rpe_index_backward_gpu(torch::Tensor grad_input, + torch::Tensor grad_output, torch::Tensor index); + m.def("forward_gpu", &rpe_index_forward_gpu, "2D RPE Index Forward (GPU)"); + m.def("backward_gpu", &rpe_index_backward_gpu, "2D RPE Index Backward (GPU)"); +#endif +} diff --git a/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_ops/rpe_index.py b/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_ops/rpe_index.py new file mode 100644 index 00000000..1d915e1a --- /dev/null +++ b/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_ops/rpe_index.py @@ -0,0 +1,100 @@ +import torch +import rpe_index_cpp + + +EXPECTED_VERSION = "1.2.0" +assert rpe_index_cpp.version() == EXPECTED_VERSION, \ + f"""Unmatched `rpe_index_cpp` version: {rpe_index_cpp.version()}, expected version: {EXPECTED_VERSION} +Please re-build the package `rpe_ops`.""" + + +class RPEIndexFunction(torch.autograd.Function): + '''Y[b, h, i, j] = input[b, h, i, index[i, j]]''' + @staticmethod + def forward(ctx, input, index): + ''' + Y[b, h, i, j] = input[b, h, i, index[i, j]] + + Parameters + ---------- + input: torch.Tensor, float32 + The shape is (B, H, L_query, num_buckets) + index: torch.Tensor, int32 + The shape is (L_query, L_key) + + where B is the batch size, and H is the number of attention heads. + + Returns + ------- + Y: torch.Tensor, float32 + The shape is (B, H, L_query, L_key) + ''' + + num_buckets = input.size(-1) + ctx.save_for_backward(index) + ctx.input_shape = input.shape + forward_fn = rpe_index_cpp.forward_cpu if \ + input.device.type == 'cpu' else rpe_index_cpp.forward_gpu + output = forward_fn(input, index) + return output + + @staticmethod + def backward(ctx, grad_output): + ''' + - Inputs + grad_output: float32 (B, H, L_query, L_key) + - Outputs + grad_input: float32 (B, H, L_query, num_buckets) + ''' + index = ctx.saved_tensors[0] + if ctx.needs_input_grad[0]: + grad_input = grad_output.new_zeros(ctx.input_shape) + backward_fn = rpe_index_cpp.backward_cpu if \ + grad_output.device.type == 'cpu' else rpe_index_cpp.backward_gpu + backward_fn(grad_input, grad_output, index) + return grad_input, None + return None, None + + +if __name__ == '__main__': + import numpy as np + import time + B = 128 + H = 32 + L_query = 50 + L_key = L_query + num_buckets = 50 + + x = torch.randn(B, H, L_query, num_buckets) + + index = torch.randint(low=0, high=num_buckets, size=(L_query, L_key)) + index = index.to(torch.int) + offset = torch.arange(0, L_query * num_buckets, num_buckets).view(-1, 1) + + def test(x, index, offset): + tic = time.time() + x1 = x.clone() + x1.requires_grad = True + x2 = x.clone() + x2.requires_grad = True + + y = RPEIndexFunction.apply(x1, index) + gt_y = x2.flatten(2)[:, :, (index + offset).flatten() + ].view(B, H, L_query, L_key) + + np.testing.assert_almost_equal( + gt_y.detach().cpu().numpy(), y.detach().cpu().numpy()) + + mask = torch.randn(gt_y.shape, device=x.device) + (gt_y * mask).sum().backward() + (y * mask).sum().backward() + + print("X1:", x1.grad.cpu().numpy().flatten().sum()) + print("X2:", x2.grad.cpu().numpy().flatten().sum()) + np.testing.assert_almost_equal( + x1.grad.cpu().numpy(), x2.grad.cpu().numpy(), decimal=5) + print("Test over", x.device) + print("Cost:", time.time() - tic) + test(x, index, offset) + if torch.cuda.is_available(): + test(x.cuda(), index.cuda(), offset.cuda()) diff --git a/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_ops/rpe_index_cuda.cu b/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_ops/rpe_index_cuda.cu new file mode 100644 index 00000000..3ddb5315 --- /dev/null +++ b/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_ops/rpe_index_cuda.cu @@ -0,0 +1,140 @@ +#include +#include + +#include +#include + +using index_t = int; + +const int HIP_MAX_GRID_NUM = 65535; +const int HIP_MAX_NUM_THREADS = 512; + +inline int HIP_GET_NUM_THREADS(const int n) { + return std::min(HIP_MAX_NUM_THREADS, ((n + 31) / 32) * 32); +} + +inline int HIP_GET_BLOCKS(const int n, const int num_threads) { + return std::min(HIP_MAX_GRID_NUM, n + num_threads - 1) / num_threads; +} + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void rpe_index_forward_gpu_kernel( + index_t n, scalar_t *p_Y, const scalar_t *__restrict__ p_input, + const index_t *__restrict__ p_index, index_t num_buckets, index_t H, + index_t L_query, index_t L_key, index_t L_qk, index_t s0, index_t s1, + index_t s2, index_t s3) { + CUDA_KERNEL_LOOP(i, n) { + index_t gi = i / L_key; + const index_t qi = gi % L_query; + gi /= L_query; + const index_t hi = gi % H; + gi /= H; + const index_t bi = gi; + const index_t ind = bi * s0 + hi * s1 + qi * s2 + p_index[i % L_qk] * s3; + p_Y[i] = __ldg(&p_input[ind]); + } +} + +template +__global__ void rpe_index_backward_gpu_kernel( + index_t n, scalar_t *p_grad_input, const index_t *__restrict__ p_index, + const scalar_t *__restrict__ p_grad_output, index_t num_buckets, + index_t L_key, index_t L_qk) { + CUDA_KERNEL_LOOP(i, n) { + const index_t input_i = i / L_key * num_buckets + p_index[i % L_qk]; + const scalar_t v = p_grad_output[i]; + gpuAtomicAdd(p_grad_input + input_i, v); + } +} + +at::Tensor rpe_index_forward_gpu(torch::Tensor input, torch::Tensor index) { + /* + - Inputs + input: float32 (B, H, L_query, num_buckets) + index: index_t (L_query, L_key) + - Outputs + Y: float32 (B, H, L_query, L_key) + */ + AT_ASSERTM(input.device().is_cuda(), "input must be a GPU tensor"); + AT_ASSERTM(index.device().is_cuda(), "index must be a GPU tensor"); + AT_ASSERTM(input.ndimension() == 4, "input must be a 4D tensor"); + AT_ASSERTM(index.ndimension() == 2, "index must be a 2D tensor"); + AT_ASSERTM(index.scalar_type() == at::kInt, "index must be Int type"); + AT_ASSERTM(index.is_contiguous(), "index should be contiguous"); + const index_t B = input.size(0); + const index_t H = input.size(1); + const index_t num_buckets = input.size(3); + const index_t L_query = index.size(0); + const index_t L_key = index.size(1); + const index_t L_qk = L_query * L_key; + at::Tensor Y = at::empty({B, H, L_query, L_key}, input.options()); + const index_t numel = Y.numel(); + const at::IntArrayRef strides = input.strides(); + + const int threadsPerBlock = HIP_GET_NUM_THREADS(numel); + const int blocks = HIP_GET_BLOCKS(numel, threadsPerBlock); + + at::cuda::CUDAGuard device_guard(input.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "rpe_index_forward_gpu", [&] { + const scalar_t *p_input = input.data_ptr(); + const index_t *p_index = index.data_ptr(); + scalar_t *p_Y = Y.data_ptr(); + rpe_index_forward_gpu_kernel<<>>( + numel, p_Y, p_input, p_index, num_buckets, H, L_query, L_key, L_qk, + strides[0], strides[1], strides[2], strides[3]); + }); + return Y; +} + +void rpe_index_backward_gpu(torch::Tensor grad_input, torch::Tensor grad_output, + torch::Tensor index) { + /* + - Inputs + grad_output: float32 (B, H, L_query, L_key) + index: index_t (L_query, L_key) + - Outputs + grad_input: float32 (B, H, L_query, num_buckets) + */ + AT_ASSERTM(grad_input.device().is_cuda(), "grad_input must be a GPU tensor"); + AT_ASSERTM(grad_output.device().is_cuda(), + "grad_output must be a GPU tensor"); + AT_ASSERTM(index.device().is_cuda(), "grad_index must be a GPU tensor"); + AT_ASSERTM(grad_input.ndimension() == 4, "input must be a 4D tensor"); + AT_ASSERTM(grad_output.ndimension() == 4, "input must be a 4D tensor"); + AT_ASSERTM(index.ndimension() == 2, "index must be a 2D tensor"); + AT_ASSERTM(index.scalar_type() == at::kInt, "index must be Int type"); + + const index_t num_buckets = grad_input.size(3); + const index_t L_query = grad_output.size(2); + const index_t L_key = grad_output.size(3); + const index_t L_qk = L_query * L_key; + + auto grad_input_ = grad_input.contiguous(); + auto grad_output_ = grad_output.contiguous(); + auto index_ = index.contiguous(); + + const index_t numel = grad_output.numel(); + + const int threadsPerBlock = HIP_GET_NUM_THREADS(numel); + const int blocks = HIP_GET_BLOCKS(numel, threadsPerBlock); + + at::cuda::CUDAGuard device_guard(grad_output.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "rpe_index_backward_gpu", [&] { + scalar_t *p_grad_input = grad_input_.data_ptr(); + const index_t *p_index = index_.data_ptr(); + const scalar_t *p_grad_output = grad_output_.data_ptr(); + rpe_index_backward_gpu_kernel<<>>( + numel, p_grad_input, p_index, p_grad_output, num_buckets, L_key, + L_qk); + }); +} diff --git a/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_ops/setup.py b/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_ops/setup.py new file mode 100644 index 00000000..4664b905 --- /dev/null +++ b/iRPE/DETR-with-iRPE/models/rpe_attention/rpe_ops/setup.py @@ -0,0 +1,25 @@ +"""Build iRPE (image RPE) Functions""" +from setuptools import setup, Extension +import torch +from torch.utils import cpp_extension + +ext_t = cpp_extension.CppExtension +ext_fnames = ['rpe_index.cpp'] +define_macros = [] +extra_compile_args = dict(cxx=['-fopenmp', '-O3'], + nvcc=['-O3']) + +if torch.cuda.is_available(): + ext_t = cpp_extension.CUDAExtension + ext_fnames.append('rpe_index_cuda.cu') + define_macros.append(('WITH_CUDA', None)) + +setup(name='rpe_index', + version="1.2.0", + ext_modules=[ext_t( + 'rpe_index_cpp', + ext_fnames, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + )], + cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/iRPE/DeiT-with-iRPE/README.md b/iRPE/DeiT-with-iRPE/README.md index 67a44d11..103e4f14 100644 --- a/iRPE/DeiT-with-iRPE/README.md +++ b/iRPE/DeiT-with-iRPE/README.md @@ -72,7 +72,7 @@ For example, we train DeiT-S with contextual product relative position encoding Run the following command: ```bash -python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model deit_small_patch16_224_ctx_product_50_shared_k --batch-size 128 --data-path ./ImageNet/ --output_dir ./outputs/ --load-tar +torchrun --nproc_per_node=8 main.py --model deit_small_patch16_224_ctx_product_50_shared_k --batch-size 128 --data-path ./ImageNet/ --output_dir ./outputs/ --load-tar ``` You can remove the flag `--load-tar` if storing images as individual files : ) @@ -80,7 +80,7 @@ You can remove the flag `--load-tar` if storing images as individual files : ) ## Evaluation The step is similar to training. Add `--eval --resume `. ```bash -python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model deit_small_patch16_224_ctx_product_50_shared_k --batch-size 128 --data-path ./ImageNet/ --output_dir ./outputs/ --load-tar --eval --resume deit_small_patch16_224_ctx_product_50_shared_k.pth +torchrun --nproc_per_node=8 main.py --model deit_small_patch16_224_ctx_product_50_shared_k --batch-size 128 --data-path ./ImageNet/ --output_dir ./outputs/ --load-tar --eval --resume deit_small_patch16_224_ctx_product_50_shared_k.pth ``` `--resume ` can be replaced by `--pretrained`, then the checkpoint will be downloaded automatically. The download directory is usually `$HOME/.cache/torch/hub/checkpoints`. diff --git a/iRPE/DeiT-with-iRPE/irpe.py b/iRPE/DeiT-with-iRPE/irpe.py index 85e7d5bf..065148fd 100644 --- a/iRPE/DeiT-with-iRPE/irpe.py +++ b/iRPE/DeiT-with-iRPE/irpe.py @@ -1,17 +1,21 @@ """The implementation of iRPE (image relative position encoding).""" from easydict import EasyDict as edict +import os +import sys import math import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +sys.path.append(os.path.dirname(__file__)) try: from rpe_ops.rpe_index import RPEIndexFunction -except ImportError: +except ImportError as e: RPEIndexFunction = None import warnings RED_STR = "\033[91m{}\033[00m" - warnings.warn(RED_STR.format("[WARNING] The module `rpe_ops` is not built. \ + warnings.warn(RED_STR.format("[WARNING] {e}. \ +The module `rpe_ops` is not built. \ For better training performance, please build `rpe_ops`."),)