-
Notifications
You must be signed in to change notification settings - Fork 733
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Orca spark estimator #3124
Orca spark estimator #3124
Changes from 13 commits
8355158
c8443cc
f8a6127
cc5767d
dee20d3
6a39160
a1223d5
86c84f1
58c0e5c
266d3e2
74f897c
ae9370a
0b05a7b
5bb79e5
b705a08
cd7c790
4fdb108
7c3bfc9
b402c9f
51d821a
0a3d07b
093a9ec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# | ||
# Copyright 2018 Analytics Zoo Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
|
||
class BaseEstimator(ABC): | ||
@abstractmethod | ||
def fit(self, data, epochs, **kwargs): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def predict(self, data, **kwargs): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def evaluate(self, data, **kwargs): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def get_model(self): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def save(self, model_path): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def load(self, checkpoint, **kwargs): | ||
raise NotImplementedError |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,11 +15,10 @@ | |
# | ||
from zoo.pipeline.estimator.estimator import Estimator as SparkEstimator | ||
from zoo.orca.learn.pytorch.training_operator import TrainingOperator | ||
from zoo.orca.learn.spark_estimator import Estimator as OrcaSparkEstimator | ||
from zoo.orca.data import SparkXShards | ||
from bigdl.optim.optimizer import MaxEpoch | ||
from zoo.feature.common import FeatureSet | ||
|
||
import torch | ||
from torch.optim.optimizer import Optimizer as TorchOptimizer | ||
from torch.utils.data import DataLoader | ||
|
||
|
@@ -195,7 +194,7 @@ def shutdown(self, force=False): | |
return self.estimator.shutdown(force=force) | ||
|
||
|
||
class PytorchSparkEstimatorWrapper(Estimator): | ||
class PytorchSparkEstimatorWrapper(OrcaSparkEstimator): | ||
def __init__(self, model, loss, optimizer, model_dir=None, bigdl_type="float"): | ||
from zoo.pipeline.api.torch import TorchModel, TorchLoss, TorchOptim | ||
self.loss = loss | ||
|
@@ -208,6 +207,8 @@ def __init__(self, model, loss, optimizer, model_dir=None, bigdl_type="float"): | |
optimizer = SGD() | ||
elif isinstance(optimizer, TorchOptimizer): | ||
optimizer = TorchOptim.from_pytorch(optimizer) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need to make sure |
||
self.log_dir = None | ||
self.app_name = None | ||
Comment on lines
+215
to
+216
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there anywhere to set these two inputs? Seems they are always None and set_tensorboard will never be triggered... @cyita |
||
self.model_dir = model_dir | ||
self.model = TorchModel.from_pytorch(model) | ||
self.estimator = SparkEstimator(self.model, optimizer, model_dir, bigdl_type=bigdl_type) | ||
|
@@ -223,6 +224,9 @@ def fit(self, data, epochs=1, batch_size=32, validation_data=None, validation_me | |
validation_methods = Metrics.convert_metrics_list(validation_methods) | ||
checkpoint_trigger = Trigger.convert_trigger(checkpoint_trigger) | ||
|
||
if self.log_dir is not None and self.app_name is not None: | ||
self.estimator.set_tensorboad(self.log_dir, self.app_name) | ||
|
||
if isinstance(data, SparkXShards): | ||
train_rdd = data.rdd.flatMap(to_sample) | ||
train_feature_set = FeatureSet.sample_rdd(train_rdd) | ||
|
@@ -274,20 +278,25 @@ def evaluate(self, data, validation_methods=None, batch_size=32): | |
def get_model(self): | ||
return self.model.to_pytorch() | ||
|
||
def save(self, checkpoint): | ||
def save(self, model_path): | ||
pass | ||
|
||
def load(self, checkpoint, loss=None): | ||
from zoo.orca.learn.utils import find_latest_checkpoint | ||
from bigdl.nn.layer import Model | ||
from bigdl.optim.optimizer import OptimMethod | ||
import os | ||
if loss is not None: | ||
from zoo.pipeline.api.torch import TorchLoss | ||
self.loss = TorchLoss.from_pytorch(loss) | ||
path, prefix, version = find_latest_checkpoint(checkpoint, model_type="pytorch") | ||
if path is None: | ||
raise ValueError("Cannot find PyTorch checkpoint, please check your checkpoint path.") | ||
self.load_orca_checkpoint(path, version=version, prefix=prefix) | ||
|
||
def load_orca_checkpoint(self, path, version, prefix=None): | ||
import os | ||
from bigdl.nn.layer import Model | ||
from bigdl.optim.optimizer import OptimMethod | ||
assert prefix is not None, "You should provide optimMethod prefix, " \ | ||
"for example 'optimMethod-TorchModelf53bddcc'" | ||
try: | ||
self.model = Model.load(os.path.join(path, "model.{}".format(version))) | ||
optimizer = OptimMethod.load(os.path.join(path, "{}.{}".format(prefix, version))) | ||
|
@@ -296,8 +305,14 @@ def load(self, checkpoint, loss=None): | |
"and checkpoint type.") | ||
self.estimator = SparkEstimator(self.model, optimizer, self.model_dir) | ||
|
||
def shutdown(self, force=False): | ||
pass | ||
def load_latest_orca_checkpoint(self, path): | ||
self.load(checkpoint=path) | ||
|
||
def get_train_summary(self, tag=None): | ||
return self.estimator.get_train_summary(tag=tag) | ||
|
||
def get_validation_summary(self, tag=None): | ||
return self.estimator.get_validation_summary(tag=tag) | ||
|
||
def clear_gradient_clipping(self): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pass
will do as the subclass needs to implement all of itsabstractmethod
:-)