diff --git a/rocketmq-broker/src/broker_runtime.rs b/rocketmq-broker/src/broker_runtime.rs index a7be606a1..381cd2850 100644 --- a/rocketmq-broker/src/broker_runtime.rs +++ b/rocketmq-broker/src/broker_runtime.rs @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -use std::cell::SyncUnsafeCell; +use std::any::Any; use std::collections::HashMap; use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicU64; @@ -28,6 +28,7 @@ use rocketmq_common::common::config_manager::ConfigManager; use rocketmq_common::common::constant::PermName; use rocketmq_common::common::server::config::ServerConfig; use rocketmq_common::common::statistics::state_getter::StateGetter; +use rocketmq_common::ArcCellWrapper; use rocketmq_common::TimeUtils::get_current_millis; use rocketmq_common::UtilAll::compute_next_morning_time_millis; use rocketmq_remoting::protocol::body::topic_info_wrapper::topic_config_wrapper::TopicConfigAndMappingSerializeWrapper; @@ -65,6 +66,7 @@ use crate::processor::client_manage_processor::ClientManageProcessor; use crate::processor::consumer_manage_processor::ConsumerManageProcessor; use crate::processor::default_pull_message_result_handler::DefaultPullMessageResultHandler; use crate::processor::pull_message_processor::PullMessageProcessor; +use crate::processor::pull_message_result_handler::PullMessageResultHandler; use crate::processor::send_message_processor::SendMessageProcessor; use crate::processor::BrokerRequestProcessor; use crate::schedule::schedule_message_service::ScheduleMessageService; @@ -374,8 +376,8 @@ impl BrokerRuntime { self.broker_config.clone(), self.message_store.as_ref().unwrap(), ); - let pull_message_result_handler = - Arc::new(SyncUnsafeCell::new(DefaultPullMessageResultHandler::new( + let mut pull_message_result_handler = + ArcCellWrapper::new(Box::new(DefaultPullMessageResultHandler::new( Arc::new(self.topic_config_manager.clone()), Arc::new(self.consumer_offset_manager.clone()), self.consumer_manager.clone(), @@ -383,7 +385,7 @@ impl BrokerRuntime { self.broker_stats_manager.clone(), self.broker_config.clone(), Arc::new(Default::default()), - ))); + )) as Box); let message_store = Arc::new(self.message_store.as_ref().unwrap().clone()); let pull_message_processor = PullMessageProcessor::new( pull_message_result_handler.clone(), @@ -414,11 +416,13 @@ impl BrokerRuntime { self.broker_config.clone(), )); - unsafe { - (*pull_message_result_handler.get()).set_pull_request_hold_service(Some(Arc::new( + let pull_message_result_handler = pull_message_result_handler.as_mut() as &mut dyn Any; + pull_message_result_handler + .downcast_mut::() + .unwrap() + .set_pull_request_hold_service(Some(Arc::new( self.pull_request_hold_service.clone().unwrap(), ))); - } self.message_store .as_mut() diff --git a/rocketmq-broker/src/offset/manager/consumer_offset_manager.rs b/rocketmq-broker/src/offset/manager/consumer_offset_manager.rs index 73b29eb72..aaeecc2f0 100644 --- a/rocketmq-broker/src/offset/manager/consumer_offset_manager.rs +++ b/rocketmq-broker/src/offset/manager/consumer_offset_manager.rs @@ -15,7 +15,6 @@ * limitations under the License. */ -use std::cell::SyncUnsafeCell; use std::collections::HashMap; use std::collections::HashSet; use std::fmt; @@ -27,6 +26,7 @@ use std::sync::Arc; use rocketmq_common::common::broker::broker_config::BrokerConfig; use rocketmq_common::common::config_manager::ConfigManager; use rocketmq_common::utils::serde_json_utils::SerdeJsonUtils; +use rocketmq_common::ArcCellWrapper; use rocketmq_remoting::protocol::DataVersion; use rocketmq_remoting::protocol::RemotingSerializable; use rocketmq_store::log_file::MessageStore; @@ -60,7 +60,7 @@ impl ConsumerOffsetManager { ConsumerOffsetManager { broker_config, consumer_offset_wrapper: ConsumerOffsetWrapper { - data_version: Arc::new(SyncUnsafeCell::new(DataVersion::default())), + data_version: ArcCellWrapper::new(DataVersion::default()), offset_table: Arc::new(parking_lot::RwLock::new(HashMap::new())), reset_offset_table: Arc::new(parking_lot::RwLock::new(HashMap::new())), pull_offset_table: Arc::new(parking_lot::RwLock::new(HashMap::new())), @@ -143,7 +143,9 @@ impl ConsumerOffsetManager { } else { 0 }; - unsafe { &mut *self.consumer_offset_wrapper.data_version.get() } + self.consumer_offset_wrapper + .data_version + .mut_from_ref() .next_version_with(state_machine_version); } } @@ -212,8 +214,8 @@ impl ConfigManager for ConsumerOffsetManager { .offset_table .write() .extend(wrapper.offset_table.read().clone()); - let data_version = unsafe { &mut *self.consumer_offset_wrapper.data_version.get() }; - *data_version = unsafe { &*wrapper.data_version.get() }.clone(); + let data_version = self.consumer_offset_wrapper.data_version.mut_from_ref(); + *data_version = wrapper.data_version.as_ref().clone(); } } } @@ -253,9 +255,9 @@ impl ConsumerOffsetManager { } } -#[derive(Debug, Default, Clone)] +#[derive(Default, Clone)] struct ConsumerOffsetWrapper { - data_version: Arc>, + data_version: ArcCellWrapper, offset_table: Arc>>>, reset_offset_table: Arc>>>, pull_offset_table: @@ -287,7 +289,7 @@ impl ConsumerOffsetWrapper { impl Serialize for ConsumerOffsetWrapper { fn serialize(&self, serializer: S) -> Result { let mut state = serializer.serialize_struct("ConsumerOffsetWrapper", 5)?; - state.serialize_field("dataVersion", unsafe { &*self.data_version.get() })?; + state.serialize_field("dataVersion", self.data_version.as_ref())?; state.serialize_field("offsetTable", &*self.offset_table.read())?; state.serialize_field("resetOffsetTable", &*self.reset_offset_table.read())?; state.serialize_field("pullOffsetTable", &*self.pull_offset_table.read())?; @@ -390,7 +392,7 @@ impl<'de> Deserialize<'de> for ConsumerOffsetWrapper { let pull_offset_table = pull_offset_table.unwrap_or_default(); Ok(ConsumerOffsetWrapper { - data_version: Arc::new(SyncUnsafeCell::new(data_version)), + data_version: ArcCellWrapper::new(data_version), offset_table: Arc::new(parking_lot::RwLock::new(offset_table)), reset_offset_table: Arc::new(parking_lot::RwLock::new(reset_offset_table)), pull_offset_table: Arc::new(parking_lot::RwLock::new(pull_offset_table)), diff --git a/rocketmq-broker/src/processor/admin_broker_processor/topic_request_handler.rs b/rocketmq-broker/src/processor/admin_broker_processor/topic_request_handler.rs index 6639b8b66..caffb37a3 100644 --- a/rocketmq-broker/src/processor/admin_broker_processor/topic_request_handler.rs +++ b/rocketmq-broker/src/processor/admin_broker_processor/topic_request_handler.rs @@ -163,7 +163,11 @@ impl TopicRequestHandler { .broker_runtime_inner() .register_increment_broker_data( vec![topic_config], - self.inner.topic_config_manager.data_version().clone(), + self.inner + .topic_config_manager + .data_version() + .as_ref() + .clone(), ) .await; } @@ -261,7 +265,11 @@ impl TopicRequestHandler { .broker_runtime_inner() .register_increment_broker_data( request_body.topic_config_list, - self.inner.topic_config_manager.data_version().clone(), + self.inner + .topic_config_manager + .data_version() + .as_ref() + .clone(), ) .await; } @@ -364,7 +372,13 @@ impl TopicRequestHandler { .lock() .clone(), ), - data_version: Some(self.inner.topic_config_manager.data_version().clone()), + data_version: Some( + self.inner + .topic_config_manager + .data_version() + .as_ref() + .clone(), + ), ..Default::default() }; let content = topic_config_and_mapping_serialize_wrapper.to_json(); diff --git a/rocketmq-broker/src/processor/pull_message_processor.rs b/rocketmq-broker/src/processor/pull_message_processor.rs index 7e4fe45f2..4ceebbd97 100644 --- a/rocketmq-broker/src/processor/pull_message_processor.rs +++ b/rocketmq-broker/src/processor/pull_message_processor.rs @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -use std::cell::SyncUnsafeCell; use std::sync::Arc; use rocketmq_common::common::broker::broker_config::BrokerConfig; @@ -22,6 +21,7 @@ use rocketmq_common::common::constant::PermName; use rocketmq_common::common::filter::expression_type::ExpressionType; use rocketmq_common::common::sys_flag::pull_sys_flag::PullSysFlag; use rocketmq_common::common::FAQUrl; +use rocketmq_common::ArcCellWrapper; use rocketmq_common::TimeUtils::get_current_millis; use rocketmq_remoting::code::request_code::RequestCode; use rocketmq_remoting::code::response_code::RemotingSysResponseCode; @@ -68,7 +68,7 @@ use crate::topic::manager::topic_queue_mapping_manager::TopicQueueMappingManager #[derive(Clone)] pub struct PullMessageProcessor { - pull_message_result_handler: Arc>, + pull_message_result_handler: ArcCellWrapper>, broker_config: Arc, subscription_group_manager: Arc>, topic_config_manager: Arc, @@ -84,7 +84,7 @@ pub struct PullMessageProcessor { impl PullMessageProcessor { pub fn new( - pull_message_result_handler: Arc>, + pull_message_result_handler: ArcCellWrapper>, broker_config: Arc, subscription_group_manager: Arc>, topic_config_manager: Arc, @@ -758,7 +758,7 @@ where } }; if let Some(get_message_result) = get_message_result { - return self.pull_message_result_handler().handle( + return self.pull_message_result_handler.handle( get_message_result, request, request_header, @@ -776,10 +776,6 @@ where None } - fn pull_message_result_handler(&self) -> &dyn PullMessageResultHandler { - unsafe { &*self.pull_message_result_handler.get() } - } - fn query_broadcast_pull_init_offset( &mut self, topic: &str, @@ -848,9 +844,8 @@ where let command = response.set_opaque(opaque).mark_response_type(); match ctx.upgrade() { None => {} - Some(ctx) => { - let ctx_ref = unsafe { &mut *ctx.get() }; - ctx_ref.write(command).await; + Some(mut ctx) => { + ctx.write(command).await; } } } diff --git a/rocketmq-broker/src/topic/manager/topic_config_manager.rs b/rocketmq-broker/src/topic/manager/topic_config_manager.rs index d4547217e..97c196546 100644 --- a/rocketmq-broker/src/topic/manager/topic_config_manager.rs +++ b/rocketmq-broker/src/topic/manager/topic_config_manager.rs @@ -15,7 +15,6 @@ * limitations under the License. */ -use std::cell::SyncUnsafeCell; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; @@ -29,6 +28,7 @@ use rocketmq_common::common::constant::PermName; use rocketmq_common::common::mix_all; use rocketmq_common::common::topic::TopicValidator; use rocketmq_common::utils::serde_json_utils::SerdeJsonUtils; +use rocketmq_common::ArcCellWrapper; use rocketmq_common::TopicAttributes::ALL; use rocketmq_remoting::protocol::body::topic_info_wrapper::topic_config_wrapper::TopicConfigAndMappingSerializeWrapper; use rocketmq_remoting::protocol::body::topic_info_wrapper::TopicConfigSerializeWrapper; @@ -45,7 +45,7 @@ use crate::broker_runtime::BrokerRuntimeInner; pub(crate) struct TopicConfigManager { topic_config_table: Arc>>, - data_version: Arc>, + data_version: ArcCellWrapper, broker_config: Arc, message_store: Option, topic_config_table_lock: Arc>, @@ -74,7 +74,7 @@ impl TopicConfigManager { ) -> Self { let mut manager = Self { topic_config_table: Arc::new(parking_lot::Mutex::new(HashMap::new())), - data_version: Arc::new(SyncUnsafeCell::new(DataVersion::default())), + data_version: ArcCellWrapper::new(DataVersion::default()), broker_config, message_store: None, topic_config_table_lock: Default::default(), @@ -231,7 +231,7 @@ impl TopicConfigManager { topic_queue_mapping_info_map: HashMap, ) -> TopicConfigAndMappingSerializeWrapper { if self.broker_config.enable_split_registration { - self.data_version_mut().next_version(); + self.data_version.mut_from_ref().next_version(); } TopicConfigAndMappingSerializeWrapper { topic_config_table: Some(topic_config_table), @@ -296,7 +296,7 @@ impl TopicConfigManager { default_topic, topic_config, remote_address ); self.put_topic_config(topic_config.clone()); - self.data_version_mut().next_version_with( + self.data_version.mut_from_ref().next_version_with( self.message_store .as_ref() .unwrap() @@ -357,7 +357,7 @@ impl TopicConfigManager { config.order = is_order; self.put_topic_config(config.clone()); - self.data_version_mut().next_version_with( + self.data_version.mut_from_ref().next_version_with( self.message_store .as_ref() .unwrap() @@ -410,7 +410,8 @@ impl TopicConfigManager { } else { 0 }; - self.data_version_mut() + self.data_version + .mut_from_ref() .next_version_with(state_machine_version); self.persist(); } else { @@ -441,7 +442,7 @@ impl TopicConfigManager { } } - self.data_version_mut().next_version_with( + self.data_version.mut_from_ref().next_version_with( self.message_store .as_ref() .unwrap() @@ -508,7 +509,7 @@ impl TopicConfigManager { config.topic_sys_flag = 0; info!("create new topic {:?}", config); self.put_topic_config(config.clone()); - self.data_version_mut().next_version_with( + self.data_version.mut_from_ref().next_version_with( self.message_store .as_ref() .unwrap() @@ -530,12 +531,8 @@ impl TopicConfigManager { self.topic_config_table.lock().contains_key(topic) } - pub fn data_version(&self) -> &DataVersion { - unsafe { &*self.data_version.get() } - } - - fn data_version_mut(&self) -> &mut DataVersion { - unsafe { &mut *self.data_version.get() } + pub fn data_version(&self) -> ArcCellWrapper { + self.data_version.clone() } #[inline] @@ -551,7 +548,7 @@ impl ConfigManager for TopicConfigManager { fn encode_pretty(&self, pretty_format: bool) -> String { let topic_config_table = self.topic_config_table.lock().clone(); - let version = self.data_version().clone(); + let version = self.data_version().as_ref().clone(); match pretty_format { true => TopicConfigSerializeWrapper::new(Some(topic_config_table), Some(version)) .to_json_pretty(), @@ -569,7 +566,7 @@ impl ConfigManager for TopicConfigManager { let wrapper = SerdeJsonUtils::from_json_str::(json_string) .expect("Decode TopicConfigSerializeWrapper from json failed"); if let Some(value) = wrapper.data_version() { - self.data_version_mut().assign_new_one(value); + self.data_version.mut_from_ref().assign_new_one(value); } if let Some(map) = wrapper.topic_config_table() { for (key, value) in map { diff --git a/rocketmq-common/src/lib.rs b/rocketmq-common/src/lib.rs index d791ac6b4..10632d404 100644 --- a/rocketmq-common/src/lib.rs +++ b/rocketmq-common/src/lib.rs @@ -19,10 +19,12 @@ #![allow(unused_imports)] #![feature(sync_unsafe_cell)] +use std::borrow::Borrow; use std::cell::SyncUnsafeCell; use std::ops::Deref; use std::ops::DerefMut; use std::sync::Arc; +use std::sync::Weak; pub use crate::common::attribute::topic_attributes as TopicAttributes; pub use crate::common::message::message_accessor as MessageAccessor; @@ -46,10 +48,44 @@ pub mod log; mod thread_pool; pub mod utils; +pub struct WeakCellWrapper { + inner: Weak>, +} + +impl Clone for WeakCellWrapper { + fn clone(&self) -> Self { + WeakCellWrapper { + inner: self.inner.clone(), + } + } +} + +impl WeakCellWrapper { + pub fn upgrade(&self) -> Option> { + self.inner + .upgrade() + .map(|value| ArcCellWrapper { inner: value }) + } +} + +#[derive(Default)] pub struct ArcCellWrapper { inner: Arc>, } +impl ArcCellWrapper { + #[allow(clippy::mut_from_ref)] + pub fn mut_from_ref(&self) -> &mut T { + unsafe { &mut *self.inner.get() } + } + + pub fn downgrade(this: &Self) -> WeakCellWrapper { + WeakCellWrapper { + inner: Arc::downgrade(&this.inner), + } + } +} + impl ArcCellWrapper { #[inline] pub fn new(value: T) -> Self { @@ -92,6 +128,51 @@ impl DerefMut for ArcCellWrapper { } } +pub struct SyncUnsafeCellWrapper { + inner: SyncUnsafeCell, +} + +impl SyncUnsafeCellWrapper { + #[inline] + pub fn new(value: T) -> Self { + Self { + inner: SyncUnsafeCell::new(value), + } + } +} + +impl SyncUnsafeCellWrapper { + #[allow(clippy::mut_from_ref)] + pub fn mut_from_ref(&self) -> &mut T { + unsafe { &mut *self.inner.get() } + } +} + +impl AsRef for SyncUnsafeCellWrapper { + fn as_ref(&self) -> &T { + unsafe { &*self.inner.get() } + } +} + +impl AsMut for SyncUnsafeCellWrapper { + fn as_mut(&mut self) -> &mut T { + &mut *self.inner.get_mut() + } +} + +impl Deref for SyncUnsafeCellWrapper { + type Target = T; + fn deref(&self) -> &Self::Target { + self.as_ref() + } +} + +impl DerefMut for SyncUnsafeCellWrapper { + fn deref_mut(&mut self) -> &mut Self::Target { + self.as_mut() + } +} + #[cfg(test)] mod arc_cell_wrapper_tests { use std::sync::Arc; diff --git a/rocketmq-remoting/src/protocol/command_custom_header.rs b/rocketmq-remoting/src/protocol/command_custom_header.rs index a9cb0f7d9..7e6fa9529 100644 --- a/rocketmq-remoting/src/protocol/command_custom_header.rs +++ b/rocketmq-remoting/src/protocol/command_custom_header.rs @@ -14,11 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +use std::any::Any; use std::collections::HashMap; use crate::rocketmq_serializable::RocketMQSerializable; -pub trait CommandCustomHeader { +pub trait CommandCustomHeader: Any { /// Checks the fields of the implementing type. /// /// Returns a `Result` indicating whether the fields are valid or not. diff --git a/rocketmq-remoting/src/protocol/remoting_command.rs b/rocketmq-remoting/src/protocol/remoting_command.rs index 1067cb022..041782d29 100644 --- a/rocketmq-remoting/src/protocol/remoting_command.rs +++ b/rocketmq-remoting/src/protocol/remoting_command.rs @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -use std::cell::SyncUnsafeCell; use std::collections::HashMap; use std::fmt; use std::sync::atomic::AtomicI32; @@ -28,6 +27,7 @@ use bytes::Bytes; use bytes::BytesMut; use lazy_static::lazy_static; use rocketmq_common::common::mq_version::RocketMqVersion; +use rocketmq_common::ArcCellWrapper; use serde::Deserialize; use serde::Serialize; use tracing::error; @@ -93,7 +93,7 @@ pub struct RemotingCommand { suspended: bool, #[serde(skip)] command_custom_header: - Option>>, + Option>>, #[serde(rename = "serializeTypeCurrentRPC")] serialize_type: SerializeType, } @@ -216,30 +216,14 @@ impl RemotingCommand { where T: CommandCustomHeader + Sync + Send + 'static, { - self.command_custom_header = Some(Arc::new(SyncUnsafeCell::new(command_custom_header))); - /*if let Some(cch) = &self.command_custom_header { - let option = cch.to_map(); - - match &mut self.ext_fields { - None => { - self.ext_fields = option; - } - Some(ext) => { - if let Some(val) = option { - for (key, value) in &val { - ext.insert(key.clone(), value.clone()); - } - } - } - } - }*/ + self.command_custom_header = Some(ArcCellWrapper::new(Box::new(command_custom_header))); self } pub fn set_command_custom_header_origin( mut self, command_custom_header: Option< - Arc>, + ArcCellWrapper>, >, ) -> Self { self.command_custom_header = command_custom_header; @@ -250,23 +234,7 @@ impl RemotingCommand { where T: CommandCustomHeader + Sync + Send + 'static, { - self.command_custom_header = Some(Arc::new(SyncUnsafeCell::new(command_custom_header))); - /*if let Some(cch) = &self.command_custom_header { - let option = cch.to_map(); - - match &mut self.ext_fields { - None => { - self.ext_fields = option; - } - Some(ext) => { - if let Some(val) = option { - for (key, value) in &val { - ext.insert(key.clone(), value.clone()); - } - } - } - } - }*/ + self.command_custom_header = Some(ArcCellWrapper::new(Box::new(command_custom_header))); } pub fn set_code(mut self, code: impl Into) -> Self { @@ -365,9 +333,7 @@ impl RemotingCommand { pub fn header_encode(&self) -> Option { self.command_custom_header.as_ref().and_then(|header| { - let header_ptr = header.get(); - let header_ref = unsafe { &*(header_ptr as *const dyn CommandCustomHeader) }; - header_ref + header .to_map() .as_ref() .map(|val| Bytes::from(serde_json::to_vec(val).unwrap())) @@ -376,9 +342,7 @@ impl RemotingCommand { pub fn make_custom_header_to_net(&mut self) { if let Some(header) = &self.command_custom_header { - let header_ptr = header.get(); - let header_ref = unsafe { &*(header_ptr as *const dyn CommandCustomHeader) }; - let option = header_ref.to_map(); + let option = header.to_map(); match &mut self.ext_fields { None => { @@ -443,22 +407,6 @@ impl RemotingCommand { .copy_from_slice(&serialize_type.to_be_bytes()); } } - - /*let header_length = header.as_ref().map_or(0, |h| h.len()) as i32; - let body_length = self.body.as_ref().map_or(0, |b| b.len()) as i32; - let total_length = 4 + header_length + body_length; - - dst.reserve((total_length + 4) as usize); - dst.put_i32(total_length); - let serialize_type = - RemotingCommand::mark_serialize_type(header_length, item.get_serialize_type()); - dst.put_i32(serialize_type); - - if let Some(header_inner) = header { - dst.put(header_inner); - } - - let st = serde_json::to_string(self).unwrap();*/ } pub fn get_body(&self) -> Option<&Bytes> { @@ -588,50 +536,40 @@ impl RemotingCommand { pub fn read_custom_header(&mut self) -> Option<&T> where - T: CommandCustomHeader + Sync + Send, + T: CommandCustomHeader + Sync + Send + 'static, { match self.command_custom_header.as_ref() { None => None, Some(value) => { - let value = value.get(); - let value = value as *const dyn CommandCustomHeader as *const T; - unsafe { Some(&*value) } + let ptr_raw = std::ptr::from_ref(&**value.as_ref()) as *const T; + unsafe { Some(&*ptr_raw) } } } } pub fn read_custom_header_mut(&mut self) -> Option<&mut T> where - T: CommandCustomHeader + Sync + Send, + T: CommandCustomHeader + Sync + Send + 'static, { - match self.command_custom_header.as_ref() { + match self.command_custom_header.as_mut() { None => None, Some(value) => { - let value = value.get(); - let value = value as *const dyn CommandCustomHeader as *mut T; - unsafe { Some(&mut *value) } + let ptr_raw = std::ptr::from_mut(&mut **value.as_mut()) as *mut T; + unsafe { Some(&mut *ptr_raw) } } } } pub fn command_custom_header_ref(&self) -> Option<&dyn CommandCustomHeader> { match self.command_custom_header.as_ref() { None => None, - Some(value) => { - let value = value.get(); - let value = value as *const dyn CommandCustomHeader; - unsafe { Some(&*value) } - } + Some(value) => Some(value.as_ref().as_ref()), } } pub fn command_custom_header_mut(&mut self) -> Option<&mut dyn CommandCustomHeader> { - match self.command_custom_header.as_ref() { + match self.command_custom_header.as_mut() { None => None, - Some(value) => { - let value = value.get(); - let value = value as *mut dyn CommandCustomHeader; - unsafe { Some(&mut *value) } - } + Some(value) => Some(value.as_mut().as_mut()), } } } diff --git a/rocketmq-remoting/src/rpc/rpc_client_impl.rs b/rocketmq-remoting/src/rpc/rpc_client_impl.rs index a6a1580ad..97c99176c 100644 --- a/rocketmq-remoting/src/rpc/rpc_client_impl.rs +++ b/rocketmq-remoting/src/rpc/rpc_client_impl.rs @@ -102,7 +102,8 @@ impl RpcClientImpl { .body() .as_ref() .map(|value| Box::new(value.clone()) as Box); - let rpc_response = RpcResponse::new(response.code(), response_header, body); + let rpc_response = + RpcResponse::new(response.code(), Box::new(response_header), body); Ok(rpc_response) } _ => Ok(RpcResponse::new_exception(Some(RpcException( @@ -138,7 +139,8 @@ impl RpcClientImpl { .body() .as_ref() .map(|value| Box::new(value.clone()) as Box); - let rpc_response = RpcResponse::new(response.code(), response_header, body); + let rpc_response = + RpcResponse::new(response.code(), Box::new(response_header), body); Ok(rpc_response) } _ => Ok(RpcResponse::new_exception(Some(RpcException( @@ -173,7 +175,8 @@ impl RpcClientImpl { .body() .as_ref() .map(|value| Box::new(value.clone()) as Box); - let rpc_response = RpcResponse::new(response.code(), response_header, body); + let rpc_response = + RpcResponse::new(response.code(), Box::new(response_header), body); Ok(rpc_response) } _ => Ok(RpcResponse::new_exception(Some(RpcException( @@ -208,7 +211,8 @@ impl RpcClientImpl { .body() .as_ref() .map(|value| Box::new(value.clone()) as Box); - let rpc_response = RpcResponse::new(response.code(), response_header, body); + let rpc_response = + RpcResponse::new(response.code(), Box::new(response_header), body); Ok(rpc_response) } _ => Ok(RpcResponse::new_exception(Some(RpcException( @@ -243,7 +247,8 @@ impl RpcClientImpl { .body() .as_ref() .map(|value| Box::new(value.clone()) as Box); - let rpc_response = RpcResponse::new(response.code(), response_header, body); + let rpc_response = + RpcResponse::new(response.code(), Box::new(response_header), body); Ok(rpc_response) } _ => Ok(RpcResponse::new_exception(Some(RpcException( @@ -278,7 +283,8 @@ impl RpcClientImpl { .body() .as_ref() .map(|value| Box::new(value.clone()) as Box); - let rpc_response = RpcResponse::new(response.code(), response_header, body); + let rpc_response = + RpcResponse::new(response.code(), Box::new(response_header), body); Ok(rpc_response) } ResponseCode::QueryNotFound => { @@ -317,7 +323,8 @@ impl RpcClientImpl { .body() .as_ref() .map(|value| Box::new(value.clone()) as Box); - let rpc_response = RpcResponse::new(response.code(), response_header, body); + let rpc_response = + RpcResponse::new(response.code(), Box::new(response_header), body); Ok(rpc_response) } _ => Ok(RpcResponse::new_exception(Some(RpcException( diff --git a/rocketmq-remoting/src/rpc/rpc_response.rs b/rocketmq-remoting/src/rpc/rpc_response.rs index 52c57954c..d8d002a7b 100644 --- a/rocketmq-remoting/src/rpc/rpc_response.rs +++ b/rocketmq-remoting/src/rpc/rpc_response.rs @@ -15,8 +15,8 @@ * limitations under the License. */ use std::any::Any; -use std::cell::SyncUnsafeCell; -use std::sync::Arc; + +use rocketmq_common::ArcCellWrapper; use crate::error::RpcException; use crate::protocol::command_custom_header::CommandCustomHeader; @@ -24,7 +24,7 @@ use crate::protocol::command_custom_header::CommandCustomHeader; #[derive(Default)] pub struct RpcResponse { pub code: i32, - pub header: Option>>, + pub header: Option>>, pub body: Option>, pub exception: Option, } @@ -32,28 +32,26 @@ pub struct RpcResponse { impl RpcResponse { pub fn get_header(&self) -> Option<&T> where - T: CommandCustomHeader + Send + Sync + 'static, + T: CommandCustomHeader + Any + Send + Sync + 'static, { match self.header.as_ref() { None => None, Some(value) => { - let value = value.get(); - let value = value as *const dyn CommandCustomHeader as *const T; - unsafe { Some(&*value) } + let ptr_raw = std::ptr::from_ref(&**value.as_ref()) as *const T; + unsafe { Some(&*ptr_raw) } } } } pub fn get_header_mut(&self) -> Option<&mut T> where - T: CommandCustomHeader + Send + Sync + 'static, + T: CommandCustomHeader + Any + Send + Sync + 'static, { match self.header.as_ref() { None => None, Some(value) => { - let value = value.get(); - let value = value as *const dyn CommandCustomHeader as *mut T; - unsafe { Some(&mut *value) } + let ptr_raw = std::ptr::from_mut(&mut **value.mut_from_ref()) as *mut T; + unsafe { Some(&mut *ptr_raw) } } } } @@ -69,12 +67,12 @@ impl RpcResponse { pub fn new( code: i32, - header: impl CommandCustomHeader + Send + Sync + 'static, + header: Box, body: Option>, ) -> Self { Self { code, - header: Some(Arc::new(SyncUnsafeCell::new(header))), + header: Some(ArcCellWrapper::new(header)), body, exception: None, } diff --git a/rocketmq-remoting/src/runtime/server.rs b/rocketmq-remoting/src/runtime/server.rs index fe784302b..4bd5ae242 100644 --- a/rocketmq-remoting/src/runtime/server.rs +++ b/rocketmq-remoting/src/runtime/server.rs @@ -15,15 +15,15 @@ * limitations under the License. */ -use std::cell::SyncUnsafeCell; use std::future::Future; use std::net::SocketAddr; use std::sync::Arc; -use std::sync::Weak; use std::time::Duration; use futures::SinkExt; use rocketmq_common::common::server::config::ServerConfig; +use rocketmq_common::ArcCellWrapper; +use rocketmq_common::WeakCellWrapper; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio::sync::broadcast; @@ -49,12 +49,11 @@ type Tx = mpsc::UnboundedSender; /// Shorthand for the receive half of the message channel. type Rx = mpsc::UnboundedReceiver; -pub type ConnectionHandlerContext = Weak>; +pub type ConnectionHandlerContext = WeakCellWrapper; pub struct ConnectionHandler { request_processor: RP, - //connection: Connection, - connection_handler_context: Arc>, + connection_handler_context: ArcCellWrapper, channel: Channel, shutdown: Shutdown, _shutdown_complete: mpsc::Sender<()>, @@ -77,9 +76,8 @@ impl Drop for ConnectionHandler { impl ConnectionHandler { async fn handle(&mut self) -> anyhow::Result<()> { while !self.shutdown.is_shutdown { - let connection_handler_context = unsafe { &mut *self.connection_handler_context.get() }; let frame = tokio::select! { - res = connection_handler_context.connection.framed.next() => res, + res = self.connection_handler_context.connection.framed.next() => res, _ = self.shutdown.recv() =>{ //If a shutdown signal is received, return from `handle`. return Ok(()); @@ -102,7 +100,7 @@ impl ConnectionHandler { //let ctx = ConnectionHandlerContext::new(&self.connection); let opaque = cmd.opaque(); let channel = self.channel.clone(); - let ctx = Arc::downgrade(&self.connection_handler_context); + let ctx = ArcCellWrapper::downgrade(&self.connection_handler_context); let response = tokio::select! { result = self.request_processor.process_request(channel,ctx,cmd) => result, }; @@ -111,7 +109,7 @@ impl ConnectionHandler { } let response = response.unwrap(); tokio::select! { - result =connection_handler_context.connection.framed.send(response.set_opaque(opaque)) => match result{ + result =self.connection_handler_context.connection.framed.send(response.set_opaque(opaque)) => match result{ Ok(_) =>{}, Err(err) => { error!("send response failed: {}", err); @@ -170,11 +168,9 @@ impl ConnectionListener { let mut handler = ConnectionHandler { request_processor: self.request_processor.clone(), //connection: Connection::new(socket, remote_addr), - connection_handler_context: Arc::new(SyncUnsafeCell::new( - ConnectionHandlerContextWrapper { - connection: Connection::new(socket), - }, - )), + connection_handler_context: ArcCellWrapper::new(ConnectionHandlerContextWrapper { + connection: Connection::new(socket), + }), channel, shutdown: Shutdown::new(self.notify_shutdown.subscribe()), _shutdown_complete: self.shutdown_complete_tx.clone(), diff --git a/rocketmq-store/src/log_file/mapped_file/default_impl.rs b/rocketmq-store/src/log_file/mapped_file/default_impl.rs index 7a6724ca4..5b7c2c0c0 100644 --- a/rocketmq-store/src/log_file/mapped_file/default_impl.rs +++ b/rocketmq-store/src/log_file/mapped_file/default_impl.rs @@ -15,7 +15,6 @@ * limitations under the License. */ -use std::cell::SyncUnsafeCell; use std::fs::File; use std::fs::OpenOptions; use std::io::Write; @@ -32,6 +31,7 @@ use bytes::BytesMut; use memmap2::MmapMut; use rocketmq_common::common::message::message_batch::MessageExtBatch; use rocketmq_common::common::message::message_single::MessageExtBrokerInner; +use rocketmq_common::SyncUnsafeCellWrapper; use rocketmq_common::UtilAll::ensure_dir_ok; use tracing::debug; use tracing::error; @@ -56,8 +56,7 @@ static TOTAL_MAPPED_FILES: AtomicI32 = AtomicI32::new(0); pub struct DefaultMappedFile { reference_resource: ReferenceResource, file: File, - // file_channel: FileChannel, - mmapped_file: SyncUnsafeCell, + mmapped_file: SyncUnsafeCellWrapper, transient_store_pool: Option, file_name: String, file_from_offset: u64, @@ -69,7 +68,6 @@ pub struct DefaultMappedFile { store_timestamp: AtomicI64, first_create_in_queue: bool, last_flush_time: u64, - // mapped_byte_buffer_wait_to_clean: Option, swap_map_time: u64, mapped_byte_buffer_access_count_since_last_swap: AtomicI64, start_timestamp: u64, @@ -111,7 +109,7 @@ impl DefaultMappedFile { first_shutdown_timestamp: AtomicI64::new(0), }, file, - mmapped_file: SyncUnsafeCell::new(mmap), + mmapped_file: SyncUnsafeCellWrapper::new(mmap), file_name, file_from_offset, mapped_byte_buffer: None, @@ -195,7 +193,7 @@ impl DefaultMappedFile { start_timestamp: 0, transient_store_pool: Some(transient_store_pool), stop_timestamp: 0, - mmapped_file: SyncUnsafeCell::new(mmap), + mmapped_file: SyncUnsafeCellWrapper::new(mmap), } } } @@ -578,29 +576,19 @@ impl MappedFile for DefaultMappedFile { todo!() } - /* fn init( - &mut self, - file_name: &str, - file_size: usize, - transient_store_pool: &TransientStorePool, - ) -> std::io::Result<()> { - todo!() - }*/ - fn is_loaded(&self, position: i64, size: usize) -> bool { true } } #[allow(unused_variables)] -#[allow(clippy::mut_from_ref)] impl DefaultMappedFile { pub fn get_mapped_file_mut(&self) -> &mut MmapMut { - unsafe { &mut *self.mmapped_file.get() } + self.mmapped_file.mut_from_ref() } pub fn get_mapped_file(&self) -> &MmapMut { - unsafe { &*self.mmapped_file.get() } + self.mmapped_file.as_ref() } fn is_able_to_flush(&self, flush_least_pages: i32) -> bool { diff --git a/rocketmq-store/src/queue.rs b/rocketmq-store/src/queue.rs index 7ec76c232..c1d55371a 100644 --- a/rocketmq-store/src/queue.rs +++ b/rocketmq-store/src/queue.rs @@ -35,7 +35,6 @@ pub mod local_file_consume_queue_store; mod queue_offset_operator; pub mod single_consume_queue; -//pub type ArcConsumeQueue = Arc>>; pub type ArcConsumeQueue = ArcCellWrapper>; pub type ConsumeQueueTable = parking_lot::Mutex>>;