Skip to content

Commit

Permalink
Add unit test for record manifest writer
Browse files Browse the repository at this point in the history
  • Loading branch information
coufon committed Dec 22, 2023
1 parent 7c3128b commit 243ce7e
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 14 deletions.
3 changes: 1 addition & 2 deletions python/src/space/core/manifests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@


def write_parquet_file(file_path: str, schema: pa.Schema,
data: pa.Table) -> str:
data: pa.Table) -> None:
"""Materialize a single Parquet file."""
# TODO: currently assume this file is small, so always write a single file.
writer = pq.ParquetWriter(file_path, schema)
writer.write_table(data)

writer.close()
return file_path
13 changes: 5 additions & 8 deletions python/tests/core/manifests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,10 @@ def test_write_all_types(self, tmp_path):
schema=schema,
primary_keys=["int64", "float64", "bool", "string"])

file_path = str(data_dir / "file0")
# TODO: the test should cover all types supported by column stats.
manifest_writer.write(
file_path,
_write_parquet_file(file_path, schema, [{
"data/file0",
_write_parquet_file(str(data_dir / "file0"), schema, [{
"int64": [1, 2, 3],
"float64": [0.1, 0.2, 0.3],
"bool": [True, False, False],
Expand All @@ -68,10 +67,9 @@ def test_write_all_types(self, tmp_path):
"bool": [False, False],
"string": ["A", "z"]
}]))
file_path = str(data_dir / "file1")
manifest_writer.write(
file_path,
_write_parquet_file(file_path, schema, [{
"data/file1",
_write_parquet_file(str(data_dir / "file1"), schema, [{
"int64": [1000, 1000000],
"float64": [-0.001, 0.001],
"bool": [False, False],
Expand All @@ -80,10 +78,9 @@ def test_write_all_types(self, tmp_path):

manifest_path = manifest_writer.finish()

data_dir_str = str(data_dir)
assert manifest_path is not None
assert pq.read_table(manifest_path).to_pydict() == {
"_FILE": [f"{data_dir_str}/file0", f"{data_dir_str}/file1"],
"_FILE": ["data/file0", "data/file1"],
"_INDEX_COMPRESSED_BYTES": [645, 334],
"_INDEX_UNCOMPRESSED_BYTES": [624, 320],
"_NUM_ROWS": [5, 2],
Expand Down
57 changes: 57 additions & 0 deletions python/tests/core/manifests/test_record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pyarrow.parquet as pq

from space.core.manifests import RecordManifestWriter
import space.core.proto.metadata_pb2 as meta


class TestRecordManifestWriter:
  """Unit tests for RecordManifestWriter."""

  def test_write(self, tmp_path):
    """Writing two record files yields a manifest with one row per file."""
    # Only the metadata directory must exist: record file paths are stored
    # as relative strings and the files themselves are never opened here.
    # (Fix: the previous version also built an unused `data_dir` local.)
    metadata_dir = tmp_path / "dataset" / "metadata"
    metadata_dir.mkdir(parents=True)

    manifest_writer = RecordManifestWriter(metadata_dir=str(metadata_dir))

    manifest_writer.write(
        "data/file0.arrayrecord", 0,
        meta.StorageStatistics(num_rows=123,
                               index_compressed_bytes=10,
                               index_uncompressed_bytes=20,
                               record_uncompressed_bytes=30))
    manifest_writer.write(
        "data/file1.arrayrecord", 1,
        meta.StorageStatistics(num_rows=456,
                               index_compressed_bytes=10,
                               index_uncompressed_bytes=20,
                               record_uncompressed_bytes=100))

    manifest_path = manifest_writer.finish()

    assert manifest_path is not None
    # The manifest records, per file: path, field ID, row count, and the
    # uncompressed record bytes from the provided statistics.
    assert pq.read_table(manifest_path).to_pydict() == {
        "_FILE": ["data/file0.arrayrecord", "data/file1.arrayrecord"],
        "_FIELD_ID": [0, 1],
        "_NUM_ROWS": [123, 456],
        "_UNCOMPRESSED_BYTES": [30, 100]
    }

  def test_empty_manifest_should_return_none(self, tmp_path):
    """finish() returns None when no record files were written."""
    metadata_dir = tmp_path / "dataset" / "metadata"
    manifest_writer = RecordManifestWriter(metadata_dir=str(metadata_dir))

    assert manifest_writer.finish() is None
6 changes: 2 additions & 4 deletions python/tests/core/manifests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ def test_write_parquet_file(tmp_path):
data_dir = tmp_path / "file.parquet"

file_path = str(data_dir)
returned_path = utils.write_parquet_file(
file_path, pa.schema([("int64", pa.int64())]),
pa.Table.from_pydict({"int64": [1, 2]}))
utils.write_parquet_file(file_path, pa.schema([("int64", pa.int64())]),
pa.Table.from_pydict({"int64": [1, 2]}))

assert data_dir.exists()
assert returned_path == file_path
32 changes: 32 additions & 0 deletions python/tests/core/ops/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,35 @@ def test_update_index_storage_stats_negative():
assert base == meta.StorageStatistics(num_rows=90,
index_compressed_bytes=180,
index_uncompressed_bytes=270)


def test_update_record_stats_bytes():
base = meta.StorageStatistics(num_rows=100,
index_compressed_bytes=200,
index_uncompressed_bytes=300,
record_uncompressed_bytes=1000)

utils.update_record_stats_bytes(
base,
meta.StorageStatistics(num_rows=-10,
index_compressed_bytes=-20,
record_uncompressed_bytes=-100))
assert base == meta.StorageStatistics(num_rows=100,
index_compressed_bytes=200,
index_uncompressed_bytes=300,
record_uncompressed_bytes=900)


def test_address_column():
result = [{
"_FILE": "data/file.arrayrecord",
"_ROW_ID": 2
}, {
"_FILE": "data/file.arrayrecord",
"_ROW_ID": 3
}, {
"_FILE": "data/file.arrayrecord",
"_ROW_ID": 4
}]
assert utils.address_column("data/file.arrayrecord", 2,
3).to_pylist() == result
32 changes: 32 additions & 0 deletions python/tests/core/schema/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,35 @@ def test_field_id_to_column_id_dict(sample_arrow_schema):
220: 3,
260: 4
}


def test_classify_fields(sample_arrow_schema):
index_fields, record_fields = arrow.classify_fields(sample_arrow_schema,
["float32", "list"])

assert index_fields == [
arrow.Field("struct", 150),
arrow.Field("list_struct", 220),
arrow.Field("struct_list", 260)
]
assert record_fields == [
arrow.Field("float32", 100),
arrow.Field("list", 120)
]


def test_classify_fields_with_selected_fields(sample_arrow_schema):
index_fields, record_fields = arrow.classify_fields(sample_arrow_schema,
["float32", "list"],
["list", "struct"])

assert index_fields == [arrow.Field("struct", 150)]
assert record_fields == [arrow.Field("list", 120)]


def test_field_names():
assert arrow.field_names([
arrow.Field("struct", 150),
arrow.Field("list_struct", 220),
arrow.Field("struct_list", 260)
]) == ["struct", "list_struct", "struct_list"]

0 comments on commit 243ce7e

Please sign in to comment.