diff --git a/dfdx/src/nn/layers/mamba_minimal.rs b/dfdx/src/nn/layers/mamba_minimal.rs index 59e81fa3..465aa97f 100644 --- a/dfdx/src/nn/layers/mamba_minimal.rs +++ b/dfdx/src/nn/layers/mamba_minimal.rs @@ -673,7 +673,7 @@ pub mod stateless { Err(_delta_a) => unreachable!(), }; let (delta_a, _delta_a_tape): (Vec>, _) = - delta_a.try_unstack()?; + delta_a.try_contiguous()?.try_unstack()?; // // delta B let delta_bu: Tensor<(usize, Batch, DInner, DState), _, _, _> = match delta_bu.try_realize() @@ -682,7 +682,7 @@ pub mod stateless { Err(_delta_bu) => unreachable!(), }; let (delta_bu, _delta_bu_tape): (Vec>, _) = - delta_bu.try_unstack()?; + delta_bu.try_contiguous()?.try_unstack()?; // // C let c: Tensor<(usize, Batch, DState, C1), _, _, _> = match c