diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index d4414c7c52..f14ed123a3 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -43,6 +43,7 @@ from enum import Enum from functools import lru_cache, singledispatch from typing import ( + IO, TYPE_CHECKING, Any, Generic, @@ -122,6 +123,7 @@ OutputStream, ) from pyiceberg.io.fileformat import DataFileStatistics as DataFileStatistics +from pyiceberg.io.fileformat import FileFormatFactory, FileFormatModel, FileFormatWriter from pyiceberg.manifest import ( DataFile, DataFileContent, @@ -1884,6 +1886,7 @@ def _to_requested_schema( include_field_ids: bool = False, projected_missing_fields: dict[int, Any] = EMPTY_DICT, allow_timestamp_tz_mismatch: bool = False, + file_format: FileFormat = FileFormat.PARQUET, ) -> pa.RecordBatch: # We could reuse some of these visitors struct_array = visit_with_partner( @@ -1895,6 +1898,7 @@ def _to_requested_schema( include_field_ids, projected_missing_fields=projected_missing_fields, allow_timestamp_tz_mismatch=allow_timestamp_tz_mismatch, + file_format=file_format, ), ArrowAccessor(file_schema), ) @@ -1907,6 +1911,7 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, pa.Array | None] _downcast_ns_timestamp_to_us: bool _projected_missing_fields: dict[int, Any] _allow_timestamp_tz_mismatch: bool + _file_format: FileFormat def __init__( self, @@ -1915,6 +1920,7 @@ def __init__( include_field_ids: bool = False, projected_missing_fields: dict[int, Any] = EMPTY_DICT, allow_timestamp_tz_mismatch: bool = False, + file_format: FileFormat = FileFormat.PARQUET, ) -> None: self._file_schema = file_schema self._include_field_ids = include_field_ids @@ -1923,6 +1929,7 @@ def __init__( # When True, allows projecting timestamptz (UTC) to timestamp (no tz). # Allowed for reading (aligns with Spark); disallowed for writing to enforce Iceberg spec's strict typing. self._allow_timestamp_tz_mismatch = allow_timestamp_tz_mismatch + self._file_format = file_format def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array: file_field = self._file_schema.find_field(field.field_id) @@ -1981,9 +1988,12 @@ def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Fi if field.doc: metadata[PYARROW_FIELD_DOC_KEY] = field.doc if self._include_field_ids: - # For projection visitor, we don't know the file format, so default to Parquet - # This is used for schema conversion during reads, not writes - metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id) + if self._file_format == FileFormat.ORC: + metadata[ORC_FIELD_ID_KEY] = str(field.field_id) + else: + metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id) + if self._file_format == FileFormat.ORC: + metadata[ORC_FIELD_REQUIRED_KEY] = str(field.required).lower() return pa.field( name=field.name, @@ -2602,21 +2612,87 @@ def data_file_statistics_from_parquet_metadata( ) +class ParquetFormatWriter(FileFormatWriter): + """Writes Arrow tables to a Parquet file.""" + + def __init__(self, output_file: OutputFile, file_schema: Schema, properties: Properties) -> None: + self._output_file = output_file + self._file_schema = file_schema + self._properties = properties + self._writer: pq.ParquetWriter | None = None + self._fos: OutputStream | None = None + self._parquet_writer_kwargs = _get_parquet_writer_kwargs(properties) + self._row_group_size = property_as_int( + properties=properties, + property_name=TableProperties.PARQUET_ROW_GROUP_LIMIT, + default=TableProperties.PARQUET_ROW_GROUP_LIMIT_DEFAULT, + ) + + def write(self, table: pa.Table) -> None: + if self._writer is None: + self._fos = self._output_file.create(overwrite=True) + self._writer = pq.ParquetWriter( + cast(IO[Any], self._fos), + schema=table.schema, + store_decimal_as_integer=True, + **self._parquet_writer_kwargs, + ) + self._writer.write(table, row_group_size=self._row_group_size) + + def close(self) -> DataFileStatistics: + if self._result is not None: + return self._result + try: + if self._writer is None: + raise ValueError("Cannot close a writer that was never written to") + self._writer.close() + self._result = data_file_statistics_from_parquet_metadata( + parquet_metadata=self._writer.writer.metadata, + stats_columns=compute_statistics_plan(self._file_schema, self._properties), + parquet_column_mapping=parquet_path_to_id_mapping(self._file_schema), + ) + return self._result + finally: + if self._fos is not None: + self._fos.close() + + +class ParquetFormatModel(FileFormatModel): + """Format model for Apache Parquet.""" + + @property + def format(self) -> FileFormat: + return FileFormat.PARQUET + + def file_extension(self) -> str: + return "parquet" + + def create_writer( + self, + output_file: OutputFile, + file_schema: Schema, + properties: Properties, + ) -> ParquetFormatWriter: + return ParquetFormatWriter(output_file, file_schema, properties) + + +FileFormatFactory.register(ParquetFormatModel()) + + def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]: from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, TableProperties - parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties) - row_group_size = property_as_int( - properties=table_metadata.properties, - property_name=TableProperties.PARQUET_ROW_GROUP_LIMIT, - default=TableProperties.PARQUET_ROW_GROUP_LIMIT_DEFAULT, + file_format = FileFormat( + table_metadata.properties.get( + TableProperties.WRITE_FILE_FORMAT, + TableProperties.WRITE_FILE_FORMAT_DEFAULT, + ) ) + format_model = FileFormatFactory.get(file_format) location_provider = load_location_provider(table_location=table_metadata.location, table_properties=table_metadata.properties) - def write_parquet(task: WriteTask) -> DataFile: + def write_data_file(task: WriteTask) -> DataFile: table_schema = table_metadata.schema() - # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly - # otherwise use the original schema if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema: file_schema = sanitized_schema else: @@ -2630,29 +2706,25 @@ def write_parquet(task: WriteTask) -> DataFile: batch=batch, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us, include_field_ids=True, + file_format=file_format, ) for batch in task.record_batches ] arrow_table = pa.Table.from_batches(batches) file_path = location_provider.new_data_location( - data_file_name=task.generate_data_file_filename("parquet"), + data_file_name=task.generate_data_file_filename(format_model.file_extension()), partition_key=task.partition_key, ) fo = io.new_output(file_path) - with fo.create(overwrite=True) as fos: - with pq.ParquetWriter( - fos, schema=arrow_table.schema, store_decimal_as_integer=True, **parquet_writer_kwargs - ) as writer: - writer.write(arrow_table, row_group_size=row_group_size) - statistics = data_file_statistics_from_parquet_metadata( - parquet_metadata=writer.writer.metadata, - stats_columns=compute_statistics_plan(file_schema, table_metadata.properties), - parquet_column_mapping=parquet_path_to_id_mapping(file_schema), - ) - data_file = DataFile.from_args( + writer = format_model.create_writer(fo, file_schema, table_metadata.properties) + with writer: + writer.write(arrow_table) + statistics = writer.result() + + return DataFile.from_args( content=DataFileContent.DATA, file_path=file_path, - file_format=FileFormat.PARQUET, + file_format=file_format, partition=task.partition_key.partition if task.partition_key else Record(), file_size_in_bytes=len(fo), # After this has been fixed: @@ -2666,10 +2738,8 @@ def write_parquet(task: WriteTask) -> DataFile: **statistics.to_serialized_dict(), ) - return data_file - executor = ExecutorFactory.get_or_create() - data_files = executor.map(write_parquet, tasks) + data_files = executor.map(write_data_file, tasks) return iter(data_files) diff --git a/tests/io/test_format_writers.py b/tests/io/test_format_writers.py new file mode 100644 index 0000000000..d6ce77b978 --- /dev/null +++ b/tests/io/test_format_writers.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Parametrized format writer tests, modeled after Java's BaseFormatModelTests.""" + +from pathlib import Path + +import pyarrow as pa +import pyarrow.dataset as ds +import pytest + +from pyiceberg.io.fileformat import FileFormatFactory, FileFormatModel +from pyiceberg.io.pyarrow import PyArrowFileIO +from pyiceberg.manifest import FileFormat +from pyiceberg.schema import Schema +from pyiceberg.types import LongType, NestedField + + +@pytest.fixture(params=FileFormatFactory.available_formats(), ids=lambda f: f.name.lower()) +def format_model(request: pytest.FixtureRequest) -> FileFormatModel: + return FileFormatFactory.get(request.param) + + +@pytest.fixture +def simple_table() -> pa.Table: + return pa.table( + { + "foo": ["a", "b", "c"], + "bar": pa.array([1, 2, 3], type=pa.int32()), + "baz": [True, False, True], + } + ) + + +def test_parquet_registered() -> None: + """ParquetFormatModel is registered in the factory.""" + model = FileFormatFactory.get(FileFormat.PARQUET) + assert model.format == FileFormat.PARQUET + assert model.file_extension() == "parquet" + + +def test_round_trip(format_model: FileFormatModel, table_schema_simple: Schema, simple_table: pa.Table, tmp_path: Path) -> None: + """Write a table and read it back, to verify equality and record count.""" + file_path = str(tmp_path / f"test.{format_model.file_extension()}") + writer = format_model.create_writer(PyArrowFileIO().new_output(file_path), table_schema_simple, {}) + writer.write(simple_table) + statistics = writer.close() + + result = ds.dataset(file_path).to_table() + assert result.equals(simple_table) + assert statistics.record_count == 3 + + +def test_statistics_record_count(format_model: FileFormatModel, table_schema_simple: Schema, tmp_path: Path) -> None: + """close() returns DataFileStatistics with correct record count.""" + table = pa.table( + { + "foo": ["a", "b", "c", "d", "e"], + "bar": pa.array([10, 20, 30, 40, 50], type=pa.int32()), + "baz": [True] * 5, + } + ) + file_path = str(tmp_path / f"test.{format_model.file_extension()}") + writer = format_model.create_writer(PyArrowFileIO().new_output(file_path), table_schema_simple, {}) + writer.write(table) + assert writer.close().record_count == 5 + + +def test_null_handling(format_model: FileFormatModel, table_schema_simple: Schema, tmp_path: Path) -> None: + """Nullable columns produce correct null_value_counts in statistics.""" + table = pa.table( + { + "foo": ["a", None, "c"], # field_id=1, optional + "bar": pa.array([1, 2, 3], type=pa.int32()), # field_id=2, required + "baz": [True, False, True], # field_id=3, optional + } + ) + file_path = str(tmp_path / f"test.{format_model.file_extension()}") + writer = format_model.create_writer(PyArrowFileIO().new_output(file_path), table_schema_simple, {}) + writer.write(table) + stats = writer.close() + assert stats.record_count == 3 + assert stats.null_value_counts.get(1) == 1 + + +def test_context_manager_caches_result( + format_model: FileFormatModel, table_schema_simple: Schema, simple_table: pa.Table, tmp_path: Path +) -> None: + """writer.result() returns cached statistics after context manager exit.""" + file_path = str(tmp_path / f"test.{format_model.file_extension()}") + writer = format_model.create_writer(PyArrowFileIO().new_output(file_path), table_schema_simple, {}) + with writer: + writer.write(simple_table) + assert writer.result().record_count == 3 + + +def test_close_is_idempotent( + format_model: FileFormatModel, table_schema_simple: Schema, simple_table: pa.Table, tmp_path: Path +) -> None: + """Calling close() twice returns the same cached statistics object.""" + file_path = str(tmp_path / f"test.{format_model.file_extension()}") + writer = format_model.create_writer(PyArrowFileIO().new_output(file_path), table_schema_simple, {}) + writer.write(simple_table) + stats1 = writer.close() + stats2 = writer.close() + assert stats1 is stats2 + + +def test_close_without_write_raises(format_model: FileFormatModel, table_schema_simple: Schema, tmp_path: Path) -> None: + """Closing a writer that was never written to raises ValueError.""" + file_path = str(tmp_path / f"test.{format_model.file_extension()}") + writer = format_model.create_writer(PyArrowFileIO().new_output(file_path), table_schema_simple, {}) + with pytest.raises(ValueError, match="Cannot close a writer that was never written to"): + writer.close() + + +def test_construct_field_uses_orc_field_id_key() -> None: + """ArrowProjectionVisitor uses ORC field ID and required keys when file_format is ORC.""" + from pyiceberg.io.pyarrow import ( + ORC_FIELD_ID_KEY, + ORC_FIELD_REQUIRED_KEY, + PYARROW_PARQUET_FIELD_ID_KEY, + ArrowProjectionVisitor, + ) + + schema = Schema(NestedField(field_id=1, name="x", field_type=LongType(), required=True)) + + visitor = ArrowProjectionVisitor(schema, include_field_ids=True, file_format=FileFormat.ORC) + field = visitor._construct_field(schema.find_field(1), pa.int64()) + assert field.metadata is not None + assert ORC_FIELD_ID_KEY in field.metadata + assert ORC_FIELD_REQUIRED_KEY in field.metadata + assert field.metadata[ORC_FIELD_REQUIRED_KEY] == b"true" + assert PYARROW_PARQUET_FIELD_ID_KEY not in field.metadata + + visitor_pq = ArrowProjectionVisitor(schema, include_field_ids=True, file_format=FileFormat.PARQUET) + field_pq = visitor_pq._construct_field(schema.find_field(1), pa.int64()) + assert field_pq.metadata is not None + assert PYARROW_PARQUET_FIELD_ID_KEY in field_pq.metadata + assert ORC_FIELD_ID_KEY not in field_pq.metadata + assert ORC_FIELD_REQUIRED_KEY not in field_pq.metadata diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 2170741bdd..bc879df542 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -28,6 +28,7 @@ import pyarrow import pyarrow as pa +import pyarrow.dataset as ds import pyarrow.orc as orc import pyarrow.parquet as pq import pytest @@ -86,7 +87,7 @@ from pyiceberg.manifest import DataFile, DataFileContent, FileFormat from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema, make_compatible_name, visit -from pyiceberg.table import FileScanTask, TableProperties +from pyiceberg.table import FileScanTask, TableProperties, WriteTask from pyiceberg.table.metadata import TableMetadataV2 from pyiceberg.table.name_mapping import create_mapping_from_schema from pyiceberg.transforms import HourTransform, IdentityTransform @@ -2930,6 +2931,61 @@ def test_write_file_rejects_timestamptz_to_timestamp(tmp_path: Path) -> None: list(write_file(io=PyArrowFileIO(), table_metadata=table_metadata, tasks=iter([task]))) +def _simple_write_task_and_metadata( + tmp_path: Path, table_schema_simple: Schema, properties: dict[str, str] +) -> tuple[TableMetadataV2, WriteTask]: + """Build a TableMetadataV2 and a 3-row WriteTask matching table_schema_simple.""" + arrow_data = pa.table( + { + "foo": ["a", "b", "c"], + "bar": pa.array([1, 2, 3], type=pa.int32()), + "baz": [True, False, True], + } + ) + table_metadata = TableMetadataV2( + location=f"file://{tmp_path}", + last_column_id=3, + current_schema_id=1, + format_version=2, + schemas=[table_schema_simple], + partition_specs=[PartitionSpec()], + properties=properties, + ) + task = WriteTask( + write_uuid=uuid.uuid4(), + task_id=0, + record_batches=arrow_data.to_batches(), + schema=table_schema_simple, + ) + return table_metadata, task + + +def test_write_file_dispatches_on_write_format_default(tmp_path: Path, table_schema_simple: Schema) -> None: + """write_file() reads write.format.default and raises if the format is not registered.""" + table_metadata, task = _simple_write_task_and_metadata(tmp_path, table_schema_simple, {"write.format.default": "orc"}) + + with pytest.raises(ValueError, match="No writer registered for"): + list(write_file(io=PyArrowFileIO(), table_metadata=table_metadata, tasks=iter([task]))) + + +def test_write_file_parquet_round_trip(tmp_path: Path, table_schema_simple: Schema) -> None: + """write_file() with write.format.default=parquet writes a readable Parquet file with correct DataFile metadata.""" + table_metadata, task = _simple_write_task_and_metadata(tmp_path, table_schema_simple, {"write.format.default": "parquet"}) + + data_files = list(write_file(io=PyArrowFileIO(), table_metadata=table_metadata, tasks=iter([task]))) + + assert len(data_files) == 1 + data_file = data_files[0] + assert data_file.file_format == FileFormat.PARQUET + assert data_file.record_count == 3 + assert data_file.file_path.endswith(".parquet") + assert data_file.file_size_in_bytes > 0 + + result = ds.dataset(data_file.file_path.replace("file://", "")).to_table() + assert result.num_rows == 3 + assert result.column_names == ["foo", "bar", "baz"] + + def test__to_requested_schema_timestamps( arrow_table_schema_with_all_timestamp_precisions: pa.Schema, arrow_table_with_all_timestamp_precisions: pa.Table,