Skip to content

Fix silent string truncation in PyGrain shared memory batching#1328

Open
copybara-service[bot] wants to merge 1 commit into
mainfrom
test_919292388
Open

Fix silent string truncation in PyGrain shared memory batching#1328
copybara-service[bot] wants to merge 1 commit into
mainfrom
test_919292388

Conversation

@copybara-service
Copy link
Copy Markdown

@copybara-service copybara-service Bot commented May 22, 2026

Fix silent string truncation in PyGrain shared memory batching

The prompt truncation is caused by a silent data-loss bug in PyGrain's shared memory batch stacking implementation which is active during evaluation when num_workers > 0 (multi-process prefetching).

To improve multi-process data loading performance, Kauldron automatically enables output_to_shared_memory=True inside PyGrain's batching transformations.

When batching elements inside shm_stacking_function, the destination SharedMemoryArray's shape and dtype are inferred exclusively from the first element in the batch:

first_arg = np.asanyarray(args[0])
shape, dtype = (len(args),) + first_arg.shape, first_arg.dtype

In a batch containing variable-length string elements, the SharedMemoryArray is allocated with a fixed-width unicode dtype bounded by that first string.

When np.stack(args, out=...) copies subsequent longer strings into the allocated fixed-width shared memory array, NumPy silently truncates the strings to the array's fixed width!

The prompt truncation is caused by a silent data-loss bug in PyGrain's shared memory batch stacking implementation which is active during evaluation when `num_workers > 0` (multi-process prefetching).

To improve multi-process data loading performance, Kauldron automatically enables `output_to_shared_memory=True` inside PyGrain's batching transformations.

When batching elements inside `shm_stacking_function`, the destination `SharedMemoryArray`'s `shape` and `dtype` are inferred exclusively from the first element in the batch:
   ```python
   first_arg = np.asanyarray(args[0])
   shape, dtype = (len(args),) + first_arg.shape, first_arg.dtype
   ```

In a batch containing variable-length string elements, the `SharedMemoryArray` is allocated with a fixed-width unicode dtype bounded by that first string.

When `np.stack(args, out=...)` copies subsequent longer strings into the allocated fixed-width shared memory array, NumPy silently truncates the strings to the array's fixed width!

PiperOrigin-RevId: 919292388
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants