From 2b4e56e72bf3cd291349baf6feb197666d368b67 Mon Sep 17 00:00:00 2001
From: Nick Pentreath
Date: Wed, 10 Aug 2016 10:39:18 +0200
Subject: [PATCH] Make RDD and DataFrame context managers

---
 python/pyspark/rdd.py           | 36 ++++++++++++++++++++++++++++++++-
 python/pyspark/sql/dataframe.py | 26 ++++++++++++++++++++++++
 2 files changed, 61 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 0508235c1c9ee..02c0a35938086 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -188,6 +188,12 @@ def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSeri
         self._id = jrdd.id()
         self.partitioner = None
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.unpersist()
+
     def _pickled(self):
         return self._reserialize(AutoBatchedSerializer(PickleSerializer()))
 
@@ -221,6 +227,21 @@ def context(self):
     def cache(self):
         """
         Persist this RDD with the default storage level (C{MEMORY_ONLY}).
+
+        :py:meth:`cache` can be used in a 'with' statement. The RDD will be automatically
+        unpersisted once the 'with' block is exited. Note, however, that any actions on the RDD
+        that require the RDD to be cached should be invoked inside the 'with' block; otherwise,
+        caching will have no effect.
+
+        >>> rdd = sc.parallelize(["b", "a", "c"])
+        >>> with rdd.cache() as cached:
+        ...     print(cached.getStorageLevel())
+        ...     print(cached.count())
+        ...
+        Memory Serialized 1x Replicated
+        3
+        >>> print(rdd.getStorageLevel())
+        Serialized 1x Replicated
         """
         self.is_cached = True
         self.persist(StorageLevel.MEMORY_ONLY)
@@ -233,9 +254,22 @@ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY):
         a new storage level if the RDD does not have a storage level set yet.
         If no storage level is specified defaults to (C{MEMORY_ONLY}).
 
+        :py:meth:`persist` can be used in a 'with' statement. The RDD will be automatically
+        unpersisted once the 'with' block is exited. Note, however, that any actions on the RDD
+        that require the RDD to be cached should be invoked inside the 'with' block; otherwise,
+        caching will have no effect.
+
         >>> rdd = sc.parallelize(["b", "a", "c"])
-        >>> rdd.persist().is_cached
+        >>> with rdd.persist() as persisted:
+        ...     print(persisted.getStorageLevel())
+        ...     print(persisted.is_cached)
+        ...     print(persisted.count())
+        ...
+        Memory Serialized 1x Replicated
         True
+        3
+        >>> print(rdd.getStorageLevel())
+        Serialized 1x Replicated
         """
         self.is_cached = True
         javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index a986092f5d634..57dc0725b9404 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -75,6 +75,12 @@ def __init__(self, jdf, sql_ctx):
         self._schema = None  # initialized lazily
         self._lazy_rdd = None
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.unpersist()
+
     @property
     @since(1.3)
     def rdd(self):
@@ -390,6 +396,16 @@ def foreachPartition(self, f):
     @since(1.3)
     def cache(self):
         """ Persists with the default storage level (C{MEMORY_ONLY}).
+
+        :py:meth:`cache` can be used in a 'with' statement. The DataFrame will be automatically
+        unpersisted once the 'with' block is exited. Note, however, that any actions on the
+        DataFrame that require it to be cached should be invoked inside the 'with' block;
+        otherwise, caching will have no effect.
+
+        >>> with df.cache() as cached:
+        ...     print(cached.count())
+        ...
+        2
         """
         self.is_cached = True
         self._jdf.cache()
@@ -401,6 +417,16 @@ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY):
         after the first time it is computed. This can only be used to assign
         a new storage level if the RDD does not have a storage level set yet.
         If no storage level is specified defaults to (C{MEMORY_ONLY}).
+
+        :py:meth:`persist` can be used in a 'with' statement. The DataFrame will be automatically
+        unpersisted once the 'with' block is exited. Note, however, that any actions on the
+        DataFrame that require it to be cached should be invoked inside the 'with' block;
+        otherwise, caching will have no effect.
+
+        >>> with df.persist() as persisted:
+        ...     print(persisted.count())
+        ...
+        2
         """
         self.is_cached = True
         javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)