Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 100 additions & 26 deletions writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ public final class VortexWriter implements Closeable {
private static final int LAYOUT_DICT = 3;
private static final int LAYOUT_ZONED = 4;

// Stat ordinals in the Rust `Stat` enum (see ZonedStatsSchema). Emitted: MAX, MIN, NULL_COUNT.
// Stat ordinals in the Rust `Stat` enum (see ZonedStatsSchema). Emitted: MAX, MIN, SUM, NULL_COUNT.
private static final int STAT_MAX = 3;
private static final int STAT_MIN = 4;
private static final int STAT_SUM = 5;
private static final int STAT_NULL_COUNT = 6;

// Columns with global cardinality below this threshold are dict-encoded across all chunks.
Expand Down Expand Up @@ -122,6 +123,7 @@ public final class VortexWriter implements Closeable {
// Stats (ScalarValue bytes) of the most recently written segment, captured for ChunkRef.
private byte[] lastStatsMin;
private byte[] lastStatsMax;
private byte[] lastStatsSum;
// Null count of the most recently written segment's input data (0 for dense arrays).
private long lastNullCount;

Expand Down Expand Up @@ -469,7 +471,7 @@ public void writeChunk(Map<String, Object> columns) throws IOException {
} else {
long rowCount = arrayLength(data);
int segIdx = writeSegment(colDtype, data);
colChunks.get(colName).add(new ChunkRef(segIdx, rowCount, lastStatsMin, lastStatsMax, lastNullCount));
colChunks.get(colName).add(new ChunkRef(segIdx, rowCount, lastStatsMin, lastStatsMax, lastStatsSum, lastNullCount));
}
}
firstChunkSeen = true;
Expand Down Expand Up @@ -572,6 +574,7 @@ private int writeSegment(DType dtype, Object data, EncodingEncoder encodingOverr
segs.add(new SegRef(offset, bytesWritten - offset));
lastStatsMin = result.statsMin();
lastStatsMax = result.statsMax();
lastStatsSum = columnSum(dtype, data);
lastNullCount = segNullCount;
return segIdx;
}
Expand Down Expand Up @@ -701,42 +704,49 @@ private void flushZoneMaps() throws IOException {
if (chunks.isEmpty()) {
continue;
}
DType minMaxDtype = zoneMinMaxDtype(columnDtype(colName));
DType colDtype = columnDtype(colName);
DType minMaxDtype = zoneMinMaxDtype(colDtype);
boolean hasMinMax = minMaxDtype != null && chunks.stream().allMatch(ChunkRef::hasStats);
DType sumDtype = zoneSumDtype(colDtype);
long[] nullCounts = new long[chunks.size()];
for (int i = 0; i < chunks.size(); i++) {
nullCounts[i] = chunks.get(i).nullCount();
}
emitZoneMap(colName, hasMinMax ? minMaxDtype : null,
chunks.stream().map(ChunkRef::statsMin).toList(),
chunks.stream().map(ChunkRef::statsMax).toList(),
sumDtype, chunks.stream().map(ChunkRef::statsSum).toList(),
nullCounts);
}
// Dict-encoded columns (one zone per code chunk). MIN/MAX come from each chunk's logical
// Dict-encoded columns (one zone per code chunk). MIN/MAX/SUM come from each chunk's logical
// values (computed at dict-build time); NULL_COUNT always. Matches Rust, whose zone-map
// stats are computed on the logical column dtype, independent of the dict encoding.
for (Map.Entry<String, DictColRef> e : dictColRefs.entrySet()) {
DictColRef ref = e.getValue();
DType minMaxDtype = zoneMinMaxDtype(columnDtype(e.getKey()));
DType colDtype = columnDtype(e.getKey());
DType minMaxDtype = zoneMinMaxDtype(colDtype);
boolean hasMinMax = minMaxDtype != null
&& ref.chunkStatsMin().stream().allMatch(java.util.Objects::nonNull)
&& ref.chunkStatsMax().stream().allMatch(java.util.Objects::nonNull);
long[] nullCounts = ref.chunkNullCounts().stream().mapToLong(Long::longValue).toArray();
emitZoneMap(e.getKey(), hasMinMax ? minMaxDtype : null,
ref.chunkStatsMin(), ref.chunkStatsMax(), nullCounts);
ref.chunkStatsMin(), ref.chunkStatsMax(),
zoneSumDtype(colDtype), ref.chunkStatsSum(), nullCounts);
}
}

private DType columnDtype(String colName) {
return schema.fieldTypes().get(schema.fieldNames().indexOf(colName));
}

