Skip to content

Commit

Permalink
Torchoptim, wrap pytorch optimizer to bigdl's optimmethod (#2869)
Browse files Browse the repository at this point in the history
* TorchOptim

* add python api

* add scala test

* add python test

* add python test

* add python test

* fix on yarn

* clean up

* add test

* fix style check
  • Loading branch information
qiuxin2012 committed Sep 15, 2020
1 parent 66417fa commit 94c3f02
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions python/orca/example/torchmodel/train/mnist/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from zoo.pipeline.api.torch import TorchModel, TorchLoss
from zoo.pipeline.api.torch import TorchModel, TorchLoss, TorchOptim
from zoo.pipeline.estimator import *
from bigdl.optim.optimizer import SGD, Adam
from zoo.common.nncontext import *
Expand Down Expand Up @@ -106,10 +106,11 @@ def main():
model.train()
criterion = nn.NLLLoss()

adam = Adam(args.lr)
adam = torch.optim.Adam(model.parameters(), lr=args.lr)
zoo_model = TorchModel.from_pytorch(model)
zoo_criterion = TorchLoss.from_pytorch(criterion)
zoo_estimator = Estimator(zoo_model, optim_methods=adam)
zoo_optim = TorchOptim.from_pytorch(adam)
zoo_estimator = Estimator(zoo_model, optim_methods=zoo_optim)
train_featureset = FeatureSet.pytorch_dataloader(train_loader)
test_featureset = FeatureSet.pytorch_dataloader(test_loader)
from bigdl.optim.optimizer import MaxEpoch, EveryEpoch
Expand Down

0 comments on commit 94c3f02

Please sign in to comment.