Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-29235][ML][PYSPARK] Support avgMetrics in read/write of CrossV…
…alidatorModel ### What changes were proposed in this pull request? Currently pyspark doesn't write/read `avgMetrics` in `CrossValidatorModel`, whereas scala supports it. ### Why are the changes needed? Test step to reproduce it: ``` dataset = spark.createDataFrame([(Vectors.dense([0.0]), 0.0), (Vectors.dense([0.4]), 1.0), (Vectors.dense([0.5]), 0.0), (Vectors.dense([0.6]), 1.0), (Vectors.dense([1.0]), 1.0)] * 10, ["features", "label"]) lr = LogisticRegression() grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() evaluator = BinaryClassificationEvaluator() cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,parallelism=2) cvModel = cv.fit(dataset) cvModel.write().save("/tmp/model") cvModel2 = CrossValidatorModel.read().load("/tmp/model") print(cvModel.avgMetrics) # prints non empty result as expected print(cvModel2.avgMetrics) # Bug: prints an empty result. ``` ### Does this PR introduce any user-facing change? No ### How was this patch tested? Manually tested Before patch: ``` >>> cvModel.write().save("/tmp/model_0") >>> cvModel2 = CrossValidatorModel.read().load("/tmp/model_0") >>> print(cvModel2.avgMetrics) [] ``` After patch: ``` >>> cvModel2 = CrossValidatorModel.read().load("/tmp/model_2") >>> print(cvModel2.avgMetrics[0]) 0.5 ``` Closes #26038 from shahidki31/avgMetrics. Authored-by: shahid <[email protected]> Signed-off-by: Sean Owen <[email protected]>
- Loading branch information