/// Writes one `vortex.stats` zone-map for `colName`: one zone per chunk, with NULL_COUNT always
/// and MAX/MIN (plus always-false `_is_truncated` flags) when `minMaxDtype` is non-null.
/// `minBytes`/`maxBytes` hold each zone's serialised min/max scalar — read only when
/// `minMaxDtype` is set. Field/bit order follows ZonedStatsSchema: MAX(3), MIN(4), NULL_COUNT(6).
private void emitZoneMap(String colName, DType minMaxDtype,
List<byte[]> minBytes, List<byte[]> maxBytes, long[] nullCounts) throws IOException {
/// Writes one `vortex.stats` zone-map for `colName`: one zone per chunk, with NULL_COUNT always,
/// MAX/MIN (plus always-false `_is_truncated` flags) when `minMaxDtype` is non-null, and SUM when
/// `sumDtype` is non-null. `minBytes`/`maxBytes`/`sumBytes` hold each zone's serialised scalar —
/// read only when the matching dtype is set; a `null` `sumBytes` entry marks an overflowed zone
/// (recorded as a null sum). Field/bit order follows ZonedStatsSchema: MAX(3), MIN(4), SUM(5),
/// NULL_COUNT(6).
private void emitZoneMap(String colName, DType minMaxDtype, List<byte[]> minBytes, List<byte[]> maxBytes,
DType sumDtype, List<byte[]> sumBytes, long[] nullCounts) throws IOException {
int nZones = nullCounts.length;
boolean[] allValid = new boolean[nZones];
java.util.Arrays.fill(allValid, true);
Expand All @@ -759,13 +769,21 @@ private void emitZoneMap(String colName, DType minMaxDtype,
types.add(new DType.Bool(false));
fields.add(notTruncated.clone());
}
if (sumDtype != null) {
boolean[] sumValid = new boolean[nZones];
Object sumArr = sumColumn(sumDtype, sumBytes, sumValid);
names.add("sum");
types.add(sumDtype);
fields.add(new NullableData(sumArr, sumValid));
}
names.add("null_count");
types.add(new DType.Primitive(PType.U64, true));
fields.add(new NullableData(nullCounts, allValid.clone()));

DType.Struct statsDtype = new DType.Struct(List.copyOf(names), List.copyOf(types), false);
int zonesSegIdx = writeSegment(statsDtype, new StructData(fields), new StructEncodingEncoder());
zoneMaps.put(colName, new ZoneMapRef(zonesSegIdx, nZones, options.chunkSize(), minMaxDtype != null));
zoneMaps.put(colName,
new ZoneMapRef(zonesSegIdx, nZones, options.chunkSize(), minMaxDtype != null, sumDtype != null));
}

/// Wraps a column's data layout in a `vortex.stats` (zoned) layout when a zone-map was
Expand All @@ -778,20 +796,23 @@ private int wrapZoneMap(FlatBufferBuilder fbb, String colName, int dataLayout, l
int zonesSegV = Layout.createSegmentsVector(fbb, new long[]{zm.zonesSegIdx()});
int zonesFlat = Layout.createLayout(fbb, LAYOUT_FLAT, zm.nZones(), 0, 0, zonesSegV);
int childV = Layout.createChildrenVector(fbb, new int[]{dataLayout, zonesFlat});
int metaV = Layout.createMetadataVector(fbb, zonedMetadataBytes(zm.zoneLen(), zm.hasMinMax()));
int metaV = Layout.createMetadataVector(fbb, zonedMetadataBytes(zm.zoneLen(), zm.hasMinMax(), zm.hasSum()));
return Layout.createLayout(fbb, LAYOUT_ZONED, colRows, metaV, childV, 0);
}

/// `vortex.stats` metadata: `u32` zone length (LE) + a 1-byte stat bitset (LSB-first) with the
/// NULL_COUNT bit always set and the MAX/MIN bits set when present, matching
/// NULL_COUNT bit always set and the MAX/MIN and SUM bits set when present, matching
/// [io.github.dfa1.vortex.inspect] `ZonedStatsSchema`.
private static byte[] zonedMetadataBytes(long zoneLen, boolean hasMinMax) {
private static byte[] zonedMetadataBytes(long zoneLen, boolean hasMinMax, boolean hasSum) {
byte[] meta = new byte[5];
ByteBuffer.wrap(meta).order(ByteOrder.LITTLE_ENDIAN).putInt((int) zoneLen);
int bits = 1 << STAT_NULL_COUNT;
if (hasMinMax) {
bits |= (1 << STAT_MAX) | (1 << STAT_MIN);
}
if (hasSum) {
bits |= (1 << STAT_SUM);
}
meta[4] = (byte) bits;
return meta;
}
Expand Down Expand Up @@ -820,6 +841,32 @@ private static DType zoneMinMaxDtype(DType dtype) {
};
}

/// The (nullable) dtype a zone-map stores SUM in for `dtype`, or `null` when the column has no
/// recordable sum. Only plain numeric primitives are summed — signed → `i64`, unsigned → `u64`,
/// float → `f64` — matching Rust, which emits SUM for primitives and decimals but not for
/// Utf8/extension/date columns even when their storage is numeric.
private static DType zoneSumDtype(DType dtype) {
if (!(dtype instanceof DType.Primitive p)) {
return null;
}
return switch (p.ptype()) {
case U8, U16, U32, U64 -> new DType.Primitive(PType.U64, true);
case I8, I16, I32, I64 -> new DType.Primitive(PType.I64, true);
case F16, F32, F64 -> new DType.Primitive(PType.F64, true);
};
}

/// The serialised per-chunk SUM scalar for `data` of logical type `dtype`, or `null` when the
/// column is not summable (non-primitive) or the sum overflowed. Validity placeholders are zero
/// and therefore sum-neutral, so a nullable carrier sums correctly via its values.
private static byte[] columnSum(DType dtype, Object data) {
if (!(dtype instanceof DType.Primitive p)) {
return null;
}
Object values = data instanceof NullableData nd ? nd.values() : data;
return PrimitiveEncodingEncoder.sumStat(p.ptype(), values);
}

/// Builds the per-zone min (or max) values array for the resolved min/max `dtype`, decoding each
/// zone's serialised [ScalarValue] stat into the array shape its encoder expects.
private static Object zoneStatValues(DType minMaxDtype, List<byte[]> statBytes) throws IOException {
Expand All @@ -830,6 +877,28 @@ private static Object zoneStatValues(DType minMaxDtype, List<byte[]> statBytes)
};
}

