diff --git a/src/stream.rs b/src/stream.rs index e1d2fb4..e8e33b8 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -62,7 +62,10 @@ impl DeflateEncoder { let chunk = chunks[0]; let compressor = &mut self.compressors[0]; let output = &mut self.output_buffers[0]; - let bound = Compressor::deflate_compress_bound(chunk.len()); + let mut bound = Compressor::deflate_compress_bound(chunk.len()); + if !final_block { + bound += 5; + } if output.len() < bound { output .try_reserve(bound - output.len()) @@ -100,7 +103,10 @@ impl DeflateEncoder { .zip(self.output_buffers.par_iter_mut()) .enumerate() .map(|(i, ((&chunk, compressor), output))| { - let bound = Compressor::deflate_compress_bound(chunk.len()); + let mut bound = Compressor::deflate_compress_bound(chunk.len()); + if !(final_block && i == num_chunks - 1) { + bound += 5; + } if output.len() < bound { output .try_reserve(bound - output.len()) @@ -149,7 +155,10 @@ impl DeflateEncoder { let compressor = &mut self.compressors[0]; let output = &mut self.output_buffers[0]; - let bound = Compressor::deflate_compress_bound(self.buffer.len()); + let mut bound = Compressor::deflate_compress_bound(self.buffer.len()); + if !final_block { + bound += 5; + } if output.len() < bound { output .try_reserve(bound - output.len()) diff --git a/tests/stream_test.rs b/tests/stream_test.rs index 1a9ca16..5f5bc92 100644 --- a/tests/stream_test.rs +++ b/tests/stream_test.rs @@ -1,5 +1,36 @@ use libdeflate::stream::{DeflateDecoder, DeflateEncoder}; use std::io::{Cursor, Read, Write}; +use std::sync::{Arc, Mutex}; + +#[derive(Clone)] +struct FlushTrackingWriter { + data: Arc>>, + flush_count: Arc>, +} + +impl Write for FlushTrackingWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.data.lock().unwrap().extend_from_slice(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + *self.flush_count.lock().unwrap() += 1; + Ok(()) + } +} + +struct ErrorFlushWriter; + +impl Write for ErrorFlushWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Err(std::io::Error::new(std::io::ErrorKind::Other, "flush error")) + } +} #[test] fn test_stream_round_trip() { @@ -43,3 +74,33 @@ fn test_stream_small_chunks() { assert_eq!(data, decompressed); } + +#[test] +fn test_encoder_flush() { + let data = Arc::new(Mutex::new(Vec::new())); + let flush_count = Arc::new(Mutex::new(0)); + let writer = FlushTrackingWriter { + data: data.clone(), + flush_count: flush_count.clone(), + }; + + let mut encoder = DeflateEncoder::new(writer, 6); + encoder.write_all(b"Hello World").unwrap(); + encoder.flush().unwrap(); + + // Verify data was written (compressed) + assert!(!data.lock().unwrap().is_empty()); + + // Verify flush was called on the underlying writer + assert_eq!(*flush_count.lock().unwrap(), 1); +} + +#[test] +fn test_encoder_flush_error() { + let writer = ErrorFlushWriter; + let mut encoder = DeflateEncoder::new(writer, 6); + encoder.write_all(b"Hello World").unwrap(); + + // flush() should fail because the underlying writer returns an error + assert!(encoder.flush().is_err()); +}