diff --git a/.github/workflows/ci_v2.yaml b/.github/workflows/ci_v2.yaml new file mode 100644 index 0000000..be86ccf --- /dev/null +++ b/.github/workflows/ci_v2.yaml @@ -0,0 +1,92 @@ +name: L4CasADi v2 + +on: + push: + branches: [ v2 ] + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + timeout-minutes: 5 + steps: + - uses: actions/checkout@v3 + with: + ref: 'v2' + - name: Run mypy + run: | + pip install mypy + mypy . --ignore-missing-imports --exclude examples + - name: Run flake8 + run: | + pip install flake8 + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + + tests: + runs-on: ${{ matrix.runs-on }} + needs: [ lint ] + timeout-minutes: 60 + strategy: + fail-fast: false + matrix: + runs-on: [ubuntu-latest, ubuntu-20.04, macos-latest] + + name: Tests on ${{ matrix.runs-on }} + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + ref: 'v2' + fetch-depth: 0 + + - name: Install Python + uses: actions/setup-python@v4 + with: + python-version: '>=3.9 <3.12' + + - name: Install L4CasADi + run: | + python -m pip install --upgrade pip + pip install torch --index-url https://download.pytorch.org/whl/cpu # Ensure CPU torch version + pip install -r requirements_build.txt + pip install . -v --no-build-isolation + + - name: Test with pytest + working-directory: ./tests + run: | + pip install pytest + pytest . + + test-on-aarch: + runs-on: ubuntu-latest + needs: [ lint ] + timeout-minutes: 60 + + name: Tests on aarch64 + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + ref: 'v2' + fetch-depth: 0 + - uses: uraimo/run-on-arch-action@v2 + name: Install and Test + with: + arch: aarch64 + distro: ubuntu20.04 + install: | + apt-get update + apt-get install -y --no-install-recommends python3.9 python3-pip python-is-python3 + pip install -U pip + apt-get install -y build-essential + + run: | + python -m pip install --upgrade pip + pip install torch --index-url https://download.pytorch.org/whl/cpu # Ensure CPU torch version + pip install -r requirements_build.txt + pip install . -v --no-build-isolation + # pip install pytest + # pytest . \ No newline at end of file diff --git a/README.md b/README.md index 4634a22..4385472 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,32 @@ arXiv: [Learning for CasADi: Data-driven Models in Numerical Optimization](https Talk: [Youtube](https://youtu.be/UYdkRnGr8eM?si=KEPcFEL9b7Vk2juI&t=3348) +## L4CasADi v2 Breaking Changes +After feedback from first use-cases L4CasADi v2 is designed with efficiency and simplicity in mind. + +This leads to the following breaking changes: + +- L4CasADi v2 can leverage PyTorch's batching capabilities for increased efficiency. When passing `batched=True`, +L4CasADi will understand the **first** input dimension as batch dimension. Thus, first and second-order derivatives +across elements of this dimension are assumed to be **sparse-zero**. To make use of this, instead of having multiple calls to a L4CasADi function in +your CasADi program, batch all inputs together and have a single L4CasADi call. An example of this can be seen when +comparing the [non-batched NeRF example](examples/nerf_trajectory_optimization/nerf_trajectory_optimization.py) with the +[batched NeRF example](examples/nerf_trajectory_optimization/nerf_trajectory_optimization_batched.py) which is faster by +a factor of 5-10x. +- L4CasADi v2 will not change the shape of an input anymore as this was a source of confusion. The tensor forwarded to +the PyTorch model will resemble the **exact dimension** of the input variable by CasADi. You are responsible to make +sure that the PyTorch model handles a **two-dimensional** input matrix! Accordingly, the parameter +`model_expects_batch_dim` is removed. +- By default, L4CasADi v2 will not provide the Hessian, but the Jacobian of the Adjoint. This is sufficient for most +many optimization problems. However, you can explicitly request the generation of the Hessian by passing +`generate_jac_jac=True`. + +[//]: # (L4CasADi v2 can use the new **torch compile** functionality starting from PyTorch 2.4. By passing `scripting=False`. This +will lead to a longer compile time on first L4CasADi function call but will lead to a overall faster +execution. However, currently this functionality is experimental and not fully stable across all models. In the long +term there is a good chance this will become the default over scripting once the functionality is stabilized by the +Torch developers.) + ## Table of Content - [Projects using L4CasADi](#projects-using-l4casadi) - [Installation](#installation) @@ -205,14 +231,6 @@ https://github.com/Tim-Salzmann/l4casadi/blob/421de6ef408267eed0fd2519248b2152b6 ## FYIs -### Batch Dimension - -If your PyTorch model expects a batch dimension as first dimension (which most models do) you should pass -`model_expects_batch_dim=True` to the `L4CasADi` constructor. The `MX` input to the L4CasADi component is then expected -to be a vector of shape `[X, 1]`. L4CasADi will add a batch dimension of `1` automatically such that the input to the -underlying PyTorch model is of shape `[1, X]`. - ---- ### Warm Up diff --git a/examples/acados.py b/examples/acados.py index d58ad03..25ba6f8 100644 --- a/examples/acados.py +++ b/examples/acados.py @@ -128,7 +128,7 @@ def ocp(self): ocp.cost.W = np.array([[1.]]) # Trivial PyTorch index 0 - l4c_y_expr = l4c.L4CasADi(lambda x: x[0], name='y_expr', model_expects_batch_dim=False) + l4c_y_expr = l4c.L4CasADi(lambda x: x[0], name='y_expr') ocp.model.cost_y_expr = l4c_y_expr(x) ocp.model.cost_y_expr_e = x[0] diff --git a/examples/cpp_usage/generate.py b/examples/cpp_usage/generate.py index fe4f41c..9748edb 100644 --- a/examples/cpp_usage/generate.py +++ b/examples/cpp_usage/generate.py @@ -10,7 +10,7 @@ def forward(self, x): def generate(): - l4casadi_model = l4c.L4CasADi(TorchModel(), model_expects_batch_dim=False, name='sin_l4c') + l4casadi_model = l4c.L4CasADi(TorchModel(), name='sin_l4c') sym_in = cs.MX.sym('x', 1, 1) diff --git a/examples/fish_turbulent_flow/utils.py b/examples/fish_turbulent_flow/utils.py index b69f1c4..5f9f4dd 100644 --- a/examples/fish_turbulent_flow/utils.py +++ b/examples/fish_turbulent_flow/utils.py @@ -266,7 +266,7 @@ def import_l4casadi_model(device): x = cs.MX.sym("x", 3) xn = (x - meanX) / stdX - y = l4c.L4CasADi(model, name="turbulent_model", model_expects_batch_dim=True)(xn) + y = l4c.L4CasADi(model, name="turbulent_model", generate_adj1=False, generate_jac_jac=True)(xn.T).T y = y * stdY + meanY fU = cs.Function("fU", [x], [y[0]]) fV = cs.Function("fV", [x], [y[1]]) diff --git a/examples/matlab/export.py b/examples/matlab/export.py index 66549f8..6350d9d 100644 --- a/examples/matlab/export.py +++ b/examples/matlab/export.py @@ -10,7 +10,7 @@ def forward(self, x): def generate(): - l4casadi_model = l4c.L4CasADi(TorchModel(), model_expects_batch_dim=False, name='sin_l4c') + l4casadi_model = l4c.L4CasADi(TorchModel(), name='sin_l4c') sym_in = cs.MX.sym('x', 1, 1) l4casadi_model.build(sym_in) return diff --git a/examples/naive/readme.py b/examples/naive/readme.py index 7a9b1b6..131f95b 100644 --- a/examples/naive/readme.py +++ b/examples/naive/readme.py @@ -3,15 +3,15 @@ naive_mlp = l4c.naive.MultiLayerPerceptron(2, 128, 1, 2, 'Tanh') -l4c_model = l4c.L4CasADi(naive_mlp, model_expects_batch_dim=True) +l4c_model = l4c.L4CasADi(naive_mlp) -x_sym = cs.MX.sym('x', 2, 1) +x_sym = cs.MX.sym('x', 3, 2) y_sym = l4c_model(x_sym) f = cs.Function('y', [x_sym], [y_sym]) df = cs.Function('dy', [x_sym], [cs.jacobian(y_sym, x_sym)]) -ddf = cs.Function('ddy', [x_sym], [cs.hessian(y_sym, x_sym)[0]]) +ddf = cs.Function('ddy', [x_sym], [cs.jacobian(cs.jacobian(y_sym, x_sym), x_sym)]) -x = cs.DM([[0.], [2.]]) +x = cs.DM([[0., 2.], [0., 2.], [0., 2.]]) print(l4c_model(x)) print(f(x)) print(df(x)) diff --git a/examples/nerf_trajectory_optimization/density_nerf.py b/examples/nerf_trajectory_optimization/density_nerf.py index 7b16e06..bf1265d 100644 --- a/examples/nerf_trajectory_optimization/density_nerf.py +++ b/examples/nerf_trajectory_optimization/density_nerf.py @@ -46,7 +46,7 @@ def __init__(self): [nn.Linear(self.input_ch, W)] + [ nn.Linear(W, W) - if i not in self.skips + if i != 4 else nn.Linear(W + self.input_ch, W) for i in range(D - 1) ] @@ -60,7 +60,7 @@ def forward(self, x): for i, l in enumerate(self.pts_linears): h = self.pts_linears[i](h) h = F.relu(h) - if i in self.skips: + if i == 4: h = torch.cat([input_pts, h], -1) alpha = self.alpha_linear(h) diff --git a/examples/nerf_trajectory_optimization/nerf_trajectory_optimization.py b/examples/nerf_trajectory_optimization/nerf_trajectory_optimization.py index 3fd11f0..d4cd455 100644 --- a/examples/nerf_trajectory_optimization/nerf_trajectory_optimization.py +++ b/examples/nerf_trajectory_optimization/nerf_trajectory_optimization.py @@ -10,6 +10,7 @@ CASE = 1 +os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' def polynomial(n, n_eval): """Generates a symbolic function for a polynomial of degree n-1""" @@ -86,7 +87,7 @@ def trajectory_generator_solver(n, n_eval, L, warmup, threshold): f += cs.sum2(sk**2) # While having a maximum density (1.) of the NeRF as constraint. - lk = L(pk.T) + lk = L(pk) g = cs.horzcat(g, lk) lbg = cs.horzcat(lbg, cs.DM([-10e32]).T) ubg = cs.horzcat(ubg, cs.DM([threshold]).T) @@ -175,7 +176,7 @@ def main(): strict=False, ) # -------------------------- Create L4CasADi Module -------------------------- # - l4c_nerf = l4c.L4CasADi(model) + l4c_nerf = l4c.L4CasADi(model, scripting=False) # ---------------------------------------------------------------------------- # # NLP warmup # diff --git a/examples/nerf_trajectory_optimization/nerf_trajectory_optimization_batched.py b/examples/nerf_trajectory_optimization/nerf_trajectory_optimization_batched.py new file mode 100644 index 0000000..c77a5c4 --- /dev/null +++ b/examples/nerf_trajectory_optimization/nerf_trajectory_optimization_batched.py @@ -0,0 +1,305 @@ +import os + +import casadi as cs +import matplotlib.pyplot as plt +import numpy as np +import torch + +import l4casadi as l4c +from density_nerf import DensityNeRF + +import os + +CASE = 1 + +DEVICE = 'cpu' + +if DEVICE != 'cpu': + torch.jit.set_fusion_strategy([('STATIC', 0)]) + + +def polynomial(n, n_eval): + """Generates a symbolic function for a polynomial of degree n-1""" + + # Polynomial symbolic function + coeffs = cs.MX.sym("coeffs", n, 3) + xi = cs.MX.sym("xi") + p = cs.MX.zeros(1, 3) + for k in range(n): + p += coeffs[k, :] * xi**k + + v = cs.jacobian(p, xi).T + a = cs.jacobian(v, xi).T + j = cs.jacobian(a, xi).T + s = cs.jacobian(j, xi).T + + f = cs.Function( + "f_poly", + [coeffs, xi], + [p, v, a, j, s], + ["coeffs", "xi"], + ["p", "v", "a", "j", "s"], + ) + + # evaluation function + p_eval = cs.MX.zeros(n_eval, 3) + v_eval = cs.MX.zeros(n_eval, 3) + a_eval = cs.MX.zeros(n_eval, 3) + j_eval = cs.MX.zeros(n_eval, 3) + s_eval = cs.MX.zeros(n_eval, 3) + xi_eval = np.linspace(0, 1, n_eval) + for k in range(n_eval): + p_eval[k, :] = f(coeffs=coeffs, xi=xi_eval[k])["p"] + v_eval[k, :] = f(coeffs=coeffs, xi=xi_eval[k])["v"] + a_eval[k, :] = f(coeffs=coeffs, xi=xi_eval[k])["a"] + j_eval[k, :] = f(coeffs=coeffs, xi=xi_eval[k])["j"] + s_eval[k, :] = f(coeffs=coeffs, xi=xi_eval[k])["s"] + + f_eval = cs.Function( + "f_eval", + [coeffs], + [p_eval, v_eval, a_eval, j_eval, s_eval], + ["coeffs"], + ["p", "v", "a", "j", "s"], + ) + + return f, f_eval + + +def trajectory_generator_solver(n, n_eval, L, warmup, threshold): + # Decision variables and parameters + f_poly, f_eval = polynomial(n, n_eval) + x = cs.MX.sym("x", n, 2) + X = cs.horzcat(cs.MX.zeros(n), x) + params = cs.MX.sym("params", n_eval, 3) + x_init = params[0, :] + x_end = params[-1, :] + + # Define NLP + f = 0 + g = [] + lbg = [] + ubg = [] + + ps = [] + ss = [] + for k in range(n_eval): + poly = f_poly(coeffs=X, xi=k / (n_eval - 1)) + ps.append(poly["p"]) + ss.append(poly["s"]) + + if not warmup: + ls = L(cs.vcat(ps)) + + for k in range(n_eval): + pk = ps[k] + sk = ss[k] + + if warmup: + f += cs.sum2((pk - params[k, :]) ** 2) + else: + # Optimize for minimum Snap. + f += cs.sum2(sk**2) + + # While having a maximum density (1.) of the NeRF as constraint. + lk = ls[k]#L(pk.T) + g = cs.horzcat(g, lk) + lbg = cs.horzcat(lbg, cs.DM([-10e32]).T) + ubg = cs.horzcat(ubg, cs.DM([threshold]).T) + + # Spatial bounds + g = cs.horzcat(g, pk[1:]) + lbg = cs.horzcat(lbg, cs.DM([-1, -0.3]).T) + ubg = cs.horzcat(ubg, cs.DM([1.2, 1.0]).T) + + # Initial and final states + eps = 0 + for key, init, end in zip( + ["p"], + [x_init], + [x_end], + ): + g = cs.horzcat(g, f_poly(coeffs=X, xi=0)[key] - init) + lbg = cs.horzcat(lbg, -cs.DM([eps, eps, eps]).T) + ubg = cs.horzcat(ubg, cs.DM([eps, eps, eps]).T) + + g = cs.horzcat(g, f_poly(coeffs=X, xi=1)[key] - end) + lbg = cs.horzcat(lbg, -cs.DM([eps, eps, eps]).T) + ubg = cs.horzcat(ubg, cs.DM([eps, eps, eps]).T) + + # Generate solver + x_nlp = cs.reshape(x, n * 2, 1) + p_nlp = cs.reshape(params, n_eval * 3, 1) + nlp_dict = { + "x": x_nlp, + "f": f, + "g": g, + "p": p_nlp, + } + + if warmup: + nlp_opts = { + "ipopt.linear_solver": "mumps", + "ipopt.sb": "yes", + "ipopt.max_iter": 100, + "ipopt.print_level": 5, + "print_time": False, + } + else: + nlp_opts = { + # High barrier parameter to adhere to warmstart. + "ipopt.mu_init": 1e-4, + "ipopt.barrier_tol_factor": 1e6, + + "ipopt.linear_solver": "mumps", + "ipopt.sb": "yes", + "ipopt.max_iter": 100, + "ipopt.print_level": 5, + "print_time": False, + } + + nlp_solver = cs.nlpsol("nerf_trajectory_optimizer", "ipopt", nlp_dict, nlp_opts) + + solver = {"solver": nlp_solver, "lbg": lbg, "ubg": ubg} + + return solver + + +def main(): + n = 9 + n_eval = 150 + optimization_threshold = 1. + viz_threshold = 10. + + if CASE == 1: # case 1 + p_start = np.array([0.0, -0.8, -0.2]) + p_goal = np.array([-0.0, 1.2, 0.8]) + elif CASE == 2: # case 2 + p_start = np.array([0.0, -0.8, -0.2]) + p_goal = np.array([-0.0, 1.2, -0.2]) + elif CASE == 3: # case 3 + p_start = np.array([0.0, -1, 1]) + p_goal = np.array([-0.0, 1.2, -0.2]) + else: + raise ValueError("Invalid case.") + + # --------------------------------- Load NERF -------------------------------- # + model = DensityNeRF() + model_path = os.path.join(os.path.dirname(__file__), "nerf_model.tar") + model.load_state_dict( + torch.load(model_path, map_location="cpu")["network_fn_state_dict"], + strict=False, + ) + # -------------------------- Create L4CasADi Module -------------------------- # + l4c_nerf = l4c.L4CasADi(model, device=DEVICE, batched=True) + + # ---------------------------------------------------------------------------- # + # NLP warmup # + # ---------------------------------------------------------------------------- # + + # --------------------------- Piecewise linear path -------------------------- # + + if CASE == 1: + points = np.array( + [[0, -0.8, -0.2], [0, -0.5, 0.4], [0, 0, 0.8], [0, 0.75, 0.3], [0, 1.2, 0.8]] + ) + elif CASE == 2: + points = np.array( + [[0, -0.8, -0.2], [0, -0.5, 0.4], [0, 0, 0.8], [0, 0.75, 0.4], [0, 1.2, -0.2]] + ) + elif CASE == 3: + points = np.array( + [[0.0, -1, 1], [0, -0.85, 0.4], [0, 0, 0.7], [0, 0.75, 0.45], [0, 1.2, -0.2]] + ) + else: + raise ValueError("Invalid case") + + dists = np.linalg.norm(np.diff(points, axis=0), axis=1) + n_eval_points = np.squeeze(dists / np.sum(dists) * n_eval).astype(int) + if np.sum(n_eval_points) != n_eval: + n_eval_points[-1] += n_eval - np.sum(n_eval_points) + piecewise_points = np.zeros((n_eval, 3)) + for k in range(len(points) - 1): + piecewise_points[ + np.sum(n_eval_points[:k]) : np.sum(n_eval_points[: k + 1]), : + ] = np.linspace(points[k], points[k + 1], n_eval_points[k] + 1)[:-1, :] + + # --------------------------------- Solve NLP -------------------------------- # + # Load solver + nlp_warm = trajectory_generator_solver(n, n_eval, l4c_nerf, warmup=True, threshold=optimization_threshold) + + # solve nlp + params_flat = piecewise_points.T.flatten() # update nlp to take this as input! + sol = nlp_warm["solver"](p=params_flat, lbg=nlp_warm["lbg"], ubg=nlp_warm["ubg"]) + + # --------------------------------- Evaluate --------------------------------- # + # Extract and evaluate solution + coeffs_warm = np.squeeze(sol["x"]).reshape(2, n).T + coeffs_warm = np.hstack([np.zeros((n, 1)), coeffs_warm]) + _, f_eval = polynomial(n, n_eval) + + # ---------------------------------------------------------------------------- # + # Collision free NLP # + # ---------------------------------------------------------------------------- # + # Load solver + nlp = trajectory_generator_solver(n, n_eval, l4c_nerf, warmup=False, threshold=optimization_threshold) + + # Solve nlp + x_init = coeffs_warm[:, 1:].T.flatten() + import time + t = time.time() + sol = nlp["solver"](x0=x_init, p=params_flat, lbg=nlp["lbg"], ubg=nlp["ubg"]) + print(time.time()-t) + + # --------------------------------- Evaluate --------------------------------- # + # Extract and evaluate solution + coeffs_sol = np.squeeze(sol["x"]).reshape(2, n).T + coeffs_sol = np.hstack([np.zeros((n, 1)), coeffs_sol]) + + _, f_eval = polynomial(n, n_eval) + p_eval = np.squeeze(f_eval(coeffs=coeffs_sol)["p"]) + + p_sol = p_eval.copy() + + # ---------------------------------------------------------------------------- # + # Visualize # + # ---------------------------------------------------------------------------- # + + meshgrid = torch.meshgrid( + torch.linspace(0, 0, 1), + torch.linspace(-1.0, 1.2, 200), + torch.linspace(-0.5, 1, 200), + indexing='ij' + ) + + points = torch.stack(meshgrid, dim=-1).reshape(-1, 3).to(DEVICE) + with torch.no_grad(): + density = model(points).detach()[..., 0].cpu().numpy() + points = points.cpu().numpy() + + with torch.no_grad(): + density_sol = model(torch.tensor(p_sol, dtype=torch.float32).to(DEVICE)).detach().cpu()[..., 0] + + print(f"Maximum Density in Solution: {density_sol.max()} < Threshold {optimization_threshold:.2f}") + + ax = plt.figure().add_subplot(111) + ax.plot(p_sol[:, 1], p_sol[:, 2], "-", color=(0.8, 0.12, 0.12), linewidth=3) + g = ax.scatter( + points[density > viz_threshold][:, 1], + points[density > viz_threshold][:, 2], + cmap="jet", + c=density[density > viz_threshold], + s=0.5, + ) + cb = plt.colorbar(g, ax=ax) + ax.scatter(p_start[1], p_start[2], color=(0.12, 0.12, 0.8), s=50., zorder=10) + ax.scatter(p_goal[1], p_goal[2], color=(0.12, 0.8, 0.12), s=50., zorder=10) + cb.set_label('NeRF Density') + plt.xticks([], []) + plt.yticks([], []) + plt.tight_layout() + plt.show() + + +if __name__ == '__main__': + main() diff --git a/examples/readme.py b/examples/readme.py index 7bf012b..f68af36 100644 --- a/examples/readme.py +++ b/examples/readme.py @@ -25,9 +25,9 @@ def forward(self, x): pyTorch_model = MultiLayerPerceptron() -l4c_model = l4c.L4CasADi(pyTorch_model, model_expects_batch_dim=True, device='cpu') # device='cuda' for GPU +l4c_model = l4c.L4CasADi(pyTorch_model, device='cpu') # device='cuda' for GPU -x_sym = cs.MX.sym('x', 2, 1) +x_sym = cs.MX.sym('x', 1, 2) y_sym = l4c_model(x_sym) f = cs.Function('y', [x_sym], [y_sym]) df = cs.Function('dy', [x_sym], [cs.jacobian(y_sym, x_sym)]) diff --git a/examples/simple_nlp.py b/examples/simple_nlp.py index d8923be..f74eaf2 100644 --- a/examples/simple_nlp.py +++ b/examples/simple_nlp.py @@ -14,7 +14,7 @@ def forward(self, input): f = PyTorchObjectiveModel() # objective -f = l4c.L4CasADi(f, name='f', model_expects_batch_dim=False)(x) +f = l4c.L4CasADi(f, name='f')(x) class PyTorchConstraintModel(torch.nn.Module): @@ -23,7 +23,7 @@ def forward(self, input): g = PyTorchConstraintModel() # constraint -g = l4c.L4CasADi(g, name='g', model_expects_batch_dim=False)(x) +g = l4c.L4CasADi(g, name='g')(x) nlp = {'x': x, 'f': f, 'g': g} diff --git a/l4casadi/l4casadi.py b/l4casadi/l4casadi.py index 455a65a..a4497c6 100644 --- a/l4casadi/l4casadi.py +++ b/l4casadi/l4casadi.py @@ -3,20 +3,22 @@ import platform import shutil import time +import warnings try: from importlib.resources import files except ImportError: - from importlib_resources import files # type: ignore[no-redef] + from importlib_resources import files # type: ignore[no-redef] from typing import Union, Optional, Callable, Text, Tuple import casadi as cs import torch + try: - from torch.func import jacrev, jacfwd, functionalize + from torch.func import jacrev, jacfwd, functionalize, vjp except ImportError: - from functorch import jacrev, jacfwd, functionalize + from functorch import jacrev, jacfwd, functionalize, vjp from l4casadi.ts_compiler import ts_compile from torch.fx.experimental.proxy_tensor import make_fx @@ -37,18 +39,21 @@ def dynamic_lib_file_ending(): class L4CasADi(object): def __init__(self, model: Callable[[torch.Tensor], torch.Tensor], - model_expects_batch_dim: bool = True, + batched: bool = False, device: Union[torch.device, Text] = 'cpu', name: Text = 'l4casadi_f', build_dir: Text = './_l4c_generated', model_search_path: Optional[Text] = None, - with_jacobian: bool = True, - with_hessian: bool = True, + generate_jac: bool = True, + generate_adj1: bool = True, + generate_jac_adj1: bool = True, + generate_jac_jac: bool = False, + scripting: bool = True, mutable: bool = False): """ :param model: PyTorch model. - :param model_expects_batch_dim: True if the PyTorch model expects a batch dimension. This is commonly True - for trained PyTorch models. + :param batched: If True, the first dimension of the two expected input dimension is assumed to be a batch + dimension. This can lead to speedups as sensitivities across this dimension can be neglected. :param device: Device on which the PyTorch model is executed. :param name: Unique name of the generated L4CasADi model. This name is used for autogenerated files. Creating two L4CasADi models with the same name will result in overwriting the files of the first model. @@ -56,10 +61,27 @@ def __init__(self, the absolute path to the `build_dir` where the model traces are exported to. This parameter can become useful if the created L4CasADi dynamic library and the exported PyTorch Models are expected to be moved to a different folder (or another device). - :param with_jacobian: If True, the Jacobian of the model is exported. - :param with_hessian: If True, the Hessian of the model is exported. + :param build_dir: Directory where the L4CasADi library is built. + :param generate_jac: If True, the Jacobian of the model is tried to be generated. + :param generate_adj1: If True, the Adjoint of the model is tried to be generated. + :param generate_jac_adj1: If True, the Jacobain of the Adjoint of the model is tried to be generated. + :param generate_jac_jac: If True, the Hessian of the model is tried to be generated. + :param scripting: If True, the model is traced using TorchScript. If False, the model is compiled. :param mutable: If True, enables updating the model online via the update method. """ + if platform.system() == "Windows": + warnings.warn("L4CasADi is currently not supported for Windows.") + + if not scripting: + warnings.warn("L4CasADi with Torch AOT compilation is experimental at this point and might not work as " + "expected.") + raise RuntimeError("PyTorch compile is not supported yet as it does not seem stable.") + if torch.__version__ < torch.torch_version.TorchVersion('2.4.0'): + raise RuntimeError("For PyTorch versions < 2.4.0 L4CasADi only supports jit scripting. Please pass " + "scripting=True.") + import torch._inductor.config as config + config.freezing = True + self.model = model self.naive = False if isinstance(self.model, NaiveL4CasADiModule): @@ -69,7 +91,7 @@ def __init__(self, for parameters in self.model.parameters(): parameters.requires_grad = False self.name = name - self.has_batch = model_expects_batch_dim + self.batched = batched self.device = device if isinstance(device, str) else f'{device.type}:{device.index}' self.build_dir = pathlib.Path(build_dir) @@ -79,12 +101,17 @@ def __init__(self, self._cs_fun: Optional[cs.Function] = None self._built = False - self._with_jacobian = with_jacobian - self._with_hessian = with_hessian + self._generate_jac = generate_jac + self._generate_adj1 = generate_adj1 + self._generate_jac_adj1 = generate_jac_adj1 + self._generate_jac_jac = generate_jac_jac + + self._scripting = scripting self._mutable = mutable - self._input_shape: Optional[Tuple[int, int]] = None + self._input_shape: Tuple[int, int] = (-1, -1) + self._output_shape: Tuple[int, int] = (-1, -1) def update(self, model: Optional[Callable[[torch.Tensor], torch.Tensor]] = None) -> None: """ @@ -105,7 +132,7 @@ def update(self, model: Optional[Callable[[torch.Tensor], torch.Tensor]] = None) for parameters in self.model.parameters(): parameters.requires_grad = False - self.export_torch_traces(*self._input_shape) # type: ignore[misc] + self.export_torch_traces() # type: ignore[misc] time.sleep(0.2) @@ -124,10 +151,6 @@ def shared_lib_dir(self): return self.build_dir.absolute().as_posix() def forward(self, inp: Union[cs.MX, cs.SX, cs.DM]): - if self.has_batch: - if not inp.shape[-1] == 1: # type: ignore[attr-defined] - raise ValueError("For batched PyTorch models only vector inputs are allowed.") - if self.naive: out = self.model(inp) else: @@ -171,16 +194,36 @@ def build(self, inp: Union[cs.MX, cs.SX, cs.DM]) -> None: self._built = True + def _verify_input_output(self): + if len(self._output_shape) != 2: + raise ValueError(f"""L4CasADi requires the model output to be a matrix (2 dimensions) but has + {len(self._output_shape)} dimensions. Please add a extra dimension of size 1. + For models which expects a batch dimension, the output should be a matrix of [1, d].""") + + if self.batched: + if self._input_shape[0] != self._output_shape[0]: + raise ValueError(f"""When the model is batched the first dimension of input and output (batch dimension) + has to be the same.""") + def generate(self, inp: Union[cs.MX, cs.SX, cs.DM]) -> None: - rows, cols = inp.shape # type: ignore[attr-defined] - has_jac, has_hess = self.export_torch_traces(rows, cols) - if not has_jac and self._with_jacobian: - print('Jacobian trace could not be generated.' - ' First-order sensitivities will not be available in CasADi.') - if not has_hess and self._with_hessian: - print('Hessian trace could not be generated.' - ' Second-order sensitivities will not be available in CasADi.') - self._generate_cpp_function_template(rows, cols, has_jac, has_hess) + self._input_shape = inp.shape # type: ignore[attr-defined] + self._output_shape = self.model(torch.zeros(*self._input_shape).to(self.device)).shape + self._verify_input_output() + + has_jac, has_adj1, has_jac_adj1, has_jac_jac = self.export_torch_traces() + if not has_jac and self._generate_jac: + warnings.warn('Jacobian trace could not be generated.' + ' First-order sensitivities will not be available in CasADi.') + if not has_adj1 and self._generate_adj1: + warnings.warn('Adjoint trace could not be generated.' + ' First-order sensitivities will not be available in CasADi.') + if not has_jac_adj1 and self._generate_jac_adj1: + warnings.warn('Jacobian Adjoint trace could not be generated.' + ' Second-order sensitivities will not be available in CasADi.') + if not has_jac_jac and self._generate_jac_jac: + warnings.warn('Hessian trace could not be generated.' + ' Second-order sensitivities will not be available in CasADi.') + self._generate_cpp_function_template(has_jac, has_adj1, has_jac_adj1, has_jac_jac) def _load_built_library_as_external_cs_fun(self): if not self._built: @@ -190,39 +233,80 @@ def _load_built_library_as_external_cs_fun(self): f"{self.build_dir / f'lib{self.name}'}{dynamic_lib_file_ending()}" ) - def _generate_cpp_function_template(self, rows: int, cols: int, has_jac: bool, has_hess: bool): - if self.has_batch: - out_shape = self.model(torch.zeros(1, rows).to(self.device)).shape - rows_out = out_shape[-1] - cols_out = 1 - else: - out_shape = self.model(torch.zeros(rows, cols).to(self.device)).shape - if len(out_shape) == 1: - rows_out = out_shape[0] - cols_out = 1 - else: - rows_out, cols_out = out_shape[-2:] - if len(out_shape) != 2: - raise ValueError(f"""L4CasADi requires the model output to be a matrix (2 dimensions) but has - {len(out_shape)} dimensions. For models which expects a batch dimension, - the output should be a matrix of [1, d].""") + @staticmethod + def generate_block_diagonal_ccs(batch_size, input_size, output_size): + """ + https://de.wikipedia.org/wiki/Harwell-Boeing-Format + :param batch_size: Size of batch dimension. + :param input_size: Size of input vector. + :param output_size: Size of output vector. + :return: + jac_ccs, hess_ccs + """ + # Jacobian dimensions [batch_size * output_size, batch_size * input_size] + col_ptr = list(range(0, batch_size * input_size * output_size, output_size)) + [ + batch_size * input_size * output_size] + row_ind = [] + for _ in range(input_size): + for batch_idx in range(batch_size): + row_ind += list(range(batch_idx, batch_idx + batch_size * output_size, batch_size)) + + jac_ccs = [batch_size * output_size, batch_size * input_size] + col_ptr + row_ind + + # Hessian dimensions [batch_size * output_size * batch_size * input_size, batch_size * input_size] + col_ptr = list(range(0, batch_size * input_size * output_size * input_size, input_size * output_size)) + [ + batch_size * input_size * output_size * input_size] + row_ind = [] + for _ in range(input_size): + for batch_idx in range(batch_size): + for jacobian_idx in range(0, batch_size * output_size * batch_size * input_size, + output_size * batch_size * batch_size): + row_ind += list(range(jacobian_idx + batch_idx * batch_size * output_size + batch_idx, + (jacobian_idx + batch_idx * batch_size * output_size + + batch_idx + batch_size * output_size), + batch_size)) + + hess_ccs = [batch_size * output_size * batch_size * input_size, batch_size * input_size] + col_ptr + row_ind + + return jac_ccs, hess_ccs + + def _generate_cpp_function_template(self, has_jac: bool, has_adj1: bool, has_jac_adj1: bool, has_jac_jac: bool): model_path = (self.build_dir.absolute().as_posix() if self._model_search_path is None else self._model_search_path) + if self.batched: + jac_ccs, jac_jac_ccs = self.generate_block_diagonal_ccs(self._input_shape[0], + self._input_shape[1], + self._output_shape[1]) + jac_adj_css, _ = self.generate_block_diagonal_ccs(self._input_shape[0], + self._input_shape[1], + self._input_shape[1]) + else: + jac_ccs, jac_adj_css, jac_jac_ccs = None, None, None + gen_params = { 'model_path': model_path, 'device': self.device, 'name': self.name, - 'rows_in': rows, - 'cols_in': cols, - 'rows_out': rows_out, - 'cols_out': cols_out, + 'rows_in': self._input_shape[0], + 'cols_in': self._input_shape[1], + 'rows_out': self._output_shape[0], + 'cols_out': self._output_shape[1], 'has_jac': 'true' if has_jac else 'false', - 'has_hess': 'true' if has_hess else 'false', - 'model_expects_batch_dim': 'true' if self.has_batch else 'false', + 'has_adj1': 'true' if has_adj1 else 'false', + 'has_jac_adj1': 'true' if has_jac_adj1 else 'false', + 'has_jac_jac': 'true' if has_jac_jac else 'false', + 'scripting': 'true' if self._scripting else 'false', 'model_is_mutable': 'true' if self._mutable else 'false', + 'batched': 'true' if self.batched else 'false', + 'jac_ccs_len': len(jac_ccs) if self.batched else 0, + 'jac_ccs': ', '.join(str(e) for e in jac_ccs) if self.batched else '', + 'jac_adj_ccs_len': len(jac_adj_css) if self.batched else 0, + 'jac_adj_ccs': ', '.join(str(e) for e in jac_adj_css) if self.batched else '', + 'jac_jac_ccs_len': len(jac_jac_ccs) if self.batched else 0, + 'jac_jac_ccs': ', '.join(str(e) for e in jac_jac_ccs) if self.batched else '', } render_casadi_c_template( @@ -254,40 +338,89 @@ def compile(self): raise Exception(f'Compilation failed!\n\nAttempted to execute OS command:\n{os_cmd}\n\n') def _trace_jac_model(self, inp): + if self.batched: + def with_batch_dim(x): + return torch.func.vmap(jacrev(self.model))(x[:, None])[:, 0].permute(1, 0, 2, 3) + + return make_fx(functionalize(with_batch_dim, remove='mutations_and_views'))(inp) return make_fx(functionalize(jacrev(self.model), remove='mutations_and_views'))(inp) - def _trace_hess_model(self, inp): - return make_fx(functionalize(jacrev(jacrev(self.model)), remove='mutations_and_views'))(inp) + def _trace_adj1_model(self): + p_d = torch.zeros(self._input_shape).to(self.device) + t_d = torch.zeros(self._output_shape).to(self.device) - def export_torch_traces(self, rows: int, cols: int) -> Tuple[bool, bool]: - if self.has_batch: - d_inp = torch.zeros((1, rows)) - else: - d_inp = torch.zeros((rows, cols)) + def _vjp(p, x): + return vjp(self.model, p)[1](x)[0] + + return make_fx(functionalize(_vjp, remove='mutations_and_views'))(p_d, t_d) - # Save input shape for online update. - self._input_shape = (rows, cols) + def _trace_jac_adj1_model(self): + p_d = torch.zeros(self._input_shape).to(self.device) + t_d = torch.zeros(self._output_shape).to(self.device) + def _vjp(p, x): + return vjp(self.model, p)[1](x)[0] + + # TODO: replace jacfwd with jacrev depending on answer in https://github.com/pytorch/pytorch/issues/130735 + if self.batched: + def with_batch_dim(p, x): + return torch.func.vmap(jacfwd(_vjp))(p[:, None], x[:, None])[:, 0].permute(3, 2, 0, 1) + + return make_fx(functionalize(with_batch_dim, remove='mutations_and_views'))(p_d, t_d) + return make_fx(functionalize(jacfwd(_vjp), remove='mutations_and_views'))(p_d, t_d) + + def _trace_hess_model(self, inp): + if self.batched: + def with_batch_dim(x): + # Permutation is trial and error + return torch.func.vmap(jacrev(jacrev(self.model)))(x[:, None])[:, 0].permute(1, 3, 2, 0, 4, 5) + + return make_fx(functionalize(with_batch_dim, remove='mutations_and_views'))(inp) + return make_fx(functionalize(jacrev(jacrev(self.model)), remove='mutations_and_views'))(inp) + + def export_torch_traces(self) -> Tuple[bool, bool, bool, bool]: + d_inp = torch.zeros(self._input_shape) d_inp = d_inp.to(self.device) + d_out = torch.zeros(self._output_shape) + d_out = d_out.to(self.device) + out_folder = self.build_dir - self._jit_compile_and_save(make_fx(functionalize(self.model, remove='mutations_and_views'))(d_inp), - (out_folder / f'{self.name}_forward.pt').as_posix(), - d_inp) + self.model_compile(make_fx(functionalize(self.model, remove='mutations_and_views'))(d_inp), + (out_folder / f'{self.name}.pt').as_posix(), + (d_inp,)) - exported_jacrev = False - if self._with_jacobian: + exported_jac = False + if self._generate_jac: jac_model = self._trace_jac_model(d_inp) - exported_jacrev = self._jit_compile_and_save( + exported_jac = self.model_compile( jac_model, - (out_folder / f'{self.name}_jacrev.pt').as_posix(), - d_inp + (out_folder / f'jac_{self.name}.pt').as_posix(), + (d_inp,) + ) + + exported_adj1 = False + if self._generate_adj1: + adj1_model = self._trace_adj1_model() + exported_adj1 = self.model_compile( + adj1_model, + (out_folder / f'adj1_{self.name}.pt').as_posix(), + (d_inp, d_out) + ) + + exported_jac_adj1 = False + if self._generate_jac_adj1: + jac_adj1_model = self._trace_jac_adj1_model() + exported_jac_adj1 = self.model_compile( + jac_adj1_model, + (out_folder / f'jac_adj1_{self.name}.pt').as_posix(), + (d_inp, d_out) ) exported_hess = False - if self._with_hessian: + if self._generate_jac_jac: hess_model = None try: hess_model = self._trace_hess_model(d_inp) @@ -295,17 +428,35 @@ def export_torch_traces(self, rows: int, cols: int) -> Tuple[bool, bool]: pass if hess_model is not None: - exported_hess = self._jit_compile_and_save( + exported_hess = self.model_compile( hess_model, - (out_folder / f'{self.name}_hess.pt').as_posix(), - d_inp + (out_folder / f'jac_jac_{self.name}.pt').as_posix(), + (d_inp,) ) - return exported_jacrev, exported_hess + return exported_jac, exported_adj1, exported_jac_adj1, exported_hess + + def model_compile(self, model, file_path: str, dummy_inp: Tuple[torch.Tensor, ...]): + if self._scripting: + return self._jit_compile_and_save(model, file_path, dummy_inp) + else: + return self._aot_compile_and_save(model, file_path, dummy_inp) + + @staticmethod + def _aot_compile_and_save(model, file_path: str, dummy_inp: Tuple[torch.Tensor, ...]): + try: + with torch.no_grad(): + torch._export.aot_compile( + model, + dummy_inp, + options={"aot_inductor.output_path": file_path[:-2] + 'so'}, + ) + return True + except: # noqa + return False @staticmethod - def _jit_compile_and_save(model, file_path: str, dummy_inp: torch.Tensor): - # TODO: Could switch to torch export https://pytorch.org/docs/stable/export.html + def _jit_compile_and_save(model, file_path: str, dummy_inp: Tuple[torch.Tensor, ...]): try: # Try scripting ts_compile(model).save(file_path) diff --git a/l4casadi/naive/nn/linear.py b/l4casadi/naive/nn/linear.py index 27c9020..da0ea11 100644 --- a/l4casadi/naive/nn/linear.py +++ b/l4casadi/naive/nn/linear.py @@ -6,8 +6,7 @@ class Linear(NaiveL4CasADiModule, torch.nn.Linear): def cs_forward(self, x): - assert x.shape[1] == 1, 'Casadi can not handle batches.' - y = cs.mtimes(self.weight.detach().numpy(), x) + y = cs.mtimes(x, self.weight.transpose(1, 0).detach().numpy()) if self.bias is not None: - y = y + self.bias.detach().numpy() + y = y + self.bias[None].repeat((x.shape[0], 1)).detach().numpy() return y diff --git a/l4casadi/realtime/realtime_l4casadi.py b/l4casadi/realtime/realtime_l4casadi.py index 84e644f..984fd0f 100644 --- a/l4casadi/realtime/realtime_l4casadi.py +++ b/l4casadi/realtime/realtime_l4casadi.py @@ -22,7 +22,7 @@ def __init__(self, :param name: Unique name of the generated L4CasADi model. This name is used for autogenerated files. Creating two L4CasADi models with the same name will result in overwriting the files of the first model. """ - super().__init__(model, model_expects_batch_dim=True, device=device, name=name) + super().__init__(model, device=device, name=name) if approximation_order > 2 or approximation_order < 1: raise ValueError("Taylor approximation order must be 1 or 2.") diff --git a/l4casadi/template_generation/templates/casadi_function.in.cpp b/l4casadi/template_generation/templates/casadi_function.in.cpp index c2774ee..e4d5f87 100644 --- a/l4casadi/template_generation/templates/casadi_function.in.cpp +++ b/l4casadi/template_generation/templates/casadi_function.in.cpp @@ -1,6 +1,6 @@ #include -L4CasADi l4casadi("{{ model_path }}", "{{ name }}", {{ model_expects_batch_dim }}, "{{ device }}", {{ has_jac }}, {{ has_hess }}, {{ model_is_mutable }}); +L4CasADi l4casadi("{{ model_path }}", "{{ name }}", {{ rows_in }}, {{ cols_in }}, {{ rows_out }}, {{ cols_out }}, "{{ device }}", {{ has_jac }}, {{ has_adj1 }}, {{ has_jac_adj1 }}, {{ has_jac_jac }}, {{ scripting }}, {{ model_is_mutable }}); #ifdef __cplusplus extern "C" { @@ -31,48 +31,143 @@ extern "C" { #endif #endif +// Function {{ name }} -static const casadi_int casadi_s_in0[3] = { {{ rows_in }}, {{ cols_in }}, 1}; -static const casadi_int casadi_s_out0[3] = { {{ rows_out }}, {{ cols_out }}, 1}; +static const casadi_int {{ name }}_s_in0[3] = { {{ rows_in }}, {{ cols_in }}, 1}; +static const casadi_int {{ name }}_s_out0[3] = { {{ rows_out }}, {{ cols_out }}, 1}; +// Only single input, single output is supported at the moment +CASADI_SYMBOL_EXPORT casadi_int {{ name }}_n_in(void) { return 1;} +CASADI_SYMBOL_EXPORT casadi_int {{ name }}_n_out(void) { return 1;} + +CASADI_SYMBOL_EXPORT const casadi_int* {{ name }}_sparsity_in(casadi_int i) { + switch (i) { + case 0: return {{ name }}_s_in0; + default: return 0; + } +} + +CASADI_SYMBOL_EXPORT const casadi_int* {{ name }}_sparsity_out(casadi_int i) { + switch (i) { + case 0: return {{ name }}_s_out0; + default: return 0; + } +} CASADI_SYMBOL_EXPORT int {{ name }}(const casadi_real** arg, casadi_real** res, casadi_int* iw, casadi_real* w, int mem){ - l4casadi.forward(arg[0], {{ rows_in }}, {{ cols_in }}, res[0]); + l4casadi.forward(arg[0], res[0]); return 0; } -{% if has_jac %} +{% if has_jac == "true" %} +// Jacobian {{ name }} + +CASADI_SYMBOL_EXPORT casadi_int jac_{{ name }}_n_in(void) { return 2;} +CASADI_SYMBOL_EXPORT casadi_int jac_{{ name }}_n_out(void) { return 1;} + CASADI_SYMBOL_EXPORT int jac_{{ name }}(const casadi_real** arg, casadi_real** res, casadi_int* iw, casadi_real* w, int mem){ - l4casadi.jac(arg[0], {{ rows_in }}, {{ cols_in }}, res[0]); + l4casadi.jac(arg[0], res[0]); return 0; } + +{% if batched == "true" %} +// Sparse output if batched. +static const casadi_int jac_{{ name }}_s_out0[{{jac_ccs_len}}] = { {{ jac_ccs }}}; + +CASADI_SYMBOL_EXPORT const casadi_int* jac_{{ name }}_sparsity_out(casadi_int i) { + switch (i) { + case 0: return jac_{{ name }}_s_out0; + default: return 0; + } +} +{% endif %} {% endif %} -{% if has_hess %} -CASADI_SYMBOL_EXPORT int jac_jac_{{ name }}(const casadi_real** arg, casadi_real** res, casadi_int* iw, casadi_real* w, int mem){ - l4casadi.hess(arg[0], {{ rows_in }}, {{ cols_in }}, res[0]); + +{% if has_adj1 == "true" %} +// adj1 {{ name }} + +CASADI_SYMBOL_EXPORT casadi_int adj1_{{ name }}_n_in(void) { return 3;} +CASADI_SYMBOL_EXPORT casadi_int adj1_{{ name }}_n_out(void) { return 1;} + +CASADI_SYMBOL_EXPORT int adj1_{{ name }}(const casadi_real** arg, casadi_real** res, casadi_int* iw, casadi_real* w, int mem){ + // adj1 [i0, out_o0, adj_o0] -> [out_adj_i0] + l4casadi.adj1(arg[0], arg[2], res[0]); return 0; } {% endif %} -// Only single input, single output is supported at the moment -CASADI_SYMBOL_EXPORT casadi_int {{ name }}_n_in(void) { return 1;} -CASADI_SYMBOL_EXPORT casadi_int {{ name }}_n_out(void) { return 1;} +{% if has_jac_adj1 == "true" %} +// jac_adj1 {{ name }} -CASADI_SYMBOL_EXPORT const casadi_int* {{ name }}_sparsity_in(casadi_int i) { +CASADI_SYMBOL_EXPORT casadi_int jac_adj1_{{ name }}_n_in(void) { return 4;} +CASADI_SYMBOL_EXPORT casadi_int jac_adj1_{{ name }}_n_out(void) { return 3;} + +CASADI_SYMBOL_EXPORT int jac_adj1_{{ name }}(const casadi_real** arg, casadi_real** res, casadi_int* iw, casadi_real* w, int mem){ + // jac_adj1 [i0, out_o0, adj_o0, out_adj_i0] -> [jac_adj_i0_i0, jac_adj_i0_out_o0, jac_adj_i0_adj_o0] + if (res[1] != NULL) { + l4casadi.invalid_argument("jac_adj_i0_out_o0 is not provided by L4CasADi. If you need this feature, please contact the L4CasADi developer."); + } + if (res[2] != NULL) { + l4casadi.invalid_argument("jac_adj_i0_adj_o0 is not provided by L4CasADi. If you need this feature, please contact the L4CasADi developer."); + } + if (res[0] == NULL) { + l4casadi.invalid_argument("L4CasADi can only provide jac_adj_i0_i0 for jac_adj1_{{ name }} function. If you need this feature, please contact the L4CasADi developer."); + } + l4casadi.jac_adj1(arg[0], arg[2], res[0]); + return 0; +} + +{% if batched == "true" %} +// Sparse output if batched. +static const casadi_int jac_adj1_{{ name }}_s_out0[{{jac_adj_ccs_len}}] = { {{ jac_adj_ccs }}}; +static const casadi_int jac_adj1_{{ name }}_s_out23[3] = { {{ rows_in }} * {{ cols_in }}, {{ rows_out }} * {{ cols_out }}, 1}; + +CASADI_SYMBOL_EXPORT const casadi_int* jac_adj1_{{ name }}_sparsity_out(casadi_int i) { switch (i) { - case 0: return casadi_s_in0; + case 0: return jac_adj1_{{ name }}_s_out0; + case 1: return jac_adj1_{{ name }}_s_out23; + case 2: return jac_adj1_{{ name }}_s_out23; default: return 0; } } +{% endif %} +{% endif %} -CASADI_SYMBOL_EXPORT const casadi_int* {{ name }}_sparsity_out(casadi_int i) { + +{% if has_jac_jac == "true" %} +// jac_jac {{ name }} + +CASADI_SYMBOL_EXPORT int jac_jac_{{ name }}(const casadi_real** arg, casadi_real** res, casadi_int* iw, casadi_real* w, int mem){ + // [i0, out_o0, out_jac_o0_i0] -> [jac_jac_o0_i0_i0, jac_jac_o0_i0_out_o0] + if (res[1] != NULL) { + l4casadi.invalid_argument("jac_jac_o0_i0_out_o0 is not provided by L4CasADi. If you need this feature, please contact the L4CasADi developer."); + } + if (res[0] == NULL) { + l4casadi.invalid_argument("L4CasADi can only provide jac_jac_o0_i0_i0 for jac_jac_{{ name }} function. If you need this feature, please contact the L4CasADi developer."); + } + l4casadi.jac_jac(arg[0], res[0]); + return 0; +} + +{% if batched == "true" %} +// jac_jac {{ name }} + +static const casadi_int jac_jac_{{ name }}_s_out0[{{jac_jac_ccs_len}}] = { {{ jac_jac_ccs }}}; +static const casadi_int jac_jac_{{ name }}_s_out1[3] = { {{ rows_in }} * {{ cols_in }} * {{ rows_out }} * {{ cols_out }}, {{ rows_out }} * {{ cols_out }}, 1}; +CASADI_SYMBOL_EXPORT const casadi_int* jac_jac_{{ name }}_sparsity_out(casadi_int i) { switch (i) { - case 0: return casadi_s_out0; + case 0: return jac_jac_{{ name }}_s_out0; + case 1: return jac_jac_{{ name }}_s_out1; default: return 0; } } +CASADI_SYMBOL_EXPORT casadi_int jac_jac_{{ name }}_n_in(void) { return 3;} + +CASADI_SYMBOL_EXPORT casadi_int jac_jac_{{ name }}_n_out(void) { return 2;} +{% endif %} +{% endif %} #ifdef __cplusplus } /* extern "C" */ diff --git a/libl4casadi/CMakeLists.txt b/libl4casadi/CMakeLists.txt index 4bfa39d..674d250 100644 --- a/libl4casadi/CMakeLists.txt +++ b/libl4casadi/CMakeLists.txt @@ -1,9 +1,7 @@ cmake_minimum_required(VERSION 3.0 FATAL_ERROR) project(L4CasADi) -# Load CUDA if it is installed -find_package(CUDAToolkit) -find_package(CUDA) +set(CMAKE_COMPILE_WARNING_AS_ERROR ON) if (WIN32) set (CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS TRUE) @@ -14,6 +12,22 @@ endif () set(CMAKE_PREFIX_PATH ${CMAKE_TORCH_PATH}) find_package(Torch REQUIRED) + +# Load CUDA if it is installed +find_package(CUDAToolkit) +find_package(CUDA) + +add_definitions(-DTORCH_VERSION_MAJOR=${Torch_VERSION_MAJOR}) +add_definitions(-DTORCH_VERSION_MINOR=${Torch_VERSION_MINOR}) +add_definitions(-DTORCH_VERSION_PATCH=${Torch_VERSION_PATCH}) + +if (Torch_VERSION_MAJOR GREATER_EQUAL 1 AND Torch_VERSION_MINOR GREATER_EQUAL 4) + # add_definitions(-DENABLE_TORCH_COMPILE) +endif () +if (USE_CUDA) + add_definitions(-DUSE_CUDA) +endif () + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") add_library(l4casadi SHARED src/l4casadi.cpp include/l4casadi.hpp) diff --git a/libl4casadi/include/l4casadi.hpp b/libl4casadi/include/l4casadi.hpp index 2885432..b0a92ec 100644 --- a/libl4casadi/include/l4casadi.hpp +++ b/libl4casadi/include/l4casadi.hpp @@ -7,16 +7,27 @@ class L4CasADi { private: - bool model_expects_batch_dim; + int rows_in; + int cols_in; + + int rows_out; + int cols_out; public: - L4CasADi(std::string, std::string, bool = false, std::string = "cpu", bool = false, bool = false, bool = false); + L4CasADi(std::string, std::string, int, int, int, int, std::string = "cpu", bool = false, bool = false, bool = false, bool = false, + bool = false,bool = false); ~L4CasADi(); - void forward(const double*, int, int, double*); - void jac(const double*, int, int, double*); - void hess(const double*, int, int, double*); + void forward(const double*, double*); + void jac(const double*, double*); + void adj1(const double*, const double*, double*); + void jac_adj1(const double*, const double*, double*); + void jac_jac(const double*, double*); + + void invalid_argument(std::string); // PImpl Idiom class L4CasADiImpl; + class L4CasADiScriptedImpl; + class L4CasADiCompiledImpl; std::unique_ptr pImpl; }; diff --git a/libl4casadi/src/l4casadi.cpp b/libl4casadi/src/l4casadi.cpp index 9fe719c..3dc586c 100644 --- a/libl4casadi/src/l4casadi.cpp +++ b/libl4casadi/src/l4casadi.cpp @@ -1,37 +1,43 @@ #include +#include #include #include #include //#include +#if ENABLE_TORCH_COMPILE +#include + +#if USE_CUDA +#include +#endif +#endif + #include "l4casadi.hpp" torch::Device cpu(torch::kCPU); class L4CasADi::L4CasADiImpl { - std::string model_path; - std::string model_prefix; +protected: + std::string path; + std::string function_name; bool has_jac; + bool has_adj1; + bool has_jac_adj1; bool has_hess; - torch::jit::script::Module forward_model; - torch::jit::script::Module jac_model; - torch::jit::script::Module hess_model; + bool is_mutable; torch::Device device; - std::thread online_model_reloader_thread; - std::mutex model_update_mutex; - std::atomic reload_model_loop_running = false; - public: - L4CasADiImpl(std::string model_path, std::string model_prefix, std::string device, bool has_jac, bool has_hess, - bool model_is_mutable): device{torch::kCPU}, model_path{model_path}, model_prefix{model_prefix}, - has_jac{has_jac}, has_hess{has_hess} { - + L4CasADiImpl(std::string path, std::string function_name, std::string device, bool has_jac, bool has_adj1, + bool has_jac_adj1, bool has_hess, bool is_mutable): device{torch::kCPU}, path{path}, + function_name{function_name}, has_jac{has_jac}, has_adj1{has_adj1}, has_jac_adj1{has_jac_adj1}, + has_hess{has_hess}, is_mutable(is_mutable) { if (torch::cuda::is_available() && device.compare("cpu")) { std::cout << "CUDA is available! Using GPU " << device << "." << std::endl; this->device = torch::Device(device); @@ -44,16 +50,183 @@ class L4CasADi::L4CasADiImpl } else { this->device = torch::Device(device); } + } + virtual torch::Tensor forward(torch::Tensor) = 0; + virtual torch::Tensor jac(torch::Tensor) = 0; + virtual torch::Tensor adj1(torch::Tensor, torch::Tensor) = 0; + virtual torch::Tensor jac_adj1(torch::Tensor, torch::Tensor) = 0; + virtual torch::Tensor hess(torch::Tensor) = 0; + + virtual ~L4CasADiImpl() = default; +}; + +#if ENABLE_TORCH_COMPILE +class L4CasADi::L4CasADiCompiledImpl : public L4CasADi::L4CasADiImpl +{ + std::unique_ptr forward_model; + std::unique_ptr jac_model; + std::unique_ptr adj1_model; + std::unique_ptr jac_adj1_model; + std::unique_ptr hess_model; + + std::mutex model_update_mutex; + +public: + L4CasADiCompiledImpl(std::string path, std::string function_name, std::string device, bool has_jac, bool has_adj1, + bool has_jac_adj1, bool has_hess, bool is_mutable): L4CasADiImpl(path, function_name, device, has_jac, + has_adj1, has_jac_adj1, has_hess, is_mutable) { this->load_model_from_disk(); - if (model_is_mutable) { + if (is_mutable) { + throw std::invalid_argument("Mutable functions are not yet supported for compiled models."); + } + } + + ~L4CasADiCompiledImpl() = default; + + void load_model_from_disk() { + std::filesystem::path dir (this->path); + std::filesystem::path forward_model_file (this->function_name + ".so"); +#if USE_CUDA + if (this-> device == cpu) { + this->forward_model = std::make_unique((dir / forward_model_file).generic_string()); + } + else { + this->forward_model = std::make_unique((dir / forward_model_file).generic_string()); + } +#else + this->forward_model = std::make_unique((dir / forward_model_file).generic_string()); +#endif + if (this->has_adj1) { + std::filesystem::path adj1_model_file ("adj1_" + this->function_name + ".so"); +#if USE_CUDA + if (this-> device == cpu) { + this->adj1_model = std::make_unique((dir / adj1_model_file).generic_string()); + } + else { + this->adj1_model = std::make_unique((dir / adj1_model_file).generic_string()); + } +#else + this->adj1_model = std::make_unique((dir / adj1_model_file).generic_string()); +#endif + } + + if (this->has_jac_adj1) { + std::filesystem::path jac_adj1_model_file ("jac_adj1_" + this->function_name + ".so"); +#if USE_CUDA + if (this-> device == cpu) { + this->jac_adj1_model = std::make_unique((dir / jac_adj1_model_file).generic_string()); + } + else { + this->jac_adj1_model = std::make_unique((dir / jac_adj1_model_file).generic_string()); + } +#else + this->jac_adj1_model = std::make_unique((dir / jac_adj1_model_file).generic_string()); +#endif + } + + if (this->has_jac) { + std::filesystem::path jac_model_file ("jac_" + this->function_name + ".so"); +#if USE_CUDA + if (this-> device == cpu) { + this->jac_model = std::make_unique((dir / jac_model_file).generic_string()); + } + else { + this->jac_model = std::make_unique((dir / jac_model_file).generic_string()); + } +#else + this->jac_model = std::make_unique((dir / jac_model_file).generic_string()); +#endif + } + + if (this->has_hess) { + std::filesystem::path hess_model_file ("jac_jac_" + this->function_name + ".so"); +#if USE_CUDA + if (this-> device == cpu) { + this->hess_model = std::make_unique((dir / hess_model_file).generic_string()); + } + else { + this->hess_model = std::make_unique((dir / hess_model_file).generic_string()); + } +#else + this->hess_model = std::make_unique((dir / hess_model_file).generic_string()); +#endif + } + } + + torch::Tensor forward(torch::Tensor x) { + std::unique_lock lock(this->model_update_mutex); + c10::InferenceMode guard; + std::vector inputs; + inputs.push_back(x); + auto out = this->forward_model->run(inputs)[0].to(cpu); + return out; + } + + torch::Tensor jac(torch::Tensor x) { + std::unique_lock lock(this->model_update_mutex); + c10::InferenceMode guard; + std::vector inputs; + inputs.push_back(x.to(this->device)); + return this->jac_model->run(inputs)[0].to(cpu); + } + + torch::Tensor adj1(torch::Tensor primal, torch::Tensor tangent) { + std::unique_lock lock(this->model_update_mutex); + c10::InferenceMode guard; + std::vector inputs; + inputs.push_back(primal.to(this->device)); + inputs.push_back(tangent.to(this->device)); + return this->adj1_model->run(inputs)[0].to(cpu); + } + + torch::Tensor jac_adj1(torch::Tensor primal, torch::Tensor tangent){ + std::unique_lock lock(this->model_update_mutex); + c10::InferenceMode guard; + std::vector inputs; + inputs.push_back(primal.to(this->device)); + inputs.push_back(tangent.to(this->device)); + return this->jac_adj1_model->run(inputs)[0].to(cpu); + } + + torch::Tensor hess(torch::Tensor x) { + std::unique_lock lock(this->model_update_mutex); + c10::InferenceMode guard; + std::vector inputs; + inputs.push_back(x.to(this->device)); + return this->hess_model->run(inputs)[0].to(cpu); + } + +}; +#endif + +class L4CasADi::L4CasADiScriptedImpl : public L4CasADi::L4CasADiImpl +{ + torch::jit::script::Module adj1_model; + torch::jit::script::Module forward_model; + torch::jit::script::Module jac_model; + torch::jit::script::Module jac_adj1_model; + torch::jit::script::Module hess_model; + + std::thread online_model_reloader_thread; + std::mutex model_update_mutex; + std::atomic reload_model_loop_running = false; + +public: + L4CasADiScriptedImpl(std::string path, std::string function_name, std::string device, bool has_jac, bool has_adj1, + bool has_jac_adj1, bool has_hess, bool is_mutable): L4CasADiImpl(path, function_name, device, has_jac, + has_adj1, has_jac_adj1, has_hess, is_mutable) { + + this->load_model_from_disk(); + + if (is_mutable) { this->reload_model_loop_running = true; - this->online_model_reloader_thread = std::thread(&L4CasADiImpl::reload_runner, this); + this->online_model_reloader_thread = std::thread(&L4CasADiScriptedImpl::reload_runner, this); } } - ~ L4CasADiImpl() { + ~ L4CasADiScriptedImpl() { if (this->reload_model_loop_running == true) { this->reload_model_loop_running = false; this->online_model_reloader_thread.join(); @@ -61,8 +234,8 @@ class L4CasADi::L4CasADiImpl } void reload_runner() { - std::filesystem::path dir (this->model_path); - std::filesystem::path reload_file (this->model_prefix + ".reload"); + std::filesystem::path dir (this->path); + std::filesystem::path reload_file (this->function_name + ".reload"); while(this->reload_model_loop_running) { std::this_thread::sleep_for(std::chrono::milliseconds(200)); @@ -75,15 +248,31 @@ class L4CasADi::L4CasADiImpl } void load_model_from_disk() { - std::filesystem::path dir (this->model_path); - std::filesystem::path forward_model_file (this->model_prefix + "_forward.pt"); + std::filesystem::path dir (this->path); + std::filesystem::path forward_model_file (this->function_name + ".pt"); this->forward_model = torch::jit::load((dir / forward_model_file).generic_string()); this->forward_model.to(this->device); this->forward_model.eval(); this->forward_model = torch::jit::optimize_for_inference(this->forward_model); + if (this->has_adj1) { + std::filesystem::path adj1_model_file ("adj1_" + this->function_name + ".pt"); + this->adj1_model = torch::jit::load((dir / adj1_model_file).generic_string()); + this->adj1_model.to(this->device); + this->adj1_model.eval(); + this->adj1_model = torch::jit::optimize_for_inference(this->adj1_model); + } + + if (this->has_jac_adj1) { + std::filesystem::path jac_adj1_model_file ("jac_adj1_" + this->function_name + ".pt"); + this->jac_adj1_model = torch::jit::load((dir / jac_adj1_model_file).generic_string()); + this->jac_adj1_model.to(this->device); + this->jac_adj1_model.eval(); + this->jac_adj1_model = torch::jit::optimize_for_inference(this->jac_adj1_model); + } + if (this->has_jac) { - std::filesystem::path jac_model_file (this->model_prefix + "_jacrev.pt"); + std::filesystem::path jac_model_file ("jac_" + this->function_name + ".pt"); this->jac_model = torch::jit::load((dir / jac_model_file).generic_string()); this->jac_model.to(this->device); this->jac_model.eval(); @@ -91,7 +280,7 @@ class L4CasADi::L4CasADiImpl } if (this->has_hess) { - std::filesystem::path hess_model_file (this->model_prefix + "_hess.pt"); + std::filesystem::path hess_model_file ("jac_jac_" + this->function_name + ".pt"); this->hess_model = torch::jit::load((dir / hess_model_file).generic_string()); this->hess_model.to(this->device); this->hess_model.eval(); @@ -99,71 +288,113 @@ class L4CasADi::L4CasADiImpl } } - torch::Tensor forward(torch::Tensor input) { + torch::Tensor forward(torch::Tensor x) { std::unique_lock lock(this->model_update_mutex); c10::InferenceMode guard; std::vector inputs; - inputs.push_back(input.to(this->device)); + inputs.push_back(x.to(this->device)); return this->forward_model.forward(inputs).toTensor().to(cpu); } - torch::Tensor jac(torch::Tensor input) { + torch::Tensor jac(torch::Tensor x) { std::unique_lock lock(this->model_update_mutex); c10::InferenceMode guard; std::vector inputs; - inputs.push_back(input.to(this->device)); + inputs.push_back(x.to(this->device)); return this->jac_model.forward(inputs).toTensor().to(cpu); } - torch::Tensor hess(torch::Tensor input) { + torch::Tensor adj1(torch::Tensor primal, torch::Tensor tangent) { std::unique_lock lock(this->model_update_mutex); c10::InferenceMode guard; std::vector inputs; - inputs.push_back(input.to(this->device)); + inputs.push_back(primal.to(this->device)); + inputs.push_back(tangent.to(this->device)); + + return this->adj1_model.forward(inputs).toTensor().to(cpu); + } + + torch::Tensor jac_adj1(torch::Tensor primal, torch::Tensor tangent){ + std::unique_lock lock(this->model_update_mutex); + c10::InferenceMode guard; + std::vector inputs; + inputs.push_back(primal.to(this->device)); + inputs.push_back(tangent.to(this->device)); + + return this->jac_adj1_model.forward(inputs).toTensor().to(cpu); + } + + torch::Tensor hess(torch::Tensor x) { + std::unique_lock lock(this->model_update_mutex); + c10::InferenceMode guard; + std::vector inputs; + inputs.push_back(x.to(this->device)); return this->hess_model.forward(inputs).toTensor().to(cpu); } }; -L4CasADi::L4CasADi(std::string model_path, std::string model_prefix, bool model_expects_batch_dim, std::string device, - bool has_jac, bool has_hess, bool model_is_mutable): - pImpl{std::make_unique(model_path, model_prefix, device, has_jac, has_hess, model_is_mutable)}, - model_expects_batch_dim{model_expects_batch_dim} {} - -void L4CasADi::forward(const double* in, int rows, int cols, double* out) { - torch::Tensor in_tensor; - if (this->model_expects_batch_dim) { - in_tensor = torch::from_blob(( void * )in, {1, rows}, at::kDouble).to(torch::kFloat); +L4CasADi::L4CasADi(std::string path, std::string function_name, int rows_in, int cols_in, int rows_out, int cols_out, + std::string device, bool has_jac, bool has_adj1, bool has_jac_adj1, bool has_hess, bool scripting, bool is_mutable): + rows_in{rows_in}, cols_in{cols_in}, rows_out{rows_out}, cols_out{cols_out} { +#if ENABLE_TORCH_COMPILE + if (scripting == true) { + this->pImpl = std::make_unique(path, function_name, device, has_jac, has_adj1, has_jac_adj1, has_hess, is_mutable); } else { - in_tensor = torch::from_blob(( void * )in, {cols, rows}, at::kDouble).to(torch::kFloat).permute({1, 0}); - } + this->pImpl = std::make_unique(path, function_name, device, has_jac, has_adj1, has_jac_adj1, has_hess, is_mutable); + } +#else + this->pImpl = std::make_unique(path, function_name, device, has_jac, has_adj1, has_jac_adj1, has_hess, is_mutable); +#endif +} - torch::Tensor out_tensor = this->pImpl->forward(in_tensor).to(torch::kDouble).permute({1, 0}).contiguous(); +void L4CasADi::forward(const double* x, double* out) { + torch::Tensor x_tensor; + x_tensor = torch::from_blob(( void * )x, {this->cols_in, this->rows_in}, at::kDouble).to(torch::kFloat).permute({1, 0}); + torch::Tensor out_tensor = this->pImpl->forward(x_tensor).to(torch::kDouble).permute({1, 0}).contiguous(); std::memcpy(out, out_tensor.data_ptr(), out_tensor.numel() * sizeof(double)); } -void L4CasADi::jac(const double* in, int rows, int cols, double* out) { - torch::Tensor in_tensor; - if (this->model_expects_batch_dim) { - in_tensor = torch::from_blob(( void * )in, {1, rows}, at::kDouble).to(torch::kFloat); - } else { - in_tensor = torch::from_blob(( void * )in, {cols, rows}, at::kDouble).to(torch::kFloat).permute({1, 0}); - } +void L4CasADi::jac(const double* x, double* out) { + torch::Tensor x_tensor; + x_tensor = torch::from_blob(( void * )x, {this->cols_in, this->rows_in}, at::kDouble).to(torch::kFloat).permute({1, 0}); // CasADi expects the return in Fortran order -> Transpose last two dimensions - torch::Tensor out_tensor = this->pImpl->jac(in_tensor).to(torch::kDouble).permute({3, 2, 1, 0}).contiguous(); + torch::Tensor out_tensor = this->pImpl->jac(x_tensor).to(torch::kDouble).permute({3, 2, 1, 0}).contiguous(); std::memcpy(out, out_tensor.data_ptr(), out_tensor.numel() * sizeof(double)); } -void L4CasADi::hess(const double* in, int rows, int cols, double* out) { - torch::Tensor in_tensor; - if (this->model_expects_batch_dim) { - in_tensor = torch::from_blob(( void * )in, {1, rows}, at::kDouble).to(torch::kFloat); - } else { - in_tensor = torch::from_blob(( void * )in, {cols, rows}, at::kDouble).to(torch::kFloat).permute({1, 0}); - } +void L4CasADi::adj1(const double* p, const double* t, double* out) { + // adj1 [i0, out_o0, adj_o0] -> [out_adj_i0] + torch::Tensor p_tensor, t_tensor; + p_tensor = torch::from_blob(( void * )p, {this->cols_in, this->rows_in}, at::kDouble).to(torch::kFloat).permute({1, 0}); + t_tensor = torch::from_blob(( void * )t, {this->cols_out, this->rows_out}, at::kDouble).to(torch::kFloat).permute({1, 0}); + + // CasADi expects the return in Fortran order -> Transpose last two dimensions + torch::Tensor out_tensor = this->pImpl->adj1(p_tensor, t_tensor).to(torch::kDouble).permute({1, 0}).contiguous(); + std::memcpy(out, out_tensor.data_ptr(), out_tensor.numel() * sizeof(double)); +} + +void L4CasADi::jac_adj1(const double* p, const double* t, double* out) { + // jac_adj1 [i0, out_o0, adj_o0, out_adj_i0] -> [jac_adj_i0_i0, jac_adj_i0_out_o0, jac_adj_i0_adj_o0] + torch::Tensor p_tensor, t_tensor; + p_tensor = torch::from_blob(( void * )p, {this->cols_in, this->rows_in}, at::kDouble).to(torch::kFloat).permute({1, 0}); + t_tensor = torch::from_blob(( void * )t, {this->cols_out, this->rows_out}, at::kDouble).to(torch::kFloat).permute({1, 0}); // CasADi expects the return in Fortran order -> Transpose last two dimensions - torch::Tensor out_tensor = this->pImpl->hess(in_tensor).to(torch::kDouble).permute({5, 4, 3, 2, 1, 0}).contiguous(); + torch::Tensor out_tensor = this->pImpl->jac_adj1(p_tensor, t_tensor).to(torch::kDouble).permute({3, 2, 1, 0}).contiguous(); std::memcpy(out, out_tensor.data_ptr(), out_tensor.numel() * sizeof(double)); } +void L4CasADi::jac_jac(const double* x, double* out) { + torch::Tensor x_tensor; + x_tensor = torch::from_blob(( void * )x, {this->cols_in, this->rows_in}, at::kDouble).to(torch::kFloat).permute({1, 0}); + + // CasADi expects the return in Fortran order -> Transpose last two dimensions + torch::Tensor out_tensor = this->pImpl->hess(x_tensor).to(torch::kDouble).permute({5, 4, 3, 2, 1, 0}).contiguous(); + std::memcpy(out, out_tensor.data_ptr(), out_tensor.numel() * sizeof(double)); +} + +void L4CasADi::invalid_argument(std::string error_msg) { + throw std::invalid_argument(error_msg); +} + L4CasADi::~L4CasADi() = default; diff --git a/pyproject.toml b/pyproject.toml index 04d70b7..9b903d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "l4casadi" -version = "1.4.1" +version = "2.0.0" authors = [ { name="Tim Salzmann", email="Tim.Salzmann@tum.de" }, ] diff --git a/setup.py b/setup.py index 6f1938a..6e61f82 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ def compile_hook(manifest): setup( cmake_process_manifest_hook=compile_hook, cmake_source_dir='libl4casadi', - cmake_args=[f'-DCMAKE_TORCH_PATH={os.path.dirname(os.path.abspath(torch.__file__))}'], + cmake_args=['-DCMAKE_BUILD_TYPE=Release', f'-DCMAKE_TORCH_PATH={os.path.dirname(os.path.abspath(torch.__file__))}'], include_package_data=True, package_data={'': [ 'lib/**.dylib', diff --git a/tests/test_batching.py b/tests/test_batching.py new file mode 100644 index 0000000..ee11abb --- /dev/null +++ b/tests/test_batching.py @@ -0,0 +1,88 @@ +import pytest +import torch +import l4casadi as l4c +import casadi as cs +import numpy as np + + +class TestL4CasADiBatching: + @pytest.mark.parametrize("batch_size,input_size,output_size,jac_ccs_target,hess_ccs_target", [ + (10, 3, 2, [20, 30, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 0, 10, 1, 11, 2, 12, 3, 13, 4, 14, 5, 15, 6, 16, 7, 17, 8, 18, 9, 19, 0, 10, 1, 11, 2, 12, 3, 13, 4, 14, 5, 15, 6, 16, 7, 17, 8, 18, 9, 19, 0, 10, 1, 11, 2, 12, 3, 13, 4, 14, 5, 15, 6, 16, 7, 17, 8, 18, 9, 19], [600, 30, 0, 6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96, 102, 108, 114, 120, 126, 132, 138, 144, 150, 156, 162, 168, 174, 180, 0, 10, 200, 210, 400, 410, 21, 31, 221, 231, 421, 431, 42, 52, 242, 252, 442, 452, 63, 73, 263, 273, 463, 473, 84, 94, 284, 294, 484, 494, 105, 115, 305, 315, 505, 515, 126, 136, 326, 336, 526, 536, 147, 157, 347, 357, 547, 557, 168, 178, 368, 378, 568, 578, 189, 199, 389, 399, 589, 599, 0, 10, 200, 210, 400, 410, 21, 31, 221, 231, 421, 431, 42, 52, 242, 252, 442, 452, 63, 73, 263, 273, 463, 473, 84, 94, 284, 294, 484, 494, 105, 115, 305, 315, 505, 515, 126, 136, 326, 336, 526, 536, 147, 157, 347, 357, 547, 557, 168, 178, 368, 378, 568, 578, 189, 199, 389, 399, 589, 599, 0, 10, 200, 210, 400, 410, 21, 31, 221, 231, 421, 431, 42, 52, 242, 252, 442, 452, 63, 73, 263, 273, 463, 473, 84, 94, 284, 294, 484, 494, 105, 115, 305, 315, 505, 515, 126, 136, 326, 336, 526, 536, 147, 157, 347, 357, 547, 557, 168, 178, 368, 378, 568, 578, 189, 199, 389, 399, 589, 599]), + (3, 4, 3, [9, 12, 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 0, 3, 6, 1, 4, 7, 2, 5, 8, 0, 3, 6, 1, 4, 7, 2, 5, 8, 0, 3, 6, 1, 4, 7, 2, 5, 8, 0, 3, 6, 1, 4, 7, 2, 5, 8], [108, 12, 0, 12, 24, 36, 48, 60, 72, 84, 96, 108, 120, 132, 144, 0, 3, 6, 27, 30, 33, 54, 57, 60, 81, 84, 87, 10, 13, 16, 37, 40, 43, 64, 67, 70, 91, 94, 97, 20, 23, 26, 47, 50, 53, 74, 77, 80, 101, 104, 107, 0, 3, 6, 27, 30, 33, 54, 57, 60, 81, 84, 87, 10, 13, 16, 37, 40, 43, 64, 67, 70, 91, 94, 97, 20, 23, 26, 47, 50, 53, 74, 77, 80, 101, 104, 107, 0, 3, 6, 27, 30, 33, 54, 57, 60, 81, 84, 87, 10, 13, 16, 37, 40, 43, 64, 67, 70, 91, 94, 97, 20, 23, 26, 47, 50, 53, 74, 77, 80, 101, 104, 107, 0, 3, 6, 27, 30, 33, 54, 57, 60, 81, 84, 87, 10, 13, 16, 37, 40, 43, 64, 67, 70, 91, 94, 97, 20, 23, 26, 47, 50, 53, 74, 77, 80, 101, 104, 107]) + ]) + def test_ccs(self, batch_size, input_size, output_size, jac_ccs_target, hess_ccs_target): + jac_ccs, hess_ccs = l4c.L4CasADi.generate_block_diagonal_ccs(batch_size, input_size, output_size) + + assert jac_ccs == jac_ccs_target + assert hess_ccs == hess_ccs_target + + def test_l4casadi_sparse_out(self): + def model(x): + return torch.stack([(x[:, 0]**2 * x[:, 1]**2 * x[:, 2]**2), - (x[:, 0]**2 * x[:, 1]**2)], dim=-1) + + def model_cs(x): + return cs.hcat([(x[:, 0]**2 * x[:, 1]**2 * x[:, 2]**2), - (x[:, 0]**2 * x[:, 1]**2)]) + + inp = np.ones((5, 3)) + inp_sym = cs.MX.sym('x', 5, 3) + + jac_func_cs = cs.Function('f', [inp_sym], [cs.jacobian(model_cs(inp_sym), inp_sym)]) + jac_sparse_cs = jac_func_cs(inp) + + hess_func_cs = cs.Function('f', [inp_sym], [cs.jacobian(cs.jacobian(model_cs(inp_sym), inp_sym), inp_sym)]) + hess_sparse_cs = hess_func_cs(inp) + + l4c_model = l4c.L4CasADi(model, batched=True, generate_jac_jac=True) + + jac_func = cs.Function('f', [inp_sym], [cs.jacobian(l4c_model(inp_sym), inp_sym)]) + jac_sparse = jac_func(inp) + + hess_func = cs.Function('f', [inp_sym], [cs.jacobian(cs.jacobian(l4c_model(inp_sym), inp_sym), inp_sym)]) + hess_sparse = hess_func(inp) + + assert np.allclose(np.array(jac_sparse), np.array(jac_sparse_cs)) + assert np.allclose(np.array(hess_sparse), np.array(hess_sparse_cs)) + + def test_l4casadi_sparse_out_adj1(self): + def model(x): + return torch.stack([(x[:, 0] ** 2 * x[:, 1] ** 2 * x[:, 2] ** 2), - (x[:, 0] ** 2 * x[:, 1] ** 2)], dim=-1) + + def model_cs(x): + return cs.hcat([(x[:, 0] ** 2 * x[:, 1] ** 2 * x[:, 2] ** 2), -(x[:, 0] ** 2 * x[:, 1] ** 2)]) + + inp = np.ones((5, 3)) + tangent = np.zeros((5, 2)) + tangent[:, 0] = 1. + + inp_sym = cs.MX.sym('x', 5, 3) + tangent_sym = cs.MX.sym('x', 5, 2) + + func_cs = cs.Function('f', [inp_sym], [model_cs(inp_sym)]) + adj1_func_cs = func_cs.reverse(1) + + out_sym = func_cs(inp_sym) + out_cs = func_cs(inp) + adj1_out_cs = adj1_func_cs(inp, out_cs, tangent) + + + l4c_model = l4c.L4CasADi(model, batched=True) + y = l4c_model(inp_sym) + + func_t = l4c_model._cs_fun + adj1_func_t = func_t.reverse(1) + + out_t = func_t(inp) + adj1_out_t = adj1_func_t(inp, out_t, tangent) + + assert (np.array(adj1_out_cs) == np.array(adj1_out_t)).all() + + jac_adj1_func_cs = cs.Function('jac_adj1_f', [inp_sym, tangent_sym], + [cs.jacobian(adj1_func_cs(inp_sym, out_sym, tangent_sym), inp_sym)]) + jac_adj1_cs = jac_adj1_func_cs(inp, tangent) + + jac_adj1_func_t = cs.Function('jac_adj1_ft', [inp_sym, tangent_sym], + [cs.jacobian(adj1_func_t(inp_sym, func_t(inp_sym), tangent_sym), inp_sym)]) + jac_adj1_t = jac_adj1_func_t(inp, tangent) + + assert (np.array(jac_adj1_cs) == np.array(jac_adj1_t)).all() + diff --git a/tests/test_l4casadi.py b/tests/test_l4casadi.py index 16a06dd..4d1642c 100644 --- a/tests/test_l4casadi.py +++ b/tests/test_l4casadi.py @@ -50,15 +50,15 @@ def test_l4casadi_deep_model(self, deep_model): rand_inp = torch.rand((1, deep_model.input_layer.in_features)) torch_out = deep_model(rand_inp) - l4c_out = l4c.L4CasADi(deep_model, model_expects_batch_dim=True)(rand_inp.transpose(-2, -1).detach().numpy()) + l4c_out = l4c.L4CasADi(deep_model, batched=True)(rand_inp.detach().numpy()) - assert np.allclose(l4c_out, torch_out.transpose(-2, -1).detach().numpy(), atol=1e-6) + assert np.allclose(l4c_out, torch_out.detach().numpy(), atol=1e-6) def test_l4casadi_triag_model(self, triag_model): rand_inp = torch.rand((12, 12)) torch_out = triag_model(rand_inp) - l4c_out = l4c.L4CasADi(triag_model, model_expects_batch_dim=False)(rand_inp.detach().numpy()) + l4c_out = l4c.L4CasADi(triag_model)(rand_inp.detach().numpy()) assert np.allclose(l4c_out, torch_out.detach().numpy(), atol=1e-6) @@ -70,7 +70,7 @@ def test_l4casadi_triag_model_jac(self, triag_model): jac_fun = cs.Function('f_jac', [mx_inp], - [cs.jacobian(l4c.L4CasADi(triag_model, model_expects_batch_dim=False)(mx_inp), mx_inp)]) + [cs.jacobian(l4c.L4CasADi(triag_model)(mx_inp), mx_inp)]) l4c_out = jac_fun(rand_inp.detach().numpy()) @@ -88,7 +88,7 @@ def test_l4casadi_triag_model_hess_double_jac(self, triag_model): [mx_inp], [cs.jacobian( cs.jacobian( - l4c.L4CasADi(triag_model, model_expects_batch_dim=False)(mx_inp), mx_inp + l4c.L4CasADi(triag_model, generate_jac_jac=True)(mx_inp), mx_inp )[0, 0], mx_inp)]) l4c_out = hess_fun(rand_inp.transpose(-2, -1).detach().numpy()) @@ -100,28 +100,43 @@ def test_l4casadi_deep_model_jac(self, deep_model): rand_inp = torch.rand((1, deep_model.input_layer.in_features)) torch_out = torch.func.vmap(torch.func.jacrev(deep_model))(rand_inp)[0] - mx_inp = cs.MX.sym('x', deep_model.input_layer.in_features, 1) + mx_inp = cs.MX.sym('x', 1, deep_model.input_layer.in_features) jac_fun = cs.Function('f_jac', [mx_inp], - [cs.jacobian(l4c.L4CasADi(deep_model, model_expects_batch_dim=True)(mx_inp), mx_inp)]) + [cs.jacobian(l4c.L4CasADi(deep_model)(mx_inp), mx_inp)]) - l4c_out = jac_fun(rand_inp.transpose(-2, -1).detach().numpy()) + l4c_out = jac_fun(rand_inp.detach().numpy()) assert np.allclose(l4c_out, torch_out.detach().numpy(), atol=1e-6) - def test_l4casadi_deep_model_hess(self): + def test_l4casadi_deep_model_hess_with_jac_adj(self): deep_model = DeepModel(4, 1) rand_inp = torch.rand((1, deep_model.input_layer.in_features)) torch_out = torch.func.vmap(torch.func.hessian(deep_model))(rand_inp)[0] - mx_inp = cs.MX.sym('x', deep_model.input_layer.in_features, 1) + mx_inp = cs.MX.sym('x', 1, deep_model.input_layer.in_features) hess_fun = cs.Function('f_hess', [mx_inp], - [cs.hessian(l4c.L4CasADi(deep_model, model_expects_batch_dim=True)(mx_inp), mx_inp)[0]]) + [cs.hessian(l4c.L4CasADi(deep_model, generate_adj1=True, generate_jac_jac=False)(mx_inp), mx_inp)[0]]) - l4c_out = hess_fun(rand_inp.transpose(-2, -1).detach().numpy()) + l4c_out = hess_fun(rand_inp.detach().numpy()) + + assert np.allclose(l4c_out, torch_out.detach().numpy(), atol=1e-6) + + def test_l4casadi_deep_model_hess_with_jac_jac(self): + deep_model = DeepModel(4, 1) + rand_inp = torch.rand((1, deep_model.input_layer.in_features)) + torch_out = torch.func.vmap(torch.func.hessian(deep_model))(rand_inp)[0] + + mx_inp = cs.MX.sym('x', 1, deep_model.input_layer.in_features) + + hess_fun = cs.Function('f_hess', + [mx_inp], + [cs.hessian(l4c.L4CasADi(deep_model, generate_adj1=False, generate_jac_jac=True)(mx_inp), mx_inp)[0]]) + + l4c_out = hess_fun(rand_inp.detach().numpy()) assert np.allclose(l4c_out, torch_out.detach().numpy(), atol=1e-6) @@ -130,25 +145,25 @@ def test_l4casadi_deep_model_hess_double_jac(self): rand_inp = torch.rand((1, deep_model.input_layer.in_features)) torch_out = torch.func.vmap(torch.func.hessian(deep_model))(rand_inp)[0] - mx_inp = cs.MX.sym('x', deep_model.input_layer.in_features, 1) + mx_inp = cs.MX.sym('x', 1, deep_model.input_layer.in_features) hess_fun = cs.Function('f_hess_double_jac', [mx_inp], [cs.jacobian( cs.jacobian( - l4c.L4CasADi(deep_model, model_expects_batch_dim=True)(mx_inp), mx_inp + l4c.L4CasADi(deep_model, generate_jac_jac=True)(mx_inp), mx_inp )[0], mx_inp)]) - l4c_out = hess_fun(rand_inp.transpose(-2, -1).detach().numpy()) + l4c_out = hess_fun(rand_inp.detach().numpy()) assert np.allclose(l4c_out, torch_out[0, 0].detach().numpy(), atol=1e-6) def test_l4casadi_deep_model_online_update(self, deep_model): rand_inp = torch.rand((1, deep_model.input_layer.in_features)) - l4c_model = l4c.L4CasADi(deep_model, model_expects_batch_dim=True, mutable=True) + l4c_model = l4c.L4CasADi(deep_model, mutable=True) - l4c_out_old = l4c_model(rand_inp.transpose(-2, -1).detach().numpy()) + l4c_out_old = l4c_model(rand_inp.detach().numpy()) # Change model and online update L4CasADi deep_model.input_layer.reset_parameters() @@ -156,7 +171,7 @@ def test_l4casadi_deep_model_online_update(self, deep_model): torch_out = deep_model(rand_inp) - l4c_out = l4c_model(rand_inp.transpose(-2, -1).detach().numpy()) + l4c_out = l4c_model(rand_inp.detach().numpy()) - assert np.allclose(l4c_out, torch_out.transpose(-2, -1).detach().numpy(), atol=1e-6) - assert not np.allclose(l4c_out_old, torch_out.transpose(-2, -1).detach().numpy(), atol=1e-6) + assert np.allclose(l4c_out, torch_out.detach().numpy(), atol=1e-6) + assert not np.allclose(l4c_out_old, torch_out.detach().numpy(), atol=1e-6) diff --git a/tests/test_naive_l4casadi.py b/tests/test_naive_l4casadi.py index 76f2a87..3198e4a 100644 --- a/tests/test_naive_l4casadi.py +++ b/tests/test_naive_l4casadi.py @@ -12,8 +12,8 @@ def test_naive_l4casadi_mlp(self): rand_inp = torch.rand((1, 2)) torch_out = naive_mlp(rand_inp) - cs_inp = cs.DM(rand_inp.transpose(-2, -1).detach().numpy()) + cs_inp = cs.DM(rand_inp.detach().numpy()) - l4c_out = l4c.L4CasADi(naive_mlp, model_expects_batch_dim=True)(cs_inp) + l4c_out = l4c.L4CasADi(naive_mlp)(cs_inp) - assert np.allclose(l4c_out, torch_out.transpose(-2, -1).detach().numpy(), atol=1e-6) + assert np.allclose(l4c_out, torch_out.detach().numpy(), atol=1e-6)