Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ impl<W: Write + Send> DeflateEncoder<W> {
let chunk = chunks[0];
let compressor = &mut self.compressors[0];
let output = &mut self.output_buffers[0];
let bound = Compressor::deflate_compress_bound(chunk.len());
let mut bound = Compressor::deflate_compress_bound(chunk.len());
if !final_block {
bound += 5;
}
if output.len() < bound {
output
.try_reserve(bound - output.len())
Expand Down Expand Up @@ -100,7 +103,10 @@ impl<W: Write + Send> DeflateEncoder<W> {
.zip(self.output_buffers.par_iter_mut())
.enumerate()
.map(|(i, ((&chunk, compressor), output))| {
let bound = Compressor::deflate_compress_bound(chunk.len());
let mut bound = Compressor::deflate_compress_bound(chunk.len());
if !(final_block && i == num_chunks - 1) {
bound += 5;
}
if output.len() < bound {
output
.try_reserve(bound - output.len())
Expand Down Expand Up @@ -149,7 +155,10 @@ impl<W: Write + Send> DeflateEncoder<W> {

let compressor = &mut self.compressors[0];
let output = &mut self.output_buffers[0];
let bound = Compressor::deflate_compress_bound(self.buffer.len());
let mut bound = Compressor::deflate_compress_bound(self.buffer.len());
if !final_block {
bound += 5;
}
if output.len() < bound {
output
.try_reserve(bound - output.len())
Expand Down
61 changes: 61 additions & 0 deletions tests/stream_test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,36 @@
use libdeflate::stream::{DeflateDecoder, DeflateEncoder};
use std::io::{Cursor, Read, Write};
use std::sync::{Arc, Mutex};

/// Test double that records everything written to it and counts
/// how many times `flush` is invoked, via shared interior state.
#[derive(Clone)]
struct FlushTrackingWriter {
    data: Arc<Mutex<Vec<u8>>>,
    flush_count: Arc<Mutex<usize>>,
}

impl Write for FlushTrackingWriter {
    /// Appends the whole buffer to the shared byte store; never short-writes.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let mut sink = self.data.lock().unwrap();
        sink.extend_from_slice(buf);
        Ok(buf.len())
    }

    /// Bumps the shared flush counter instead of performing real I/O.
    fn flush(&mut self) -> std::io::Result<()> {
        let mut count = self.flush_count.lock().unwrap();
        *count += 1;
        Ok(())
    }
}

/// Test double whose writes always succeed but whose `flush` always fails,
/// used to verify that flush errors from the inner writer are propagated.
struct ErrorFlushWriter;

impl Write for ErrorFlushWriter {
    /// Claims to have consumed the whole buffer; the bytes are discarded.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        Ok(buf.len())
    }

    /// Always fails with an `ErrorKind::Other` error.
    fn flush(&mut self) -> std::io::Result<()> {
        Err(std::io::Error::new(std::io::ErrorKind::Other, "flush error"))
    }
}

#[test]
fn test_stream_round_trip() {
Expand Down Expand Up @@ -43,3 +74,33 @@ fn test_stream_small_chunks() {

assert_eq!(data, decompressed);
}

#[test]
fn test_encoder_flush() {
    let captured = Arc::new(Mutex::new(Vec::new()));
    let flushes = Arc::new(Mutex::new(0));
    let sink = FlushTrackingWriter {
        data: Arc::clone(&captured),
        flush_count: Arc::clone(&flushes),
    };

    let mut encoder = DeflateEncoder::new(sink, 6);
    encoder.write_all(b"Hello World").unwrap();
    encoder.flush().unwrap();

    // Some compressed bytes must have reached the underlying writer...
    assert!(!captured.lock().unwrap().is_empty());

    // ...and exactly one flush must have been forwarded to it.
    assert_eq!(*flushes.lock().unwrap(), 1);
}

#[test]
fn test_encoder_flush_error() {
    let mut encoder = DeflateEncoder::new(ErrorFlushWriter, 6);
    encoder.write_all(b"Hello World").unwrap();

    // The inner writer's flush always fails, so the encoder's flush
    // must surface that error rather than swallow it.
    assert!(encoder.flush().is_err());
}