From 9e000facdec8c1f1fc7b57fb814720de261d0c31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mamy=20Andr=C3=A9-Ratsimbazafy?= Date: Sun, 10 Jan 2021 17:12:26 +0100 Subject: [PATCH] I tried so hard... And got so far... - https://github.com/nim-lang/Nim/issues/16664 - https://github.com/nim-lang/Nim/issues/4687 --- flambeau/raw_bindings/data_api.nim | 38 +++++++++++++++++++++----- flambeau/raw_bindings/neural_nets.nim | 2 +- flambeau/raw_bindings/optimizers.nim | 1 + flambeau/raw_bindings/tensors.nim | 5 ++++ proof_of_concepts/poc09_end_to_end.nim | 37 ++++++++++++++++++++----- 5 files changed, 68 insertions(+), 15 deletions(-) diff --git a/flambeau/raw_bindings/data_api.nim b/flambeau/raw_bindings/data_api.nim index d0ba859..6b88028 100644 --- a/flambeau/raw_bindings/data_api.nim +++ b/flambeau/raw_bindings/data_api.nim @@ -40,7 +40,8 @@ type [Batch] = object func next*(it: var TorchDataIterator) {.importcpp: "(++#)".} -func get*[Batch](it: var TorchDataIterator[Batch]): lent Batch {.importcpp: "(*#)".} +func get*[Batch](it: var TorchDataIterator[Batch]): Batch {.importcpp: "(*#)".} + # TODO: this should be lent? func `==`*(it1, it2: TorchDataIterator): bool {.importcpp: "# == #".} # ####################################################################### @@ -247,13 +248,13 @@ type StatelessDataLoader* {.bycopy, pure, importcpp: "torch::data::StatelessDataLoader".} - [Dataset, Sampler] + [D, S] # Dataset, Sampler = object of DataLoaderBase StatefulDataLoader* {.bycopy, pure, importcpp: "torch::data::StatefulDataLoader".} - [Dataset] + [D] # Dataset = object of DataLoaderBase DataLoaderOptions* @@ -265,7 +266,7 @@ type # and https://github.com/nim-lang/Nim/issues/16655 # BatchDataset and Dataset have no generics attached # and so we can't infer their Iterator type :/ -func start*(dl: DataLoaderBase +func start*(dl: StatelessDataLoader ): TorchDataIterator[Example[Tensor, Tensor]] {.importcpp: "#.begin()".} ## Start an iterator @@ -274,13 +275,36 @@ func start*(dl: DataLoaderBase ## and so the output is fixed to Example[Tensor, Tensor] ## which is the output of the Stack transform -func stop*(dl: DataLoaderBase +func stop*(dl: StatelessDataLoader ): TorchDataIterator[Example[Tensor, Tensor]] {.importcpp: "#.end()".} ## Returns a sentinel value that denotes ## the end of an iterator -iterator items*(dl: DataLoaderBase): Example[Tensor, Tensor] = +func start*[D, S]( + dl: CppUniquePtr[StatelessDataLoader[D, S]] + ): TorchDataIterator[Example[Tensor, Tensor]] + {.importcpp: "#->begin()".} + ## Start an iterator + ## Note: due to compiler bugs with C++ interop + ## we can't attach the DataLoaderBase generic type, + ## and so the output is fixed to Example[Tensor, Tensor] + ## which is the output of the Stack transform + ## + ## Overload as StatelessDataLoader has no default constructors + ## So we don't want Nim to use temporaries + +func stop*[D, S]( + dl: CppUniquePtr[StatelessDataLoader[D, S]] + ): TorchDataIterator[Example[Tensor, Tensor]] + {.importcpp: "#->end()".} + ## Returns a sentinel value that denotes + ## the end of an iterator + ## + ## Overload as StatelessDataLoader has no default constructors + ## So we don't want Nim to use temporaries + +iterator items*(dl: StatelessDataLoader or CppUniquePtr[StatelessDataLoader]): Example[Tensor, Tensor] = # TODO: lent Example[Tensor, Tensor], # borrow checker complains about 'cur' escaping it's frame # but `cur.get()` already returns a borrowed view @@ -290,7 +314,7 @@ iterator items*(dl: DataLoaderBase): Example[Tensor, Tensor] = yield cur.get() cur.next() -iterator pairs*(dl: DataLoaderBase): tuple[index: int, value: Example[Tensor, Tensor]] = +iterator pairs*(dl: StatelessDataLoader or CppUniquePtr[StatelessDataLoader]): tuple[index: int, value: Example[Tensor, Tensor]] = # TODO: lent Example[Tensor, Tensor] # borrow checker complains about 'cur' escaping it's frame # but `cur.get()` already returns a borrowed view diff --git a/flambeau/raw_bindings/neural_nets.nim b/flambeau/raw_bindings/neural_nets.nim index 2035c40..565a2d6 100644 --- a/flambeau/raw_bindings/neural_nets.nim +++ b/flambeau/raw_bindings/neural_nets.nim @@ -168,7 +168,7 @@ func reset_parameters*(linear: Linear){.importcpp: "#.reset_parameters()".} # pretty_print -func forward*(linear: Linear, input: Tensor): Tensor {.importcpp: "#.forward(#)".} +func forward*(linear: Linear, input: Tensor): Tensor {.importcpp: "#->forward(#)".} ## Transforms the ``input`` tensor ## by multiplying with the ``weight`` ## and optionally adding the ``bias``, diff --git a/flambeau/raw_bindings/optimizers.nim b/flambeau/raw_bindings/optimizers.nim index 65ada82..52e8ac9 100644 --- a/flambeau/raw_bindings/optimizers.nim +++ b/flambeau/raw_bindings/optimizers.nim @@ -56,3 +56,4 @@ func init*( {.constructor, importcpp:"torch::optim::SGD(@)".} func step*(optim: var SGD){.importcpp: "#.step()".} +func zero_grad*(optim: var SGD){.importcpp: "#.zero_grad()".} diff --git a/flambeau/raw_bindings/tensors.nim b/flambeau/raw_bindings/tensors.nim index 8a2b392..1dd6ccf 100644 --- a/flambeau/raw_bindings/tensors.nim +++ b/flambeau/raw_bindings/tensors.nim @@ -81,6 +81,8 @@ const torchHeader* = torchHeadersPath / "torch/torch.h" {.push header: torchHeader.} +{.passC: "-Wfatal-errors".} # The default "-fmax-errors=3" is unreadable + # Assumptions # ----------------------------------------------------------------------- # @@ -322,6 +324,9 @@ func vulkan*(a: Tensor): Tensor {.importcpp: "#.vulkan()".} # libtorch/include/ATen/TensorIndexing.h # and https://pytorch.org/cppdocs/notes/tensor_indexing.html +func item*(a: Tensor, T: typedesc): T {.importcpp: "#.item<'0>()".} + ## Extract the scalar from a 0-dimensional tensor + # Unsure what those corresponds to in Python # func `[]`*(a: Tensor, index: Scalar): Tensor {.importcpp: "#[#]".} # func `[]`*(a: Tensor, index: Tensor): Tensor {.importcpp: "#[#]".} diff --git a/proof_of_concepts/poc09_end_to_end.nim b/proof_of_concepts/poc09_end_to_end.nim index 2b29b94..9615183 100644 --- a/proof_of_concepts/poc09_end_to_end.nim +++ b/proof_of_concepts/poc09_end_to_end.nim @@ -1,7 +1,9 @@ # This is a port of the C++ end-to-end example # at https://pytorch.org/cppdocs/frontend.html -import ../flambeau +import + ../flambeau, + std/[enumerate, strformat] # Argh, need Linear{nullptr} in the codegen # so we cheat by inlining C++ @@ -20,15 +22,16 @@ struct Net: public torch::nn::Module { }; """].} -type Net{.importcpp.} = object of Module +type Net{.pure, importcpp.} = object of Module fc1: Linear fc2: Linear fc3: Linear -proc init(T: type Net): Net = - result.fc1 = result.register_module("fc1", Linear.init(784, 64)) - result.fc2 = result.register_module("fc2", Linear.init(64, 32)) - result.fc3 = result.register_module("fc3", Linear.init(32, 10)) +proc init(net: var Net) = + # Note: PyTorch Model serialization requires shared_ptr + net.fc1 = net.register_module("fc1", Linear.init(784, 64)) + net.fc2 = net.register_module("fc2", Linear.init(64, 32)) + net.fc3 = net.register_module("fc3", Linear.init(32, 10)) func forward*(net: Net, x: Tensor): Tensor = var x = x @@ -39,7 +42,8 @@ func forward*(net: Net, x: Tensor): Tensor = return x proc main() = - let net = Net.init() # TODO: make_shared + let net = make_shared(Net) + net.init() let data_loader = make_data_loader( mnist("build/mnist").map(Stack[Example[Tensor, Tensor]].init()), @@ -51,4 +55,23 @@ proc main() = learning_rate = 0.01 ) + for epoch in 1 .. 10: + # Iterate the data loader to yield batches from the dataset. + for batch_index, batch in data_loader.pairs(): + # Reset gradients. + optimizer.zero_grad() + # Execute the model on the input data. + let prediction = net.forward(batch.data) + # Compute a loss value to judge the prediction of our model. + var loss = nll_loss(prediction, batch.target) + # Compute the gradients of the loss w.r.t. the parameters of our model. + loss.backward() + # Update the parameters based on the calculated gradients. + optimizer.step() + # output the loss and checkpoint every 100 batches. + if batch_index mod 100 == 0: + echo &"Epoch: {epoch} | Batch: {batch_index} | Loss: {loss.item(float32)}" + # Serialize your model periodically as a checkpoint. + net.save("net.pt") + main()