diff --git a/core/payment/src/processor.rs b/core/payment/src/processor.rs index 91051763e4..e349e52675 100644 --- a/core/payment/src/processor.rs +++ b/core/payment/src/processor.rs @@ -4,12 +4,12 @@ use crate::error::processor::{ SchedulePaymentError, ValidateAllocationError, VerifyPaymentError, }; use crate::models::order::ReadObj as DbOrder; +use actix_web::rt::Arbiter; use bigdecimal::{BigDecimal, Zero}; -use futures::lock::Mutex; +use futures::FutureExt; use metrics::counter; use std::collections::hash_map::Entry; use std::collections::HashMap; -use std::sync::Arc; use ya_client_model::payment::{Account, ActivityPayment, AgreementPayment, Payment}; use ya_core_model::driver::{ self, driver_bus_id, AccountMode, PaymentConfirmation, PaymentDetails, ValidateAllocation, @@ -266,7 +266,7 @@ impl DriverRegistry { #[derive(Clone)] pub struct PaymentProcessor { db_executor: DbExecutor, - registry: Arc>, + registry: DriverRegistry, } impl PaymentProcessor { @@ -277,24 +277,30 @@ impl PaymentProcessor { } } - pub async fn register_driver(&self, msg: RegisterDriver) -> Result<(), RegisterDriverError> { - self.registry.lock().await.register_driver(msg) + pub async fn register_driver( + &mut self, + msg: RegisterDriver, + ) -> Result<(), RegisterDriverError> { + self.registry.register_driver(msg) } - pub async fn unregister_driver(&self, msg: UnregisterDriver) { - self.registry.lock().await.unregister_driver(msg) + pub async fn unregister_driver(&mut self, msg: UnregisterDriver) { + self.registry.unregister_driver(msg) } - pub async fn register_account(&self, msg: RegisterAccount) -> Result<(), RegisterAccountError> { - self.registry.lock().await.register_account(msg) + pub async fn register_account( + &mut self, + msg: RegisterAccount, + ) -> Result<(), RegisterAccountError> { + self.registry.register_account(msg) } - pub async fn unregister_account(&self, msg: UnregisterAccount) { - self.registry.lock().await.unregister_account(msg) + pub async fn unregister_account(&mut self, msg: UnregisterAccount) { + self.registry.unregister_account(msg) } pub async fn get_accounts(&self) -> Vec { - self.registry.lock().await.get_accounts() + self.registry.get_accounts() } pub async fn get_network( @@ -302,7 +308,7 @@ impl PaymentProcessor { driver: String, network: Option, ) -> Result<(String, Network), RegisterAccountError> { - self.registry.lock().await.get_network(driver, network) + self.registry.get_network(driver, network) } pub async fn get_platform( @@ -311,10 +317,7 @@ impl PaymentProcessor { network: Option, token: Option, ) -> Result { - self.registry - .lock() - .await - .get_platform(driver, network, token) + self.registry.get_platform(driver, network, token) } pub async fn notify_payment(&self, msg: NotifyPayment) -> Result<(), NotifyPaymentError> { @@ -395,12 +398,18 @@ impl PaymentProcessor { counter!("payment.amount.sent", ya_metrics::utils::cryptocurrency_to_u64(&msg.amount), "platform" => payment_platform); let msg = SendPayment::new(payment, signature); - ya_net::from(payer_id) - .to(payee_id) - .service(BUS_ID) - .call(msg) - .await??; + // Spawning to avoid deadlock in a case that payee is the same node as payer + Arbiter::spawn( + ya_net::from(payer_id) + .to(payee_id) + .service(BUS_ID) + .call(msg) + .map(|res| match res { + Ok(Ok(_)) => (), + err => log::error!("Error sending payment message to provider: {:?}", err), + }), + ); // TODO: Implement re-sending mechanism in case SendPayment fails counter!("payment.invoices.requestor.paid", 1); @@ -409,11 +418,9 @@ impl PaymentProcessor { pub async fn schedule_payment(&self, msg: SchedulePayment) -> Result<(), SchedulePaymentError> { let amount = msg.amount.clone(); - let driver = self.registry.lock().await.driver( - &msg.payment_platform, - &msg.payer_addr, - AccountMode::SEND, - )?; + let driver = + self.registry + .driver(&msg.payment_platform, &msg.payer_addr, AccountMode::SEND)?; let order_id = driver_endpoint(&driver) .send(driver::SchedulePayment::new( amount, @@ -439,7 +446,7 @@ impl PaymentProcessor { ) -> Result<(), VerifyPaymentError> { // TODO: Split this into smaller functions let platform = payment.payment_platform.clone(); - let driver = self.registry.lock().await.driver( + let driver = self.registry.driver( &payment.payment_platform, &payment.payee_addr, AccountMode::RECV, @@ -536,11 +543,9 @@ impl PaymentProcessor { platform: String, address: String, ) -> Result { - let driver = - self.registry - .lock() - .await - .driver(&platform, &address, AccountMode::empty())?; + let driver = self + .registry + .driver(&platform, &address, AccountMode::empty())?; let amount = driver_endpoint(&driver) .send(driver::GetAccountBalance::new(address, platform)) .await??; @@ -558,11 +563,9 @@ impl PaymentProcessor { .as_dao::() .get_for_address(platform.clone(), address.clone()) .await?; - let driver = - self.registry - .lock() - .await - .driver(&platform, &address, AccountMode::empty())?; + let driver = self + .registry + .driver(&platform, &address, AccountMode::empty())?; let msg = ValidateAllocation { address, platform, diff --git a/core/payment/src/service.rs b/core/payment/src/service.rs index c61136321f..debb38eafc 100644 --- a/core/payment/src/service.rs +++ b/core/payment/src/service.rs @@ -1,6 +1,8 @@ use crate::processor::PaymentProcessor; +use futures::lock::Mutex; use futures::prelude::*; use metrics::counter; +use std::sync::Arc; use ya_core_model as core; use ya_persistence::executor::DbExecutor; use ya_service_bus::typed::ServiceBinder; @@ -8,6 +10,7 @@ use ya_service_bus::typed::ServiceBinder; pub fn bind_service(db: &DbExecutor, processor: PaymentProcessor) { log::debug!("Binding payment service to service bus"); + let processor = Arc::new(Mutex::new(processor)); local::bind_service(db, processor.clone()); public::bind_service(db, processor); @@ -22,7 +25,7 @@ mod local { use ya_core_model::payment::local::*; use ya_persistence::types::Role; - pub fn bind_service(db: &DbExecutor, processor: PaymentProcessor) { + pub fn bind_service(db: &DbExecutor, processor: Arc>) { log::debug!("Binding payment local service to service bus"); ServiceBinder::new(BUS_ID, db, processor) @@ -68,74 +71,74 @@ mod local { async fn schedule_payment( db: DbExecutor, - processor: PaymentProcessor, + processor: Arc>, sender: String, msg: SchedulePayment, ) -> Result<(), GenericError> { - processor.schedule_payment(msg).await?; + processor.lock().await.schedule_payment(msg).await?; Ok(()) } async fn register_driver( db: DbExecutor, - processor: PaymentProcessor, + processor: Arc>, sender: String, msg: RegisterDriver, ) -> Result<(), RegisterDriverError> { - processor.register_driver(msg).await + processor.lock().await.register_driver(msg).await } async fn unregister_driver( db: DbExecutor, - processor: PaymentProcessor, + processor: Arc>, sender: String, msg: UnregisterDriver, ) -> Result<(), NoError> { - processor.unregister_driver(msg).await; + processor.lock().await.unregister_driver(msg).await; Ok(()) } async fn register_account( db: DbExecutor, - processor: PaymentProcessor, + processor: Arc>, sender: String, msg: RegisterAccount, ) -> Result<(), RegisterAccountError> { - processor.register_account(msg).await + processor.lock().await.register_account(msg).await } async fn unregister_account( db: DbExecutor, - processor: PaymentProcessor, + processor: Arc>, sender: String, msg: UnregisterAccount, ) -> Result<(), NoError> { - processor.unregister_account(msg).await; + processor.lock().await.unregister_account(msg).await; Ok(()) } async fn get_accounts( db: DbExecutor, - processor: PaymentProcessor, + processor: Arc>, sender: String, msg: GetAccounts, ) -> Result, GenericError> { - Ok(processor.get_accounts().await) + Ok(processor.lock().await.get_accounts().await) } async fn notify_payment( db: DbExecutor, - processor: PaymentProcessor, + processor: Arc>, sender: String, msg: NotifyPayment, ) -> Result<(), GenericError> { - processor.notify_payment(msg).await?; + processor.lock().await.notify_payment(msg).await?; Ok(()) } async fn get_status( db: DbExecutor, - processor: PaymentProcessor, + processor: Arc>, _caller: String, msg: GetStatus, ) -> Result { @@ -148,6 +151,8 @@ mod local { } = msg; let (network, network_details) = processor + .lock() + .await .get_network(driver.clone(), network) .await .map_err(GenericError::new)?; @@ -183,9 +188,14 @@ mod local { } .map_err(GenericError::new); - let amount_fut = processor - .get_status(platform.clone(), address.clone()) - .map_err(GenericError::new); + let amount_fut = async { + processor + .lock() + .await + .get_status(platform.clone(), address.clone()) + .await + } + .map_err(GenericError::new); let (incoming, outgoing, amount, reserved) = future::try_join4(incoming_fut, outgoing_fut, amount_fut, reserved_fut).await?; @@ -203,7 +213,7 @@ mod local { async fn get_invoice_stats( db: DbExecutor, - processor: PaymentProcessor, + processor: Arc>, _caller: String, msg: GetInvoiceStats, ) -> Result { @@ -255,11 +265,13 @@ mod local { async fn validate_allocation( db: DbExecutor, - processor: PaymentProcessor, + processor: Arc>, sender: String, msg: ValidateAllocation, ) -> Result { Ok(processor + .lock() + .await .validate_allocation(msg.platform, msg.address, msg.amount) .await?) } @@ -277,7 +289,7 @@ mod public { use ya_core_model::payment::public::*; use ya_persistence::types::Role; - pub fn bind_service(db: &DbExecutor, processor: PaymentProcessor) { + pub fn bind_service(db: &DbExecutor, processor: Arc>) { log::debug!("Binding payment public service to service bus"); ServiceBinder::new(BUS_ID, db, processor) @@ -591,7 +603,7 @@ mod public { async fn send_payment( db: DbExecutor, - processor: PaymentProcessor, + processor: Arc>, sender_id: String, msg: SendPayment, ) -> Result { @@ -604,7 +616,12 @@ mod public { let platform = payment.payment_platform.clone(); let amount = payment.amount.clone(); let num_paid_invoices = payment.agreement_payments.len() as u64; - match processor.verify_payment(payment, signature).await { + match processor + .lock() + .await + .verify_payment(payment, signature) + .await + { Ok(_) => { counter!("payment.amount.received", ya_metrics::utils::cryptocurrency_to_u64(&amount), "platform" => platform); counter!("payment.invoices.provider.paid", num_paid_invoices);