
Commit 4b91105

sungwy and Fokko authored
Support Time Travel in InspectTable.entries (#599)
* time travel in entries table
* undo
* Update pyiceberg/table/__init__.py
* adopt review feedback
* docs

Co-authored-by: Fokko Driesprong <[email protected]>
1 parent 35d4648 commit 4b91105

3 files changed: 162 additions & 134 deletions

mkdocs/docs/api.md

Lines changed: 12 additions & 0 deletions
@@ -342,6 +342,18 @@ table.append(df)
 
 To explore the table metadata, tables can be inspected.
 
+<!-- prettier-ignore-start -->
+
+!!! tip "Time Travel"
+    To inspect a table's metadata with the time travel feature, call the inspect table method with the `snapshot_id` argument.
+    Time travel is supported on all metadata tables except `snapshots` and `refs`.
+
+    ```python
+    table.inspect.entries(snapshot_id=805611270568163028)
+    ```
+
+<!-- prettier-ignore-end -->
+
 ### Snapshots
 
 Inspect the snapshots of the table:
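The documentation addition above uses a literal snapshot ID. As a slightly fuller sketch of the same feature (not part of the commit; the catalog name `default` and the table identifier `docs_example.my_table` are placeholders), one could walk every snapshot recorded in the table metadata and inspect the entries as of each one:

```python
# Sketch only: placeholder catalog and table names, assuming a configured catalog.
from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")
table = catalog.load_table("docs_example.my_table")

# Walk the table history and inspect the manifest entries at each snapshot.
for snapshot in table.metadata.snapshots:
    entries = table.inspect.entries(snapshot_id=snapshot.snapshot_id)
    print(snapshot.snapshot_id, entries.num_rows)
```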

pyiceberg/table/__init__.py

Lines changed: 70 additions & 58 deletions
@@ -3253,6 +3253,18 @@ def __init__(self, tbl: Table) -> None:
         except ModuleNotFoundError as e:
             raise ModuleNotFoundError("For metadata operations PyArrow needs to be installed") from e
 
+    def _get_snapshot(self, snapshot_id: Optional[int] = None) -> Snapshot:
+        if snapshot_id is not None:
+            if snapshot := self.tbl.metadata.snapshot_by_id(snapshot_id):
+                return snapshot
+            else:
+                raise ValueError(f"Cannot find snapshot with ID {snapshot_id}")
+
+        if snapshot := self.tbl.metadata.current_snapshot():
+            return snapshot
+        else:
+            raise ValueError("Cannot get a snapshot as the table does not have any.")
+
     def snapshots(self) -> "pa.Table":
         import pyarrow as pa
 
@@ -3287,7 +3299,7 @@ def snapshots(self) -> "pa.Table":
             schema=snapshots_schema,
         )
 
-    def entries(self) -> "pa.Table":
+    def entries(self, snapshot_id: Optional[int] = None) -> "pa.Table":
         import pyarrow as pa
 
         from pyiceberg.io.pyarrow import schema_to_pyarrow
@@ -3346,64 +3358,64 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
         ])
 
         entries = []
-        if snapshot := self.tbl.metadata.current_snapshot():
-            for manifest in snapshot.manifests(self.tbl.io):
-                for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
-                    column_sizes = entry.data_file.column_sizes or {}
-                    value_counts = entry.data_file.value_counts or {}
-                    null_value_counts = entry.data_file.null_value_counts or {}
-                    nan_value_counts = entry.data_file.nan_value_counts or {}
-                    lower_bounds = entry.data_file.lower_bounds or {}
-                    upper_bounds = entry.data_file.upper_bounds or {}
-                    readable_metrics = {
-                        schema.find_column_name(field.field_id): {
-                            "column_size": column_sizes.get(field.field_id),
-                            "value_count": value_counts.get(field.field_id),
-                            "null_value_count": null_value_counts.get(field.field_id),
-                            "nan_value_count": nan_value_counts.get(field.field_id),
-                            # Makes them readable
-                            "lower_bound": from_bytes(field.field_type, lower_bound)
-                            if (lower_bound := lower_bounds.get(field.field_id))
-                            else None,
-                            "upper_bound": from_bytes(field.field_type, upper_bound)
-                            if (upper_bound := upper_bounds.get(field.field_id))
-                            else None,
-                        }
-                        for field in self.tbl.metadata.schema().fields
-                    }
-
-                    partition = entry.data_file.partition
-                    partition_record_dict = {
-                        field.name: partition[pos]
-                        for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
+        snapshot = self._get_snapshot(snapshot_id)
+        for manifest in snapshot.manifests(self.tbl.io):
+            for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
+                column_sizes = entry.data_file.column_sizes or {}
+                value_counts = entry.data_file.value_counts or {}
+                null_value_counts = entry.data_file.null_value_counts or {}
+                nan_value_counts = entry.data_file.nan_value_counts or {}
+                lower_bounds = entry.data_file.lower_bounds or {}
+                upper_bounds = entry.data_file.upper_bounds or {}
+                readable_metrics = {
+                    schema.find_column_name(field.field_id): {
+                        "column_size": column_sizes.get(field.field_id),
+                        "value_count": value_counts.get(field.field_id),
+                        "null_value_count": null_value_counts.get(field.field_id),
+                        "nan_value_count": nan_value_counts.get(field.field_id),
+                        # Makes them readable
+                        "lower_bound": from_bytes(field.field_type, lower_bound)
+                        if (lower_bound := lower_bounds.get(field.field_id))
+                        else None,
+                        "upper_bound": from_bytes(field.field_type, upper_bound)
+                        if (upper_bound := upper_bounds.get(field.field_id))
+                        else None,
                     }
-
-                    entries.append({
-                        'status': entry.status.value,
-                        'snapshot_id': entry.snapshot_id,
-                        'sequence_number': entry.data_sequence_number,
-                        'file_sequence_number': entry.file_sequence_number,
-                        'data_file': {
-                            "content": entry.data_file.content,
-                            "file_path": entry.data_file.file_path,
-                            "file_format": entry.data_file.file_format,
-                            "partition": partition_record_dict,
-                            "record_count": entry.data_file.record_count,
-                            "file_size_in_bytes": entry.data_file.file_size_in_bytes,
-                            "column_sizes": dict(entry.data_file.column_sizes),
-                            "value_counts": dict(entry.data_file.value_counts),
-                            "null_value_counts": dict(entry.data_file.null_value_counts),
-                            "nan_value_counts": entry.data_file.nan_value_counts,
-                            "lower_bounds": entry.data_file.lower_bounds,
-                            "upper_bounds": entry.data_file.upper_bounds,
-                            "key_metadata": entry.data_file.key_metadata,
-                            "split_offsets": entry.data_file.split_offsets,
-                            "equality_ids": entry.data_file.equality_ids,
-                            "sort_order_id": entry.data_file.sort_order_id,
-                            "spec_id": entry.data_file.spec_id,
-                        },
-                        'readable_metrics': readable_metrics,
-                    })
+                    for field in self.tbl.metadata.schema().fields
+                }
+
+                partition = entry.data_file.partition
+                partition_record_dict = {
+                    field.name: partition[pos]
+                    for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
+                }
+
+                entries.append({
+                    'status': entry.status.value,
+                    'snapshot_id': entry.snapshot_id,
+                    'sequence_number': entry.data_sequence_number,
+                    'file_sequence_number': entry.file_sequence_number,
+                    'data_file': {
+                        "content": entry.data_file.content,
+                        "file_path": entry.data_file.file_path,
+                        "file_format": entry.data_file.file_format,
+                        "partition": partition_record_dict,
+                        "record_count": entry.data_file.record_count,
+                        "file_size_in_bytes": entry.data_file.file_size_in_bytes,
+                        "column_sizes": dict(entry.data_file.column_sizes),
+                        "value_counts": dict(entry.data_file.value_counts),
+                        "null_value_counts": dict(entry.data_file.null_value_counts),
+                        "nan_value_counts": entry.data_file.nan_value_counts,
+                        "lower_bounds": entry.data_file.lower_bounds,
+                        "upper_bounds": entry.data_file.upper_bounds,
+                        "key_metadata": entry.data_file.key_metadata,
+                        "split_offsets": entry.data_file.split_offsets,
+                        "equality_ids": entry.data_file.equality_ids,
+                        "sort_order_id": entry.data_file.sort_order_id,
+                        "spec_id": entry.data_file.spec_id,
+                    },
+                    'readable_metrics': readable_metrics,
+                })
 
         return pa.Table.from_pylist(
             entries,
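From a caller's point of view, the new `_get_snapshot` helper resolves which snapshot `entries()` reads: an unknown `snapshot_id` raises the `ValueError` added above, while omitting the argument falls back to the current snapshot. A minimal sketch of that behavior (`table` is a loaded PyIceberg table and the ID `123` is a placeholder, not from the commit):

```python
# Sketch: how the snapshot resolution added in _get_snapshot surfaces to callers.
# `table` is a loaded pyiceberg Table; the ID 123 is a placeholder.
try:
    table.inspect.entries(snapshot_id=123)
except ValueError as err:
    print(err)  # "Cannot find snapshot with ID 123" when no such snapshot exists

# Without snapshot_id, the current snapshot is used; a table with no snapshots
# at all raises "Cannot get a snapshot as the table does not have any."
current_entries = table.inspect.entries()
```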

tests/integration/test_inspect_table.py

Lines changed: 80 additions & 76 deletions
@@ -22,7 +22,7 @@
 import pyarrow as pa
 import pytest
 import pytz
-from pyspark.sql import SparkSession
+from pyspark.sql import DataFrame, SparkSession
 
 from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
@@ -148,81 +148,85 @@ def test_inspect_entries(
     # Write some data
     tbl.append(arrow_table_with_null)
 
-    df = tbl.inspect.entries()
-
-    assert df.column_names == [
-        'status',
-        'snapshot_id',
-        'sequence_number',
-        'file_sequence_number',
-        'data_file',
-        'readable_metrics',
-    ]
-
-    # Make sure that they are filled properly
-    for int_column in ['status', 'snapshot_id', 'sequence_number', 'file_sequence_number']:
-        for value in df[int_column]:
-            assert isinstance(value.as_py(), int)
-
-    for snapshot_id in df['snapshot_id']:
-        assert isinstance(snapshot_id.as_py(), int)
-
-    lhs = df.to_pandas()
-    rhs = spark.table(f"{identifier}.entries").toPandas()
-    for column in df.column_names:
-        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
-            if column == 'data_file':
-                right = right.asDict(recursive=True)
-                for df_column in left.keys():
-                    if df_column == 'partition':
-                        # Spark leaves out the partition if the table is unpartitioned
-                        continue
-
-                    df_lhs = left[df_column]
-                    df_rhs = right[df_column]
-                    if isinstance(df_rhs, dict):
-                        # Arrow turns dicts into lists of tuple
-                        df_lhs = dict(df_lhs)
-
-                    assert df_lhs == df_rhs, f"Difference in data_file column {df_column}: {df_lhs} != {df_rhs}"
-            elif column == 'readable_metrics':
-                right = right.asDict(recursive=True)
-
-                assert list(left.keys()) == [
-                    'bool',
-                    'string',
-                    'string_long',
-                    'int',
-                    'long',
-                    'float',
-                    'double',
-                    'timestamp',
-                    'timestamptz',
-                    'date',
-                    'binary',
-                    'fixed',
-                ]
-
-                assert left.keys() == right.keys()
-
-                for rm_column in left.keys():
-                    rm_lhs = left[rm_column]
-                    rm_rhs = right[rm_column]
-
-                    assert rm_lhs['column_size'] == rm_rhs['column_size']
-                    assert rm_lhs['value_count'] == rm_rhs['value_count']
-                    assert rm_lhs['null_value_count'] == rm_rhs['null_value_count']
-                    assert rm_lhs['nan_value_count'] == rm_rhs['nan_value_count']
-
-                    if rm_column == 'timestamptz':
-                        # PySpark does not correctly set the timstamptz
-                        rm_rhs['lower_bound'] = rm_rhs['lower_bound'].replace(tzinfo=pytz.utc)
-                        rm_rhs['upper_bound'] = rm_rhs['upper_bound'].replace(tzinfo=pytz.utc)
-
-                    assert rm_lhs['lower_bound'] == rm_rhs['lower_bound']
-                    assert rm_lhs['upper_bound'] == rm_rhs['upper_bound']
-            else:
-                assert left == right, f"Difference in column {column}: {left} != {right}"
+    def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
+        assert df.column_names == [
+            'status',
+            'snapshot_id',
+            'sequence_number',
+            'file_sequence_number',
+            'data_file',
+            'readable_metrics',
+        ]
+
+        # Make sure that they are filled properly
+        for int_column in ['status', 'snapshot_id', 'sequence_number', 'file_sequence_number']:
+            for value in df[int_column]:
+                assert isinstance(value.as_py(), int)
+
+        for snapshot_id in df['snapshot_id']:
+            assert isinstance(snapshot_id.as_py(), int)
+
+        lhs = df.to_pandas()
+        rhs = spark_df.toPandas()
+        for column in df.column_names:
+            for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+                if column == 'data_file':
+                    right = right.asDict(recursive=True)
+                    for df_column in left.keys():
+                        if df_column == 'partition':
+                            # Spark leaves out the partition if the table is unpartitioned
+                            continue
+
+                        df_lhs = left[df_column]
+                        df_rhs = right[df_column]
+                        if isinstance(df_rhs, dict):
+                            # Arrow turns dicts into lists of tuple
+                            df_lhs = dict(df_lhs)
+
+                        assert df_lhs == df_rhs, f"Difference in data_file column {df_column}: {df_lhs} != {df_rhs}"
+                elif column == 'readable_metrics':
+                    right = right.asDict(recursive=True)
+
+                    assert list(left.keys()) == [
+                        'bool',
+                        'string',
+                        'string_long',
+                        'int',
+                        'long',
+                        'float',
+                        'double',
+                        'timestamp',
+                        'timestamptz',
+                        'date',
+                        'binary',
+                        'fixed',
+                    ]
+
+                    assert left.keys() == right.keys()
+
+                    for rm_column in left.keys():
+                        rm_lhs = left[rm_column]
+                        rm_rhs = right[rm_column]
+
+                        assert rm_lhs['column_size'] == rm_rhs['column_size']
+                        assert rm_lhs['value_count'] == rm_rhs['value_count']
+                        assert rm_lhs['null_value_count'] == rm_rhs['null_value_count']
+                        assert rm_lhs['nan_value_count'] == rm_rhs['nan_value_count']
+
+                        if rm_column == 'timestamptz':
+                            # PySpark does not correctly set the timstamptz
+                            rm_rhs['lower_bound'] = rm_rhs['lower_bound'].replace(tzinfo=pytz.utc)
+                            rm_rhs['upper_bound'] = rm_rhs['upper_bound'].replace(tzinfo=pytz.utc)
+
+                        assert rm_lhs['lower_bound'] == rm_rhs['lower_bound']
+                        assert rm_lhs['upper_bound'] == rm_rhs['upper_bound']
+                else:
+                    assert left == right, f"Difference in column {column}: {left} != {right}"
+
+    for snapshot in tbl.metadata.snapshots:
+        df = tbl.inspect.entries(snapshot_id=snapshot.snapshot_id)
+        spark_df = spark.sql(f"SELECT * FROM {identifier}.entries VERSION AS OF {snapshot.snapshot_id}")
+        check_pyiceberg_df_equals_spark_df(df, spark_df)
 
 
 @pytest.mark.integration
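The refactored test wraps the column checks in a helper and validates each historical snapshot against Spark's `VERSION AS OF` output. As a small sketch of an additional consistency check the same helper could support (not part of this commit), the default `entries()` call should match an explicit call with the current snapshot's ID:

```python
# Sketch (not in the commit): the default entries() call should equal an
# explicit call with the current snapshot's ID.
current = tbl.metadata.current_snapshot()
assert current is not None
assert tbl.inspect.entries().equals(tbl.inspect.entries(snapshot_id=current.snapshot_id))
```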
