From 94396162b222826fc0bb2ad9e9fb9a58bafe4a4a Mon Sep 17 00:00:00 2001
From: Carl Lerche <me@carllerche.com>
Date: Tue, 28 Feb 2017 18:33:34 -0800
Subject: [PATCH] Implement `chain` combinator for `Buf`

---
 src/buf/buf.rs      |  26 ++++-
 src/buf/chain.rs    | 241 ++++++++++++++++++++++++++++++++++++++++++++
 src/buf/mod.rs      |   2 +
 tests/test_chain.rs |  93 +++++++++++++++++
 4 files changed, 361 insertions(+), 1 deletion(-)
 create mode 100644 src/buf/chain.rs
 create mode 100644 tests/test_chain.rs

diff --git a/src/buf/buf.rs b/src/buf/buf.rs
index 0732aff..e85b321 100644
--- a/src/buf/buf.rs
+++ b/src/buf/buf.rs
@@ -1,4 +1,4 @@
-use super::{Take, Reader, Iter, FromBuf};
+use super::{IntoBuf, Take, Reader, Iter, FromBuf, Chain};
 use byteorder::ByteOrder;
 use iovec::IoVec;
 
@@ -522,6 +522,30 @@ pub trait Buf {
         super::take::new(self, limit)
     }
 
+    /// Creates an adaptor which will chain this buffer with another.
+    ///
+    /// The returned `Buf` instance will first consume all bytes from `self`.
+    /// Afterwards the output is equivalent to the output of next.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use bytes::{Bytes, Buf, IntoBuf};
+    /// use bytes::buf::Chain;
+    ///
+    /// let buf = Bytes::from(&b"hello "[..]).into_buf()
+    ///             .chain(Bytes::from(&b"world"[..]));
+    ///
+    /// let full: Bytes = buf.collect();
+    /// assert_eq!(full[..], b"hello world"[..]);
+    /// ```
+    fn chain<U>(self, next: U) -> Chain<Self, U::Buf>
+        where U: IntoBuf,
+              Self: Sized,
+    {
+        Chain::new(self, next.into_buf())
+    }
+
     /// Creates a "by reference" adaptor for this instance of `Buf`.
     ///
     /// The returned adaptor also implements `Buf` and will simply borrow `self`.
