Skip to content

Commit

Permalink
Remove dead code comment
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui committed Mar 19, 2024
1 parent 1cb8f5c commit 7bff1e5
Showing 1 changed file with 0 additions and 107 deletions.
107 changes: 0 additions & 107 deletions yolox-burn/src/model/boxes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,110 +139,3 @@ pub fn non_maximum_suppression(bboxes: &mut [Vec<BoundingBox>], threshold: f32)
bboxes_for_class.truncate(current_index);
}
}

// pub fn nms<B: Backend>(
// boxes: Tensor<B, 3>,
// scores: Tensor<B, 3>,
// iou_threshold: f64,
// score_threshold: f64,
// ) -> () {
// // Tensor<B, 2>
// let [_, num_boxes, num_classes] = scores.shape().dims;
// let device = boxes.device();
// let is_valid = Tensor::<B, 1, Bool>::from_data([true; 1], &device);
// // let values: Tensor<B, 1, _> = Tensor::from_data([1, 0, 3, 4, -1], &device);

// let x = Tensor::arange(0..6 as i64, &device).reshape(Shape::new([2, 3]));
// let mask = x.greater_elem(1);

// let selected_indices = boxes
// .iter_dim(0)
// .enumerate()
// // Per-batch NMS
// .map(|(batch_idx, candidate_boxes)| {
// // Per-class filtering
// (0..num_classes).map(|cls_idx| {
// scores.select(
// 0,
// Tensor::<B, 1, Int>::full(Shape::new([1]), batch_idx as i64, &device),
// );
// // [batch_size, num_boxes, 1]
// let scores_mask = scores.greater_equal_elem(score_threshold).any_dim(2);
// // 1. `non_zero` to get the mask's valid boxes indices
// // 2. `select` to return only valid boxes (by score threshold)
// // 3. `argsort` to get the indices ranked by classification score
// // NOTE: we could do it all on CPU (preds to_vec) to iterate and compare as done in
// // [candle](https://github.com/huggingface/candle/blob/main/candle-examples/examples/yolo-v3/main.rs#L38)
// // ([nms](https://github.com/huggingface/candle/blob/main/candle-transformers/src/object_detection.rs#L31))
// // but this is not so nice
// let box_idx = scores_mask
// .iter_dim(1)
// .filter_map(|box_idx| {
// // NOTE: bool tensors do not support `select` and `into_scalar()`
// if scores_mask
// .slice([batch_idx..batch_idx + 1])
// .any()
// .int()
// .into_scalar() as bool
// {
// Some(box_idx)
// } else {
// None
// }
// })
// .collect();
// // scores[batch_idx, ]
// // [batch_size, num_boxes, 1]
// let scores_max = scores.gather(dim, indices);
// })
// });
// }

// std::vector<SelectedIndex> selected_indices;
// std::vector<BoxInfoPtr> selected_boxes_inside_class;
// selected_boxes_inside_class.reserve(std::min<size_t>(static_cast<size_t>(max_output_boxes_per_class), pc.num_boxes_));

// for (int64_t batch_index = 0; batch_index < pc.num_batches_; ++batch_index) {
// for (int64_t class_index = 0; class_index < pc.num_classes_; ++class_index) {
// int64_t box_score_offset = (batch_index * pc.num_classes_ + class_index) * pc.num_boxes_;
// const float* batch_boxes = boxes_data + (batch_index * pc.num_boxes_ * 4);
// std::vector<BoxInfoPtr> candidate_boxes;
// candidate_boxes.reserve(pc.num_boxes_);

// // Filter by score_threshold_
// const auto* class_scores = scores_data + box_score_offset;
// if (pc.score_threshold_ != nullptr) {
// for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) {
// if (*class_scores > score_threshold) {
// candidate_boxes.emplace_back(*class_scores, box_index);
// }
// }
// } else {
// for (int64_t box_index = 0; box_index < pc.num_boxes_; ++box_index, ++class_scores) {
// candidate_boxes.emplace_back(*class_scores, box_index);
// }
// }
// std::priority_queue<BoxInfoPtr, std::vector<BoxInfoPtr>> sorted_boxes(std::less<BoxInfoPtr>(), std::move(candidate_boxes));

// selected_boxes_inside_class.clear();
// // Get the next box with top score, filter by iou_threshold
// while (!sorted_boxes.empty() && static_cast<int64_t>(selected_boxes_inside_class.size()) < max_output_boxes_per_class) {
// const BoxInfoPtr& next_top_score = sorted_boxes.top();

// bool selected = true;
// // Check with existing selected boxes for this class, suppress if exceed the IOU (Intersection Over Union) threshold
// for (const auto& selected_index : selected_boxes_inside_class) {
// if (SuppressByIOU(batch_boxes, next_top_score.index_, selected_index.index_, center_point_box, iou_threshold)) {
// selected = false;
// break;
// }
// }

// if (selected) {
// selected_boxes_inside_class.push_back(next_top_score);
// selected_indices.emplace_back(batch_index, class_index, next_top_score.index_);
// }
// sorted_boxes.pop();
// } // while
// } // for class_index
// } // for batch_index

0 comments on commit 7bff1e5

Please sign in to comment.