diff --git a/src/common/py-serde/src/python.rs b/src/common/py-serde/src/python.rs index e634743f2d..10e467e931 100644 --- a/src/common/py-serde/src/python.rs +++ b/src/common/py-serde/src/python.rs @@ -49,12 +49,19 @@ impl<'de> Visitor<'de> for PyObjectVisitor { where E: DeError, { - Python::with_gil(|py| { - py.import_bound(pyo3::intern!(py, "daft.pickle")) - .and_then(|m| m.getattr(pyo3::intern!(py, "loads"))) - .and_then(|f| Ok(f.call1((v,))?.into())) - .map_err(|e| DeError::custom(e.to_string())) - }) + self.visit_bytes(&v) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut v: Vec = Vec::with_capacity(seq.size_hint().unwrap_or_default()); + while let Some(elem) = seq.next_element()? { + v.push(elem); + } + + self.visit_bytes(&v) } } diff --git a/tests/io/test_s3_credentials_refresh.py b/tests/io/test_s3_credentials_refresh.py index 16a98fadf0..1b9aeccc8e 100644 --- a/tests/io/test_s3_credentials_refresh.py +++ b/tests/io/test_s3_credentials_refresh.py @@ -34,9 +34,8 @@ def test_s3_credentials_refresh(aws_log_file: io.IOBase): server_url = f"http://{host}:{port}" bucket_name = "mybucket" - file_name = "test.parquet" - - s3_file_path = f"s3://{bucket_name}/{file_name}" + input_file_path = f"s3://{bucket_name}/input.parquet" + output_file_path = f"s3://{bucket_name}/output.parquet" old_env = os.environ.copy() # Set required AWS environment variables before starting server. @@ -98,21 +97,28 @@ def get_credentials(): ) df = daft.from_pydict({"a": [1, 2, 3]}) - df.write_parquet(s3_file_path, io_config=static_config) + df.write_parquet(input_file_path, io_config=static_config) - df = daft.read_parquet(s3_file_path, io_config=dynamic_config) + df = daft.read_parquet(input_file_path, io_config=dynamic_config) assert count_get_credentials == 1 df.collect() assert count_get_credentials == 1 - df = daft.read_parquet(s3_file_path, io_config=dynamic_config) + df = daft.read_parquet(input_file_path, io_config=dynamic_config) assert count_get_credentials == 1 time.sleep(1) df.collect() assert count_get_credentials == 2 + df.write_parquet(output_file_path, io_config=dynamic_config) + assert count_get_credentials == 2 + + df2 = daft.read_parquet(output_file_path, io_config=static_config) + + assert df.to_arrow() == df2.to_arrow() + # Shutdown moto server. stop_process(process) # Restore old set of environment variables.