[pytorch] How to use SequentialImpl StringAnyModuleDict AnyModuleVector these class? #1283
Comments
Yes, unfortunately, Sequential is not currently usable, see issue #623 (comment). |
Right. Don't try to use
public class ConvBnRelu extends Module {
    final Conv2dImpl conv;
    final BatchNorm2dImpl bn;
    final ReLUImpl relu;

    ConvBnRelu(int inChannels, int outChannels, int kernelSize, int stride) {
        Conv2dOptions convOpt = new Conv2dOptions(inChannels, outChannels, new ExpandingArray2(kernelSize));
        convOpt.stride().put(new long[]{stride, stride});
        convOpt.padding().put(new ExpandingArray2(kernelSize/2));
        convOpt.bias().put(false);
        conv = new Conv2dImpl(convOpt);
        register_module("conv", conv);

        BatchNormOptions bnOpt = new BatchNormOptions(outChannels);
        bn = new BatchNorm2dImpl(bnOpt);
        register_module("bn", bn);

        ReLUOptions reluOpt = new ReLUOptions();
        reluOpt.inplace().put(true);
        relu = new ReLUImpl(reluOpt);
        register_module("relu", relu);
    }

    Tensor forward(Tensor x) {
        return relu.forward(bn.forward(conv.forward(x)));
    }
} |
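A note on the padding used above: with padding set to kernelSize/2 and an odd kernel size, a stride of 1 leaves the spatial size unchanged, while a larger stride shrinks it. The following is a minimal sketch in plain Java of the standard convolution output-size arithmetic, assuming a dilation of 1. The ConvShape class and its outSize method are hypothetical helpers for illustration and are not part of the presets.

```java
// Hypothetical helper, not part of the JavaCPP presets.
final class ConvShape {
    private ConvShape() {}

    // Output length along one spatial dimension for a convolution with
    // dilation 1: floor((n + 2 * padding - kernel) / stride) + 1.
    static long outSize(long n, long kernel, long stride, long padding) {
        long span = n + 2 * padding - kernel;
        if (span < 0 || stride <= 0) {
            throw new IllegalArgumentException("Kernel larger than padded input, or invalid stride");
        }
        return span / stride + 1; // integer division floors because span >= 0
    }
}

class ConvShapeDemo {
    public static void main(String[] args) {
        // kernelSize = 3, padding = 3 / 2 = 1
        System.out.println(ConvShape.outSize(10, 3, 1, 1)); // stride 1
        System.out.println(ConvShape.outSize(10, 3, 3, 1)); // stride 3
    }
}
```

For a 10 x 10 input and kernelSize 3, this gives 10 for stride 1 and 4 for stride 3.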
@HGuillemet Thank you very much. I just wanted to ask whether there is another way to build something like a collection of modules. Thanks. |
Sorry I don't understand your question. |
I mean something to use instead of Sequential. |
If I create a ConvBnRelu layer block, I then find that the model Net needs a list of ConvBnRelu blocks, if I use Sequential or |
It also has not been done because it's easy, and more Java-ish, to do without Sequential. For instance, to chain 2 blocks:
public class Chain extends Module {
    final ConvBnRelu block1 = new ConvBnRelu(3, 10, 3, 1);
    final ConvBnRelu block2 = new ConvBnRelu(10, 20, 3, 3);

    Chain() {
        register_module("block1", block1);
        register_module("block2", block2);
    }

    Tensor forward(Tensor x) {
        return block2.forward(block1.forward(x));
    }
} |
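The same idea extends from two fixed fields to a list of blocks of any length. Here is a minimal sketch in plain Java, assuming every block takes one input and returns one output of the same type. Block and BlockList are hypothetical stand-ins so that the example runs without the native library; in a real Module subclass the element type would be something like ConvBnRelu, the values would be Tensor, and each added block would also be passed to register_module() as in the Chain example.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a module with a single-input forward().
interface Block<T> {
    T forward(T x);
}

// Holds an ordered list of blocks and applies them one after another.
final class BlockList<T> implements Block<T> {
    private final List<Block<T>> blocks = new ArrayList<>();

    // In a real Module subclass, register_module(name, block) would be
    // called here so that the block's parameters are tracked.
    BlockList<T> add(Block<T> block) {
        blocks.add(block);
        return this;
    }

    @Override
    public T forward(T x) {
        T out = x;
        for (Block<T> block : blocks) {
            out = block.forward(out);
        }
        return out;
    }
}

class BlockListDemo {
    public static void main(String[] args) {
        BlockList<Integer> net = new BlockList<Integer>()
                .add(x -> x + 1)
                .add(x -> x * 10);
        System.out.println(net.forward(2)); // computes (2 + 1) * 10
    }
}
```

Because the element type is known at compile time, no reflection or AnyModule-style type erasure is needed; the trade-off is that all blocks must share one forward() signature.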
If that's the only issue, I think simply mapping the forward() function to a few overloads taking Tensor arguments as follows, which covers most use cases, should work:
public native @ByVal Tensor forward(@Const @ByRef Tensor input0);
public native @ByVal Tensor forward(@Const @ByRef Tensor input0, @Const @ByRef Tensor input1);
public native @ByVal Tensor forward(@Const @ByRef Tensor input0, @Const @ByRef Tensor input1, @Const @ByRef Tensor input2);
public native @ByVal Tensor forward(@Const @ByRef Tensor input0, @Const @ByRef Tensor input1, @Const @ByRef Tensor input2, @Const @ByRef Tensor input3);
... |
First issue is to map As a side note: I don't think we need Issue 2 is more problematic : we should be able to use |
Now that you mention it, we should be able to map all instances of the AnyModule constructor for all concrete Module types, just like with register_module(). That should work, right?
I'd have to look at that bit more again, but there's a reason that's there. It doesn't work if we don't map all instances.
From what I understand looking at that again, we don't need to implement |
I think it will, but I'm not sure it's a good thing to do it if
You're right, a quick test gives me a SIGSEGV when calling
The aim of |
It obviously doesn't work for modules implemented in Java, it's only going to work for the ones implemented in C++, but it's better than nothing :) |
I also know that Sequential and AnyModule are difficult to implement in Java. Thanks to the contributors for doing a lot of background work. |
I have tried to map The preset can provide this helper class, but as @saudet often points out, this is rather the role of higher-level software. The lack of mapping for the native Sequential module is not functionally blocking. |
Could you provide the compiler errors that you get when you try what I wrote above #1283 (comment)? |
I don't know how to translate this as public native @ByVal Tensor forward(@Const @ByRef Tensor input0); in
Conv2dImpl impl = new Conv2dImpl(1, 1, new LongPointer(3, 3));
AnyModule any = new AnyModule(impl);
SequentialImpl seq = new SequentialImpl();
seq.push_back("conv", any);
Tensor in = torch.rand(1, 1, 10, 10);
seq.forward(in);
However, I'm still unsure if we'd better map the native
package org.bytedeco.pytorch;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

public class Sequential<O> extends Module {
    private final Module[] modules;
    private final Method[] forwardMethods;

    public Sequential(Module... modules) {
        this.modules = modules;
        forwardMethods = new Method[modules.length];
        MODULE:
        for (int i = 0; i < modules.length; i++) {
            for (Method method: modules[i].getClass().getMethods()) {
                if (method.getName().equals("forward")) {
                    forwardMethods[i] = method;
                    register_module(Integer.toString(i), modules[i]);
                    continue MODULE;
                }
            }
            throw new IllegalArgumentException("No forward method found for module "+i);
        }
    }

    public O forward(java.lang.Object... inputs) {
        java.lang.Object output;
        for (int i = 0; ; i++) {
            try {
                output = forwardMethods[i].invoke(modules[i], inputs);
            } catch (IllegalAccessException | InvocationTargetException e) {
                throw new RuntimeException(e);
            }
            if (i == modules.length - 1) break;
            inputs = new java.lang.Object[] { output };
        }
        return (O) output;
    }
} |
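One thing to watch in the reflective approach above: it keeps the first method named "forward" returned by getMethods(), and the order of that array is not specified, so a module with several forward() overloads may be bound to the wrong one. A possible refinement, sketched here in plain Java, is to pick the overload at call time by matching the argument count and types. ForwardResolver, Doubler and Describer are hypothetical names used only for this illustration, and the sketch assumes reference-typed parameters (it does not handle primitives or varargs).

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

// Hypothetical helper: resolves a public "forward" overload by arguments.
final class ForwardResolver {
    private ForwardResolver() {}

    // Returns a public method named "forward" whose parameter count matches
    // and whose parameter types accept the given (non-primitive) arguments.
    static Method find(Object module, Object... args) {
        for (Method m : module.getClass().getMethods()) {
            if (!m.getName().equals("forward") || m.getParameterCount() != args.length) {
                continue;
            }
            Class<?>[] types = m.getParameterTypes();
            boolean matches = true;
            for (int i = 0; i < args.length; i++) {
                if (args[i] != null && !types[i].isInstance(args[i])) {
                    matches = false;
                    break;
                }
            }
            if (matches) {
                return m;
            }
        }
        throw new IllegalArgumentException(
                "No matching forward method on " + module.getClass().getName());
    }

    // Invokes the resolved overload and returns its result.
    static Object call(Object module, Object... args) {
        try {
            return find(module, args).invoke(module, args);
        } catch (IllegalAccessException | InvocationTargetException e) {
            throw new RuntimeException(e);
        }
    }
}

// Hypothetical example modules with differing forward() signatures.
class Doubler {
    public Integer forward(Integer x) { return x * 2; }
    public Integer forward(Integer x, Integer y) { return (x + y) * 2; }
}

class Describer {
    public String forward(Integer x) { return "value=" + x; }
}

class ForwardResolverDemo {
    public static void main(String[] args) {
        Object doubled = ForwardResolver.call(new Doubler(), 3);
        System.out.println(ForwardResolver.call(new Describer(), doubled));
    }
}
```

Resolving per call costs a reflective lookup on every forward pass; caching the Method per module and argument signature would be a natural follow-up if that matters.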
We can probably just pass that as Info.javaText for the forward() function.
We can add code like that as part of a helper class, sure. We can probably add all that we want as a subclass like org.bytedeco.pytorch.AbstractSequential and an overloaded constructor though. If possible, I wouldn't add another Sequential class just for that. |
I would either provide the native Sequential or a Java one, not both, nor one as a subclass of the other, since they do the same thing. My preference is for the Java one since it's compatible with custom Java modules. |
If it's not meant to support the C++ API, but provide a higher-level Java API on top of the C++ API, that should probably be done as part of another module, in another repository. We can add classes like that as part of JavaCV if you like?
I wasn't aware of any issues with that. You'll need to provide more details to go about it. I can't fix what isn't broken :) |
Subclassing Module is the way the API is supposed to be used, but the native Sequential cannot chain Java modules. So the Java Sequential is a way to fix the native Sequential. But I agree that it could belong to higher-level software, since it's nothing essential, as said before. |
I dug a bit further and here is what I understood:
Module m = new MyModule(); // some Java subclass of Module
Conv2dImpl c = new Conv2dImpl(1, 1, new LongPointer(3, 3));
register_module("x", m); // => ok
register_module("y", c); // => ok
m = c;
register_module("z", m); // => SIGSEGV
And that's an issue for any method taking a
Currently the presets also define classes like My impression is that we should be able to find a more consistent mapping, and that we could hide the |
I'm not entirely sure what the purpose of ModuleHolder is supposed to be, but you can try to map register_module() for those and see if it behaves more like we'd expect it to when used from Java. |
Ok. I misunderstood indeed. Each time a C++ function takes a Did I get it right this time? |
Not quite. |
Not really, no. We need to add methods somewhere to call static_cast and/or dynamic_cast, manually. That's what the asModule() methods are for. |
Hi, does the new version of pytorch-java offer friendly support for SequentialImpl, ModuleDict and ModuleList? Complex models, such as those used in recommender system algorithms, need them. We want to port Python PyTorch code to Java, not only load PyTorch models trained in Python into Java. |
Not in the 2.0.1 version currently online, but I'm working on a big overhaul of the presets and I'll try to add support for those. |
Hi,
I write Python PyTorch code that uses torch.nn.Sequential, but using Sequential in the JavaCPP PyTorch presets is complicated: the object has no forward() method, and its constructor needs a StringAnyModuleDict. When iterating over it, I cannot get the real type of each Module element and only see AnyModule. How can I get the real module type and, following the order in which the modules were inserted, invoke each element's forward() method in Scala?