Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 126 additions & 73 deletions writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -701,69 +701,71 @@ private void flushZoneMaps() throws IOException {
if (chunks.isEmpty()) {
continue;
}
int nZones = chunks.size();
boolean[] allValid = new boolean[nZones];
java.util.Arrays.fill(allValid, true);

// NULL_COUNT is computable for every column type; MIN/MAX only for fixed-width
// primitives whose chunks all carry stats. Field/bit order follows
// ZonedStatsSchema: MAX(3), MIN(4), NULL_COUNT(6); each stat field is nullable.
DType colDtype = schema.fieldTypes().get(schema.fieldNames().indexOf(colName));
boolean hasMinMax = colDtype instanceof DType.Primitive
&& chunks.stream().allMatch(ChunkRef::hasStats);

List<String> names = new java.util.ArrayList<>();
List<DType> types = new java.util.ArrayList<>();
List<Object> fields = new java.util.ArrayList<>();
if (hasMinMax) {
PType ptype = ((DType.Primitive) colDtype).ptype();
DType nullablePrim = new DType.Primitive(ptype, true);
boolean[] notTruncated = new boolean[nZones];
names.add("max");
types.add(nullablePrim);
fields.add(new NullableData(statColumn(ptype, chunks, true), allValid.clone()));
names.add("max_is_truncated");
types.add(new DType.Bool(false));
fields.add(notTruncated);
names.add("min");
types.add(nullablePrim);
fields.add(new NullableData(statColumn(ptype, chunks, false), allValid.clone()));
names.add("min_is_truncated");
types.add(new DType.Bool(false));
fields.add(notTruncated.clone());
}
long[] nullCounts = new long[nZones];
for (int i = 0; i < nZones; i++) {
DType minMaxDtype = zoneMinMaxDtype(columnDtype(colName));
boolean hasMinMax = minMaxDtype != null && chunks.stream().allMatch(ChunkRef::hasStats);
long[] nullCounts = new long[chunks.size()];
for (int i = 0; i < chunks.size(); i++) {
nullCounts[i] = chunks.get(i).nullCount();
}
names.add("null_count");
types.add(new DType.Primitive(PType.U64, true));
fields.add(new NullableData(nullCounts, allValid.clone()));

DType.Struct statsDtype = new DType.Struct(List.copyOf(names), List.copyOf(types), false);
int zonesSegIdx = writeSegment(statsDtype, new StructData(fields), new StructEncodingEncoder());
zoneMaps.put(colName, new ZoneMapRef(zonesSegIdx, nZones, options.chunkSize(), hasMinMax));
emitZoneMap(colName, hasMinMax ? minMaxDtype : null,
chunks.stream().map(ChunkRef::statsMin).toList(),
chunks.stream().map(ChunkRef::statsMax).toList(),
nullCounts);
}
// Dict-encoded columns live in a separate path (one zone per code chunk); they carry
// NULL_COUNT only for now (no dict-level MIN/MAX yet). Matches Rust, which zone-maps dict
// columns (vortex.stats wrapping vortex.dict).
// Dict-encoded columns (one zone per code chunk). MIN/MAX come from each chunk's logical
// values (computed at dict-build time); NULL_COUNT always. Matches Rust, whose zone-map
// stats are computed on the logical column dtype, independent of the dict encoding.
for (Map.Entry<String, DictColRef> e : dictColRefs.entrySet()) {
// A dict column always has at least one code chunk, so null counts are non-empty.
long[] nullCounts = e.getValue().chunkNullCounts().stream().mapToLong(Long::longValue).toArray();
writeNullCountZoneMap(e.getKey(), nullCounts);
DictColRef ref = e.getValue();
DType minMaxDtype = zoneMinMaxDtype(columnDtype(e.getKey()));
boolean hasMinMax = minMaxDtype != null
&& ref.chunkStatsMin().stream().allMatch(java.util.Objects::nonNull)
&& ref.chunkStatsMax().stream().allMatch(java.util.Objects::nonNull);
long[] nullCounts = ref.chunkNullCounts().stream().mapToLong(Long::longValue).toArray();
emitZoneMap(e.getKey(), hasMinMax ? minMaxDtype : null,
ref.chunkStatsMin(), ref.chunkStatsMax(), nullCounts);
}
}

