Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
aapatni committed Mar 25, 2023
1 parent 559bf3b commit 5e5b331
Show file tree
Hide file tree
Showing 6 changed files with 2 additions and 13 deletions.
2 changes: 0 additions & 2 deletions cross_view_transformer/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,12 @@ def setup_data_module(cfg: DictConfig) -> DataModule:


def setup_viz(cfg: DictConfig) -> Callable:
# return None
return instantiate(cfg.visualization)


def setup_experiment(cfg: DictConfig) -> Tuple[ModelModule, DataModule, Callable]:
model_module = setup_model_module(cfg)
data_module = setup_data_module(cfg)
viz_fn = None
viz_fn = setup_viz(cfg)

return model_module, data_module, viz_fn
Expand Down
2 changes: 1 addition & 1 deletion cross_view_transformer/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_split(self, split, loader=True, shuffle=False):
if loader_config['num_workers'] == 0:
loader_config['prefetch_factor'] = 2

return torch.utils.data.DataLoader(dataset, shuffle=shuffle, num_workers=1)
return torch.utils.data.DataLoader(dataset, shuffle=shuffle, **loader_config)

def train_dataloader(self, shuffle=True):
return self.get_split('train', loader=True, shuffle=shuffle)
Expand Down
4 changes: 1 addition & 3 deletions cross_view_transformer/data/nuscenes_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from .common import INTERPOLATION, get_view_matrix, get_pose, get_split
from .transforms import Sample, SaveDataTransform

import ipdb

STATIC = ['lane', 'road_segment']
DIVIDER = ['road_divider', 'lane_divider']
Expand Down Expand Up @@ -209,8 +208,7 @@ def get_dynamic_objects(self, sample, annotations):
visibility = np.full((h, w), 255, dtype=np.uint8)

coords = np.stack(np.meshgrid(np.arange(w), np.arange(h)), -1).astype(np.float32)
# import ipdb
# ipdb.sset_trace()

for ann, p in zip(annotations, self.convert_to_box(sample, annotations)):
box = p[:2, :4]
center = p[:2, 4]
Expand Down
2 changes: 0 additions & 2 deletions cross_view_transformer/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ def __init__(

def forward(self, pred, batch):
pred = pred['center']
# import pdb
# pdb.set_trace()
label = batch['center']
loss = super().forward(pred, label)

Expand Down
3 changes: 0 additions & 3 deletions cross_view_transformer/model/cvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,3 @@ def forward(self, batch):
z = self.to_logits(y)

return {k: z[:, start:stop] for k, (start, stop) in self.outputs.items()}

# model = CrossViewTransformer()
# num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
2 changes: 0 additions & 2 deletions cross_view_transformer/model/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,6 @@ def __init__(
self.layers = nn.ModuleList(layers)

def forward(self, batch):
# print(batch.keys())
# print(batch)
b, n, _, _, _ = batch['image'].shape

image = batch['image'].flatten(0, 1) # b n c h w
Expand Down

0 comments on commit 5e5b331

Please sign in to comment.