diff --git a/src/buf/chain.rs b/src/buf/chain.rs
new file mode 100644
index 0000000..6b7cd4b
--- /dev/null
+++ b/src/buf/chain.rs
@@ -0,0 +1,241 @@
+use {Buf, BufMut};
+use iovec::IoVec;
+
+/// A `Chain` sequences two buffers.
+///
+/// `Chain` is an adapter that links two underlying buffers and provides a
+/// continous view across both buffers. It is able to sequence either immutable
+/// buffers ([`Buf`] values) or mutable buffers ([`BufMut`] values).
+///
+/// This struct is generally created by calling [`Buf::chain`]. Please see that
+/// function's documentation for more detail.
+///
+/// # Examples
+///
+/// ```
+/// use bytes::{Bytes, Buf, IntoBuf};
+/// use bytes::buf::Chain;
+///
+/// let buf = Bytes::from(&b"hello "[..]).into_buf()
+///             .chain(Bytes::from(&b"world"[..]));
+///
+/// let full: Bytes = buf.collect();
+/// assert_eq!(full[..], b"hello world"[..]);
+/// ```
+///
+/// [`Buf::chain`]: trait.Buf.html#method.chain
+/// [`Buf`]: trait.Buf.html
+/// [`BufMut`]: trait.BufMut.html
+pub struct Chain<T, U> {
+    a: T,
+    b: U,
+}
+
+impl<T, U> Chain<T, U> {
+    /// Creates a new `Chain` sequencing the provided values.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use bytes::BytesMut;
+    /// use bytes::buf::Chain;
+    ///
+    /// let buf = Chain::new(
+    ///     BytesMut::with_capacity(1024),
+    ///     BytesMut::with_capacity(1024));
+    ///
+    /// // Use the chained buffer
+    /// ```
+    pub fn new(a: T, b: U) -> Chain<T, U> {
+        Chain {
+            a: a,
+            b: b,
+        }
+    }
+
+    /// Gets a reference to the first underlying `Buf`.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use bytes::{Bytes, Buf, IntoBuf};
+    ///
+    /// let buf = Bytes::from(&b"hello"[..]).into_buf()
+    ///             .chain(Bytes::from(&b"world"[..]));
+    ///
+    /// assert_eq!(buf.first_ref().get_ref()[..], b"hello"[..]);
+    /// ```
+    pub fn first_ref(&self) -> &T {
+        &self.a
+    }
+
+    /// Gets a mutable reference to the first underlying `Buf`.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use bytes::{Bytes, Buf, IntoBuf};
+    ///
+    /// let mut buf = Bytes::from(&b"hello "[..]).into_buf()
+    ///                 .chain(Bytes::from(&b"world"[..]));
+    ///
+    /// buf.first_mut().set_position(1);
+    ///
+    /// let full: Bytes = buf.collect();
+    /// assert_eq!(full[..], b"ello world"[..]);
+    /// ```
+    pub fn first_mut(&mut self) -> &mut T {
+        &mut self.a
+    }
+
+    /// Gets a reference to the last underlying `Buf`.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use bytes::{Bytes, Buf, IntoBuf};
+    ///
+    /// let buf = Bytes::from(&b"hello"[..]).into_buf()
+    ///             .chain(Bytes::from(&b"world"[..]));
+    ///
+    /// assert_eq!(buf.last_ref().get_ref()[..], b"world"[..]);
+    /// ```
+    pub fn last_ref(&self) -> &U {
+        &self.b
+    }
+
+    /// Gets a mutable reference to the last underlying `Buf`.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use bytes::{Bytes, Buf, IntoBuf};
+    ///
+    /// let mut buf = Bytes::from(&b"hello "[..]).into_buf()
+    ///                 .chain(Bytes::from(&b"world"[..]));
+    ///
+    /// buf.last_mut().set_position(1);
+    ///
+    /// let full: Bytes = buf.collect();
+    /// assert_eq!(full[..], b"hello orld"[..]);
+    /// ```
+    pub fn last_mut(&mut self) -> &mut U {
+        &mut self.b
+    }
+
+    /// Consumes this `Chain`, returning the underlying values.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use bytes::{Bytes, Buf, IntoBuf};
+    ///
+    /// let buf = Bytes::from(&b"hello"[..]).into_buf()
+    ///             .chain(Bytes::from(&b"world"[..]));
+    ///
+    /// let (first, last) = buf.into_inner();
+    /// assert_eq!(first.get_ref()[..], b"hello"[..]);
+    /// assert_eq!(last.get_ref()[..], b"world"[..]);
+    /// ```
+    pub fn into_inner(self) -> (T, U) {
+        (self.a, self.b)
+    }
+}
+
+impl<T, U> Buf for Chain<T, U>
+    where T: Buf,
+          U: Buf,
+{
+    fn remaining(&self) -> usize {
+        self.a.remaining() + self.b.remaining()
+    }
+
+    fn bytes(&self) -> &[u8] {
+        if self.a.has_remaining() {
+            self.a.bytes()
+        } else {
+            self.b.bytes()
+        }
+    }
+
+    fn advance(&mut self, mut cnt: usize) {
+        let a_rem = self.a.remaining();
+
+        if a_rem != 0 {
+            if a_rem >= cnt {
+                self.a.advance(cnt);
+                return;
+            }
+
+            // Consume what is left of a
+            self.a.advance(a_rem);
+
+            cnt -= a_rem;
+        }
+
+        self.b.advance(cnt);
+    }
+
+    fn bytes_vec<'a>(&'a self, dst: &mut [&'a IoVec]) -> usize {
+        if self.a.has_remaining() {
+            let mut n = self.a.bytes_vec(dst);
+
+            if n < dst.len() {
+                n += self.b.bytes_vec(&mut dst[n..]);
+            }
+
+            n
+        } else {
+            self.b.bytes_vec(dst)
+        }
+    }
+}
+
+impl<T, U> BufMut for Chain<T, U>
+    where T: BufMut,
+          U: BufMut,
+{
+    fn remaining_mut(&self) -> usize {
+        self.a.remaining_mut() + self.b.remaining_mut()
+    }
+
+    unsafe fn bytes_mut(&mut self) -> &mut [u8] {
+        if self.a.has_remaining_mut() {
+            self.a.bytes_mut()
+        } else {
+            self.b.bytes_mut()
+        }
+    }
+
+    unsafe fn advance_mut(&mut self, mut cnt: usize) {
+        let a_rem = self.a.remaining_mut();
+
+        if a_rem != 0 {
+            if a_rem >= cnt {
+                self.a.advance_mut(cnt);
+                return;
+            }
+
+            // Consume what is left of a
+            self.a.advance_mut(a_rem);
+
+            cnt -= a_rem;
+        }
+
+        self.b.advance_mut(cnt);
+    }
+
+    unsafe fn bytes_vec_mut<'a>(&'a mut self, dst: &mut [&'a mut IoVec]) -> usize {
+        if self.a.has_remaining_mut() {
+            let mut n = self.a.bytes_vec_mut(dst);
+
+            if n < dst.len() {
+                n += self.b.bytes_vec_mut(&mut dst[n..]);
+            }
+
+            n
+        } else {
+            self.b.bytes_vec_mut(dst)
+        }
+    }
+}
diff --git a/src/buf/mod.rs b/src/buf/mod.rs
index f8cf37a..7d7ce22 100644
--- a/src/buf/mod.rs
+++ b/src/buf/mod.rs
@@ -19,6 +19,7 @@
 mod buf;
 mod buf_mut;
 mod from_buf;