/// Builds the per-zone SUM array for `sumDtype` (i64/u64 → `long[]`, f64 → `double[]`), decoding
/// each zone's serialised scalar. Zones whose sum overflowed carry a `null` entry in `sumBytes`;
/// `valid[i]` is set accordingly so the stat field reports them as null.
private static Object sumColumn(DType sumDtype, List<byte[]> sumBytes, boolean[] valid) throws IOException {
PType ptype = ((DType.Primitive) sumDtype).ptype();
int n = sumBytes.size();
if (ptype == PType.F64) {
double[] a = new double[n];
for (int i = 0; i < n; i++) {
valid[i] = sumBytes.get(i) != null;
a[i] = valid[i] ? scalarDouble(sumBytes.get(i)) : 0.0;
}
return a;
}
long[] a = new long[n];
for (int i = 0; i < n; i++) {
valid[i] = sumBytes.get(i) != null;
a[i] = valid[i] ? scalarLong(sumBytes.get(i)) : 0L;
}
return a;
}

/// Builds the per-zone string array by decoding each zone's serialised string [ScalarValue]
/// stat. Used for Utf8 columns whose `vortex.varbin` encoder records full string min/max scalars.
private static String[] statStringColumn(List<byte[]> statBytes) throws IOException {
Expand Down Expand Up @@ -1078,7 +1147,7 @@ private void writeGlobalDictColumn(String colName, DType.Primitive dtype, List<O
for (Object chunk : chunks) {
long rowCount = arrayLength(chunk);
int segIdx = writeSegment(dtype, chunk);
colChunks.get(colName).add(new ChunkRef(segIdx, rowCount, lastStatsMin, lastStatsMax, lastNullCount));
colChunks.get(colName).add(new ChunkRef(segIdx, rowCount, lastStatsMin, lastStatsMax, lastStatsSum, lastNullCount));
}
return;
}
Expand All @@ -1099,23 +1168,25 @@ private void writeGlobalDictColumn(String colName, DType.Primitive dtype, List<O
List<Long> chunkNullCounts = new ArrayList<>();
List<byte[]> chunkStatsMin = new ArrayList<>();
List<byte[]> chunkStatsMax = new ArrayList<>();
List<byte[]> chunkStatsSum = new ArrayList<>();
for (Object chunk : chunks) {
int len = primitiveArrayLen(chunk, ptype);
Object codesArr = buildCodesArray(chunk, ptype, valueMap, codePType, len);
codesSegIdxes.add(writeSegment(codesDtype, codesArr));
chunkRowCounts.add((long) len);
chunkNullCounts.add(chunk instanceof NullableData nd ? countNulls(nd.validity()) : 0L);
// Per-zone min/max over the chunk's logical values (matches the flat primitive path:
// computed on nd.values(), placeholders included). Lets the dict zone-map prune like a
// plain primitive column.
// Per-zone min/max + sum over the chunk's logical values (matches the flat primitive
// path: computed on nd.values(), placeholders included). Lets the dict zone-map prune
// and aggregate like a plain primitive column.
Object values = chunk instanceof NullableData nd ? nd.values() : chunk;
byte[][] mm = PrimitiveEncodingEncoder.minMaxStats(ptype, values);
chunkStatsMin.add(mm != null ? mm[0] : null);
chunkStatsMax.add(mm != null ? mm[1] : null);
chunkStatsSum.add(PrimitiveEncodingEncoder.sumStat(ptype, values));
}

dictColRefs.put(colName, new DictColRef(valuesSegIdx, dictSize, codesSegIdxes,
chunkRowCounts, chunkNullCounts, chunkStatsMin, chunkStatsMax));
chunkRowCounts, chunkNullCounts, chunkStatsMin, chunkStatsMax, chunkStatsSum));
}

