From 7c3128ba2d53b600257b7de29aa775a53847c909 Mon Sep 17 00:00:00 2001 From: coufon Date: Thu, 21 Dec 2023 04:59:45 +0000 Subject: [PATCH 1/3] Add writing record fields in Append op. --- python/pyproject.toml | 24 +-- python/src/space/core/manifests/__init__.py | 4 + python/src/space/core/manifests/index.py | 23 +-- python/src/space/core/manifests/record.py | 77 +++++++ python/src/space/core/manifests/utils.py | 29 +++ python/src/space/core/ops/append.py | 118 ++++++++++- python/src/space/core/ops/utils.py | 23 ++- python/src/space/core/proto/metadata.proto | 8 +- python/src/space/core/proto/metadata_pb2.py | 12 +- python/src/space/core/proto/metadata_pb2.pyi | 15 +- python/src/space/core/proto/runtime.proto | 5 +- python/src/space/core/proto/runtime_pb2.py | 12 +- python/src/space/core/proto/runtime_pb2.pyi | 7 +- python/src/space/core/schema/arrow.py | 55 ++++- python/src/space/core/schema/constants.py | 7 + python/src/space/core/storage.py | 18 +- .../space/core/utils/lazy_imports_utils.py | 192 ++++++++++++++++++ python/tests/core/manifests/test_index.py | 2 +- python/tests/core/manifests/test_utils.py | 29 +++ python/tests/core/ops/test_append.py | 6 +- python/tests/core/ops/test_utils.py | 8 +- python/tests/core/test_storage.py | 9 +- 22 files changed, 607 insertions(+), 76 deletions(-) create mode 100644 python/src/space/core/manifests/record.py create mode 100644 python/src/space/core/manifests/utils.py create mode 100644 python/src/space/core/utils/lazy_imports_utils.py create mode 100644 python/tests/core/manifests/test_utils.py diff --git a/python/pyproject.toml b/python/pyproject.toml index 9b80fbd..5ffe16e 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,34 +1,29 @@ [project] name = "space" version = "0.0.1" -authors = [ - { name="Space team", email="no-reply@google.com" }, -] +authors = [{ name = "Space team", email = "no-reply@google.com" }] description = "A storage framework for machine learning datasets" -license = {text = "Apache-2.0"} +license = { text = "Apache-2.0" } classifiers = [ "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11" + "Programming Language :: Python :: 3.11", ] requires-python = ">=3.8" dependencies = [ + "array-record", "numpy", "protobuf", "pyarrow >= 14.0.0", "tensorflow_datasets", - "typing_extensions" + "typing_extensions", ] [project.optional-dependencies] -dev = [ - "pyarrow-stubs", - "tensorflow", - "types-protobuf" -] +dev = ["pyarrow-stubs", "tensorflow", "types-protobuf"] [project.urls] Homepage = "https://github.com/google/space" @@ -49,4 +44,9 @@ disable = ['fixme'] [tool.pylint.MAIN] ignore = 'space/core/proto' -ignored-modules = ['space.core.proto', 'google.protobuf', 'substrait'] +ignored-modules = [ + 'space.core.proto', + 'google.protobuf', + 'substrait', + 'array_record', +] diff --git a/python/src/space/core/manifests/__init__.py b/python/src/space/core/manifests/__init__.py index 36a92ed..ab96cad 100644 --- a/python/src/space/core/manifests/__init__.py +++ b/python/src/space/core/manifests/__init__.py @@ -12,3 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
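Taken together, this series lets a dataset declare record fields whose values go to row-oriented ArrayRecord files while index fields stay in Parquet. A minimal end-to-end sketch of the intended flow, based on the unit tests added later in the series (the location, feature spec and payloads are illustrative only):

import numpy as np
import pyarrow as pa
from tensorflow_datasets import features

from space.core.ops import LocalAppendOp
from space.core.schema.types import TfFeatures
from space.core.storage import Storage

# Hypothetical feature spec and dataset location.
images = features.FeaturesDict(
    {"images": features.Image(shape=(None, None, 3), dtype=np.uint8)})
schema = pa.schema([pa.field("int64", pa.int64()),
                    pa.field("images", TfFeatures(images))])

storage = Storage.create(location="/tmp/dataset",
                         schema=schema,
                         primary_keys=["int64"],
                         record_fields=["images"])  # stored in ArrayRecord files

op = LocalAppendOp("/tmp/dataset", storage.metadata)
op.write({"int64": [1, 2], "images": [b"img0", b"img1"]})
patch = op.finish()  # runtime.Patch with index/record manifest files and stats
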
# +"""Manifest files writer and reader implementation.""" + +from space.core.manifests.index import IndexManifestWriter +from space.core.manifests.record import RecordManifestWriter diff --git a/python/src/space/core/manifests/index.py b/python/src/space/core/manifests/index.py index 2ed7bea..46bfac0 100644 --- a/python/src/space/core/manifests/index.py +++ b/python/src/space/core/manifests/index.py @@ -19,13 +19,13 @@ import pyarrow as pa import pyarrow.parquet as pq +from space.core.manifests.utils import write_parquet_file import space.core.proto.metadata_pb2 as meta +from space.core.schema import constants from space.core.schema.arrow import field_id, field_id_to_column_id_dict from space.core.utils import paths # Manifest file fields. -_FILE_PATH_FIELD = '_FILE' -_NUM_ROWS_FIELD = '_NUM_ROWS' _INDEX_COMPRESSED_BYTES_FIELD = '_INDEX_COMPRESSED_BYTES' _INDEX_UNCOMPRESSED_BYTES_FIELD = '_INDEX_UNCOMPRESSED_BYTES' @@ -57,7 +57,8 @@ def _manifest_schema( """Build the index manifest file schema, based on storage schema.""" primary_keys_ = set(primary_keys) - fields = [(_FILE_PATH_FIELD, pa.utf8()), (_NUM_ROWS_FIELD, pa.int64()), + fields = [(constants.FILE_PATH_FIELD, pa.utf8()), + (constants.NUM_ROWS_FIELD, pa.int64()), (_INDEX_COMPRESSED_BYTES_FIELD, pa.int64()), (_INDEX_UNCOMPRESSED_BYTES_FIELD, pa.int64())] @@ -209,16 +210,6 @@ def finish(self) -> Optional[str]: if manifest_data.num_rows == 0: return None - return _write_index_manifest(self._metadata_dir, self._manifest_schema, - manifest_data) - - -def _write_index_manifest(metadata_dir: str, schema: pa.Schema, - data: pa.Table) -> str: - # TODO: currently assume this file is small, so always write a single file. - file_path = paths.new_index_manifest_path(metadata_dir) - writer = pq.ParquetWriter(file_path, schema) - writer.write_table(data) - - writer.close() - return file_path + file_path = paths.new_index_manifest_path(self._metadata_dir) + write_parquet_file(file_path, self._manifest_schema, manifest_data) + return file_path diff --git a/python/src/space/core/manifests/record.py b/python/src/space/core/manifests/record.py new file mode 100644 index 0000000..521260a --- /dev/null +++ b/python/src/space/core/manifests/record.py @@ -0,0 +1,77 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+"""Record manifest files writer and reader implementation."""
+
+from typing import List, Optional
+
+import pyarrow as pa
+
+from space.core.manifests.utils import write_parquet_file
+import space.core.proto.metadata_pb2 as meta
+from space.core.utils import paths
+from space.core.schema import constants
+
+
+def _manifest_schema() -> pa.Schema:
+  fields = [(constants.FILE_PATH_FIELD, pa.utf8()),
+            (constants.FIELD_ID_FIELD, pa.int32()),
+            (constants.NUM_ROWS_FIELD, pa.int64()),
+            (constants.UNCOMPRESSED_BYTES_FIELD, pa.int64())]
+  return pa.schema(fields)  # type: ignore[arg-type]
+
+
+class RecordManifestWriter:
+  """Writer of record manifest files."""
+
+  def __init__(self, metadata_dir: str):
+    self._metadata_dir = metadata_dir
+    self._manifest_schema = _manifest_schema()
+
+    self._file_paths: List[str] = []
+    self._field_ids: List[int] = []
+    self._num_rows: List[int] = []
+    self._uncompressed_bytes: List[int] = []
+
+  def write(self, file_path: str, field_id: int,
+            storage_statistics: meta.StorageStatistics) -> None:
+    """Write a new manifest row.
+
+    Args:
+      file_path: a relative file path of the ArrayRecord file.
+      field_id: the field ID of the associated field for this ArrayRecord file.
+      storage_statistics: storage statistics of the file.
+    """
+    self._file_paths.append(file_path)
+    self._field_ids.append(field_id)
+    self._num_rows.append(storage_statistics.num_rows)
+    self._uncompressed_bytes.append(
+        storage_statistics.record_uncompressed_bytes)
+
+  def finish(self) -> Optional[str]:
+    """Materialize the manifest file and return the file path."""
+    if not self._file_paths:
+      return None
+
+    arrays = [
+        self._file_paths, self._field_ids, self._num_rows,
+        self._uncompressed_bytes
+    ]
+    manifest_data = pa.Table.from_arrays(
+        arrays,  # type: ignore[arg-type]
+        schema=self._manifest_schema)  # type: ignore[call-arg]
+
+    file_path = paths.new_record_manifest_path(self._metadata_dir)
+    write_parquet_file(file_path, self._manifest_schema, manifest_data)
+    return file_path
diff --git a/python/src/space/core/manifests/utils.py b/python/src/space/core/manifests/utils.py
new file mode 100644
index 0000000..cc99418
--- /dev/null
+++ b/python/src/space/core/manifests/utils.py
@@ -0,0 +1,29 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Manifest utilities."""
+
+import pyarrow as pa
+import pyarrow.parquet as pq
+
+
+def write_parquet_file(file_path: str, schema: pa.Schema,
+                       data: pa.Table) -> str:
+  """Materialize a single Parquet file."""
+  # TODO: currently assume this file is small, so always write a single file.
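For reference, RecordManifestWriter is driven the same way as the index manifest writer: one write() per flushed ArrayRecord file, then finish(). A minimal sketch mirroring the unit test added in the second patch of this series (paths and numbers are illustrative):

import space.core.proto.metadata_pb2 as meta
from space.core.manifests import RecordManifestWriter

writer = RecordManifestWriter(metadata_dir="/tmp/dataset/metadata")  # hypothetical dir
writer.write(
    "data/file0.arrayrecord", 2,
    meta.StorageStatistics(num_rows=123, record_uncompressed_bytes=30))
manifest_path = writer.finish()  # None if no rows were written
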
+ writer = pq.ParquetWriter(file_path, schema) + writer.write_table(data) + + writer.close() + return file_path diff --git a/python/src/space/core/ops/append.py b/python/src/space/core/ops/append.py index 1ae8474..6c0f84c 100644 --- a/python/src/space/core/ops/append.py +++ b/python/src/space/core/ops/append.py @@ -16,19 +16,21 @@ from __future__ import annotations from abc import abstractmethod -from dataclasses import dataclass -from typing import Optional +from dataclasses import dataclass, field as dataclass_field +from typing import Dict, Optional, Tuple import pyarrow as pa import pyarrow.parquet as pq -from space.core.manifests.index import IndexManifestWriter +from space.core.manifests import IndexManifestWriter +from space.core.manifests import RecordManifestWriter from space.core.ops import utils from space.core.ops.base import BaseOp, InputData from space.core.proto import metadata_pb2 as meta from space.core.proto import runtime_pb2 as runtime -from space.core.schema.arrow import arrow_schema +from space.core.schema import arrow from space.core.utils import paths +from space.core.utils.lazy_imports_utils import array_record_module as ar from space.core.utils.paths import StoragePaths # TODO: to obtain the values from user provided options. @@ -52,11 +54,23 @@ def finish(self) -> Optional[runtime.Patch]: @dataclass class _IndexWriterInfo: - """Contain information of index file writer.""" + """Information of index file writer.""" writer: pq.ParquetWriter file_path: str +@dataclass +class _RecordWriterInfo: + """Information of record file writer.""" + writer: ar.ArrayRecordWriter + file_path: str + file_id: int = 0 + next_row_id: int = 0 + storage_statistics: meta.StorageStatistics = dataclass_field( + default_factory=meta.StorageStatistics) + + +# pylint: disable=too-many-instance-attributes class LocalAppendOp(BaseAppendOp, StoragePaths): """Append operation running locally. @@ -70,11 +84,19 @@ def __init__(self, location: str, metadata: meta.StorageMetadata): StoragePaths.__init__(self, location) self._metadata = metadata - self._schema = arrow_schema(self._metadata.schema.fields) + self._schema = arrow.arrow_schema(self._metadata.schema.fields) + + self._index_fields, self._record_fields = arrow.classify_fields( + self._schema, + set(self._metadata.schema.record_fields), + selected_fields=None) # Data file writers. self._index_writer_info: Optional[_IndexWriterInfo] = None + # Key is field name. + self._record_writers: Dict[str, _RecordWriterInfo] = {} + # Local runtime caches. self._cached_index_data: Optional[pa.Table] = None self._cached_index_file_bytes = 0 @@ -83,6 +105,7 @@ def __init__(self, location: str, metadata: meta.StorageMetadata): self._index_manifest_writer = IndexManifestWriter( self._metadata_dir, self._schema, self._metadata.schema.primary_keys) # type: ignore[arg-type] + self._record_manifest_writer = RecordManifestWriter(self._metadata_dir) self._patch = runtime.Patch() @@ -98,6 +121,11 @@ def finish(self) -> Optional[runtime.Patch]: Returns: A patch to the storage or None if no actual storage modification happens. """ + # Flush all cached record data. + for f in self._record_fields: + if f.name in self._record_writers: + self._finish_record_writer(f, self._record_writers[f.name]) + # Flush all cached index data. if self._cached_index_data is not None: self._maybe_create_index_writer() @@ -107,11 +135,17 @@ def finish(self) -> Optional[runtime.Patch]: if self._index_writer_info is not None: self._finish_index_writer() + # Write manifest files. 
index_manifest_full_path = self._index_manifest_writer.finish() if index_manifest_full_path is not None: self._patch.added_index_manifest_files.append( self.short_path(index_manifest_full_path)) + record_manifest_path = self._record_manifest_writer.finish() + if record_manifest_path: + self._patch.added_record_manifest_files.append( + self.short_path(record_manifest_path)) + if self._patch.storage_statistics_update.num_rows == 0: return None @@ -127,6 +161,21 @@ def _append_arrow(self, data: pa.Table) -> None: index_data = data self._maybe_create_index_writer() + index_data = data.select(arrow.field_names(self._index_fields)) + + # Write record fields into files. + # TODO: to parallelize it. + record_addresses = [ + self._write_record_column(f, data.column(f.name)) + for f in self._record_fields + ] + + # TODO: to preserve the field order in schema. + for field_name, address_column in record_addresses: + # TODO: the field/column added must have field ID. + index_data = index_data.append_column(field_name, address_column) + + # Write index fields into files. self._cached_index_file_bytes += index_data.nbytes if self._cached_index_data is None: @@ -154,7 +203,7 @@ def _maybe_create_index_writer(self) -> None: writer, self.short_path(full_file_path)) def _finish_index_writer(self) -> None: - """Materialize a new index file, update metadata and stats.""" + """Materialize a new index file (Parquet), update metadata and stats.""" if self._index_writer_info is None: return @@ -164,8 +213,61 @@ def _finish_index_writer(self) -> None: stats = self._index_manifest_writer.write( self._index_writer_info.file_path, self._index_writer_info.writer.writer.metadata) - utils.update_index_storage_statistics( + utils.update_index_storage_stats( base=self._patch.storage_statistics_update, update=stats) self._index_writer_info = None self._cached_index_file_bytes = 0 + + def _write_record_column( + self, field: arrow.Field, + column: pa.ChunkedArray) -> Tuple[str, pa.StructArray]: + """Write record field into files. + + Returns: + A tuple (field_name, address_column). + """ + field_name = field.name + + # TODO: this section needs to be locked when supporting threaded execution. + if field_name in self._record_writers: + writer_info = self._record_writers[field_name] + else: + file_path = paths.new_record_file_path(self._data_dir, field_name) + writer = ar.ArrayRecordWriter(file_path, options="") + writer_info = _RecordWriterInfo(writer, self.short_path(file_path)) + self._record_writers[field_name] = writer_info + + num_rows = column.length() + writer_info.storage_statistics.num_rows += num_rows + writer_info.storage_statistics.record_uncompressed_bytes += column.nbytes + + for chunk in column.chunks: + for v in chunk: + writer_info.writer.write(v.as_py()) + + # Generate record address field values to return. + next_row_id = writer_info.next_row_id + num_rows + address_column = utils.address_column(writer_info.file_path, + writer_info.next_row_id, num_rows) + writer_info.next_row_id = next_row_id + + # Materialize the file when size is over threshold. + if (writer_info.storage_statistics.record_uncompressed_bytes + > _MAX_ARRAY_RECORD_BYTES): + self._finish_record_writer(field, writer_info) + + return field_name, address_column + + def _finish_record_writer(self, field: arrow.Field, + writer_info: _RecordWriterInfo) -> None: + """Materialize a new record file (ArrayRecord), update metadata and + stats. 
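Each value of a record field is written to an ArrayRecord file and replaced, in the index table, by an address struct naming the file and the row offset within it. The address_column helper added to ops/utils.py below builds that column; a small sketch of its output, with an illustrative file name:

from space.core.ops import utils

addresses = utils.address_column("data/file0.arrayrecord", start_row=2, num_rows=3)
assert addresses.to_pylist() == [
    {"_FILE": "data/file0.arrayrecord", "_ROW_ID": 2},
    {"_FILE": "data/file0.arrayrecord", "_ROW_ID": 3},
    {"_FILE": "data/file0.arrayrecord", "_ROW_ID": 4},
]
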
+ """ + writer_info.writer.close() + self._record_manifest_writer.write(writer_info.file_path, field.field_id, + writer_info.storage_statistics) + utils.update_record_stats_bytes(self._patch.storage_statistics_update, + writer_info.storage_statistics) + + del self._record_writers[field.name] diff --git a/python/src/space/core/ops/utils.py b/python/src/space/core/ops/utils.py index b619c49..50f9ae7 100644 --- a/python/src/space/core/ops/utils.py +++ b/python/src/space/core/ops/utils.py @@ -14,10 +14,14 @@ # """Utilities for operation classes.""" +import numpy as np +import pyarrow as pa + +from space.core.schema import arrow from space.core.proto import metadata_pb2 as meta -def update_index_storage_statistics( +def update_index_storage_stats( base: meta.StorageStatistics, update: meta.StorageStatistics, ) -> None: @@ -25,3 +29,20 @@ def update_index_storage_statistics( base.num_rows += update.num_rows base.index_compressed_bytes += update.index_compressed_bytes base.index_uncompressed_bytes += update.index_uncompressed_bytes + + +def update_record_stats_bytes(base: meta.StorageStatistics, + update: meta.StorageStatistics) -> None: + """Update record storage statistics.""" + base.record_uncompressed_bytes += update.record_uncompressed_bytes + + +def address_column(file_path: str, start_row: int, + num_rows: int) -> pa.StructArray: + """Construct an record address column by a file path and row ID range.""" + return pa.StructArray.from_arrays( + [ + [file_path] * num_rows, # type: ignore[arg-type] + np.arange(start_row, start_row + num_rows, dtype=np.int32) + ], + fields=arrow.record_address_types()) # type: ignore[arg-type] diff --git a/python/src/space/core/proto/metadata.proto b/python/src/space/core/proto/metadata.proto index 1215f95..801ed15 100644 --- a/python/src/space/core/proto/metadata.proto +++ b/python/src/space/core/proto/metadata.proto @@ -70,6 +70,9 @@ message Schema { // Primary key field names. Required but primary keys are un-enforced. repeated string primary_keys = 2; + + // Names of record fields that are stored in row formats (ArrayRecord). + repeated string record_fields = 3; } // Storage snapshot persisting physical metadata such as manifest file paths. @@ -84,7 +87,7 @@ message Snapshot { } // Statistics of storage data. -// NEXT_ID: 4 +// NEXT_ID: 5 message StorageStatistics { // Number of rows. int64 num_rows = 1; @@ -94,4 +97,7 @@ message StorageStatistics { // Uncompressed bytes of index data. int64 index_uncompressed_bytes = 3; + + // Uncompressed bytes of record data. 
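The two statistics helpers cover disjoint fields: index flushes add row counts and index byte sizes, while record flushes add only record_uncompressed_bytes (record rows are already counted when their index rows are flushed). A small sketch of how the patch-level totals accumulate:

from space.core.ops import utils
from space.core.proto import metadata_pb2 as meta

totals = meta.StorageStatistics()
utils.update_index_storage_stats(
    base=totals,
    update=meta.StorageStatistics(num_rows=5,
                                  index_compressed_bytes=114,
                                  index_uncompressed_bytes=126))
# Only the record byte count is taken from the record-side update.
utils.update_record_stats_bytes(
    totals, meta.StorageStatistics(num_rows=5, record_uncompressed_bytes=115))
assert totals.num_rows == 5
assert totals.record_uncompressed_bytes == 115
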
+ int64 record_uncompressed_bytes = 4; } diff --git a/python/src/space/core/proto/metadata_pb2.py b/python/src/space/core/proto/metadata_pb2.py index d6e061e..ec2af20 100644 --- a/python/src/space/core/proto/metadata_pb2.py +++ b/python/src/space/core/proto/metadata_pb2.py @@ -15,7 +15,7 @@ from substrait import type_pb2 as substrait_dot_type__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1fspace/core/proto/metadata.proto\x12\x0bspace.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x14substrait/type.proto\"#\n\nEntryPoint\x12\x15\n\rmetadata_file\x18\x01 \x01(\t\"\xdb\x03\n\x0fStorageMetadata\x12/\n\x0b\x63reate_time\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x34\n\x10last_update_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12/\n\x04type\x18\x03 \x01(\x0e\x32!.space.proto.StorageMetadata.Type\x12#\n\x06schema\x18\x04 \x01(\x0b\x32\x13.space.proto.Schema\x12\x1b\n\x13\x63urrent_snapshot_id\x18\x05 \x01(\x03\x12>\n\tsnapshots\x18\x06 \x03(\x0b\x32+.space.proto.StorageMetadata.SnapshotsEntry\x12:\n\x12storage_statistics\x18\x07 \x01(\x0b\x32\x1e.space.proto.StorageStatistics\x1aG\n\x0eSnapshotsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x03\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.space.proto.Snapshot:\x02\x38\x01\")\n\x04Type\x12\x14\n\x10TYPE_UNSPECIFIED\x10\x00\x12\x0b\n\x07\x44\x41TASET\x10\x01\"F\n\x06Schema\x12&\n\x06\x66ields\x18\x01 \x01(\x0b\x32\x16.substrait.NamedStruct\x12\x14\n\x0cprimary_keys\x18\x02 \x03(\t\"P\n\x08Snapshot\x12\x13\n\x0bsnapshot_id\x18\x01 \x01(\x03\x12/\n\x0b\x63reate_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"g\n\x11StorageStatistics\x12\x10\n\x08num_rows\x18\x01 \x01(\x03\x12\x1e\n\x16index_compressed_bytes\x18\x02 \x01(\x03\x12 \n\x18index_uncompressed_bytes\x18\x03 \x01(\x03\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1fspace/core/proto/metadata.proto\x12\x0bspace.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x14substrait/type.proto\"#\n\nEntryPoint\x12\x15\n\rmetadata_file\x18\x01 \x01(\t\"\xdb\x03\n\x0fStorageMetadata\x12/\n\x0b\x63reate_time\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x34\n\x10last_update_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12/\n\x04type\x18\x03 \x01(\x0e\x32!.space.proto.StorageMetadata.Type\x12#\n\x06schema\x18\x04 \x01(\x0b\x32\x13.space.proto.Schema\x12\x1b\n\x13\x63urrent_snapshot_id\x18\x05 \x01(\x03\x12>\n\tsnapshots\x18\x06 \x03(\x0b\x32+.space.proto.StorageMetadata.SnapshotsEntry\x12:\n\x12storage_statistics\x18\x07 \x01(\x0b\x32\x1e.space.proto.StorageStatistics\x1aG\n\x0eSnapshotsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x03\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.space.proto.Snapshot:\x02\x38\x01\")\n\x04Type\x12\x14\n\x10TYPE_UNSPECIFIED\x10\x00\x12\x0b\n\x07\x44\x41TASET\x10\x01\"]\n\x06Schema\x12&\n\x06\x66ields\x18\x01 \x01(\x0b\x32\x16.substrait.NamedStruct\x12\x14\n\x0cprimary_keys\x18\x02 \x03(\t\x12\x15\n\rrecord_fields\x18\x03 \x03(\t\"P\n\x08Snapshot\x12\x13\n\x0bsnapshot_id\x18\x01 \x01(\x03\x12/\n\x0b\x63reate_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"\x8a\x01\n\x11StorageStatistics\x12\x10\n\x08num_rows\x18\x01 \x01(\x03\x12\x1e\n\x16index_compressed_bytes\x18\x02 \x01(\x03\x12 \n\x18index_uncompressed_bytes\x18\x03 \x01(\x03\x12!\n\x19record_uncompressed_bytes\x18\x04 \x01(\x03\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'space.core.proto.metadata_pb2', globals()) @@ -33,9 +33,9 @@ 
_STORAGEMETADATA_TYPE._serialized_start=575 _STORAGEMETADATA_TYPE._serialized_end=616 _SCHEMA._serialized_start=618 - _SCHEMA._serialized_end=688 - _SNAPSHOT._serialized_start=690 - _SNAPSHOT._serialized_end=770 - _STORAGESTATISTICS._serialized_start=772 - _STORAGESTATISTICS._serialized_end=875 + _SCHEMA._serialized_end=711 + _SNAPSHOT._serialized_start=713 + _SNAPSHOT._serialized_end=793 + _STORAGESTATISTICS._serialized_start=796 + _STORAGESTATISTICS._serialized_end=934 # @@protoc_insertion_point(module_scope) diff --git a/python/src/space/core/proto/metadata_pb2.pyi b/python/src/space/core/proto/metadata_pb2.pyi index 949daaa..56fc217 100644 --- a/python/src/space/core/proto/metadata_pb2.pyi +++ b/python/src/space/core/proto/metadata_pb2.pyi @@ -153,20 +153,25 @@ class Schema(google.protobuf.message.Message): FIELDS_FIELD_NUMBER: builtins.int PRIMARY_KEYS_FIELD_NUMBER: builtins.int + RECORD_FIELDS_FIELD_NUMBER: builtins.int @property def fields(self) -> substrait.type_pb2.NamedStruct: """Fields persisted as Substrait named struct.""" @property def primary_keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: """Primary key field names. Required but primary keys are un-enforced.""" + @property + def record_fields(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Names of record fields that are stored in row formats (ArrayRecord).""" def __init__( self, *, fields: substrait.type_pb2.NamedStruct | None = ..., primary_keys: collections.abc.Iterable[builtins.str] | None = ..., + record_fields: collections.abc.Iterable[builtins.str] | None = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["fields", b"fields"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["fields", b"fields", "primary_keys", b"primary_keys"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["fields", b"fields", "primary_keys", b"primary_keys", "record_fields", b"record_fields"]) -> None: ... global___Schema = Schema @@ -200,7 +205,7 @@ global___Snapshot = Snapshot @typing_extensions.final class StorageStatistics(google.protobuf.message.Message): """Statistics of storage data. - NEXT_ID: 4 + NEXT_ID: 5 """ DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -208,19 +213,23 @@ class StorageStatistics(google.protobuf.message.Message): NUM_ROWS_FIELD_NUMBER: builtins.int INDEX_COMPRESSED_BYTES_FIELD_NUMBER: builtins.int INDEX_UNCOMPRESSED_BYTES_FIELD_NUMBER: builtins.int + RECORD_UNCOMPRESSED_BYTES_FIELD_NUMBER: builtins.int num_rows: builtins.int """Number of rows.""" index_compressed_bytes: builtins.int """Compressed bytes of index data.""" index_uncompressed_bytes: builtins.int """Uncompressed bytes of index data.""" + record_uncompressed_bytes: builtins.int + """Uncompressed bytes of record data.""" def __init__( self, *, num_rows: builtins.int = ..., index_compressed_bytes: builtins.int = ..., index_uncompressed_bytes: builtins.int = ..., + record_uncompressed_bytes: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["index_compressed_bytes", b"index_compressed_bytes", "index_uncompressed_bytes", b"index_uncompressed_bytes", "num_rows", b"num_rows"]) -> None: ... 
+ def ClearField(self, field_name: typing_extensions.Literal["index_compressed_bytes", b"index_compressed_bytes", "index_uncompressed_bytes", b"index_uncompressed_bytes", "num_rows", b"num_rows", "record_uncompressed_bytes", b"record_uncompressed_bytes"]) -> None: ... global___StorageStatistics = StorageStatistics diff --git a/python/src/space/core/proto/runtime.proto b/python/src/space/core/proto/runtime.proto index 77624f2..5786096 100644 --- a/python/src/space/core/proto/runtime.proto +++ b/python/src/space/core/proto/runtime.proto @@ -27,8 +27,11 @@ message Patch { // Index manifest file paths to be removed from the storage. repeated string deleted_index_manifest_files = 2; + // Record manifest file paths newly added to the storage. + repeated string added_record_manifest_files = 3; + // The change of the storage statistics. - StorageStatistics storage_statistics_update = 3; + StorageStatistics storage_statistics_update = 4; } // Result of a job. diff --git a/python/src/space/core/proto/runtime_pb2.py b/python/src/space/core/proto/runtime_pb2.py index 97a3b99..fce8899 100644 --- a/python/src/space/core/proto/runtime_pb2.py +++ b/python/src/space/core/proto/runtime_pb2.py @@ -14,7 +14,7 @@ from space.core.proto import metadata_pb2 as space_dot_core_dot_proto_dot_metadata__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1espace/core/proto/runtime.proto\x12\x0bspace.proto\x1a\x1fspace/core/proto/metadata.proto\"\x94\x01\n\x05Patch\x12\"\n\x1a\x61\x64\x64\x65\x64_index_manifest_files\x18\x01 \x03(\t\x12$\n\x1c\x64\x65leted_index_manifest_files\x18\x02 \x03(\t\x12\x41\n\x19storage_statistics_update\x18\x03 \x01(\x0b\x32\x1e.space.proto.StorageStatistics\"\xc3\x01\n\tJobResult\x12+\n\x05state\x18\x01 \x01(\x0e\x32\x1c.space.proto.JobResult.State\x12\x41\n\x19storage_statistics_update\x18\x02 \x01(\x0b\x32\x1e.space.proto.StorageStatistics\"F\n\x05State\x12\x15\n\x11STATE_UNSPECIFIED\x10\x00\x12\r\n\tSUCCEEDED\x10\x01\x12\n\n\x06\x46\x41ILED\x10\x02\x12\x0b\n\x07SKIPPED\x10\x03\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1espace/core/proto/runtime.proto\x12\x0bspace.proto\x1a\x1fspace/core/proto/metadata.proto\"\xb9\x01\n\x05Patch\x12\"\n\x1a\x61\x64\x64\x65\x64_index_manifest_files\x18\x01 \x03(\t\x12$\n\x1c\x64\x65leted_index_manifest_files\x18\x02 \x03(\t\x12#\n\x1b\x61\x64\x64\x65\x64_record_manifest_files\x18\x03 \x03(\t\x12\x41\n\x19storage_statistics_update\x18\x04 \x01(\x0b\x32\x1e.space.proto.StorageStatistics\"\xc3\x01\n\tJobResult\x12+\n\x05state\x18\x01 \x01(\x0e\x32\x1c.space.proto.JobResult.State\x12\x41\n\x19storage_statistics_update\x18\x02 \x01(\x0b\x32\x1e.space.proto.StorageStatistics\"F\n\x05State\x12\x15\n\x11STATE_UNSPECIFIED\x10\x00\x12\r\n\tSUCCEEDED\x10\x01\x12\n\n\x06\x46\x41ILED\x10\x02\x12\x0b\n\x07SKIPPED\x10\x03\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'space.core.proto.runtime_pb2', globals()) @@ -22,9 +22,9 @@ DESCRIPTOR._options = None _PATCH._serialized_start=81 - _PATCH._serialized_end=229 - _JOBRESULT._serialized_start=232 - _JOBRESULT._serialized_end=427 - _JOBRESULT_STATE._serialized_start=357 - _JOBRESULT_STATE._serialized_end=427 + _PATCH._serialized_end=266 + _JOBRESULT._serialized_start=269 + _JOBRESULT._serialized_end=464 + _JOBRESULT_STATE._serialized_start=394 + _JOBRESULT_STATE._serialized_end=464 # @@protoc_insertion_point(module_scope) diff --git a/python/src/space/core/proto/runtime_pb2.pyi 
b/python/src/space/core/proto/runtime_pb2.pyi index b72be8e..4e809a8 100644 --- a/python/src/space/core/proto/runtime_pb2.pyi +++ b/python/src/space/core/proto/runtime_pb2.pyi @@ -42,6 +42,7 @@ class Patch(google.protobuf.message.Message): ADDED_INDEX_MANIFEST_FILES_FIELD_NUMBER: builtins.int DELETED_INDEX_MANIFEST_FILES_FIELD_NUMBER: builtins.int + ADDED_RECORD_MANIFEST_FILES_FIELD_NUMBER: builtins.int STORAGE_STATISTICS_UPDATE_FIELD_NUMBER: builtins.int @property def added_index_manifest_files(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: @@ -50,6 +51,9 @@ class Patch(google.protobuf.message.Message): def deleted_index_manifest_files(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: """Index manifest file paths to be removed from the storage.""" @property + def added_record_manifest_files(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Record manifest file paths newly added to the storage.""" + @property def storage_statistics_update(self) -> space.core.proto.metadata_pb2.StorageStatistics: """The change of the storage statistics.""" def __init__( @@ -57,10 +61,11 @@ class Patch(google.protobuf.message.Message): *, added_index_manifest_files: collections.abc.Iterable[builtins.str] | None = ..., deleted_index_manifest_files: collections.abc.Iterable[builtins.str] | None = ..., + added_record_manifest_files: collections.abc.Iterable[builtins.str] | None = ..., storage_statistics_update: space.core.proto.metadata_pb2.StorageStatistics | None = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["storage_statistics_update", b"storage_statistics_update"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["added_index_manifest_files", b"added_index_manifest_files", "deleted_index_manifest_files", b"deleted_index_manifest_files", "storage_statistics_update", b"storage_statistics_update"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["added_index_manifest_files", b"added_index_manifest_files", "added_record_manifest_files", b"added_record_manifest_files", "deleted_index_manifest_files", b"deleted_index_manifest_files", "storage_statistics_update", b"storage_statistics_update"]) -> None: ... 
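With the new repeated field, a Patch emitted by an append that flushed both kinds of files looks roughly like the following (file names and sizes are illustrative; they mirror the unit test later in this series):

from space.core.proto import metadata_pb2 as meta
from space.core.proto import runtime_pb2 as runtime

patch = runtime.Patch(
    added_index_manifest_files=["metadata/index_manifest0.parquet"],
    added_record_manifest_files=["metadata/record_manifest0.parquet"],
    storage_statistics_update=meta.StorageStatistics(
        num_rows=5,
        index_compressed_bytes=114,
        index_uncompressed_bytes=126,
        record_uncompressed_bytes=115))
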
global___Patch = Patch diff --git a/python/src/space/core/schema/arrow.py b/python/src/space/core/schema/arrow.py index da3f0ca..898c240 100644 --- a/python/src/space/core/schema/arrow.py +++ b/python/src/space/core/schema/arrow.py @@ -15,13 +15,13 @@ """Utilities for schemas in the Arrow format.""" from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Set, Tuple import pyarrow as pa from substrait.type_pb2 import NamedStruct, Type from space.core.utils.constants import UTF_8 -from space.core.schema.constants import TF_FEATURES_TYPE +from space.core.schema import constants from space.core.schema.types import TfFeatures _PARQUET_FIELD_ID_KEY = b"PARQUET:field_id" @@ -121,7 +121,7 @@ def _user_defined_arrow_type(type_: Type) -> pa.ExtensionType: type_name = type_.user_defined.type_parameters[0].string serialized = type_.user_defined.type_parameters[1].string - if type_name == TF_FEATURES_TYPE: + if type_name == constants.TF_FEATURES_TYPE: return TfFeatures.__arrow_ext_deserialize__( None, serialized) # type: ignore[arg-type] @@ -140,3 +140,52 @@ def field_id_to_column_id_dict(schema: pa.Schema) -> Dict[int, int]: field_id_dict[name]: column_id for column_id, name in enumerate(schema.names) } + + +@dataclass +class Field: + """Information of a field.""" + name: str + field_id: int + + +def classify_fields( + schema: pa.Schema, + record_fields: Set[str], + selected_fields: Optional[Set[str]] = None +) -> Tuple[List[Field], List[Field]]: + """Classify fields into indexes and records. + + Args: + schema: storage logical or physical schema. + record_fields: names of record fields. + selected_fields: selected fields to be accessed. + + Returns: + A tuple (index_fields, record_fields). + """ + index_fields: List[Field] = [] + record_fields_: List[Field] = [] + + for f in schema: + if selected_fields is not None and f.name not in selected_fields: + continue + + field = Field(f.name, field_id(f)) + if f.name in record_fields: + record_fields_.append(field) + else: + index_fields.append(field) + + return index_fields, record_fields_ + + +def field_names(fields: List[Field]) -> List[str]: + """Extract field names from a list of fields.""" + return list(map(lambda f: f.name, fields)) + + +def record_address_types() -> List[Tuple[str, pa.DataType]]: + """Returns Arrow fields of record addresses.""" + return [(constants.FILE_PATH_FIELD, pa.string()), + (constants.ROW_ID_FIELD, pa.int32())] diff --git a/python/src/space/core/schema/constants.py b/python/src/space/core/schema/constants.py index b910a34..91daa6a 100644 --- a/python/src/space/core/schema/constants.py +++ b/python/src/space/core/schema/constants.py @@ -16,3 +16,10 @@ # Substrait type name of Arrow custom type TfFeatures. 
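classify_fields above splits a schema into index and record fields by name, keeping each field's Parquet field ID. A small sketch (assuming field_metadata attaches the PARQUET:field_id key, as it is used in the tests):

import pyarrow as pa

from space.core.schema import arrow

schema = pa.schema([
    pa.field("int64", pa.int64(), metadata=arrow.field_metadata(0)),
    pa.field("images", pa.binary(), metadata=arrow.field_metadata(1)),
])
index_fields, record_fields = arrow.classify_fields(schema, {"images"})
assert arrow.field_names(index_fields) == ["int64"]
assert record_fields == [arrow.Field("images", 1)]
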
TF_FEATURES_TYPE = "TF_FEATURES" + +FILE_PATH_FIELD = "_FILE" +ROW_ID_FIELD = "_ROW_ID" +FIELD_ID_FIELD = "_FIELD_ID" + +NUM_ROWS_FIELD = "_NUM_ROWS" +UNCOMPRESSED_BYTES_FIELD = "_UNCOMPRESSED_BYTES" diff --git a/python/src/space/core/storage.py b/python/src/space/core/storage.py index 921b883..4008ecd 100644 --- a/python/src/space/core/storage.py +++ b/python/src/space/core/storage.py @@ -75,20 +75,19 @@ def snapshot(self, snapshot_id: Optional[int] = None) -> meta.Snapshot: @classmethod def create( - cls, - location: str, - schema: pa.Schema, - primary_keys: List[str], - ) -> Storage: # pylint: disable=unused-argument + cls, location: str, schema: pa.Schema, primary_keys: List[str], + record_fields: List[str]) -> Storage: # pylint: disable=unused-argument """Create a new empty storage. Args: location: the directory path to the storage. schema: the schema of the storage. primary_keys: un-enforced primary keys. + record_fields: fields stored in row format (ArrayRecord). """ # TODO: to verify that location is an empty directory. - # TODO: to verify primary key fields (and types) are valid. + # TODO: to verify primary key fields and record_fields (and types) are + # valid. field_id_mgr = FieldIdManager() schema = field_id_mgr.assign_field_ids(schema) @@ -98,8 +97,11 @@ def create( metadata = meta.StorageMetadata( create_time=now, last_update_time=now, - schema=meta.Schema(fields=substrait_schema.substrait_fields(schema), - primary_keys=primary_keys), + schema=meta.Schema( + fields=substrait_schema.substrait_fields(schema), + primary_keys=primary_keys, + # TODO: to optionally auto infer record fields. + record_fields=record_fields), current_snapshot_id=_INIT_SNAPSHOT_ID, type=meta.StorageMetadata.DATASET) diff --git a/python/src/space/core/utils/lazy_imports_utils.py b/python/src/space/core/utils/lazy_imports_utils.py new file mode 100644 index 0000000..f211e6e --- /dev/null +++ b/python/src/space/core/utils/lazy_imports_utils.py @@ -0,0 +1,192 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" Lazy import utilities. + +This file uses code from +https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/core/utils/lazy_imports_utils.py +""" + +from __future__ import annotations + +import builtins +import contextlib +import dataclasses +import functools +import importlib +import time +import types +from typing import Any, Callable, Iterator, Optional, Tuple + +Callback = Callable[..., None] + + +@dataclasses.dataclass +class LazyModule: + """Module loaded lazily during first call.""" + + module_name: str + module: Optional[types.ModuleType] = None + fromlist: Optional[Tuple[str, ...]] = () + error_callback: Optional[Callback] = None + success_callback: Optional[Callback] = None + + @classmethod + @functools.lru_cache(maxsize=None) + def from_cache(cls, **kwargs): + """Factory to cache all instances of module. + + Note: The cache is global to all instances of the + `lazy_imports` context manager. 
+ + Args: + **kwargs: Init kwargs + + Returns: + New object + """ + return cls(**kwargs) + + def __getattr__(self, name: str) -> Any: + if self.fromlist and name in self.fromlist: + module_name = f"{self.module_name}.{name}" + return self.from_cache( + module_name=module_name, + module=self.module, + fromlist=self.fromlist, + error_callback=self.error_callback, + success_callback=self.success_callback, + ) + if self.module is None: # Load on first call + try: + start_import_time = time.perf_counter() + self.module = importlib.import_module(self.module_name) + import_time_ms = (time.perf_counter() - start_import_time) * 1000 + if self.success_callback is not None: + self.success_callback( + import_time_ms=import_time_ms, + module=self.module, + module_name=self.module_name, + ) + except ImportError as exception: + if self.error_callback is not None: + self.error_callback(exception=exception, + module_name=self.module_name) + raise exception + return getattr(self.module, name) + + +@contextlib.contextmanager +def lazy_imports( + error_callback: Optional[Callback] = None, + success_callback: Optional[Callback] = None, +) -> Iterator[None]: + """Context Manager which lazy loads packages. + + Their import is not executed immediately, but is postponed to the first + call of one of their attributes. + + Warning: mind current implementation's limitations: + + - You can only lazy load modules (`from x import y` will not work if `y` is a + constant or a function or a class). + - You cannot `import x.y` if `y` is not imported in the `x/__init__.py`. + + Usage: + + ```python + from tensorflow_datasets.core.utils.lazy_imports_utils import lazy_imports + + with lazy_imports(): + import tensorflow as tf + ``` + + Args: + error_callback: a callback to trigger when an import fails. The callback is + passed kwargs containing: 1) exception (ImportError): the exception that + was raised after the error; 2) module_name (str): the name of the imported + module. + success_callback: a callback to trigger when an import succeeds. The + callback is passed kwargs containing: 1) import_time_ms (float): the + import time (in milliseconds); 2) module (Any): the imported module; 3) + module_name (str): the name of the imported module. + + Yields: + None + """ + # Need to mock `__import__` (instead of `sys.meta_path`, as we do not want + # to modify the `sys.modules` cache in any way) + original_import = builtins.__import__ + try: + builtins.__import__ = functools.partial( + _lazy_import, + error_callback=error_callback, + success_callback=success_callback, + ) + yield + finally: + builtins.__import__ = original_import + + +# pylint: disable=too-many-arguments +def _lazy_import( + name: str, + globals_=None, + locals_=None, + fromlist: tuple[str, ...] = (), + level: int = 0, + *, + error_callback: Optional[Callback], + success_callback: Optional[Callback], +): + """Mock of `builtins.__import__`.""" + del globals_, locals_ # Unused + + if level: + raise ValueError(f"Relative import statements not supported ({name}).") + + if not fromlist: + # import x.y.z + # import x.y.z as z + # In that case, Python would usually import the entirety of `x` if each + # submodule is imported in its parent's `__init__.py`. So we do the same. 
+ root_name = name.split(".")[0] + return LazyModule.from_cache( + module_name=root_name, + error_callback=error_callback, + success_callback=success_callback, + ) + # from x.y.z import a, b + return LazyModule.from_cache( + module_name=name, + fromlist=fromlist, + error_callback=error_callback, + success_callback=success_callback, + ) + + +def array_record_error_callback(**kwargs): + """Print error of ArrayRecord import error.""" + del kwargs + print("\n\n***************************************************************") + print( + "Failed to import ArrayRecord. This probably means that you are running" + " on macOS or Windows. ArrayRecord currently does not work for your" + " infrastructure, because it uses Python bindings in C++. We are actively" + " working on this issue. Thanks for your understanding.") + print("***************************************************************\n\n") + + +with lazy_imports(error_callback=array_record_error_callback): + from array_record.python import array_record_module # type: ignore[import-untyped] # pylint: disable=unused-import diff --git a/python/tests/core/manifests/test_index.py b/python/tests/core/manifests/test_index.py index fb821ed..fe6cfc8 100644 --- a/python/tests/core/manifests/test_index.py +++ b/python/tests/core/manifests/test_index.py @@ -17,7 +17,7 @@ import pyarrow as pa import pyarrow.parquet as pq -from space.core.manifests.index import IndexManifestWriter +from space.core.manifests import IndexManifestWriter from space.core.schema.arrow import field_metadata _SCHEMA = pa.schema([ diff --git a/python/tests/core/manifests/test_utils.py b/python/tests/core/manifests/test_utils.py new file mode 100644 index 0000000..030098c --- /dev/null +++ b/python/tests/core/manifests/test_utils.py @@ -0,0 +1,29 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
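The lazy_imports helper above defers the real import until the first attribute access, so platforms without ArrayRecord wheels only fail if record features are actually used. A minimal usage sketch, mirroring the bottom of lazy_imports_utils.py (the output path is hypothetical):

from space.core.utils.lazy_imports_utils import lazy_imports

with lazy_imports():
  from array_record.python import array_record_module as ar  # not imported yet

# The real import (and any ImportError) happens here, on first attribute access.
writer = ar.ArrayRecordWriter("/tmp/file.arrayrecord", options="")
writer.close()
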
+ +import pyarrow as pa + +from space.core.manifests import utils + + +def test_write_parquet_file(tmp_path): + data_dir = tmp_path / "file.parquet" + + file_path = str(data_dir) + returned_path = utils.write_parquet_file( + file_path, pa.schema([("int64", pa.int64())]), + pa.Table.from_pydict({"int64": [1, 2]})) + + assert data_dir.exists() + assert returned_path == file_path diff --git a/python/tests/core/ops/test_append.py b/python/tests/core/ops/test_append.py index 23122cd..bd4b67b 100644 --- a/python/tests/core/ops/test_append.py +++ b/python/tests/core/ops/test_append.py @@ -32,7 +32,8 @@ def test_write_pydict_all_types(self, tmp_path): ]) storage = Storage.create(location=str(location), schema=schema, - primary_keys=["int64"]) + primary_keys=["int64"], + record_fields=[]) op = LocalAppendOp(str(location), storage.metadata) @@ -79,7 +80,8 @@ def test_empty_op_return_none(self, tmp_path): schema = pa.schema([pa.field("int64", pa.int64())]) storage = Storage.create(location=str(location), schema=schema, - primary_keys=["int64"]) + primary_keys=["int64"], + record_fields=[]) op = LocalAppendOp(str(location), storage.metadata) assert op.finish() is None diff --git a/python/tests/core/ops/test_utils.py b/python/tests/core/ops/test_utils.py index 54aa3c9..fdc1ba6 100644 --- a/python/tests/core/ops/test_utils.py +++ b/python/tests/core/ops/test_utils.py @@ -16,12 +16,12 @@ from space.core.proto import metadata_pb2 as meta -def test_update_index_storage_statistics_positive(): +def test_update_index_storage_stats_positive(): base = meta.StorageStatistics(num_rows=100, index_compressed_bytes=200, index_uncompressed_bytes=300) - utils.update_index_storage_statistics( + utils.update_index_storage_stats( base, meta.StorageStatistics(num_rows=10, index_compressed_bytes=20, @@ -31,12 +31,12 @@ def test_update_index_storage_statistics_positive(): index_uncompressed_bytes=330) -def test_update_index_storage_statistics_negative(): +def test_update_index_storage_stats_negative(): base = meta.StorageStatistics(num_rows=100, index_compressed_bytes=200, index_uncompressed_bytes=300) - utils.update_index_storage_statistics( + utils.update_index_storage_stats( base, meta.StorageStatistics(num_rows=-10, index_compressed_bytes=-20, diff --git a/python/tests/core/test_storage.py b/python/tests/core/test_storage.py index 90c7e85..eb6debf 100644 --- a/python/tests/core/test_storage.py +++ b/python/tests/core/test_storage.py @@ -68,7 +68,8 @@ def test_create_storage(self, tmp_path): location = tmp_path / "dataset" storage = Storage.create(location=str(location), schema=_SCHEMA, - primary_keys=["int64"]) + primary_keys=["int64"], + record_fields=["string"]) entry_point_file = location / "metadata" / _ENTRY_POINT_FILE assert entry_point_file.exists() @@ -87,13 +88,15 @@ def test_create_storage(self, tmp_path): Type(i64=Type.I64(type_variation_reference=0)), Type(string=Type.String(type_variation_reference=1)) ])), - primary_keys=["int64"]) + primary_keys=["int64"], + record_fields=["string"]) def test_load_storage(self, tmp_path): location = tmp_path / "dataset" storage = Storage.create(location=str(location), schema=_SCHEMA, - primary_keys=["int64"]) + primary_keys=["int64"], + record_fields=[]) loaded_storage = Storage.load(str(location)) assert loaded_storage.metadata == storage.metadata From 243ce7e2a76e34a931347cf3247b292163225eb6 Mon Sep 17 00:00:00 2001 From: coufon Date: Fri, 22 Dec 2023 05:48:36 +0000 Subject: [PATCH 2/3] Add unit test for record manifest writer --- 
python/src/space/core/manifests/utils.py | 3 +- python/tests/core/manifests/test_index.py | 13 ++--- python/tests/core/manifests/test_record.py | 57 ++++++++++++++++++++++ python/tests/core/manifests/test_utils.py | 6 +-- python/tests/core/ops/test_utils.py | 32 ++++++++++++ python/tests/core/schema/test_arrow.py | 32 ++++++++++++ 6 files changed, 129 insertions(+), 14 deletions(-) create mode 100644 python/tests/core/manifests/test_record.py diff --git a/python/src/space/core/manifests/utils.py b/python/src/space/core/manifests/utils.py index cc99418..5ff32c4 100644 --- a/python/src/space/core/manifests/utils.py +++ b/python/src/space/core/manifests/utils.py @@ -19,11 +19,10 @@ def write_parquet_file(file_path: str, schema: pa.Schema, - data: pa.Table) -> str: + data: pa.Table) -> None: """Materialize a single Parquet file.""" # TODO: currently assume this file is small, so always write a single file. writer = pq.ParquetWriter(file_path, schema) writer.write_table(data) writer.close() - return file_path diff --git a/python/tests/core/manifests/test_index.py b/python/tests/core/manifests/test_index.py index fe6cfc8..fc15891 100644 --- a/python/tests/core/manifests/test_index.py +++ b/python/tests/core/manifests/test_index.py @@ -53,11 +53,10 @@ def test_write_all_types(self, tmp_path): schema=schema, primary_keys=["int64", "float64", "bool", "string"]) - file_path = str(data_dir / "file0") # TODO: the test should cover all types supported by column stats. manifest_writer.write( - file_path, - _write_parquet_file(file_path, schema, [{ + "data/file0", + _write_parquet_file(str(data_dir / "file0"), schema, [{ "int64": [1, 2, 3], "float64": [0.1, 0.2, 0.3], "bool": [True, False, False], @@ -68,10 +67,9 @@ def test_write_all_types(self, tmp_path): "bool": [False, False], "string": ["A", "z"] }])) - file_path = str(data_dir / "file1") manifest_writer.write( - file_path, - _write_parquet_file(file_path, schema, [{ + "data/file1", + _write_parquet_file(str(data_dir / "file1"), schema, [{ "int64": [1000, 1000000], "float64": [-0.001, 0.001], "bool": [False, False], @@ -80,10 +78,9 @@ def test_write_all_types(self, tmp_path): manifest_path = manifest_writer.finish() - data_dir_str = str(data_dir) assert manifest_path is not None assert pq.read_table(manifest_path).to_pydict() == { - "_FILE": [f"{data_dir_str}/file0", f"{data_dir_str}/file1"], + "_FILE": ["data/file0", "data/file1"], "_INDEX_COMPRESSED_BYTES": [645, 334], "_INDEX_UNCOMPRESSED_BYTES": [624, 320], "_NUM_ROWS": [5, 2], diff --git a/python/tests/core/manifests/test_record.py b/python/tests/core/manifests/test_record.py new file mode 100644 index 0000000..b254915 --- /dev/null +++ b/python/tests/core/manifests/test_record.py @@ -0,0 +1,57 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pyarrow.parquet as pq + +from space.core.manifests import RecordManifestWriter +import space.core.proto.metadata_pb2 as meta + + +class TestRecordManifestWriter: + + def test_write(self, tmp_path): + data_dir = tmp_path / "dataset" / "data" + metadata_dir = tmp_path / "dataset" / "metadata" + metadata_dir.mkdir(parents=True) + + manifest_writer = RecordManifestWriter(metadata_dir=str(metadata_dir)) + + manifest_writer.write( + "data/file0.arrayrecord", 0, + meta.StorageStatistics(num_rows=123, + index_compressed_bytes=10, + index_uncompressed_bytes=20, + record_uncompressed_bytes=30)) + manifest_writer.write( + "data/file1.arrayrecord", 1, + meta.StorageStatistics(num_rows=456, + index_compressed_bytes=10, + index_uncompressed_bytes=20, + record_uncompressed_bytes=100)) + + manifest_path = manifest_writer.finish() + + assert manifest_path is not None + assert pq.read_table(manifest_path).to_pydict() == { + "_FILE": ["data/file0.arrayrecord", "data/file1.arrayrecord"], + "_FIELD_ID": [0, 1], + "_NUM_ROWS": [123, 456], + "_UNCOMPRESSED_BYTES": [30, 100] + } + + def test_empty_manifest_should_return_none(self, tmp_path): + metadata_dir = tmp_path / "dataset" / "metadata" + manifest_writer = RecordManifestWriter(metadata_dir=str(metadata_dir)) + + assert manifest_writer.finish() is None diff --git a/python/tests/core/manifests/test_utils.py b/python/tests/core/manifests/test_utils.py index 030098c..bb7a54b 100644 --- a/python/tests/core/manifests/test_utils.py +++ b/python/tests/core/manifests/test_utils.py @@ -21,9 +21,7 @@ def test_write_parquet_file(tmp_path): data_dir = tmp_path / "file.parquet" file_path = str(data_dir) - returned_path = utils.write_parquet_file( - file_path, pa.schema([("int64", pa.int64())]), - pa.Table.from_pydict({"int64": [1, 2]})) + utils.write_parquet_file(file_path, pa.schema([("int64", pa.int64())]), + pa.Table.from_pydict({"int64": [1, 2]})) assert data_dir.exists() - assert returned_path == file_path diff --git a/python/tests/core/ops/test_utils.py b/python/tests/core/ops/test_utils.py index fdc1ba6..ba716d1 100644 --- a/python/tests/core/ops/test_utils.py +++ b/python/tests/core/ops/test_utils.py @@ -44,3 +44,35 @@ def test_update_index_storage_stats_negative(): assert base == meta.StorageStatistics(num_rows=90, index_compressed_bytes=180, index_uncompressed_bytes=270) + + +def test_update_record_stats_bytes(): + base = meta.StorageStatistics(num_rows=100, + index_compressed_bytes=200, + index_uncompressed_bytes=300, + record_uncompressed_bytes=1000) + + utils.update_record_stats_bytes( + base, + meta.StorageStatistics(num_rows=-10, + index_compressed_bytes=-20, + record_uncompressed_bytes=-100)) + assert base == meta.StorageStatistics(num_rows=100, + index_compressed_bytes=200, + index_uncompressed_bytes=300, + record_uncompressed_bytes=900) + + +def test_address_column(): + result = [{ + "_FILE": "data/file.arrayrecord", + "_ROW_ID": 2 + }, { + "_FILE": "data/file.arrayrecord", + "_ROW_ID": 3 + }, { + "_FILE": "data/file.arrayrecord", + "_ROW_ID": 4 + }] + assert utils.address_column("data/file.arrayrecord", 2, + 3).to_pylist() == result diff --git a/python/tests/core/schema/test_arrow.py b/python/tests/core/schema/test_arrow.py index 64dbfdd..420c220 100644 --- a/python/tests/core/schema/test_arrow.py +++ b/python/tests/core/schema/test_arrow.py @@ -49,3 +49,35 @@ def test_field_id_to_column_id_dict(sample_arrow_schema): 220: 3, 260: 4 } + + +def test_classify_fields(sample_arrow_schema): + index_fields, record_fields = 
arrow.classify_fields(sample_arrow_schema, + ["float32", "list"]) + + assert index_fields == [ + arrow.Field("struct", 150), + arrow.Field("list_struct", 220), + arrow.Field("struct_list", 260) + ] + assert record_fields == [ + arrow.Field("float32", 100), + arrow.Field("list", 120) + ] + + +def test_classify_fields_with_selected_fields(sample_arrow_schema): + index_fields, record_fields = arrow.classify_fields(sample_arrow_schema, + ["float32", "list"], + ["list", "struct"]) + + assert index_fields == [arrow.Field("struct", 150)] + assert record_fields == [arrow.Field("list", 120)] + + +def test_field_names(): + assert arrow.field_names([ + arrow.Field("struct", 150), + arrow.Field("list_struct", 220), + arrow.Field("struct_list", 260) + ]) == ["struct", "list_struct", "struct_list"] From 612215618cf47128a82c9ebf30077f63eed2a5dd Mon Sep 17 00:00:00 2001 From: coufon Date: Fri, 22 Dec 2023 22:24:06 +0000 Subject: [PATCH 3/3] Add unit tests of appending records and schema methods --- python/src/space/core/ops/append.py | 10 +- python/src/space/core/schema/arrow.py | 29 ++++-- .../space/core/schema/types/tf_features.py | 11 ++- python/tests/core/manifests/test_record.py | 1 - python/tests/core/ops/test_append.py | 94 +++++++++++++++++++ python/tests/core/schema/conftest.py | 34 ++++++- python/tests/core/schema/test_arrow.py | 30 +++++- .../core/schema/types/test_tf_features.py | 6 ++ 8 files changed, 193 insertions(+), 22 deletions(-) diff --git a/python/src/space/core/ops/append.py b/python/src/space/core/ops/append.py index 6c0f84c..761cce6 100644 --- a/python/src/space/core/ops/append.py +++ b/python/src/space/core/ops/append.py @@ -84,12 +84,12 @@ def __init__(self, location: str, metadata: meta.StorageMetadata): StoragePaths.__init__(self, location) self._metadata = metadata - self._schema = arrow.arrow_schema(self._metadata.schema.fields) - + record_fields = set(self._metadata.schema.record_fields) + self._schema = arrow.arrow_schema(self._metadata.schema.fields, + record_fields, + physical=True) self._index_fields, self._record_fields = arrow.classify_fields( - self._schema, - set(self._metadata.schema.record_fields), - selected_fields=None) + self._schema, record_fields, selected_fields=None) # Data file writers. self._index_writer_info: Optional[_IndexWriterInfo] = None diff --git a/python/src/space/core/schema/arrow.py b/python/src/space/core/schema/arrow.py index 898c240..8f7961b 100644 --- a/python/src/space/core/schema/arrow.py +++ b/python/src/space/core/schema/arrow.py @@ -20,9 +20,9 @@ import pyarrow as pa from substrait.type_pb2 import NamedStruct, Type -from space.core.utils.constants import UTF_8 from space.core.schema import constants from space.core.schema.types import TfFeatures +from space.core.utils.constants import UTF_8 _PARQUET_FIELD_ID_KEY = b"PARQUET:field_id" @@ -49,11 +49,13 @@ def next(self) -> str: return name -def arrow_schema(fields: NamedStruct) -> pa.Schema: +def arrow_schema(fields: NamedStruct, record_fields: Set[str], + physical: bool) -> pa.Schema: """Return Arrow schema from Substrait fields. Args: fields: schema fields in the Substrait format. + record_fields: a set of record field names. physical: if true, return the physical schema. Physical schema matches with the underlying index (Parquet) file schema. Record fields are stored by their references, e.g., row position in ArrayRecord file. 
@@ -61,19 +63,28 @@ def arrow_schema(fields: NamedStruct) -> pa.Schema: return pa.schema( _arrow_fields( _NamesVisitor(fields.names), # type: ignore[arg-type] - fields.struct.types)) # type: ignore[arg-type] + fields.struct.types, # type: ignore[arg-type] + record_fields, + physical)) -def _arrow_fields(names_visitor: _NamesVisitor, - types: List[Type]) -> List[pa.Field]: +def _arrow_fields(names_visitor: _NamesVisitor, types: List[Type], + record_fields: Set[str], physical: bool) -> List[pa.Field]: fields: List[pa.Field] = [] for type_ in types: name = names_visitor.next() - arrow_field = pa.field(name, - _arrow_type(type_, names_visitor), - metadata=field_metadata(_substrait_field_id(type_))) - fields.append(arrow_field) + + if physical and name in record_fields: + arrow_type: pa.DataType = pa.struct( + record_address_types()) # type: ignore[arg-type] + else: + arrow_type = _arrow_type(type_, names_visitor) + + fields.append( + pa.field(name, + arrow_type, + metadata=field_metadata(_substrait_field_id(type_)))) return fields diff --git a/python/src/space/core/schema/types/tf_features.py b/python/src/space/core/schema/types/tf_features.py index 1e3dcbd..c183000 100644 --- a/python/src/space/core/schema/types/tf_features.py +++ b/python/src/space/core/schema/types/tf_features.py @@ -15,7 +15,7 @@ """Define a custom Arrow type for Tensorflow Dataset Features.""" from __future__ import annotations -from typing import Any +from typing import Any, Union import json import pyarrow as pa @@ -47,9 +47,12 @@ def __arrow_ext_serialize__(self) -> bytes: def __arrow_ext_deserialize__( cls, storage_type: pa.DataType, # pylint: disable=unused-argument - serialized: bytes) -> TfFeatures: - return TfFeatures( - f.FeaturesDict.from_json(json.loads(serialized.decode(UTF_8)))) + serialized: Union[bytes, str] + ) -> TfFeatures: + if isinstance(serialized, bytes): + serialized = serialized.decode(UTF_8) + + return TfFeatures(f.FeaturesDict.from_json(json.loads(serialized))) def serialize(self, value: Any) -> bytes: """Serialize value using the provided features_dict.""" diff --git a/python/tests/core/manifests/test_record.py b/python/tests/core/manifests/test_record.py index b254915..a77be0f 100644 --- a/python/tests/core/manifests/test_record.py +++ b/python/tests/core/manifests/test_record.py @@ -21,7 +21,6 @@ class TestRecordManifestWriter: def test_write(self, tmp_path): - data_dir = tmp_path / "dataset" / "data" metadata_dir = tmp_path / "dataset" / "metadata" metadata_dir.mkdir(parents=True) diff --git a/python/tests/core/ops/test_append.py b/python/tests/core/ops/test_append.py index bd4b67b..8e4ab95 100644 --- a/python/tests/core/ops/test_append.py +++ b/python/tests/core/ops/test_append.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List +import numpy as np import pyarrow as pa import pyarrow.parquet as pq +from tensorflow_datasets import features # type: ignore[import-untyped] from space.core.ops import LocalAppendOp import space.core.proto.metadata_pb2 as meta +from space.core.schema.types import TfFeatures from space.core.storage import Storage class TestLocalAppendOp: + # TODO: to add tests using Arrow table input. 
+ def test_write_pydict_all_types(self, tmp_path): location = tmp_path / "dataset" schema = pa.schema([ @@ -75,6 +81,82 @@ def test_write_pydict_all_types(self, tmp_path): assert patch.storage_statistics_update == meta.StorageStatistics( num_rows=5, index_compressed_bytes=114, index_uncompressed_bytes=126) + def test_write_pydict_with_record_fields(self, tmp_path): + tf_features_images = features.FeaturesDict( + {"images": features.Image(shape=(None, None, 3), dtype=np.uint8)}) + tf_features_objects = features.FeaturesDict({ + "objects": + features.Sequence({ + "bbox": features.BBoxFeature(), + "id": np.int64 + }), + }) + + location = tmp_path / "dataset" + schema = pa.schema([ + pa.field("int64", pa.int64()), + pa.field("string", pa.string()), + pa.field("images", TfFeatures(tf_features_images)), + pa.field("objects", TfFeatures(tf_features_objects)) + ]) + storage = Storage.create(location=str(location), + schema=schema, + primary_keys=["int64"], + record_fields=["images", "objects"]) + + op = LocalAppendOp(str(location), storage.metadata) + + op.write({ + "int64": [1, 2, 3], + "string": ["a", "b", "c"], + "images": [b"images0", b"images1", b"images2"], + "objects": [b"objects0", b"objects1", b"objects2"] + }) + op.write({ + "int64": [0, 10], + "string": ["A", "z"], + "images": [b"images3", b"images4"], + "objects": [b"objects3", b"objects4"] + }) + + patch = op.finish() + assert patch is not None + + # Validate index manifest files. + index_manifest = self._read_manifests( + storage, list(patch.added_index_manifest_files)) + assert index_manifest == { + "_FILE": index_manifest["_FILE"], + "_INDEX_COMPRESSED_BYTES": [114], + "_INDEX_UNCOMPRESSED_BYTES": [126], + "_NUM_ROWS": [5], + "_STATS_f0": [{ + "_MAX": 10, + "_MIN": 0 + }] + } + + # Validate record manifest files. + record_manifest = self._read_manifests( + storage, list(patch.added_record_manifest_files)) + assert record_manifest == { + "_FILE": record_manifest["_FILE"], + "_FIELD_ID": [2, 3], + "_NUM_ROWS": [5, 5], + "_UNCOMPRESSED_BYTES": [55, 60] + } + + # Data file exists. + self._check_file_exists(location, index_manifest["_FILE"]) + self._check_file_exists(location, record_manifest["_FILE"]) + + # Validate statistics. + assert patch.storage_statistics_update == meta.StorageStatistics( + num_rows=5, + index_compressed_bytes=114, + index_uncompressed_bytes=126, + record_uncompressed_bytes=115) + def test_empty_op_return_none(self, tmp_path): location = tmp_path / "dataset" schema = pa.schema([pa.field("int64", pa.int64())]) @@ -85,3 +167,15 @@ def test_empty_op_return_none(self, tmp_path): op = LocalAppendOp(str(location), storage.metadata) assert op.finish() is None + + def _read_manifests(self, storage: Storage, + file_paths: List[str]) -> pa.Table: + manifests = [] + for f in file_paths: + manifests.append(pq.read_table(storage.full_path(f))) + + return pa.concat_tables(manifests).to_pydict() + + def _check_file_exists(self, location, file_paths: List[str]): + for f in file_paths: + assert (location / f).exists() diff --git a/python/tests/core/schema/conftest.py b/python/tests/core/schema/conftest.py index 4205696..58d4f5f 100644 --- a/python/tests/core/schema/conftest.py +++ b/python/tests/core/schema/conftest.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest +import json +import numpy as np import pyarrow as pa +import pytest +from tensorflow_datasets import features # type: ignore[import-untyped] from substrait.type_pb2 import NamedStruct, Type from space.core.schema.arrow import field_metadata +from space.core.schema.types import TfFeatures @pytest.fixture @@ -92,3 +96,31 @@ def sample_arrow_schema(): ]), metadata=field_metadata(260)) ]) + + +@pytest.fixture +def tf_features(): + return features.FeaturesDict( + {"images": features.Image(shape=(None, None, 3), dtype=np.uint8)}) + + +@pytest.fixture +def tf_features_substrait_fields(tf_features): # pylint: disable=redefined-outer-name + return NamedStruct( + names=["int64", "features"], + struct=Type.Struct(types=[ + Type(i64=Type.I64(type_variation_reference=0)), + Type(user_defined=Type.UserDefined(type_parameters=[ + Type.Parameter(string="TF_FEATURES"), + Type.Parameter(string=json.dumps(tf_features.to_json())) + ], + type_variation_reference=1)) + ])) + + +@pytest.fixture +def tf_features_arrow_schema(tf_features): # pylint: disable=redefined-outer-name + return pa.schema([ + pa.field("int64", pa.int64(), metadata=field_metadata(0)), + pa.field("features", TfFeatures(tf_features), metadata=field_metadata(1)) + ]) diff --git a/python/tests/core/schema/test_arrow.py b/python/tests/core/schema/test_arrow.py index 420c220..9d7c409 100644 --- a/python/tests/core/schema/test_arrow.py +++ b/python/tests/core/schema/test_arrow.py @@ -15,6 +15,7 @@ import pyarrow as pa from space.core.schema import arrow +from space.core.schema.arrow import field_metadata def test_field_metadata(): @@ -27,8 +28,33 @@ def test_field_id(): b"123"})) == 123 -def test_arrow_schema(sample_substrait_fields, sample_arrow_schema): - assert sample_arrow_schema == arrow.arrow_schema(sample_substrait_fields) +def test_arrow_schema_logical_without_records(sample_substrait_fields, + sample_arrow_schema): + assert arrow.arrow_schema(sample_substrait_fields, [], + False) == sample_arrow_schema + + +def test_arrow_schema_logical_with_records(tf_features_substrait_fields, + tf_features_arrow_schema): + assert arrow.arrow_schema(tf_features_substrait_fields, [], + False) == tf_features_arrow_schema + + +def test_arrow_schema_physical_without_records(sample_substrait_fields, + sample_arrow_schema): + assert arrow.arrow_schema(sample_substrait_fields, [], + True) == sample_arrow_schema + + +def test_arrow_schema_physical_with_records(tf_features_substrait_fields): + arrow_schema = pa.schema([ + pa.field("int64", pa.int64(), metadata=field_metadata(0)), + pa.field("features", + pa.struct([("_FILE", pa.string()), ("_ROW_ID", pa.int32())]), + metadata=field_metadata(1)) + ]) + assert arrow.arrow_schema(tf_features_substrait_fields, ["features"], + True) == arrow_schema def test_field_name_to_id_dict(sample_arrow_schema): diff --git a/python/tests/core/schema/types/test_tf_features.py b/python/tests/core/schema/types/test_tf_features.py index 1f988d6..5d0d019 100644 --- a/python/tests/core/schema/types/test_tf_features.py +++ b/python/tests/core/schema/types/test_tf_features.py @@ -53,10 +53,16 @@ def test_arrow_ext_serialize_deserialize(self, tf_features, sample_objects): "type"] == "tensorflow_datasets.core.features.features_dict.FeaturesDict" # pylint: disable=line-too-long assert "sequence" in features_dict["content"]["features"]["objects"] + # Bytes input. tf_features = TfFeatures.__arrow_ext_deserialize__(storage_type=None, serialized=serialized) assert len(tf_features.serialize(sample_objects)) > 0 + # String input. 
+ tf_features = TfFeatures.__arrow_ext_deserialize__( + storage_type=None, serialized=serialized.decode(UTF_8)) + assert len(tf_features.serialize(sample_objects)) > 0 + def test_serialize_deserialize(self, tf_features, sample_objects): value_bytes = tf_features.serialize(sample_objects) assert len(value_bytes) > 0
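
For quick reference, the record-field append flow covered by test_write_pydict_with_record_fields above condenses to the short usage sketch below. The calls mirror the test exactly (Storage.create with record_fields, then LocalAppendOp.write/finish); the dataset location and payload values are placeholders, not part of the tests.

import numpy as np
import pyarrow as pa
from tensorflow_datasets import features

from space.core.ops import LocalAppendOp
from space.core.schema.types import TfFeatures
from space.core.storage import Storage

tf_features_images = features.FeaturesDict(
    {"images": features.Image(shape=(None, None, 3), dtype=np.uint8)})

schema = pa.schema([
    pa.field("int64", pa.int64()),
    pa.field("images", TfFeatures(tf_features_images))
])

# Record fields ("images") are written to ArrayRecord files; index fields
# and record addresses go to Parquet.
storage = Storage.create(location="/tmp/space_dataset",
                         schema=schema,
                         primary_keys=["int64"],
                         record_fields=["images"])

op = LocalAppendOp("/tmp/space_dataset", storage.metadata)
op.write({"int64": [1, 2], "images": [b"payload0", b"payload1"]})
patch = op.finish()  # Returns None if nothing was written.
if patch is not None:
  print(patch.storage_statistics_update)
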