Skip to content

Commit

Permalink
Update typing. (#142)
Browse files Browse the repository at this point in the history
* Update typing.

* Update utils.py

* Add pytorch channel.

* Updat CI.

* Fix new lines.

* Update deps.

* Update CI.

* Move module metadata.

* Fix env name.

* Update CI.
  • Loading branch information
benjaminrwilson authored Mar 29, 2023
1 parent 6149f8a commit ccd0a6e
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 11 deletions.
28 changes: 23 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,35 @@ jobs:
- uses: actions/checkout@v2
- uses: conda-incubator/setup-miniconda@v2
with:
activate-environment: av2
environment-file: conda/environment.yml
mamba-version: "*"
miniforge-version: latest
python-version: ${{ matrix.python_version }}

- name: Install prerequisites.
- name: Build `av2`.
run: |
mamba install -y nox pip pyyaml
maturin develop --extras test
- name: Run nox with ${{ matrix.venv_backend }}.
run: | # Cache the environments (-r), fail on missing python interpreters (--error-on-missing-interpreters), set dependency resolver backend.
python -m nox -r --error-on-missing-interpreters --python ${{ matrix.python_version }} --default-venv-backend ${{ matrix.venv_backend }}
- name: Run `black`.
run: |
black .
- name: Run `isort`.
run: |
isort .
- name: Run `flake8`.
run: |
flake8 .
- name: Run `flake8`.
run: |
mypy .
- name: Run `pytest`.
run: |
pytest tests --cov src/av2
- name: Install pypa/build.
run: >-
Expand Down
1 change: 1 addition & 0 deletions conda/environment.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
name: av2
channels:
- pytorch
- conda-forge
dependencies:
- av
Expand Down
19 changes: 19 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,31 @@ dynamic = [
"version"
]

[project.optional-dependencies]
test = [
"black[jupyter]",
"flake8",
"flake8-annotations",
"flake8-black",
"flake8-bugbear",
"flake8-docstrings",
"flake8-import-order",
"darglint",
"isort",
"mypy",
"types-pyyaml",
"pytest",
"pytest-benchmark",
"pytest-cov",
]

[project.urls]
homepage = "argoverse.org"
repository = "https://github.com/argoverse/av2-api"

[tool.maturin]
features = ["pyo3/extension-module"]
module-name = "av2._r"

[tool.black]
line-length = 120
Expand Down
3 changes: 0 additions & 3 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ name = "av2"
version = "0.3.0"
edition = "2021"

[package.metadata.maturin]
name = "av2._r"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "av2"
Expand Down
6 changes: 3 additions & 3 deletions src/av2/torch/structures/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ def SE3_from_frame(frame: pd.DataFrame) -> Se3:
SE(3) object representing a (N,4,4) tensor of homogeneous transformations.
"""
quaternion_npy = frame.loc[0, list(QWXYZ_COLUMNS)].to_numpy().astype(float)
quat_wxyz = Quaternion(torch.as_tensor(quaternion_npy, dtype=torch.float32))
quat_wxyz = Quaternion(torch.as_tensor(quaternion_npy, dtype=torch.float32)[None])
rotation = So3(quat_wxyz)

translation_npy = frame.loc[0, list(TRANSLATION_COLUMNS)].to_numpy().astype(np.float32)
translation = torch.as_tensor(translation_npy, dtype=torch.float32)
dst_SE3_src = Se3(rotation[None], translation[None])
translation = torch.as_tensor(translation_npy, dtype=torch.float32)[None]
dst_SE3_src = Se3(rotation, translation)
dst_SE3_src.rotation._q.requires_grad_(False)
dst_SE3_src.translation.requires_grad_(False)
return dst_SE3_src

0 comments on commit ccd0a6e

Please sign in to comment.