Skip to content

Commit

Permalink
add continuity requirement for unstack
Browse files Browse the repository at this point in the history
  • Loading branch information
swfsql committed Feb 8, 2024
1 parent 5994ac5 commit 5ffff2d
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions dfdx-core/src/tensor_ops/unstack/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ impl<S: Shape, E: Dtype, D: UnstackKernel<E>, T, const N: usize> TryUnstack<Cons
for Tensor<S, E, D, T>
where
S: SubDim<Head = Const<N>>,
D: super::reshape_to::ReshapeKernel<E>,
T: Tape<E, D>,
{
type Unstacked = ([Tensor<S::Tail, E, D, T>; N], T);
Expand All @@ -57,6 +58,7 @@ where
impl<S: Shape, E: Dtype, D: UnstackKernel<E>, T> TryUnstack<usize> for Tensor<S, E, D, T>
where
S: SubDim<Head = usize>,
D: super::reshape_to::ReshapeKernel<E>,
T: Tape<E, D>,
{
type Unstacked = (Vec<Tensor<S::Tail, E, D, T>>, T);
Expand Down Expand Up @@ -136,6 +138,7 @@ fn try_unstack<OptionalItems, Items, S: Shape, E: Dtype, D: UnstackKernel<E>, T>
where
S: SubDim,
T: Tape<E, D>,
D: super::reshape_to::ReshapeKernel<E>,
OptionalItems: Array<Option<Tensor<S::Tail, E, D, NoneTape>>, Dim = S::Head>
+ std::ops::IndexMut<usize, Output = Option<Tensor<S::Tail, E, D, NoneTape>>>,
Items: Array<Tensor<S::Tail, E, D, T>, Dim = S::Head>,
Expand All @@ -144,10 +147,17 @@ where
let (head, _tail) = stack.shape().sub_dim();
let (stack, stack_tape) = stack.split_tape();

// TODO: remove this overhead, and panic on a non-contiguous condition
let stack = {
use super::reshape_to::ReshapeTo;
stack.try_contiguous()?
};

let stack_ghost = stack.ghost();

// list of optional tensors (all are Some)
let mut unstacks = device.forward::<_, OptionalItems>(stack)?;

let mut unstacks = UnstackKernel::forward::<_, OptionalItems>(&device, stack)?;

// tensors from unstacks must get tapes inserted into them.
// to do this, from_fn is re-utilized, but this time without optionals
Expand All @@ -163,7 +173,7 @@ where
grads.try_alloc_for(&stack_ghost)?;
grads.try_alloc_for(&unstack_ghost)?;
let (grad_stack, grad_unstack) = grads.mut_and_ref(&stack_ghost, &unstack_ghost);
device.backward(grad_stack, grad_unstack, i)
UnstackKernel::backward(&device, grad_stack, grad_unstack, i)
});
unstack.put_tape(unstack_tape)
},
Expand Down

0 comments on commit 5ffff2d

Please sign in to comment.