private void writeGlobalDictUtf8Column(String colName, DType.Utf8 dtype, List<Object> chunks)
Expand All @@ -1134,7 +1205,7 @@ private void writeGlobalDictUtf8Column(String colName, DType.Utf8 dtype, List<Ob
for (Object chunk : chunks) {
long rowCount = arrayLength(chunk);
int segIdx = writeSegment(dtype, chunk);
colChunks.get(colName).add(new ChunkRef(segIdx, rowCount, lastStatsMin, lastStatsMax, lastNullCount));
colChunks.get(colName).add(new ChunkRef(segIdx, rowCount, lastStatsMin, lastStatsMax, lastStatsSum, lastNullCount));
}
return;
}
Expand Down Expand Up @@ -1166,8 +1237,10 @@ private void writeGlobalDictUtf8Column(String colName, DType.Utf8 dtype, List<Ob
chunkStatsMin.add(mm != null ? mm[0] : null);
chunkStatsMax.add(mm != null ? mm[1] : null);
}
// Utf8 columns are not summed (zoneSumDtype is null), so the sum bytes are never read.
List<byte[]> noSum = java.util.Collections.nCopies(codesSegIdxes.size(), null);
dictColRefs.put(colName, new DictColRef(valuesSegIdx, dictSize, codesSegIdxes,
chunkRowCounts, chunkNullCounts, chunkStatsMin, chunkStatsMax));
chunkRowCounts, chunkNullCounts, chunkStatsMin, chunkStatsMax, noSum));
}

private static Object buildUtf8CodesArray(String[] strs, Map<String, Integer> valueMap, PType codePType) {
Expand Down Expand Up @@ -1359,7 +1432,8 @@ private static Object buildCodesArray(Object data, PType ptype, Map<Object, Inte
private record SegRef(long offset, long len) {
}

private record ChunkRef(int segIdx, long rowCount, byte[] statsMin, byte[] statsMax, long nullCount) {
private record ChunkRef(int segIdx, long rowCount, byte[] statsMin, byte[] statsMax,
byte[] statsSum, long nullCount) {
boolean hasStats() {
return statsMin != null && statsMax != null;
}
Expand All @@ -1368,12 +1442,12 @@ boolean hasStats() {
/// Per-column zone-map: the flat segment holding the per-zone stats table, the zone
/// count (one zone per chunk), the logical rows per zone, and whether the table carries
/// MIN/MAX (else NULL_COUNT only).
private record ZoneMapRef(int zonesSegIdx, long nZones, long zoneLen, boolean hasMinMax) {
private record ZoneMapRef(int zonesSegIdx, long nZones, long zoneLen, boolean hasMinMax, boolean hasSum) {
}

private record DictColRef(int valuesSegIdx, long valuesLen, List<Integer> codesSegIdxes,
List<Long> chunkRowCounts, List<Long> chunkNullCounts,
List<byte[]> chunkStatsMin, List<byte[]> chunkStatsMax) {
List<byte[]> chunkStatsMin, List<byte[]> chunkStatsMax, List<byte[]> chunkStatsSum) {
long totalRows() {
return chunkRowCounts.stream().mapToLong(Long::longValue).sum();
}
Expand Down
Loading