diff --git a/rocketmq-broker/src/out_api/broker_outer_api.rs b/rocketmq-broker/src/out_api/broker_outer_api.rs index c226ee35..aab94edb 100644 --- a/rocketmq-broker/src/out_api/broker_outer_api.rs +++ b/rocketmq-broker/src/out_api/broker_outer_api.rs @@ -36,7 +36,7 @@ use rocketmq_remoting::{ pub struct BrokerOuterAPI { remoting_client: RocketmqDefaultClient, name_server_address: Option, - broker_outer_executor: TokioExecutorService, + broker_outer_executor: Option, } impl BrokerOuterAPI { @@ -66,13 +66,14 @@ impl BrokerOuterAPI { } impl BrokerOuterAPI { - pub fn update_name_server_address_list(&mut self, addrs: String) { + pub async fn update_name_server_address_list(&mut self, addrs: String) { let addr_vec = addrs .split("';'") .map(|s| s.to_string()) .collect::>(); self.remoting_client .update_name_server_address_list(addr_vec) + .await } pub async fn register_broker_all( diff --git a/rocketmq-remoting/src/clients.rs b/rocketmq-remoting/src/clients.rs index de5f0958..ca7a4e87 100644 --- a/rocketmq-remoting/src/clients.rs +++ b/rocketmq-remoting/src/clients.rs @@ -81,7 +81,7 @@ impl RemoteClient { #[allow(async_fn_in_trait)] pub trait RemotingClient: RemotingService { - fn update_name_server_address_list(&mut self, addrs: Vec); + async fn update_name_server_address_list(&mut self, addrs: Vec); fn get_name_server_address_list(&self) -> Vec; diff --git a/rocketmq-remoting/src/clients/rocketmq_default_impl.rs b/rocketmq-remoting/src/clients/rocketmq_default_impl.rs index 2b3e2a68..ea659b9a 100644 --- a/rocketmq-remoting/src/clients/rocketmq_default_impl.rs +++ b/rocketmq-remoting/src/clients/rocketmq_default_impl.rs @@ -33,11 +33,10 @@ pub struct RocketmqDefaultClient { service_bridge: ServiceBridge, tokio_client_config: TokioClientConfig, //cache connection - connection_tables: HashMap, - connection_tables_lock: std::sync::RwLock<()>, - lock: tokio::sync::RwLock<()>, - namesrv_addr_list: Arc>>, - namesrv_addr_choosed: Arc>>, + connection_tables: + tokio::sync::Mutex>>>, + namesrv_addr_list: Arc>>, + namesrv_addr_choosed: Arc>>, } impl RocketmqDefaultClient { @@ -46,8 +45,6 @@ impl RocketmqDefaultClient { service_bridge: ServiceBridge::new(), tokio_client_config, connection_tables: Default::default(), - connection_tables_lock: Default::default(), - lock: Default::default(), namesrv_addr_list: Arc::new(Default::default()), namesrv_addr_choosed: Arc::new(Default::default()), } @@ -55,33 +52,16 @@ impl RocketmqDefaultClient { } impl RocketmqDefaultClient { - async fn get_and_create_client(&mut self, addr: String) -> &Client { - let lc = self.lock.write().await; - if self.connection_tables.contains_key(&addr) { - return self.connection_tables.get(&addr).unwrap(); + async fn get_and_create_client(&mut self, addr: String) -> Arc> { + let mut mutex_guard = self.connection_tables.lock().await; + if mutex_guard.contains_key(&addr) { + return mutex_guard.get(&addr).unwrap().clone(); } let addr_inner = addr.clone(); let client = Client::connect(addr_inner).await.unwrap(); - - self.connection_tables.insert(addr.clone(), client); - drop(lc); - self.connection_tables.get(&addr).unwrap() - } - - async fn get_and_create_client_mut(&mut self, addr: String) -> Option<&mut Client> { - let lc = self.lock.write().await; - - if self.connection_tables.contains_key(&addr) { - return self.connection_tables.get_mut(&addr); - } - - let addr_inner = addr.clone(); - let client = Client::connect(addr_inner).await.unwrap(); - - self.connection_tables.insert(addr.clone(), client); - drop(lc); - self.connection_tables.get_mut(&addr) + mutex_guard.insert(addr.clone(), Arc::new(tokio::sync::Mutex::new(client))); + mutex_guard.get(&addr).unwrap().clone() } } @@ -106,8 +86,8 @@ impl RemotingService for RocketmqDefaultClient { #[allow(unused_variables)] impl RemotingClient for RocketmqDefaultClient { - fn update_name_server_address_list(&mut self, addrs: Vec) { - let mut old = self.namesrv_addr_list.lock().unwrap(); + async fn update_name_server_address_list(&mut self, addrs: Vec) { + let mut old = self.namesrv_addr_list.lock().await; let mut update = false; if !addrs.is_empty() { @@ -134,19 +114,18 @@ impl RemotingClient for RocketmqDefaultClient { old.clone_from(&addrs); // should close the channel if choosed addr is not exist. - if let Some(namesrv_addr) = self.namesrv_addr_choosed.lock().unwrap().as_ref() { + if let Some(namesrv_addr) = self.namesrv_addr_choosed.lock().await.as_ref() { if !addrs.contains(namesrv_addr) { - let write_guard = self.connection_tables_lock.write().unwrap(); let mut remove_vec = Vec::new(); - for (addr, client) in self.connection_tables.iter() { + let mut result = self.connection_tables.lock().await; + for (addr, client) in result.iter() { if addr.contains(namesrv_addr) { remove_vec.push(addr.clone()); } } for addr in &remove_vec { - self.connection_tables.remove(addr); + result.remove(addr); } - drop(write_guard); } } } @@ -167,12 +146,9 @@ impl RemotingClient for RocketmqDefaultClient { request: RemotingCommand, timeout_millis: u64, ) -> RemotingCommand { - let client = self - .get_and_create_client_mut(addr.clone()) - .await - .take() - .unwrap(); - ServiceBridge::invoke_sync(client, request, timeout_millis) + let client = self.get_and_create_client(addr.clone()).await; + let client_ref = &mut *client.lock().await; + ServiceBridge::invoke_sync(client_ref, request, timeout_millis) .await .unwrap() } @@ -184,12 +160,9 @@ impl RemotingClient for RocketmqDefaultClient { timeout_millis: u64, invoke_callback: impl InvokeCallback, ) -> Result<(), Box> { - let client = self - .get_and_create_client_mut(addr.clone()) - .await - .take() - .unwrap(); - ServiceBridge::invoke_async(client, request, timeout_millis, invoke_callback).await; + let client = self.get_and_create_client(addr.clone()).await; + let client_ref = &mut *client.lock().await; + ServiceBridge::invoke_async(client_ref, request, timeout_millis, invoke_callback).await; Ok(()) }