From dcd83336b68049763973709733bf2721a687507d Mon Sep 17 00:00:00 2001 From: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com> Date: Thu, 17 Oct 2024 16:30:45 +0530 Subject: [PATCH] Testcases (#2567) --- candle-core/src/tensor.rs | 7 +- candle-core/tests/tensor_tests.rs | 274 ++++++++++++++++++++++++++++++ 2 files changed, 278 insertions(+), 3 deletions(-) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 7dd24abf9..e7355aadc 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1520,14 +1520,15 @@ impl Tensor { /// # Arguments /// /// * `self` - The input tensor. - /// * `indexes` - The indices of elements to gather, this should have the same shape as `self` - /// but can have a different number of elements on the target dimension. + /// * `indexes` - The indices of elements to gather, this should have same number of dimensions as `self` + /// and indexes.dims()[d] <= self.dims()[d] for all dimensions d != dim /// * `dim` - the target dimension. /// /// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on /// dimension `dim` by the values in `indexes`. pub fn gather(&self, indexes: &Self, dim: D) -> Result { let dim = dim.to_index(self.shape(), "gather")?; + let self_dims = self.dims(); let indexes_dims = indexes.dims(); let mismatch = if indexes_dims.len() != self_dims.len() { @@ -1535,7 +1536,7 @@ impl Tensor { } else { let mut mismatch = false; for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() { - if i != dim && d1 != d2 { + if i != dim && d1 < d2 { mismatch = true; break; } diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index e0cea15c6..e3246a33a 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1047,6 +1047,280 @@ fn gather(device: &Device) -> Result<()> { let ids = Tensor::new(&[[0u32, 2u32, 0u32], [0u32, 1u32, 1u32]], device)?; let hs = t.gather(&ids, 0)?; assert_eq!(hs.to_vec2::()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]); + + // Random data + + // Dim: 0 + let t = Tensor::new( + &[ + [ + [108_f32, -47., 16., -56., -83., -130., 210.], + [253., 95., 151., 228., -210., -123., -127.], + [-9., -217., 2., -78., 163., 245., -204.], + [-246., 79., -238., 88., -226., -184., 171.], + [8., -48., -153., 234., -34., 166., -153.], + [124., 0., -10., -61., -242., -15., -238.], + ], + [ + [12., -64., -199., 244., -240., 156., -128.], + [173., -57., 4., -198., 233., -110., 238.], + [95., 82., 0., 240., 53., -211., 209.], + [-122., 167., -212., 227., -144., 61., 118.], + [-63., -146., 200., 244., 168., -167., 116.], + [-125., -147., 110., -253., -178., -250., -18.], + ], + [ + [57., 86., -50., 56., 92., 205., -78.], + [-137., -156., -18., 248., -61., -239., 14.], + [-248., -30., -50., -70., -251., 250., -83.], + [-221., 67., 72., 59., -24., -154., 232.], + [-144., -23., -74., 5., 93., 171., 205.], + [46., -77., -38., -226., 246., 161., -17.], + ], + [ + [-153., -231., -236., 161., 126., 2., -22.], + [-229., -41., 209., 164., 234., 160., 57.], + [223., 254., -186., -162., -46., -160., -102.], + [65., 30., 213., -253., 59., 224., -154.], + [-82., -203., -177., 17., 31., -256., -246.], + [176., -135., -65., 54., -56., 210., 76.], + ], + [ + [-10., -245., 168., 124., -14., -33., -178.], + [25., -43., -39., 132., -89., 169., 179.], + [187., -215., 32., -133., 87., -7., -168.], + [-224., -215., -5., -230., -58., -162., 128.], + [158., -137., -122., -100., -202., -83., 136.], + [30., -185., -144., 250., 209., -40., 127.], + ], + [ + [-196., 108., -245., 122., 146., -228., 62.], + [-1., -66., 160., 137., 13., -172., -21.], + [244., 199., -164., 28., 119., -175., 198.], + [-62., 253., -162., 195., -95., -230., -211.], + [123., -72., -26., -107., -139., 64., 245.], + [11., -126., -182., 108., -12., 184., -127.], + ], + [ + [-159., 126., 176., 161., 73., -111., -138.], + [-187., 214., -217., -33., -223., -201., -212.], + [-61., -120., -166., -172., -95., 53., 196.], + [-33., 86., 134., -152., 154., -53., 74.], + [186., -28., -154., -174., 141., -109., 217.], + [82., 35., 252., 145., 181., 74., -87.], + ], + ], + device, + )?; + + let ids = Tensor::new( + &[ + [ + [6_u32, 6, 4, 3, 4, 4, 6], + [3, 3, 2, 4, 4, 4, 6], + [3, 3, 0, 2, 4, 6, 4], + [2, 5, 1, 2, 6, 6, 1], + [2, 1, 6, 5, 3, 2, 3], + [6, 1, 0, 1, 0, 2, 6], + ], + [ + [4, 6, 4, 3, 3, 3, 2], + [4, 3, 2, 4, 4, 4, 6], + [2, 3, 0, 2, 4, 6, 4], + [6, 5, 1, 2, 6, 6, 1], + [4, 1, 6, 5, 3, 2, 3], + [1, 1, 0, 1, 0, 2, 6], + ], + [ + [3, 6, 4, 3, 3, 3, 2], + [2, 3, 2, 4, 4, 4, 6], + [4, 3, 0, 2, 4, 6, 4], + [0, 5, 1, 2, 6, 6, 1], + [6, 1, 6, 5, 3, 2, 3], + [4, 1, 0, 1, 0, 2, 6], + ], + [ + [0, 6, 4, 3, 3, 3, 2], + [5, 3, 2, 4, 4, 4, 6], + [0, 3, 0, 2, 4, 6, 4], + [3, 5, 1, 2, 6, 6, 1], + [0, 1, 6, 5, 3, 2, 3], + [3, 1, 0, 1, 0, 2, 6], + ], + ], + device, + )?; + + let hs = t.gather(&ids, 0)?; + assert_eq!( + hs.to_vec3::()?, + &[ + [ + [-159_f32, 126., 168., 161., -14., -33., -138.], + [-229., -41., -18., 132., -89., 169., -212.], + [223., 254., 2., -70., 87., 53., -168.], + [-221., 253., -212., 59., 154., -53., 118.], + [-144., -146., -154., -107., 31., 171., -246.], + [82., -147., -10., -253., -242., 161., -87.] + ], + [ + [-10., 126., 168., 161., 126., 2., -78.], + [25., -41., -18., 132., -89., 169., -212.], + [-248., 254., 2., -70., 87., 53., -168.], + [-33., 253., -212., 59., 154., -53., 118.], + [158., -146., -154., -107., 31., 171., -246.], + [-125., -147., -10., -253., -242., 161., -87.] + ], + [ + [-153., 126., 168., 161., 126., 2., -78.], + [-137., -41., -18., 132., -89., 169., -212.], + [187., 254., 2., -70., 87., 53., -168.], + [-246., 253., -212., 59., 154., -53., 118.], + [186., -146., -154., -107., 31., 171., -246.], + [30., -147., -10., -253., -242., 161., -87.] + ], + [ + [108., 126., 168., 161., 126., 2., -78.], + [-1., -41., -18., 132., -89., 169., -212.], + [-9., 254., 2., -70., 87., 53., -168.], + [65., 253., -212., 59., 154., -53., 118.], + [8., -146., -154., -107., 31., 171., -246.], + [176., -147., -10., -253., -242., 161., -87.] + ] + ] + ); + + // Dim: 1 + let t = Tensor::new( + &[ + [ + [-117_f32, -175., 69., -163.], + [200., 242., -21., -67.], + [179., 150., -126., -75.], + [-118., 38., -138., -13.], + [-221., 136., -185., 180.], + [58., 182., -204., -149.], + ], + [ + [3., -148., -58., -154.], + [-43., 45., -108., 4.], + [-69., -249., -71., -21.], + [80., 110., -152., -235.], + [-88., 7., 92., -250.], + [-186., 207., -242., 98.], + ], + [ + [238., 19., 64., -242.], + [-150., -97., 218., 58.], + [111., -233., 204., -212.], + [-242., -232., 83., 42.], + [153., 62., -251., 219.], + [-117., 36., -119., 10.], + ], + [ + [215., 159., -169., -27.], + [-83., 101., -88., 169.], + [-205., 93., 225., -64.], + [-162., 240., 214., 23.], + [-112., 6., 21., 245.], + [-38., 113., 93., 215.], + ], + [ + [91., -188., -148., 101.], + [74., 203., -35., 55.], + [-116., -130., -153., -96.], + [58., 22., -45., -194.], + [-221., -134., 73., 159.], + [-203., -254., 31., 235.], + ], + [ + [105., -53., 61., 186.], + [-195., 234., 75., -1.], + [51., 139., 160., -108.], + [-173., -167., 161., 19.], + [83., -246., 156., -222.], + [109., 39., -149., 137.], + ], + ], + device, + )?; + + let ids = Tensor::new( + &[ + [[4_u32, 4, 4, 2]], + [[0, 4, 4, 3]], + [[1, 5, 3, 4]], + [[0, 3, 3, 2]], + [[1, 1, 5, 2]], + [[1, 4, 5, 4]], + ], + device, + )?; + + let hs = t.gather(&ids, 1)?; + assert_eq!( + hs.to_vec3::()?, + &[ + [[-221., 136., -185., -75.]], + [[3., 7., 92., -235.]], + [[-150., 36., 83., 219.]], + [[215., 240., 214., -64.]], + [[74., 203., 31., -96.]], + [[-195., -246., -149., -222.]] + ] + ); + + // Dim: 2 + let t = Tensor::new( + &[ + [[-162_f32, 202.], [-126., -39.], [35., -65.], [1., 80.]], + [[37., 248.], [-191., 89.], [117., -40.], [-217., 220.]], + ], + device, + )?; + + let ids = Tensor::new(&[[[1_u32], [0], [1], [1]], [[0], [1], [0], [1]]], device)?; + + let hs = t.gather(&ids, 2)?; + assert_eq!( + hs.to_vec3::()?, + &[ + [[202.], [-126.], [-65.], [80.]], + [[37.], [89.], [117.], [220.]] + ] + ); + + let t = Tensor::new( + &[ + [[-21_f32, -197.], [194., 122.]], + [[255., -106.], [-191., 250.]], + [[33., -117.], [43., 10.]], + [[-130., 238.], [-217., -92.]], + ], + device, + )?; + + let ids = Tensor::new( + &[ + [[0_u32, 1], [1, 0]], + [[1, 0], [0, 1]], + [[0, 1], [0, 1]], + [[1, 0], [1, 0]], + ], + device, + )?; + + let hs = t.gather(&ids, 2)?; + assert_eq!( + hs.to_vec3::()?, + &[ + [[-21., -197.], [122., 194.]], + [[-106., 255.], [-191., 250.]], + [[33., -117.], [43., 10.]], + [[238., -130.], [-92., -217.]] + ] + ); + Ok(()) }