/// Emits a NULL_COUNT-only `vortex.stats` zone-map (one zone per chunk) for `colName`.
private void writeNullCountZoneMap(String colName, long[] nullCounts) throws IOException {
private DType columnDtype(String colName) {
return schema.fieldTypes().get(schema.fieldNames().indexOf(colName));
}

/// Writes one `vortex.stats` zone-map for `colName`: one zone per chunk, with NULL_COUNT always
/// and MAX/MIN (plus always-false `_is_truncated` flags) when `minMaxDtype` is non-null.
/// `minBytes`/`maxBytes` hold each zone's serialised min/max scalar — read only when
/// `minMaxDtype` is set. Field/bit order follows ZonedStatsSchema: MAX(3), MIN(4), NULL_COUNT(6).
private void emitZoneMap(String colName, DType minMaxDtype,
List<byte[]> minBytes, List<byte[]> maxBytes, long[] nullCounts) throws IOException {
int nZones = nullCounts.length;
boolean[] allValid = new boolean[nZones];
java.util.Arrays.fill(allValid, true);
DType.Struct statsDtype = new DType.Struct(
List.of("null_count"), List.of(new DType.Primitive(PType.U64, true)), false);
StructData sd = new StructData(List.of(new NullableData(nullCounts, allValid)));
int zonesSegIdx = writeSegment(statsDtype, sd, new StructEncodingEncoder());
zoneMaps.put(colName, new ZoneMapRef(zonesSegIdx, nZones, options.chunkSize(), false));

List<String> names = new java.util.ArrayList<>();
List<DType> types = new java.util.ArrayList<>();
List<Object> fields = new java.util.ArrayList<>();
if (minMaxDtype != null) {
boolean[] notTruncated = new boolean[nZones];
names.add("max");
types.add(minMaxDtype);
fields.add(new NullableData(zoneStatValues(minMaxDtype, maxBytes), allValid.clone()));
names.add("max_is_truncated");
types.add(new DType.Bool(false));
fields.add(notTruncated);
names.add("min");
types.add(minMaxDtype);
fields.add(new NullableData(zoneStatValues(minMaxDtype, minBytes), allValid.clone()));
names.add("min_is_truncated");
types.add(new DType.Bool(false));
fields.add(notTruncated.clone());
}
names.add("null_count");
types.add(new DType.Primitive(PType.U64, true));
fields.add(new NullableData(nullCounts, allValid.clone()));

DType.Struct statsDtype = new DType.Struct(List.copyOf(names), List.copyOf(types), false);
int zonesSegIdx = writeSegment(statsDtype, new StructData(fields), new StructEncodingEncoder());
zoneMaps.put(colName, new ZoneMapRef(zonesSegIdx, nZones, options.chunkSize(), minMaxDtype != null));
}

/// Wraps a column's data layout in a `vortex.stats` (zoned) layout when a zone-map was
Expand Down Expand Up @@ -804,73 +806,107 @@ private static long countNulls(boolean[] validity) {
return nulls;
}

/// Builds the per-zone min (or max) values array in the storage shape the primitive encoder
/// expects, decoding each chunk's serialised [ScalarValue] stat.
private static Object statColumn(PType ptype, List<ChunkRef> chunks, boolean max) throws IOException {
int n = chunks.size();
/// The (nullable) dtype a zone-map stores per-zone min/max in for `dtype`, or `null` when the
/// column has no recordable min/max. Primitives store the primitive; extension columns unwrap
/// to their storage primitive (`ExtEncoding` propagates the storage min/max scalars unchanged);
/// Utf8 stores the full string value. Matches [ZonedStatsSchema#statDtype]. Binary is excluded:
/// `vortex.varbin` records its min/max as string scalars, not `bytes`.
private static DType zoneMinMaxDtype(DType dtype) {
return switch (dtype) {
case DType.Primitive p -> p.withNullable(true);
case DType.Extension ext when ext.storageDType() instanceof DType.Primitive p -> p.withNullable(true);
case DType.Utf8 u -> u.withNullable(true);
default -> null;
};
}

/// Builds the per-zone min (or max) values array for the resolved min/max `dtype`, decoding each
/// zone's serialised [ScalarValue] stat into the array shape its encoder expects.
private static Object zoneStatValues(DType minMaxDtype, List<byte[]> statBytes) throws IOException {
return switch (minMaxDtype) {
case DType.Primitive p -> statColumn(p.ptype(), statBytes);
case DType.Utf8 ignored -> statStringColumn(statBytes);
default -> throw new IllegalStateException("no zone stat values for " + minMaxDtype);
};
}

/// Builds the per-zone string array by decoding each zone's serialised string [ScalarValue]
/// stat. Used for Utf8 columns whose `vortex.varbin` encoder records full string min/max scalars.
private static String[] statStringColumn(List<byte[]> statBytes) throws IOException {
String[] out = new String[statBytes.size()];
for (int i = 0; i < out.length; i++) {
out[i] = decodeScalar(statBytes.get(i)).string_value();
}
return out;
}

/// Builds the per-zone values array in the storage shape the primitive encoder expects, decoding
/// each zone's serialised [ScalarValue] stat.
private static Object statColumn(PType ptype, List<byte[]> statBytes) throws IOException {
int n = statBytes.size();
return switch (ptype) {
case I8, U8 -> {
byte[] a = new byte[n];
for (int i = 0; i < n; i++) {
a[i] = (byte) scalarLong(chunks.get(i), max);
a[i] = (byte) scalarLong(statBytes.get(i));
}
yield a;
}
case I16, U16 -> {
short[] a = new short[n];
for (int i = 0; i < n; i++) {
a[i] = (short) scalarLong(chunks.get(i), max);
a[i] = (short) scalarLong(statBytes.get(i));
}
yield a;
}
case I32, U32 -> {
int[] a = new int[n];
for (int i = 0; i < n; i++) {
a[i] = (int) scalarLong(chunks.get(i), max);
a[i] = (int) scalarLong(statBytes.get(i));
}
yield a;
}
case I64, U64 -> {
long[] a = new long[n];
for (int i = 0; i < n; i++) {
a[i] = scalarLong(chunks.get(i), max);
a[i] = scalarLong(statBytes.get(i));
}
yield a;
}
case F32 -> {
float[] a = new float[n];
for (int i = 0; i < n; i++) {
a[i] = (float) scalarDouble(chunks.get(i), max);
a[i] = (float) scalarDouble(statBytes.get(i));
}
yield a;
}
case F64 -> {
double[] a = new double[n];
for (int i = 0; i < n; i++) {
a[i] = scalarDouble(chunks.get(i), max);
a[i] = scalarDouble(statBytes.get(i));
}
yield a;
}
case F16 -> {
// F16 min/max are serialised as f32 scalars; re-pack to float16 storage.
short[] a = new short[n];
for (int i = 0; i < n; i++) {
a[i] = Float.floatToFloat16((float) scalarDouble(chunks.get(i), max));
a[i] = Float.floatToFloat16((float) scalarDouble(statBytes.get(i)));
}
yield a;
}
};
}

private static long scalarLong(ChunkRef cr, boolean max) throws IOException {
private static long scalarLong(byte[] bytes) throws IOException {
// Integer columns serialise min/max as int64 (signed) or uint64 (unsigned).
ScalarValue sv = decodeScalar(max ? cr.statsMax() : cr.statsMin());
ScalarValue sv = decodeScalar(bytes);
return sv.int64_value() != null ? sv.int64_value() : sv.uint64_value();
}

private static double scalarDouble(ChunkRef cr, boolean max) throws IOException {
private static double scalarDouble(byte[] bytes) throws IOException {
// Float columns serialise min/max as f64 (F64) or f32 (F32).
ScalarValue sv = decodeScalar(max ? cr.statsMax() : cr.statsMin());
ScalarValue sv = decodeScalar(bytes);
return sv.f64_value() != null ? sv.f64_value() : sv.f32_value();
}

Expand Down Expand Up @@ -1061,16 +1097,25 @@ private void writeGlobalDictColumn(String colName, DType.Primitive dtype, List<O
List<Integer> codesSegIdxes = new ArrayList<>();
List<Long> chunkRowCounts = new ArrayList<>();
List<Long> chunkNullCounts = new ArrayList<>();
List<byte[]> chunkStatsMin = new ArrayList<>();
List<byte[]> chunkStatsMax = new ArrayList<>();
for (Object chunk : chunks) {
int len = primitiveArrayLen(chunk, ptype);
Object codesArr = buildCodesArray(chunk, ptype, valueMap, codePType, len);
codesSegIdxes.add(writeSegment(codesDtype, codesArr));
chunkRowCounts.add((long) len);
chunkNullCounts.add(chunk instanceof NullableData nd ? countNulls(nd.validity()) : 0L);
// Per-zone min/max over the chunk's logical values (matches the flat primitive path:
// computed on nd.values(), placeholders included). Lets the dict zone-map prune like a
// plain primitive column.
Object values = chunk instanceof NullableData nd ? nd.values() : chunk;
byte[][] mm = PrimitiveEncodingEncoder.minMaxStats(ptype, values);
chunkStatsMin.add(mm != null ? mm[0] : null);
chunkStatsMax.add(mm != null ? mm[1] : null);
}

dictColRefs.put(colName,
new DictColRef(valuesSegIdx, dictSize, codesSegIdxes, chunkRowCounts, chunkNullCounts));
dictColRefs.put(colName, new DictColRef(valuesSegIdx, dictSize, codesSegIdxes,
chunkRowCounts, chunkNullCounts, chunkStatsMin, chunkStatsMax));
}

private void writeGlobalDictUtf8Column(String colName, DType.Utf8 dtype, List<Object> chunks)
Expand Down Expand Up @@ -1107,15 +1152,22 @@ private void writeGlobalDictUtf8Column(String colName, DType.Utf8 dtype, List<Ob
List<Integer> codesSegIdxes = new ArrayList<>();
List<Long> chunkRowCounts = new ArrayList<>();
List<Long> chunkNullCounts = new ArrayList<>();
List<byte[]> chunkStatsMin = new ArrayList<>();
List<byte[]> chunkStatsMax = new ArrayList<>();
for (Object chunk : chunks) {
String[] strs = (String[]) chunk;
Object codesArr = buildUtf8CodesArray(strs, valueMap, codePType);
codesSegIdxes.add(writeSegment(codesDtype, codesArr));
chunkRowCounts.add((long) strs.length);
chunkNullCounts.add(0L); // global-dict Utf8 columns are dense (non-nullable)
// Per-zone string min/max over the chunk's values (matches the flat varbin path), so the
// dict zone-map prunes like a plain Utf8 column.
byte[][] mm = VarBinEncodingEncoder.minMaxStats(strs);
chunkStatsMin.add(mm != null ? mm[0] : null);
chunkStatsMax.add(mm != null ? mm[1] : null);
}
dictColRefs.put(colName,
new DictColRef(valuesSegIdx, dictSize, codesSegIdxes, chunkRowCounts, chunkNullCounts));
dictColRefs.put(colName, new DictColRef(valuesSegIdx, dictSize, codesSegIdxes,
chunkRowCounts, chunkNullCounts, chunkStatsMin, chunkStatsMax));
}

private static Object buildUtf8CodesArray(String[] strs, Map<String, Integer> valueMap, PType codePType) {
Expand Down Expand Up @@ -1320,7 +1372,8 @@ private record ZoneMapRef(int zonesSegIdx, long nZones, long zoneLen, boolean ha
}

private record DictColRef(int valuesSegIdx, long valuesLen, List<Integer> codesSegIdxes,
List<Long> chunkRowCounts, List<Long> chunkNullCounts) {
List<Long> chunkRowCounts, List<Long> chunkNullCounts,
List<byte[]> chunkStatsMin, List<byte[]> chunkStatsMax) {
long totalRows() {
return chunkRowCounts.stream().mapToLong(Long::longValue).sum();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) {
MemorySegment seg = encodePrimitive(ptype, data, ctx.arena());
byte[] min = null;
byte[] max = null;
byte[][] stats = computeStats(ptype, data);
byte[][] stats = minMaxStats(ptype, data);
if (stats != null) {
min = stats[0];
max = stats[1];
Expand Down Expand Up @@ -86,7 +86,15 @@ private static MemorySegment encodePrimitive(PType ptype, Object data, Arena are
};
}

private static byte[][] computeStats(PType ptype, Object data) {
/// Computes the serialised min/max [io.github.dfa1.vortex.proto.ScalarValue] pair for a raw
/// primitive array, in the same signed/unsigned/float shape the per-segment stats use. Returns
/// `null` for an empty array. Shared so the dictionary zone-map path computes per-chunk min/max
/// identically to the flat path.
///
/// @param ptype the primitive type of `data`
/// @param data the raw primitive array (e.g. `long[]`, `int[]`, `String`-free)
/// @return a two-element `{min, max}` array of encoded scalars, or `null` if `data` is empty
public static byte[][] minMaxStats(PType ptype, Object data) {
return switch (ptype) {
case I8 -> {
byte[] arr = (byte[]) data;
Expand Down
Loading