Skip to content

Commit

Permalink
Fix scaled masked softmax
Browse files Browse the repository at this point in the history
  • Loading branch information
hubertdelajonquieresonos authored and kali committed Dec 13, 2024
1 parent d47b6f9 commit 0100dc0
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions metal/src/kernels/nn/scaled_masked_softmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,12 @@ mod tests {
where
F: Datum + Float + std::ops::AddAssign,
usize: AsPrimitive<F>,
+ f32: AsPrimitive<F>,
{
pub fn reference(&self) -> Result<Tensor> {
let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?;
let mask = Tensor::from_shape(self.mask_shape.as_slice(), &self.mask)?;
- let scale: Arc<_> = tensor0(0.125f32).into();
+ let scale: Arc<_> = tensor0::<F>(0.125f32.as_()).into();

let cpu_output = BasicScaledMaskedSoftmax { scale }
.eval(tvec![a.into_tvalue(), mask.into_tvalue()])?[0]
Expand All @@ -237,7 +238,7 @@ mod tests {
let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?.into_metal()?;
let mask =
Tensor::from_shape(self.mask_shape.as_slice(), &self.mask)?.into_metal()?;
- let scale: Arc<_> = tensor0(0.125f32).into();
+ let scale: Arc<_> = tensor0::<F>(0.125f32.as_()).into();
let metal_output = ScaledMaskedSoftmax.eval(context, &a, &scale, &mask)?;
metal_output.to_cpu()
})
Expand Down

0 comments on commit 0100dc0

Please sign in to comment.