Skip to content

Commit ffdbb10

Browse files
Verify DeflateEncoder::flush() and fix buffer size calculation
- Added `test_encoder_flush` and `test_encoder_flush_error` to `tests/stream_test.rs`. - Implemented `FlushTrackingWriter` and `ErrorFlushWriter` helpers to verify underlying writer behavior. - Fixed a bug in `src/stream.rs` where `deflate_compress_bound` did not account for the 5-byte overhead of `FlushMode::Sync` blocks (used in non-final chunks), which caused `flush()` to fail on small inputs. - Verified that `flush()` propagates data to the underlying writer and calls its `flush()` method. Co-authored-by: 404Setup <[email protected]>
1 parent 3e5bb8a commit ffdbb10

2 files changed

Lines changed: 73 additions & 3 deletions

File tree

src/stream.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ impl<W: Write + Send> DeflateEncoder<W> {
6262
let chunk = chunks[0];
6363
let compressor = &mut self.compressors[0];
6464
let output = &mut self.output_buffers[0];
65-
let bound = Compressor::deflate_compress_bound(chunk.len());
65+
let mut bound = Compressor::deflate_compress_bound(chunk.len());
66+
if !final_block {
67+
bound += 5;
68+
}
6669
if output.len() < bound {
6770
output
6871
.try_reserve(bound - output.len())
@@ -100,7 +103,10 @@ impl<W: Write + Send> DeflateEncoder<W> {
100103
.zip(self.output_buffers.par_iter_mut())
101104
.enumerate()
102105
.map(|(i, ((&chunk, compressor), output))| {
103-
let bound = Compressor::deflate_compress_bound(chunk.len());
106+
let mut bound = Compressor::deflate_compress_bound(chunk.len());
107+
if !(final_block && i == num_chunks - 1) {
108+
bound += 5;
109+
}
104110
if output.len() < bound {
105111
output
106112
.try_reserve(bound - output.len())
@@ -149,7 +155,10 @@ impl<W: Write + Send> DeflateEncoder<W> {
149155

150156
let compressor = &mut self.compressors[0];
151157
let output = &mut self.output_buffers[0];
152-
let bound = Compressor::deflate_compress_bound(self.buffer.len());
158+
let mut bound = Compressor::deflate_compress_bound(self.buffer.len());
159+
if !final_block {
160+
bound += 5;
161+
}
153162
if output.len() < bound {
154163
output
155164
.try_reserve(bound - output.len())

tests/stream_test.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,36 @@
11
use libdeflate::stream::{DeflateDecoder, DeflateEncoder};
22
use std::io::{Cursor, Read, Write};
3+
use std::sync::{Arc, Mutex};
4+
5+
#[derive(Clone)]
6+
struct FlushTrackingWriter {
7+
data: Arc<Mutex<Vec<u8>>>,
8+
flush_count: Arc<Mutex<usize>>,
9+
}
10+
11+
impl Write for FlushTrackingWriter {
12+
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
13+
self.data.lock().unwrap().extend_from_slice(buf);
14+
Ok(buf.len())
15+
}
16+
17+
fn flush(&mut self) -> std::io::Result<()> {
18+
*self.flush_count.lock().unwrap() += 1;
19+
Ok(())
20+
}
21+
}
22+
23+
struct ErrorFlushWriter;
24+
25+
impl Write for ErrorFlushWriter {
26+
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
27+
Ok(buf.len())
28+
}
29+
30+
fn flush(&mut self) -> std::io::Result<()> {
31+
Err(std::io::Error::new(std::io::ErrorKind::Other, "flush error"))
32+
}
33+
}
334

435
#[test]
536
fn test_stream_round_trip() {
@@ -43,3 +74,33 @@ fn test_stream_small_chunks() {
4374

4475
assert_eq!(data, decompressed);
4576
}
77+
78+
#[test]
79+
fn test_encoder_flush() {
80+
let data = Arc::new(Mutex::new(Vec::new()));
81+
let flush_count = Arc::new(Mutex::new(0));
82+
let writer = FlushTrackingWriter {
83+
data: data.clone(),
84+
flush_count: flush_count.clone(),
85+
};
86+
87+
let mut encoder = DeflateEncoder::new(writer, 6);
88+
encoder.write_all(b"Hello World").unwrap();
89+
encoder.flush().unwrap();
90+
91+
// Verify data was written (compressed)
92+
assert!(!data.lock().unwrap().is_empty());
93+
94+
// Verify flush was called on the underlying writer
95+
assert_eq!(*flush_count.lock().unwrap(), 1);
96+
}
97+
98+
#[test]
99+
fn test_encoder_flush_error() {
100+
let writer = ErrorFlushWriter;
101+
let mut encoder = DeflateEncoder::new(writer, 6);
102+
encoder.write_all(b"Hello World").unwrap();
103+
104+
// flush() should fail because the underlying writer returns an error
105+
assert!(encoder.flush().is_err());
106+
}

0 commit comments

Comments
 (0)