Skip to content

Commit

Permalink
add GPU number and lazy load img to GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
sword4869 committed Nov 8, 2023
1 parent 2eee0e2 commit dfd866c
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 6 deletions.
2 changes: 2 additions & 0 deletions arguments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def __init__(self, parser, sentinel=False):
self._images = "images"
self._resolution = -1
self._white_background = False
self._device_number = 0
self._lazy_load = False
self.data_device = "cuda"
self.eval = False
super().__init__(parser, "Loading Parameters", sentinel)
Expand Down
2 changes: 1 addition & 1 deletion render.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,6 @@ def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParam
print("Rendering " + args.model_path)

# Initialize system state (RNG)
safe_state(args.quiet)
safe_state(args.quiet, args.device_number)

render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)
5 changes: 4 additions & 1 deletion scene/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class Camera(nn.Module):
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
image_name, uid,
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", lazy_load=False
):
super(Camera, self).__init__()

Expand All @@ -36,6 +36,9 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
self.data_device = torch.device("cuda")

if lazy_load:
self.data_device = torch.device("cpu")

self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
self.image_width = self.original_image.shape[2]
self.image_height = self.original_image.shape[1]
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
print("Optimizing " + args.model_path)

# Initialize system state (RNG)
safe_state(args.quiet)
safe_state(args.quiet, args.device_number)

# Start GUI server, configure and run training
network_gui.init(args.ip, args.port)
Expand Down
2 changes: 1 addition & 1 deletion utils/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def loadCam(args, id, cam_info, resolution_scale):
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
image=gt_image, gt_alpha_mask=loaded_mask,
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
image_name=cam_info.image_name, uid=id, data_device=args.data_device, lazy_load=args.lazy_load)

def cameraList_from_camInfos(cam_infos, resolution_scale, args):
camera_list = []
Expand Down
4 changes: 2 additions & 2 deletions utils/general_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def build_scaling_rotation(s, r):
L = R @ L
return L

def safe_state(silent):
def safe_state(silent, device_number):
old_f = sys.stdout
class F:
def __init__(self, silent):
Expand All @@ -130,4 +130,4 @@ def flush(self):
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.set_device(torch.device("cuda:0"))
torch.cuda.set_device(torch.device(f"cuda:{device_number}"))

0 comments on commit dfd866c

Please sign in to comment.