diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index e7a68de47034d..b97af4fc04925 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -48,12 +48,12 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed): >>> centers = model.clusterCenters() >>> len(centers) 2 - >>> transformed = model.transform(df) - >>> (transformed.columns)[0] == 'features' + >>> transformed = model.transform(df).select("features", "prediction") + >>> "features" in transformed.columns True - >>> (transformed.columns)[1] == 'prediction' + >>> "prediction" in transformed.columns True - >>> rows = sorted(transformed.collect(), key = lambda r: r[0]) + >>> rows = transformed.collect() >>> rows[0].prediction == rows[1].prediction True >>> rows[2].prediction == rows[3].prediction