Add typing fixes. (#239)
* Add typing fixes.

* Revert ignore comments.

* Fix rust lint.

* Fix rust lint.

* Update environment.

* Add libtiff.

* Update channels.

* Try torch assert fix.

* Update python versions.

* Fix 3.11 noop.

* Fix 3.11 ci.

* Remove excess dep.

* Remove excess dep.

* Fix constraint.

* Add pytorch channel.

* Remove redundant python req + uncomment code.
benjaminrwilson authored Dec 12, 2023 · 1 parent 046f01f · commit b2810d8
Showing 9 changed files with 269 additions and 273 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -14,7 +14,7 @@ jobs:
       matrix:
         os: [macos-latest, ubuntu-latest]
         python_version:
-          ["3.8", "3.9", "3.10"]
+          ["3.9", "3.10", "3.11"]
       defaults:
         run:
           shell: bash -l {0}
2 changes: 1 addition & 1 deletion conda/environment.yml
@@ -27,4 +27,4 @@ dependencies:
   - tqdm
   - universal_pathlib
   - pip:
-      - git+https://github.com/JonathonLuiten/TrackEval.git
\ No newline at end of file
+      - git+https://github.com/JonathonLuiten/TrackEval.git
4 changes: 2 additions & 2 deletions rust/src/data_loader.rs
@@ -157,7 +157,7 @@ impl DataLoader {
     pub fn get(&self, index: usize) -> Sweep {
         let row = self.file_index.0.get_row(index).unwrap().0;
         let (log_id, timestamp_ns) = (
-            row.get(0).unwrap().get_str().unwrap(),
+            row.first().unwrap().get_str().unwrap(),
             row.get(1).unwrap().try_extract::<u64>().unwrap(),
         );

@@ -313,7 +313,7 @@ impl DataLoader {
     pub fn get_synchronized_images(&self, index: usize) -> Vec<TimeStampedImage> {
         let row = self.file_index.0.get_row(index).unwrap().0;
         let (log_id, _) = (
-            row.get(0).unwrap().get_str().unwrap(),
+            row.first().unwrap().get_str().unwrap(),
             row.get(1).unwrap().try_extract::<u64>().unwrap(),
         );
2 changes: 1 addition & 1 deletion src/av2/torch/structures/flow.py
@@ -97,7 +97,7 @@ def from_sweep_pair(cls, sweeps: Tuple[Sweep, Sweep]) -> Flow:
         c1_SE3_c0 = c1.dst_SE3_object.compose(c0.dst_SE3_object.inverse())
         flow[obj_mask] = (
             torch.as_tensor(
-                c1_SE3_c0.transform_point_cloud(obj_pts.numpy()),
+                c1_SE3_c0.transform_point_cloud(obj_pts_npy),
                 dtype=torch.float32,
             )
             - obj_pts
6 changes: 3 additions & 3 deletions src/av2/utils/dense_grid_interpolation.py
@@ -45,13 +45,13 @@ def interp_dense_grid_from_sparse(
     """
     if interp_method not in ["linear", "nearest"]:
         raise ValueError("Unknown interpolation method.")
+    if grid_img.dtype != values.dtype:
+        raise ValueError("Grid and values should be the same datatype.")

     if points.shape[0] < MIN_REQUIRED_POINTS_SIMPLEX:
         # return the empty grid, since we can't interpolate.
         return grid_img

-    output_dtype = values.dtype
-
     # get (x,y) tuples back
     grid_coords = mesh_grid.get_mesh_grid_as_point_cloud(
         min_x=0, max_x=grid_w - 1, min_y=0, max_y=grid_h - 1

@@ -65,4 +65,4 @@
     v = grid_coords[:, 1].astype(np.int32)
     # Now index in at (y,x) locations
     grid_img[v, u] = interp_vals
-    return grid_img.astype(output_dtype)
+    return grid_img
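The cast on return is gone because dtype agreement is now checked up front. A rough standalone sketch of that contract with toy shapes, assuming the helper is backed by `scipy.interpolate.griddata` (an assumption; the backing interpolator is outside this hunk):

```python
import numpy as np
from scipy.interpolate import griddata

# Toy sparse samples at the four corners of a 4x4 grid; the values dtype
# matches grid_img.dtype, which the function now requires instead of
# casting on the way out.
grid_img = np.zeros((4, 4), dtype=np.float32)
points = np.array([[0.0, 0.0], [3.0, 0.0], [0.0, 3.0], [3.0, 3.0]])
values = np.array([0.0, 1.0, 1.0, 2.0], dtype=np.float32)
assert grid_img.dtype == values.dtype  # mirrors the new ValueError guard

ys, xs = np.mgrid[0:4, 0:4]
grid_img[ys, xs] = griddata(points, values, (xs, ys), method="linear")
```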
6 changes: 4 additions & 2 deletions src/av2/utils/synchronization_database.py
@@ -142,12 +142,14 @@ def __init__(
             self.per_log_cam_timestamps_index[log_id] = {}
             for cam_name in list(RingCameras) + list(StereoCameras):
                 sensor_folder_wildcard = (
-                    f"{dataset_dir}/{log_id}/sensors/cameras/{cam_name}/*.jpg"
+                    f"{dataset_dir}/{log_id}/sensors/cameras/{cam_name.value}/*.jpg"
                 )
                 cam_timestamps = get_timestamps_from_sensor_folder(
                     sensor_folder_wildcard
                 )
-                self.per_log_cam_timestamps_index[log_id][cam_name] = cam_timestamps
+                self.per_log_cam_timestamps_index[log_id][
+                    cam_name.value
+                ] = cam_timestamps

         sensor_folder_wildcard = f"{dataset_dir}/{log_id}/sensors/lidar/*.feather"
         lidar_timestamps = get_timestamps_from_sensor_folder(sensor_folder_wildcard)
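This hunk is likely motivated by the Python 3.11 support added in this commit: 3.11 changed `Enum.__format__` to match `Enum.__str__`, so f-strings over mixed-in enum members started producing `ClassName.MEMBER` instead of the raw value. A minimal sketch (the single-member enum is illustrative, assuming av2's camera enums are `str` mixins, not the real definition):

```python
from enum import Enum


class RingCameras(str, Enum):
    """Illustrative str-mixin enum; av2 defines many more cameras."""

    RING_FRONT_CENTER = "ring_front_center"


cam_name = RingCameras.RING_FRONT_CENTER
# On Python <= 3.10 this interpolated the value ("ring_front_center");
# on 3.11 it yields "RingCameras.RING_FRONT_CENTER" and breaks the glob.
print(f"sensors/cameras/{cam_name}/*.jpg")
# Using .value is stable across versions:
print(f"sensors/cameras/{cam_name.value}/*.jpg")
```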
54 changes: 27 additions & 27 deletions tests/unit/datasets/sensor/test_sensor_dataloader.py
@@ -24,33 +24,33 @@

 def _create_dummy_sensor_dataloader(log_id: str) -> SensorDataloader:
     """Create a dummy sensor dataloader."""
-    with Path(tempfile.TemporaryDirectory().name) as sensor_dataset_dir:
-        for sensor_name, timestamps_ms in SENSOR_TIMESTAMPS_MS_DICT.items():
-            for t in timestamps_ms:
-                if "ring" in sensor_name:
-                    fpath = Path(
-                        sensor_dataset_dir,
-                        "dummy",
-                        log_id,
-                        "sensors",
-                        "cameras",
-                        sensor_name,
-                        f"{int(t*1e6)}.jpg",
-                    )
-                    Path(fpath).parent.mkdir(exist_ok=True, parents=True)
-                    fpath.open("w").close()
-                elif "lidar" in sensor_name:
-                    fpath = Path(
-                        sensor_dataset_dir,
-                        "dummy",
-                        log_id,
-                        "sensors",
-                        sensor_name,
-                        f"{int(t*1e6)}.feather",
-                    )
-                    Path(fpath).parent.mkdir(exist_ok=True, parents=True)
-                    fpath.open("w").close()
-        return SensorDataloader(dataset_dir=sensor_dataset_dir, with_cache=False)
+    sensor_dataset_dir = Path(tempfile.TemporaryDirectory().name)
+    for sensor_name, timestamps_ms in SENSOR_TIMESTAMPS_MS_DICT.items():
+        for t in timestamps_ms:
+            if "ring" in sensor_name:
+                fpath = Path(
+                    sensor_dataset_dir,
+                    "dummy",
+                    log_id,
+                    "sensors",
+                    "cameras",
+                    sensor_name,
+                    f"{int(t*1e6)}.jpg",
+                )
+                Path(fpath).parent.mkdir(exist_ok=True, parents=True)
+                fpath.open("w").close()
+            elif "lidar" in sensor_name:
+                fpath = Path(
+                    sensor_dataset_dir,
+                    "dummy",
+                    log_id,
+                    "sensors",
+                    sensor_name,
+                    f"{int(t*1e6)}.feather",
+                )
+                Path(fpath).parent.mkdir(exist_ok=True, parents=True)
+                fpath.open("w").close()
+    return SensorDataloader(dataset_dir=sensor_dataset_dir, with_cache=False)


 def test_sensor_data_loader_milliseconds() -> None:
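Another 3.11-related fix: `pathlib.Path` objects accidentally supported the context manager protocol through Python 3.10 (deprecated in 3.9), and 3.11 removed `Path.__enter__`/`__exit__`, so the old `with Path(...) as p:` pattern now raises at runtime. The rewrite keeps the same temporary-directory trick with a plain assignment; a condensed sketch:

```python
import tempfile
from pathlib import Path

# `with Path(...) as p:` ran (with a deprecation) before 3.11; on 3.11+
# it fails, so assign directly and create the directory explicitly.
sensor_dataset_dir = Path(tempfile.TemporaryDirectory().name)
sensor_dataset_dir.mkdir(parents=True, exist_ok=True)
```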
106 changes: 52 additions & 54 deletions tests/unit/evaluation/scene_flow/test_sf_eval.py
@@ -271,67 +271,65 @@ def test_average_metrics() -> None:
     Verify that the weighted average metric breakdown has the correct subset counts and values.
     """
-    with Path(tempfile.TemporaryDirectory().name) as test_dir:
-        test_dir.mkdir()
-        anno_dir = test_dir / "annotations"
-        anno_dir.mkdir()
-
-        pred_dir = test_dir / "predictions"
-        pred_dir.mkdir()
-
-        timestamp_ns_1 = 111111111111111111
-        timestamp_ns_2 = 111111111111111112
-
-        write_annotation(
-            gts_classes,
-            gts_close,
-            gts_dynamic,
-            gts_valid,
-            gts,
-            ("log", timestamp_ns_1),
-            anno_dir,
-        )
-        write_annotation(
-            gts_classes,
-            gts_close,
-            gts_dynamic,
-            gts_valid,
-            gts,
-            ("log", timestamp_ns_2),
-            anno_dir,
-        )
+    test_dir = Path(tempfile.TemporaryDirectory().name)
+    test_dir.mkdir()
+    anno_dir = test_dir / "annotations"
+    anno_dir.mkdir()

-        write_annotation(
-            gts_classes,
-            gts_close,
-            gts_dynamic,
-            gts_valid,
-            gts,
-            ("log_missing", timestamp_ns_1),
-            anno_dir,
-        )
+    pred_dir = test_dir / "predictions"
+    pred_dir.mkdir()

-        write_output_file(dts_perfect, dts_dynamic, ("log", timestamp_ns_1), pred_dir)
-        write_output_file(dts_perfect, dts_dynamic, ("log", timestamp_ns_2), pred_dir)
+    timestamp_ns_1 = 111111111111111111
+    timestamp_ns_2 = 111111111111111112

-        results_df = eval.evaluate_directories(anno_dir, pred_dir)
+    write_annotation(
+        gts_classes,
+        gts_close,
+        gts_dynamic,
+        gts_valid,
+        gts,
+        ("log", timestamp_ns_1),
+        anno_dir,
+    )
+    write_annotation(
+        gts_classes,
+        gts_close,
+        gts_dynamic,
+        gts_valid,
+        gts,
+        ("log", timestamp_ns_2),
+        anno_dir,
+    )

-        assert len(results_df) == 16
-        assert results_df.Count.sum() == 18
+    write_annotation(
+        gts_classes,
+        gts_close,
+        gts_dynamic,
+        gts_valid,
+        gts,
+        ("log_missing", timestamp_ns_1),
+        anno_dir,
+    )

-        assert np.allclose(results_df.EPE.mean(), 0.0)
-        assert np.allclose(results_df["ACCURACY_STRICT"].mean(), 1.0)
-        assert np.allclose(results_df["ACCURACY_RELAX"].mean(), 1.0)
-        assert np.allclose(results_df["ANGLE_ERROR"].mean(), 0.0)
+    write_output_file(dts_perfect, dts_dynamic, ("log", timestamp_ns_1), pred_dir)
+    write_output_file(dts_perfect, dts_dynamic, ("log", timestamp_ns_2), pred_dir)

-        assert results_df.TP.sum() == 2 * 2
-        assert results_df.TN.sum() == 2 * 2  # First true negative marked invalid
-        assert results_df.FP.sum() == 3 * 2
-        assert results_df.FN.sum() == 2 * 2
+    results_df = eval.evaluate_directories(anno_dir, pred_dir)

-        assert (
-            results_df.groupby(["Class", "Motion"]).Count.sum().Background.Dynamic == 0
-        )
+    assert len(results_df) == 16
+    assert results_df.Count.sum() == 18
+
+    assert np.allclose(results_df.EPE.mean(), 0.0)
+    assert np.allclose(results_df["ACCURACY_STRICT"].mean(), 1.0)
+    assert np.allclose(results_df["ACCURACY_RELAX"].mean(), 1.0)
+    assert np.allclose(results_df["ANGLE_ERROR"].mean(), 0.0)
+
+    assert results_df.TP.sum() == 2 * 2
+    assert results_df.TN.sum() == 2 * 2  # First true negative marked invalid
+    assert results_df.FP.sum() == 3 * 2
+    assert results_df.FN.sum() == 2 * 2
+
+    assert results_df.groupby(["Class", "Motion"]).Count.sum().Background.Dynamic == 0
     results_dict = eval.results_to_dict(results_df)

     assert len(results_dict) == 38
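The dedent here mirrors the same `Path` context-manager removal described above. Separately, the reflowed `groupby` assertion relies on pandas attribute access into a MultiIndex result; a toy example of that lookup pattern (fabricated frame, not the evaluator's real output):

```python
import pandas as pd

# Summing Count per (Class, Motion) yields a Series with a two-level
# MultiIndex; chained attribute access drills into both levels.
df = pd.DataFrame(
    {
        "Class": ["Background", "Background", "Pedestrian"],
        "Motion": ["Dynamic", "Static", "Dynamic"],
        "Count": [0, 5, 3],
    }
)
counts = df.groupby(["Class", "Motion"]).Count.sum()
assert counts.Background.Dynamic == 0  # same shape as the assertion above
```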