Skip to content

Commit

Permalink
Merge pull request #341 from ankith26/more-typing
Browse files Browse the repository at this point in the history
Typing: Add typehinting for `Rod` modules
  • Loading branch information
skim0119 authored May 20, 2024
2 parents 4d05804 + 856e901 commit 03c00f4
Show file tree
Hide file tree
Showing 22 changed files with 1,118 additions and 750 deletions.
32 changes: 22 additions & 10 deletions elastica/_calculus.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
__doc__ = """ Quadrature and difference kernels """
from typing import Any, Union
import numpy as np
from numpy import zeros, empty
from numpy.typing import NDArray
from numba import njit
from elastica.reset_functions_for_block_structure._reset_ghost_vector_or_scalar import (
_reset_vector_ghost,
Expand All @@ -9,15 +11,17 @@


@functools.lru_cache(maxsize=2)
def _get_zero_array(dim, ndim):
def _get_zero_array(dim: int, ndim: int) -> Union[float, NDArray[np.floating], None]:
if ndim == 1:
return 0.0
if ndim == 2:
return np.zeros((dim, 1))

return None


@njit(cache=True)
def _trapezoidal(array_collection):
def _trapezoidal(array_collection: NDArray[np.floating]) -> NDArray[np.floating]:
"""
Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way
Expand Down Expand Up @@ -63,7 +67,9 @@ def _trapezoidal(array_collection):


@njit(cache=True)
def _trapezoidal_for_block_structure(array_collection, ghost_idx):
def _trapezoidal_for_block_structure(
array_collection: NDArray[np.floating], ghost_idx: NDArray[np.integer]
) -> NDArray[np.floating]:
"""
Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way. This form
specifically for the block structure implementation and there is a reset function call, to reset
Expand Down Expand Up @@ -115,7 +121,9 @@ def _trapezoidal_for_block_structure(array_collection, ghost_idx):


@njit(cache=True)
def _two_point_difference(array_collection):
def _two_point_difference(
array_collection: NDArray[np.floating],
) -> NDArray[np.floating]:
"""
This function does differentiation.
Expand Down Expand Up @@ -156,7 +164,9 @@ def _two_point_difference(array_collection):


@njit(cache=True)
def _two_point_difference_for_block_structure(array_collection, ghost_idx):
def _two_point_difference_for_block_structure(
array_collection: NDArray[np.floating], ghost_idx: NDArray[np.integer]
) -> NDArray[np.floating]:
"""
This function does the differentiation, for Cosserat rod model equations. This form
specifically for the block structure implementation and there is a reset function call, to
Expand Down Expand Up @@ -207,7 +217,7 @@ def _two_point_difference_for_block_structure(array_collection, ghost_idx):


@njit(cache=True)
def _difference(vector):
def _difference(vector: NDArray[np.floating]) -> NDArray[np.floating]:
"""
This function computes difference between elements of a batch vector.
Expand Down Expand Up @@ -238,7 +248,7 @@ def _difference(vector):


@njit(cache=True)
def _average(vector):
def _average(vector: NDArray[np.floating]) -> NDArray[np.floating]:
"""
This function computes the average between elements of a vector.
Expand Down Expand Up @@ -268,7 +278,9 @@ def _average(vector):


@njit(cache=True)
def _clip_array(input_array, vmin, vmax):
def _clip_array(
input_array: NDArray[np.floating], vmin: np.floating, vmax: np.floating
) -> NDArray[np.floating]:
"""
This function clips an array values
between user defined minimum and maximum
Expand Down Expand Up @@ -304,7 +316,7 @@ def _clip_array(input_array, vmin, vmax):


@njit(cache=True)
def _isnan_check(array):
def _isnan_check(array: NDArray[Any]) -> bool:
"""
This function checks if there is any nan inside the array.
If there is nan, it returns True boolean.
Expand All @@ -324,7 +336,7 @@ def _isnan_check(array):
Python version: 2.24 µs ± 96.1 ns per loop
This version: 479 ns ± 6.49 ns per loop
"""
return np.isnan(array).any()
return bool(np.isnan(array).any())


position_difference_kernel = _difference
Expand Down
213 changes: 107 additions & 106 deletions elastica/_contact_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,29 @@
)
import numba
import numpy as np
from numpy.typing import NDArray


@numba.njit(cache=True)
def _calculate_contact_forces_rod_cylinder(
x_collection_rod,
edge_collection_rod,
x_cylinder_center,
x_cylinder_tip,
edge_cylinder,
radii_sum,
length_sum,
internal_forces_rod,
external_forces_rod,
external_forces_cylinder,
external_torques_cylinder,
cylinder_director_collection,
velocity_rod,
velocity_cylinder,
contact_k,
contact_nu,
velocity_damping_coefficient,
friction_coefficient,
x_collection_rod: NDArray[np.floating],
edge_collection_rod: NDArray[np.floating],
x_cylinder_center: NDArray[np.floating],
x_cylinder_tip: NDArray[np.floating],
edge_cylinder: NDArray[np.floating],
radii_sum: NDArray[np.floating],
length_sum: NDArray[np.floating],
internal_forces_rod: NDArray[np.floating],
external_forces_rod: NDArray[np.floating],
external_forces_cylinder: NDArray[np.floating],
external_torques_cylinder: NDArray[np.floating],
cylinder_director_collection: NDArray[np.floating],
velocity_rod: NDArray[np.floating],
velocity_cylinder: NDArray[np.floating],
contact_k: np.floating,
contact_nu: np.floating,
velocity_damping_coefficient: np.floating,
friction_coefficient: np.floating,
) -> None:
# We already pass in only the first n_elem x
n_points = x_collection_rod.shape[1]
Expand Down Expand Up @@ -155,22 +156,22 @@ def _calculate_contact_forces_rod_cylinder(

@numba.njit(cache=True)
def _calculate_contact_forces_rod_rod(
x_collection_rod_one,
radius_rod_one,
length_rod_one,
tangent_rod_one,
velocity_rod_one,
internal_forces_rod_one,
external_forces_rod_one,
x_collection_rod_two,
radius_rod_two,
length_rod_two,
tangent_rod_two,
velocity_rod_two,
internal_forces_rod_two,
external_forces_rod_two,
contact_k,
contact_nu,
x_collection_rod_one: NDArray[np.floating],
radius_rod_one: NDArray[np.floating],
length_rod_one: NDArray[np.floating],
tangent_rod_one: NDArray[np.floating],
velocity_rod_one: NDArray[np.floating],
internal_forces_rod_one: NDArray[np.floating],
external_forces_rod_one: NDArray[np.floating],
x_collection_rod_two: NDArray[np.floating],
radius_rod_two: NDArray[np.floating],
length_rod_two: NDArray[np.floating],
tangent_rod_two: NDArray[np.floating],
velocity_rod_two: NDArray[np.floating],
internal_forces_rod_two: NDArray[np.floating],
external_forces_rod_two: NDArray[np.floating],
contact_k: np.floating,
contact_nu: np.floating,
) -> None:
# We already pass in only the first n_elem x
n_points_rod_one = x_collection_rod_one.shape[1]
Expand Down Expand Up @@ -272,14 +273,14 @@ def _calculate_contact_forces_rod_rod(

@numba.njit(cache=True)
def _calculate_contact_forces_self_rod(
x_collection_rod,
radius_rod,
length_rod,
tangent_rod,
velocity_rod,
external_forces_rod,
contact_k,
contact_nu,
x_collection_rod: NDArray[np.floating],
radius_rod: NDArray[np.floating],
length_rod: NDArray[np.floating],
tangent_rod: NDArray[np.floating],
velocity_rod: NDArray[np.floating],
external_forces_rod: NDArray[np.floating],
contact_k: np.floating,
contact_nu: np.floating,
) -> None:
# We already pass in only the first n_elem x
n_points_rod = x_collection_rod.shape[1]
Expand Down Expand Up @@ -360,24 +361,24 @@ def _calculate_contact_forces_self_rod(

@numba.njit(cache=True)
def _calculate_contact_forces_rod_sphere(
x_collection_rod,
edge_collection_rod,
x_sphere_center,
x_sphere_tip,
edge_sphere,
radii_sum,
length_sum,
internal_forces_rod,
external_forces_rod,
external_forces_sphere,
external_torques_sphere,
sphere_director_collection,
velocity_rod,
velocity_sphere,
contact_k,
contact_nu,
velocity_damping_coefficient,
friction_coefficient,
x_collection_rod: NDArray[np.floating],
edge_collection_rod: NDArray[np.floating],
x_sphere_center: NDArray[np.floating],
x_sphere_tip: NDArray[np.floating],
edge_sphere: NDArray[np.floating],
radii_sum: NDArray[np.floating],
length_sum: NDArray[np.floating],
internal_forces_rod: NDArray[np.floating],
external_forces_rod: NDArray[np.floating],
external_forces_sphere: NDArray[np.floating],
external_torques_sphere: NDArray[np.floating],
sphere_director_collection: NDArray[np.floating],
velocity_rod: NDArray[np.floating],
velocity_sphere: NDArray[np.floating],
contact_k: np.floating,
contact_nu: np.floating,
velocity_damping_coefficient: np.floating,
friction_coefficient: np.floating,
) -> None:
# We already pass in only the first n_elem x
n_points = x_collection_rod.shape[1]
Expand Down Expand Up @@ -486,18 +487,18 @@ def _calculate_contact_forces_rod_sphere(

@numba.njit(cache=True)
def _calculate_contact_forces_rod_plane(
plane_origin,
plane_normal,
surface_tol,
k,
nu,
radius,
mass,
position_collection,
velocity_collection,
internal_forces,
external_forces,
):
plane_origin: NDArray[np.floating],
plane_normal: NDArray[np.floating],
surface_tol: np.floating,
k: np.floating,
nu: np.floating,
radius: NDArray[np.floating],
mass: NDArray[np.floating],
position_collection: NDArray[np.floating],
velocity_collection: NDArray[np.floating],
internal_forces: NDArray[np.floating],
external_forces: NDArray[np.floating],
) -> tuple[NDArray[np.floating], NDArray[np.intp]]:
"""
This function computes the plane force response on the element, in the
case of contact. Contact model given in Eqn 4.8 Gazzola et. al. RSoS 2018 paper
Expand Down Expand Up @@ -571,30 +572,30 @@ def _calculate_contact_forces_rod_plane(

@numba.njit(cache=True)
def _calculate_contact_forces_rod_plane_with_anisotropic_friction(
plane_origin,
plane_normal,
surface_tol,
slip_velocity_tol,
k,
nu,
kinetic_mu_forward,
kinetic_mu_backward,
kinetic_mu_sideways,
static_mu_forward,
static_mu_backward,
static_mu_sideways,
radius,
mass,
tangents,
position_collection,
director_collection,
velocity_collection,
omega_collection,
internal_forces,
external_forces,
internal_torques,
external_torques,
):
plane_origin: NDArray[np.floating],
plane_normal: NDArray[np.floating],
surface_tol: np.floating,
slip_velocity_tol: np.floating,
k: np.floating,
nu: np.floating,
kinetic_mu_forward: np.floating,
kinetic_mu_backward: np.floating,
kinetic_mu_sideways: np.floating,
static_mu_forward: np.floating,
static_mu_backward: np.floating,
static_mu_sideways: np.floating,
radius: NDArray[np.floating],
mass: NDArray[np.floating],
tangents: NDArray[np.floating],
position_collection: NDArray[np.floating],
director_collection: NDArray[np.floating],
velocity_collection: NDArray[np.floating],
omega_collection: NDArray[np.floating],
internal_forces: NDArray[np.floating],
external_forces: NDArray[np.floating],
internal_torques: NDArray[np.floating],
external_torques: NDArray[np.floating],
) -> None:
(
plane_response_force_mag,
no_contact_point_idx,
Expand Down Expand Up @@ -784,16 +785,16 @@ def _calculate_contact_forces_rod_plane_with_anisotropic_friction(

@numba.njit(cache=True)
def _calculate_contact_forces_cylinder_plane(
plane_origin,
plane_normal,
surface_tol,
k,
nu,
length,
position_collection,
velocity_collection,
external_forces,
):
plane_origin: NDArray[np.floating],
plane_normal: NDArray[np.floating],
surface_tol: np.floating,
k: np.floating,
nu: np.floating,
length: NDArray[np.floating],
position_collection: NDArray[np.floating],
velocity_collection: NDArray[np.floating],
external_forces: NDArray[np.floating],
) -> tuple[NDArray[np.floating], NDArray[np.intp]]:

# Compute plane response force
# total_forces = system.internal_forces + system.external_forces
Expand Down
Loading

0 comments on commit 03c00f4

Please sign in to comment.