Skip to content

Commit

Permalink
Fix bug with calculating n_joints
Browse files Browse the repository at this point in the history
  • Loading branch information
karfly authored Sep 1, 2020
1 parent 4ada88f commit acee6d2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def one_epoch(model, criterion, opt, config, dataloader, device, epoch, n_iters_
keypoints_3d_pred, heatmaps_pred, volumes_pred, confidences_pred, cuboids_pred, coord_volumes_pred, base_points_pred = model(images_batch, proj_matricies_batch, batch)

batch_size, n_views, image_shape = images_batch.shape[0], images_batch.shape[1], tuple(images_batch.shape[3:])
n_joints = keypoints_3d_pred[0].shape[1]
n_joints = keypoints_3d_pred.shape[1]

keypoints_3d_binary_validity_gt = (keypoints_3d_validity_gt > 0.0).type(torch.float32)

Expand Down

0 comments on commit acee6d2

Please sign in to comment.