diff --git a/src/buf/ring.rs b/src/buf/ring.rs index ca1df7f9e476fb4a86cd3c27177366f0b9c8f1f5..57ec32dabebc14b84cb5ca0410a5af48e042dcee 100644 --- a/src/buf/ring.rs +++ b/src/buf/ring.rs @@ -1,6 +1,11 @@ use {alloc, Buf, MutBuf}; use std::{cmp, fmt, io, ptr}; +enum Mark { + NoMark, + At { pos: usize, len: usize }, +} + /// Buf backed by a continous chunk of memory. Maintains a read cursor and a /// write cursor. When reads and writes reach the end of the allocated buffer, /// wraps around to the start. @@ -9,7 +14,7 @@ pub struct RingBuf { cap: usize, // Capacity of the buffer pos: usize, // Offset of read cursor len: usize, // Number of bytes to read - mark: Option<usize>, // Marked read position + mark: Mark, // Marked read position } // TODO: There are most likely many optimizations that can be made @@ -22,7 +27,7 @@ impl RingBuf { cap: 0, pos: 0, len: 0, - mark: None, + mark: Mark::NoMark, } } @@ -36,7 +41,7 @@ impl RingBuf { cap: capacity, pos: 0, len: 0, - mark: None, + mark: Mark::NoMark, } } @@ -58,7 +63,7 @@ impl RingBuf { /// buffer multiple times. The mark will be cleared if it is overwritten /// during a write. pub fn mark(&mut self) { - self.mark = Some(self.pos); + self.mark = Mark::At { pos: self.pos, len: self.len }; } /// Resets the read position to the previously marked position. @@ -70,9 +75,14 @@ impl RingBuf { /// /// This method will panic if no mark has been set, pub fn reset(&mut self){ - let mark = self.mark.take().expect("no mark set"); - self.len = (self.len + self.pos + self.cap - mark) % self.cap; - self.pos = mark; + match self.mark { + Mark::NoMark => panic!("no mark set"), + Mark::At {pos, len} => { + self.pos = pos; + self.len = len; + self.mark = Mark::NoMark; + } + } } fn read_remaining(&self) -> usize { @@ -97,9 +107,16 @@ impl RingBuf { fn advance_writer(&mut self, mut cnt: usize) { cnt = cmp::min(cnt, self.write_remaining()); self.len += cnt; - if let Some(mark) = self.mark { - if (self.pos + self.len) % self.cap > mark { - self.mark = None; + + // Adjust the mark to account for bytes written. + if let Mark::At { ref mut len, .. } = self.mark { + *len += cnt; + } + + // Clear the mark if we've written past it. + if let Mark::At { len, .. } = self.mark { + if len > self.cap { + self.mark = Mark::NoMark; } } } diff --git a/test/test_ring.rs b/test/test_ring.rs index d2c1ba957d5f98e43f69c5f35ba276f88eb97a74..78bf4a0fe8817bc607864130e1808b4cb8f371a5 100644 --- a/test/test_ring.rs +++ b/test/test_ring.rs @@ -89,3 +89,34 @@ fn test_wrap_reset() { buf.write(&[1, 2, 3, 4]).unwrap(); buf.reset(); } + +#[test] +// Test that writes across a mark/reset are preserved. +fn test_mark_write() { + use std::io::{Read, Write}; + + let mut buf = RingBuf::new(8); + buf.write(&[1, 2, 3, 4, 5, 6, 7]).unwrap(); + buf.mark(); + buf.write(&[8]).unwrap(); + buf.reset(); + + let mut buf2 = [0; 8]; + buf.read(&mut buf2).unwrap(); + assert_eq!(buf2, [1, 2, 3, 4, 5, 6, 7, 8]); +} + +#[test] +// Test that "RingBuf::reset" does not reset the length of a +// full buffer to zero. +fn test_reset_full() { + use bytes::traits::MutBuf; + use std::io::Write; + + let mut buf = RingBuf::new(8); + buf.write(&[1, 2, 3, 4, 5, 6, 7, 8]).unwrap(); + assert_eq!(MutBuf::remaining(&buf), 0); + buf.mark(); + buf.reset(); + assert_eq!(MutBuf::remaining(&buf), 0); +}