From 82e7582248db07d9d729da6655784d72ba92b996 Mon Sep 17 00:00:00 2001 From: mxsm Date: Fri, 31 May 2024 07:35:08 +0000 Subject: [PATCH] =?UTF-8?q?[ISSUE=20#409]=E2=9A=A1=EF=B8=8FOptimize=20Remo?= =?UTF-8?q?tingCommand=20decode=20and=20encode?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/codec/remoting_command_codec.rs | 166 +++++++++++++----- .../src/protocol/remoting_command.rs | 7 + 2 files changed, 131 insertions(+), 42 deletions(-) diff --git a/rocketmq-remoting/src/codec/remoting_command_codec.rs b/rocketmq-remoting/src/codec/remoting_command_codec.rs index 07f16422..23418bd9 100644 --- a/rocketmq-remoting/src/codec/remoting_command_codec.rs +++ b/rocketmq-remoting/src/codec/remoting_command_codec.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use bytes::{Buf, BufMut, BytesMut}; use tokio_util::codec::{Decoder, Encoder}; use crate::{ @@ -23,6 +23,26 @@ use crate::{ protocol::remoting_command::RemotingCommand, }; +/// Encodes a `RemotingCommand` into a `BytesMut` buffer. +/// +/// This method takes a `RemotingCommand` and a mutable reference to a `BytesMut` buffer as +/// parameters. It first encodes the header of the `RemotingCommand` and calculates the lengths of +/// the header and body. It then reserves the necessary space in the `BytesMut` buffer and writes +/// the total length, serialize type, header, and body to the buffer. +/// +/// # Arguments +/// +/// * `item` - A `RemotingCommand` that is to be encoded. +/// * `dst` - A mutable reference to a `BytesMut` buffer where the encoded command will be written. +/// +/// # Returns +/// +/// * `Result<(), Self::Error>` - Returns `Ok(())` if the encoding is successful, otherwise returns +/// an `Err` with a `RemotingError`. +/// +/// # Errors +/// +/// This function will return an error if the encoding process fails. #[derive(Debug, Clone)] pub struct RemotingCommandCodec; @@ -42,6 +62,32 @@ impl Decoder for RemotingCommandCodec { type Item = RemotingCommand; type Error = RemotingError; + /// Decodes a `RemotingCommand` from a `BytesMut` buffer. + /// + /// This method takes a mutable reference to a `BytesMut` buffer as a parameter. + /// It first checks if there are at least 4 bytes in the buffer, if not, it returns `Ok(None)`. + /// Then it reads the total size of the incoming data as a big-endian i32 from the first 4 + /// bytes. If the available data is less than the total size, it returns `Ok(None)`. + /// It then splits the `BytesMut` buffer to get the command data including the total size and + /// discards the first i32 (total size). It reads the header length as a big-endian i32 and + /// checks if the header length is greater than the total size minus 4. If it is, it returns + /// an error. It then splits the buffer again to get the header data and deserializes it + /// into a `RemotingCommand`. If the total size minus 4 is greater than the header length, + /// it sets the body of the `RemotingCommand`. + /// + /// # Arguments + /// + /// * `src` - A mutable reference to a `BytesMut` buffer from which the `RemotingCommand` will + /// be decoded. + /// + /// # Returns + /// + /// * `Result, Self::Error>` - Returns `Ok(Some(cmd))` if the decoding is + /// successful, otherwise returns an `Err` with a `RemotingError`. + /// + /// # Errors + /// + /// This function will return an error if the decoding process fails. fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { let read_to = src.len(); if read_to < 4 { @@ -58,55 +104,76 @@ impl Decoder for RemotingCommandCodec { // Split the BytesMut to get the command data including the total size. let mut cmd_data = src.split_to(total_size + 4); // Discard the first i32 (total size). - let _ = cmd_data.get_i32(); - + cmd_data.advance(4); + if cmd_data.remaining() < 4 { + return Ok(None); + } // Read the header length as a big-endian i32. let header_length = cmd_data.get_i32() as usize; + if header_length > total_size - 4 { + return Err(RemotingCommandDecoderError(format!( + "Header length {} is greater than total size {}", + header_length, total_size + ))); + } // Assume the header is of i32 type and directly get it from the data. let header_data = cmd_data.split_to(header_length); - let cmd = serde_json::from_slice::(&header_data).map_err(|error| { + let mut cmd = serde_json::from_slice::(&header_data).map_err(|error| { // Handle deserialization error gracefully RemotingCommandDecoderError(format!("Deserialization error: {}", error)) })?; + if total_size - 4 > header_length { + cmd.set_body_mut_ref(Some( + cmd_data.split_to(total_size - 4 - header_length).freeze(), + )); + } - let body_length = total_size - 4 - header_length; - Ok(Some(if body_length > 0 { - let body_data = cmd_data.split_to(body_length).to_vec(); - cmd.set_body(Some(Bytes::from(body_data))) - } else { - cmd - })) + Ok(Some(cmd)) } } impl Encoder for RemotingCommandCodec { type Error = RemotingError; + /// Encodes a `RemotingCommand` into a `BytesMut` buffer. + /// + /// This method takes a `RemotingCommand` and a mutable reference to a `BytesMut` buffer as + /// parameters. It first encodes the header of the `RemotingCommand` and calculates the + /// lengths of the header and body. It then reserves the necessary space in the `BytesMut` + /// buffer and writes the total length, serialize type, header, and body to the buffer. + /// + /// # Arguments + /// + /// * `item` - A `RemotingCommand` that is to be encoded. + /// * `dst` - A mutable reference to a `BytesMut` buffer where the encoded command will be + /// written. + /// + /// # Returns + /// + /// * `Result<(), Self::Error>` - Returns `Ok(())` if the encoding is successful, otherwise + /// returns an `Err` with a `RemotingError`. + /// + /// # Errors + /// + /// This function will return an error if the encoding process fails. fn encode(&mut self, item: RemotingCommand, dst: &mut BytesMut) -> Result<(), Self::Error> { - let mut total_length = 4i32; let header = item.fast_header_encode(); - let mut header_length = 0i32; - if let Some(header) = &header { - header_length = header.len() as i32; - total_length += header_length; - } - let body = item.get_body(); - if let Some(body) = &body { - total_length += body.len() as i32; - } + let header_length = header.as_ref().map_or(0, |h| h.len()) as i32; + let body_length = item.get_body().map_or(0, |b| b.len()) as i32; + let total_length = 4 + header_length + body_length; - dst.reserve(total_length as usize); - // total length: 8 + header length + body length + dst.reserve((total_length + 4) as usize); dst.put_i32(total_length); let serialize_type = RemotingCommand::mark_serialize_type(header_length, item.get_serialize_type()); dst.put_i32(serialize_type); + if let Some(header_inner) = header { dst.put(header_inner); } - if let Some(body_inner) = body { + if let Some(body_inner) = item.get_body() { dst.put(body_inner); } Ok(()) @@ -115,36 +182,51 @@ impl Encoder for RemotingCommandCodec { #[cfg(test)] mod tests { + use bytes::Bytes; + use super::*; use crate::protocol::{header::client_request_header::GetRouteInfoRequestHeader, LanguageCode}; - #[test] - fn test_encode() { + #[tokio::test] + async fn decode_handles_insufficient_data() { + let mut decoder = RemotingCommandCodec::new(); + let mut src = BytesMut::from(&[0, 0, 0, 1][..]); + assert!(matches!(decoder.decode(&mut src), Ok(None))); + } + + #[tokio::test] + async fn decode_handles_sufficient_data() { + let mut decoder = RemotingCommandCodec::new(); + let mut src = BytesMut::from(&[0, 0, 0, 1, 0, 0, 0, 0][..]); + assert!(matches!(decoder.decode(&mut src), Ok(None))); + } + + #[tokio::test] + async fn encode_handles_empty_body() { + let mut encoder = RemotingCommandCodec::new(); let mut dst = BytesMut::new(); let command = RemotingCommand::create_remoting_command(1) .set_code(1) .set_language(LanguageCode::JAVA) .set_opaque(1) .set_flag(1) - .set_body(Some(Bytes::from("body"))) .set_command_custom_header(GetRouteInfoRequestHeader::new("1111", Some(true))) .set_remark(Some("remark".to_string())); - println!("{}", serde_json::to_string(&command).unwrap()); - let mut encoder = RemotingCommandCodec::new(); - let _ = encoder.encode(command, &mut dst); - - let _expected_length = 8 + "header".len() as i32 + "body".len() as i32; - let result = encoder.decode(&mut dst); - println!("{:?}", result.unwrap().unwrap().get_serialize_type()); + assert!(encoder.encode(command, &mut dst).is_ok()); } - #[test] - fn tsts() { - let mut bytes1 = bytes::BytesMut::from("122222"); - let _bytes2 = bytes1.split_to(1); - println!("{}", bytes1.len()); - bytes1.reserve(1); - let _bytes2 = bytes1.split_to(1); - println!("{}", bytes1.len()); + #[tokio::test] + async fn encode_handles_non_empty_body() { + let mut encoder = RemotingCommandCodec::new(); + let mut dst = BytesMut::new(); + let command = RemotingCommand::create_remoting_command(1) + .set_code(1) + .set_language(LanguageCode::JAVA) + .set_opaque(1) + .set_flag(1) + .set_body(Some(Bytes::from("body"))) + .set_command_custom_header(GetRouteInfoRequestHeader::new("1111", Some(true))) + .set_remark(Some("remark".to_string())); + assert!(encoder.encode(command, &mut dst).is_ok()); } } diff --git a/rocketmq-remoting/src/protocol/remoting_command.rs b/rocketmq-remoting/src/protocol/remoting_command.rs index 86330f9c..53560b65 100644 --- a/rocketmq-remoting/src/protocol/remoting_command.rs +++ b/rocketmq-remoting/src/protocol/remoting_command.rs @@ -238,6 +238,13 @@ impl RemotingCommand { } self } + + pub fn set_body_mut_ref(&mut self, body: Option>) { + if let Some(value) = body { + self.body = Some(value.into()); + } + } + pub fn set_suspended(mut self, suspended: bool) -> Self { self.suspended = suspended; self