How can I reorder tensors in place like x[:,:,:,[2,0,1]] In PyTorch? #2647

Closed
i10416 opened this issue Jun 11, 2023 · 4 comments
i10416 commented Jun 11, 2023

Given x of shape (1, 10, 10, 3), in PyTorch I can reorder entries along a specific axis (and write the result back in place) with, for example, x[:,:,:,[2,0,1]].
How can I achieve the same result using the DJL NDArray API?

import torch

batch_size = 1
height = 10
width = 10
channel = 3
x = torch.randn((batch_size, height, width, channel))
x
tensor([[[[ 1.2450,  0.2395, -1.4496],
          [-1.4505,  0.8022, -0.8087],
          [-0.8357, -0.5123,  1.1846],
          [-1.1332,  0.0763,  1.2089],
          [-1.0103, -2.2320,  0.1810],
          [ 0.7712,  0.5609, -0.2574],
          [ 0.3336,  0.4204, -0.4664],
          [ 1.8834,  0.3339, -1.4987],
          [-1.5052,  0.1414,  2.9350],
          [ 0.3335,  0.3214,  1.6047]],

         [[-0.3892,  0.4478, -0.4097],
          [ 1.2167, -1.5380, -0.1554],
          [-2.2246,  0.2458, -0.3464],
          [-1.2612,  0.4891, -1.4027],
          [ 1.6989, -0.1904, -1.4988],
          [ 1.2409,  0.8922,  1.4012],
....
x[:,:,:,[2,0,1]]
tensor([[[[-1.4496,  1.2450,  0.2395],
          [-0.8087, -1.4505,  0.8022],
          [ 1.1846, -0.8357, -0.5123],
          [ 1.2089, -1.1332,  0.0763],
          [ 0.1810, -1.0103, -2.2320],
          [-0.2574,  0.7712,  0.5609],
          [-0.4664,  0.3336,  0.4204],
          [-1.4987,  1.8834,  0.3339],
          [ 2.9350, -1.5052,  0.1414],
          [ 1.6047,  0.3335,  0.3214]],

         [[-0.4097, -0.3892,  0.4478],
          [-0.1554,  1.2167, -1.5380],
          [-0.3464, -2.2246,  0.2458],
          [-1.4027, -1.2612,  0.4891],
          [-1.4988,  1.6989, -0.1904],
          [ 1.4012,  1.2409,  0.8922],
          [ 0.7958,  0.1829,  0.7539],
          [-0.1230,  0.8494, -1.2449],

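The fancy index [2, 0, 1] in the last position simply permutes entries along the final axis: output channel j takes the value of input channel order[j]. A minimal pure-Python sketch of that semantics (nested lists standing in for the tensor; no PyTorch required, and `permute_last_axis` is a hypothetical helper, not a torch or DJL API):

```python
def permute_last_axis(x, order):
    """Return x with its innermost axis reordered, like x[..., order] in PyTorch."""
    if not isinstance(x[0], list):              # innermost axis reached
        return [x[i] for i in order]
    return [permute_last_axis(sub, order) for sub in x]

x = [[[1.0, 2.0, 3.0],
      [4.0, 5.0, 6.0]]]                         # shape (1, 2, 3)
print(permute_last_axis(x, [2, 0, 1]))          # [[[3.0, 1.0, 2.0], [6.0, 4.0, 5.0]]]
```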
For now, I use the following code.

//> using scala "3.3.0"
//> using dep "ai.djl:api:0.22.1"
//> using dep "ai.djl:basicdataset:0.22.1"
//> using dep "org.slf4j:slf4j-simple:2.0.7"
//> using dep "ai.djl.pytorch:pytorch-engine:0.22.1"

import ai.djl.*
import ai.djl.ndarray.*
import ai.djl.ndarray.types.*
import ai.djl.ndarray.index.NDIndex

import scala.util.chaining.*
import scala.jdk.CollectionConverters.*

val mg = NDManager.newBaseManager()
val sample = mg.randomNormal(new Shape(1, 10, 10, 3))

val a0 = sample.get(NDIndex(":,:,:,0"))
val a0clone = mg.create(a0.getShape())
a0.copyTo(a0clone)
val a1 = sample.get(NDIndex(":,:,:,1"))
val a1clone = mg.create(a1.getShape())
a1.copyTo(a1clone)
val a2 = sample.get(NDIndex(":,:,:,2"))
val a2clone = mg.create(a2.getShape())
a2.copyTo(a2clone)

sample.set(NDIndex(":,:,:,0"), a2clone)
sample.set(NDIndex(":,:,:,1"), a0clone)
sample.set(NDIndex(":,:,:,2"), a1clone)
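The clone-and-set workaround above moves one channel at a time: each channel is copied out first so that later writes cannot clobber values that are still needed. A plain-Python model of that bookkeeping (a single pixel's channel values stand in for the full (1, 10, 10) slices; this is an illustration, not DJL code):

```python
# One (batch, h, w) position; list indices 0..2 model the channel axis.
pixel = [10.0, 20.0, 30.0]

# Clone every channel before any write (the a0clone/a1clone/a2clone step) ...
a0, a1, a2 = pixel[0], pixel[1], pixel[2]

# ... then write the clones back in the new order, mirroring the three set() calls:
pixel[0] = a2   # set(":,:,:,0", a2clone)
pixel[1] = a0   # set(":,:,:,1", a0clone)
pixel[2] = a1   # set(":,:,:,2", a1clone)

print(pixel)    # [30.0, 10.0, 20.0] -- the original values indexed by [2, 0, 1]
```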
@KexinFeng

We already have full support for PyTorch indexing: #1719 and #1755. See the demos therein.

For your use case, it can be done as follows:

val mg = NDManager.newBaseManager()
val sample = mg.randomNormal(new Shape(1, 10, 10, 3))
val indexArray = mg.create(Array(2, 0, 1)) // the desired channel order

val newArray = sample.get(new NDIndex(":, :, :, {}", indexArray))
sample.set(new NDIndex(":, :, :, :"), newArray) // write the result back in place if needed
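The same get-then-write-back pattern in plain Python (the `reorder` helper is hypothetical, not part of DJL; it only illustrates that a single fancy-index get followed by a full-slice set matches the manual per-channel swap):

```python
def reorder(row, order):
    """Build a reordered copy of row -- the fancy-index get step."""
    return [row[i] for i in order]

x = [1.0, 2.0, 3.0]
new_x = reorder(x, [2, 0, 1])   # reordered copy of the last axis
x[:] = new_x                    # overwrite in place -- the full-slice set step
print(x)                        # [3.0, 1.0, 2.0]
```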

i10416 commented Jun 14, 2023

Ah, I didn't know those PRs. Thanks a lot!

i10416 closed this as completed Jun 14, 2023
i10416 commented Jun 14, 2023

By the way, is there a correspondence table for the PyTorch tensor API vs. the DJL tensor API, similar to the one for NumPy vs. Breeze (see https://github.com/scalanlp/breeze/wiki/Linear-Algebra-Cheat-Sheet)?

And if there isn't, would it be helpful to write such a comparison as a wiki page or a document?

zachgk commented Jun 14, 2023

@i10416 We don't have a correspondence table like that. I could see it being useful, though! If you are interested in writing one, you can submit it as a markdown document in a PR and we can add it to our docs near http://docs.djl.ai/master/engines/pytorch/index.html
