-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
bug fix: DLModel prediction #2194
bug fix: DLModel prediction #2194
Conversation
Make sure DLModel.train=False when predicting in pipeline API
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for finding the issue. This is a regression issue.
@@ -391,7 +391,8 @@ class DLModel[@specialized(Float, Double) T: ClassTag]( | |||
val localBatchSize = $(batchSize) | |||
|
|||
val resultRDD = dataFrame.rdd.mapPartitions { rowIter => | |||
val localModel = modelBroadCast.value() | |||
// call the evaluate method to enable DLModel.train=False during the predict process | |||
val localModel = modelBroadCast.value().evaluate() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it better if we move evaluate before model broadcasting?
.setBatchSize(nRecords) | ||
.setOptimMethod(new LBFGS[Float]()) | ||
.setLearningRate(0.1) | ||
.setMaxEpoch(maxEpoch) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's set this to 1.
@@ -91,6 +91,27 @@ class DLEstimatorSpec extends FlatSpec with Matchers with BeforeAndAfter { | |||
assert(correct > nRecords * 0.8) | |||
} | |||
|
|||
"An DLEstimator" should "throws exception when DLModel is predicting with DLModel.train=True" in { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we use intercept exception in this unit test? The content and the title does not quite match.
Hi @hhbyyh, |
Thanks for the update. |
jenkins passed |
* bug fix: DLModel prediction (#4) Make sure DLModel.train=False when predicting in pipeline API * 1. broadcast transformer in DLModel.transform ; 2. remove useless ut
* rebase * revert automl readme
* rebase * revert automl readme
* rebase * revert automl readme
* rebase * revert automl readme
What changes were proposed in this pull request?
How was this patch tested?
manual test