+mod chain;
 mod into_buf;
 mod iter;
 mod reader;
@@ -29,6 +30,7 @@ mod writer;
 pub use self::buf::Buf;
 pub use self::buf_mut::BufMut;
 pub use self::from_buf::FromBuf;
+pub use self::chain::Chain;
 pub use self::into_buf::IntoBuf;
 pub use self::iter::Iter;
 pub use self::reader::Reader;
diff --git a/tests/test_chain.rs b/tests/test_chain.rs
new file mode 100644
index 0000000..08181bd
--- /dev/null
+++ b/tests/test_chain.rs
@@ -0,0 +1,93 @@
+extern crate bytes;
+extern crate iovec;
+
+use bytes::{Buf, BufMut, Bytes, BytesMut};
+use bytes::buf::Chain;
+use iovec::IoVec;
+use std::io::Cursor;
+
+#[test]
+fn collect_two_bufs() {
+    let a = Cursor::new(Bytes::from(&b"hello"[..]));
+    let b = Cursor::new(Bytes::from(&b"world"[..]));
+
+    let res: Vec<u8> = a.chain(b).collect();
+    assert_eq!(res, &b"helloworld"[..]);
+}
+
+#[test]
+fn writing_chained() {
+    let mut a = BytesMut::with_capacity(64);
+    let mut b = BytesMut::with_capacity(64);
+
+    {
+        let mut buf = Chain::new(&mut a, &mut b);
+
+        for i in 0..128 {
+            buf.put(i as u8);
+        }
+    }
+
+    assert_eq!(64, a.len());
+    assert_eq!(64, b.len());
+
+    for i in 0..64 {
+        let expect = i as u8;
+        assert_eq!(expect, a[i]);
+        assert_eq!(expect + 64, b[i]);
+    }
+}
+
+#[test]
+fn iterating_two_bufs() {
+    let a = Cursor::new(Bytes::from(&b"hello"[..]));
+    let b = Cursor::new(Bytes::from(&b"world"[..]));
+
+    let res: Vec<u8> = a.chain(b).iter().collect();
+    assert_eq!(res, &b"helloworld"[..]);
+}
+
+#[test]
+fn vectored_read() {
+    let a = Cursor::new(Bytes::from(&b"hello"[..]));
+    let b = Cursor::new(Bytes::from(&b"world"[..]));
+
+    let mut buf = a.chain(b);
+    let mut iovecs: [&IoVec; 4] = Default::default();
+
+    assert_eq!(2, buf.bytes_vec(&mut iovecs));
+    assert_eq!(iovecs[0][..], b"hello"[..]);
+    assert_eq!(iovecs[1][..], b"world"[..]);
+    assert!(iovecs[2].is_empty());
+    assert!(iovecs[3].is_empty());
+
+    buf.advance(2);
+
+    iovecs = Default::default();
+
+    assert_eq!(2, buf.bytes_vec(&mut iovecs));
+    assert_eq!(iovecs[0][..], b"llo"[..]);
+    assert_eq!(iovecs[1][..], b"world"[..]);
+    assert!(iovecs[2].is_empty());
+    assert!(iovecs[3].is_empty());
+
+    buf.advance(3);
+
+    iovecs = Default::default();
+
+    assert_eq!(1, buf.bytes_vec(&mut iovecs));
+    assert_eq!(iovecs[0][..], b"world"[..]);
+    assert!(iovecs[1].is_empty());
+    assert!(iovecs[2].is_empty());
+    assert!(iovecs[3].is_empty());
+
+    buf.advance(3);
+
+    iovecs = Default::default();
+
+    assert_eq!(1, buf.bytes_vec(&mut iovecs));
+    assert_eq!(iovecs[0][..], b"ld"[..]);
+    assert!(iovecs[1].is_empty());
+    assert!(iovecs[2].is_empty());
+    assert!(iovecs[3].is_empty());
+}
-- 
GitLab