diff --git a/components/brave_vpn/browser/brave_vpn_service_unittest.cc b/components/brave_vpn/browser/brave_vpn_service_unittest.cc index 3bcdb26dd0a2..36cc8fd9d32e 100644 --- a/components/brave_vpn/browser/brave_vpn_service_unittest.cc +++ b/components/brave_vpn/browser/brave_vpn_service_unittest.cc @@ -506,7 +506,7 @@ class BraveVPNServiceTest : public testing::Test { PurchasedState state) { observer->ResetStates(); SetPurchasedState(env, state); - base::RunLoop().RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_TRUE(observer->GetPurchasedState().has_value()); EXPECT_EQ(observer->GetPurchasedState().value(), state); } @@ -799,7 +799,7 @@ TEST_F(BraveVPNServiceTest, LoadPurchasedStateForAnotherEnvFailed) { EXPECT_FALSE(observer.GetPurchasedState().has_value()); LoadPurchasedState(development); - base::RunLoop().RunUntilIdle(); + task_environment_.RunUntilIdle(); // Successfully set purchased state for dev env. EXPECT_TRUE(observer.GetPurchasedState().has_value()); EXPECT_EQ(observer.GetPurchasedState().value(), PurchasedState::PURCHASED); @@ -813,7 +813,7 @@ TEST_F(BraveVPNServiceTest, LoadPurchasedStateForAnotherEnvFailed) { EXPECT_FALSE(observer.GetPurchasedState().has_value()); // no order found for staging. LoadPurchasedState(staging); - base::RunLoop().RunUntilIdle(); + task_environment_.RunUntilIdle(); // The purchased state was not changed from dev env. EXPECT_FALSE(observer.GetPurchasedState().has_value()); EXPECT_EQ(GetCurrentEnvironment(), skus::GetDefaultEnvironment()); @@ -826,7 +826,7 @@ TEST_F(BraveVPNServiceTest, LoadPurchasedStateForAnotherEnvFailed) { // No region data for staging. SetInterceptorResponse(""); LoadPurchasedState(staging); - base::RunLoop().RunUntilIdle(); + task_environment_.RunUntilIdle(); // The purchased state was not changed from dev env. EXPECT_FALSE(observer.GetPurchasedState().has_value()); EXPECT_EQ(GetCurrentEnvironment(), skus::GetDefaultEnvironment()); @@ -983,7 +983,7 @@ TEST_F(BraveVPNServiceTest, LoadPurchasedStateNotifications) { EXPECT_TRUE(observer.GetPurchasedState().has_value()); EXPECT_EQ(PurchasedState::LOADING, observer.GetPurchasedState().value()); } - base::RunLoop().RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_TRUE(observer.GetPurchasedState().has_value()); EXPECT_EQ(PurchasedState::NOT_PURCHASED, GetPurchasedInfoSync()); // Observer called when state will be changed. @@ -1000,7 +1000,7 @@ TEST_F(BraveVPNServiceTest, LoadPurchasedStateForAnotherEnv) { EXPECT_EQ(PurchasedState::NOT_PURCHASED, GetPurchasedInfoSync()); EXPECT_EQ(GetCurrentEnvironment(), skus::GetDefaultEnvironment()); LoadPurchasedState(development); - base::RunLoop().RunUntilIdle(); + task_environment_.RunUntilIdle(); // Successfully set purchased state for dev env. EXPECT_TRUE(observer.GetPurchasedState().has_value()); EXPECT_EQ(observer.GetPurchasedState().value(), PurchasedState::PURCHASED); @@ -1011,7 +1011,7 @@ TEST_F(BraveVPNServiceTest, LoadPurchasedStateForAnotherEnv) { EXPECT_EQ(skus::GetEnvironmentForDomain(staging), skus::kEnvStaging); EXPECT_EQ(GetCurrentEnvironment(), skus::GetDefaultEnvironment()); LoadPurchasedState(staging); - base::RunLoop().RunUntilIdle(); + task_environment_.RunUntilIdle(); // Successfully changed purchased state for dev env. EXPECT_TRUE(observer.GetPurchasedState().has_value()); EXPECT_EQ(observer.GetPurchasedState().value(), PurchasedState::PURCHASED); diff --git a/components/skus/browser/rs/cxx/src/httpclient.rs b/components/skus/browser/rs/cxx/src/httpclient.rs index 48829484cfb7..52922ab723f0 100644 --- a/components/skus/browser/rs/cxx/src/httpclient.rs +++ b/components/skus/browser/rs/cxx/src/httpclient.rs @@ -85,7 +85,9 @@ impl From> for Result>, InternalErr ) })?; - response.headers_mut().ok_or(InternalError::BorrowFailed)?.insert(key, value); + if let Some(headers) = response.headers_mut() { + headers.insert(key, value); + } } response @@ -105,9 +107,17 @@ impl NativeClient { ) -> Result>, InternalError> { let (tx, rx) = oneshot::channel(); let context = Box::new(HttpRoundtripContext { tx, client: self.clone() }); + let ctx = self + .inner + .lock().await + .ctx + .clone(); let fetcher = ffi::shim_executeRequest( - &self.ctx.try_borrow().map_err(|_| InternalError::BorrowFailed)?.ctx, + &*ctx + .try_borrow() + .map_err(|_| InternalError::BorrowFailed)? + , &req, |context, resp| { let _ = context.tx.send(resp.into()); diff --git a/components/skus/browser/rs/cxx/src/lib.rs b/components/skus/browser/rs/cxx/src/lib.rs index 346d67df9766..e3957e890550 100644 --- a/components/skus/browser/rs/cxx/src/lib.rs +++ b/components/skus/browser/rs/cxx/src/lib.rs @@ -11,29 +11,39 @@ mod storage; use std::cell::RefCell; use std::fmt; use std::rc::Rc; +use std::thread; use cxx::{type_id, ExternType, UniquePtr}; use futures::executor::{LocalPool, LocalSpawner}; use futures::task::LocalSpawnExt; +use futures::lock::Mutex; use tracing::debug; pub use skus; use crate::httpclient::{HttpRoundtripContext, WakeupContext}; +use crate::storage::{StorageGetContext, StoragePurgeContext, StorageSetContext}; use errors::result_to_string; -pub struct NativeClientContext { +pub struct NativeClientExecutor { + is_shutdown: bool, + pool: Option, + spawner: LocalSpawner, + thread_id: thread::ThreadId, +} + +#[derive(Clone)] +pub struct NativeClientInner { environment: skus::Environment, - ctx: UniquePtr, + executor: Rc>, + ctx: Rc>>, } #[derive(Clone)] pub struct NativeClient { - is_shutdown: Rc>, - pool: Rc>, - spawner: LocalSpawner, - ctx: Rc>, + executor: Rc>, + inner: Rc>, } impl fmt::Debug for NativeClient { @@ -44,11 +54,44 @@ impl fmt::Debug for NativeClient { impl NativeClient { fn try_run_until_stalled(&self) { - if *self.is_shutdown.borrow() { + let executor = self.executor.clone(); + if let Ok(mut executor) = executor.try_borrow_mut() { + executor.try_run_until_stalled() + }; + } + + fn get_spawner(&self) -> LocalSpawner { + self.executor.borrow().spawner.clone() + } +} + +impl NativeClientExecutor { + fn new() -> Self { + let pool = LocalPool::new(); + let spawner = pool.spawner(); + Self { + is_shutdown: false, + pool: Some(pool), + spawner, + thread_id: thread::current().id(), + } + } + + fn shutdown(&mut self) { + // drop any existing futures + drop(self.pool.take()); + // ensure lingering callbacks passed to c++ are short circuited + self.is_shutdown = true; + } + + fn try_run_until_stalled(&mut self) { + assert!(thread::current().id() == self.thread_id, "sdk called on a different thread!"); + let _ = thread::current().id() == self.thread_id; + if self.is_shutdown { debug!("sdk is shutdown, exiting"); return; } - if let Ok(mut pool) = self.pool.try_borrow_mut() { + if let Some(pool) = &mut self.pool { pool.run_until_stalled(); } } @@ -128,6 +171,9 @@ mod ffi { extern "Rust" { type HttpRoundtripContext; type WakeupContext; + type StoragePurgeContext; + type StorageSetContext; + type StorageGetContext; type CppSDK; fn initialize_sdk(ctx: UniquePtr, env: String) -> Box; @@ -195,9 +241,26 @@ mod ffi { ctx: Box, ); - fn shim_purge(ctx: Pin<&mut SkusContext>); - fn shim_set(ctx: Pin<&mut SkusContext>, key: &str, value: &str); - fn shim_get(ctx: Pin<&mut SkusContext>, key: &str) -> String; + fn shim_purge( + ctx: Pin<&mut SkusContext>, + done: fn(Box, bool), + st_ctx: Box, + ); + + fn shim_set( + ctx: Pin<&mut SkusContext>, + key: &str, + value: &str, + done: fn(Box, bool), + st_ctx: Box, + ); + + fn shim_get( + ctx: Pin<&mut SkusContext>, + key: &str, + done: fn(Box, String, bool), + st_ctx: Box, + ); type RefreshOrderCallbackState; type RefreshOrderCallback = crate::RefreshOrderCallback; @@ -230,14 +293,15 @@ fn initialize_sdk(ctx: UniquePtr, env: String) -> Box let env = env.parse::().unwrap_or(skus::Environment::Local); - let pool = LocalPool::new(); - let spawner = pool.spawner(); + let executor = Rc::new(RefCell::new(NativeClientExecutor::new())); let sdk = skus::sdk::SDK::new( NativeClient { - is_shutdown: Rc::new(RefCell::new(false)), - pool: Rc::new(RefCell::new(pool)), - spawner: spawner.clone(), - ctx: Rc::new(RefCell::new(NativeClientContext { environment: env.clone(), ctx })), + executor: executor.clone(), + inner: Rc::new(Mutex::new(NativeClientInner { + environment: env.clone(), + executor, + ctx: Rc::new(RefCell::new(ctx)), + })), }, env, None, @@ -246,23 +310,21 @@ fn initialize_sdk(ctx: UniquePtr, env: String) -> Box let sdk = Rc::new(sdk); { let sdk = sdk.clone(); + let spawner = sdk.client.get_spawner(); let init = async move { sdk.initialize().await }; if spawner.spawn_local(init).is_err() { debug!("pool is shutdown"); } } - sdk.client.pool.borrow_mut().run_until_stalled(); + sdk.client.try_run_until_stalled(); Box::new(CppSDK { sdk }) } impl CppSDK { fn shutdown(&self) { - // drop any existing futures - drop(self.sdk.client.pool.take()); - // ensure lingering callbacks passed to c++ are short circuited - *self.sdk.client.is_shutdown.borrow_mut() = true; + self.sdk.client.executor.borrow_mut().shutdown(); } fn refresh_order( @@ -271,7 +333,7 @@ impl CppSDK { callback_state: UniquePtr, order_id: String, ) { - let spawner = self.sdk.client.spawner.clone(); + let spawner = self.sdk.client.get_spawner(); if spawner .spawn_local(refresh_order_task(self.sdk.clone(), callback, callback_state, order_id)) .is_err() @@ -288,7 +350,7 @@ impl CppSDK { callback_state: UniquePtr, order_id: String, ) { - let spawner = self.sdk.client.spawner.clone(); + let spawner = self.sdk.client.get_spawner(); if spawner .spawn_local(fetch_order_credentials_task( self.sdk.clone(), @@ -311,7 +373,7 @@ impl CppSDK { domain: String, path: String, ) { - let spawner = self.sdk.client.spawner.clone(); + let spawner = self.sdk.client.get_spawner(); if spawner .spawn_local(prepare_credentials_presentation_task( self.sdk.clone(), @@ -334,7 +396,7 @@ impl CppSDK { callback_state: UniquePtr, domain: String, ) { - let spawner = self.sdk.client.spawner.clone(); + let spawner = self.sdk.client.get_spawner(); if spawner .spawn_local(credential_summary_task( self.sdk.clone(), @@ -357,7 +419,7 @@ impl CppSDK { order_id: String, receipt: String, ) { - let spawner = self.sdk.client.spawner.clone(); + let spawner = self.sdk.client.get_spawner(); if spawner .spawn_local(submit_receipt_task( self.sdk.clone(), @@ -379,7 +441,7 @@ impl CppSDK { callback_state: UniquePtr, receipt: String, ) { - let spawner = self.sdk.client.spawner.clone(); + let spawner = self.sdk.client.get_spawner(); if spawner .spawn_local(create_order_from_receipt_task( self.sdk.clone(), @@ -563,11 +625,7 @@ async fn create_order_from_receipt_task( callback_state: UniquePtr, receipt: String, ) { - match sdk - .create_order_from_receipt(&receipt) - .await - .map_err(|e| e.into()) - { + match sdk.create_order_from_receipt(&receipt).await.map_err(|e| e.into()) { Ok(order_id) => callback.0(callback_state.into_raw(), ffi::SkusResult::Ok, &order_id), Err(e) => callback.0(callback_state.into_raw(), e, ""), } diff --git a/components/skus/browser/rs/cxx/src/shim.h b/components/skus/browser/rs/cxx/src/shim.h index b4f5760497a0..cae428bd4abd 100644 --- a/components/skus/browser/rs/cxx/src/shim.h +++ b/components/skus/browser/rs/cxx/src/shim.h @@ -24,6 +24,9 @@ struct HttpRequest; struct HttpResponse; struct HttpRoundtripContext; struct WakeupContext; +struct StoragePurgeContext; +struct StorageSetContext; +struct StorageGetContext; class FetchOrderCredentialsCallbackState { public: @@ -82,9 +85,23 @@ class SkusContext { public: virtual ~SkusContext() = default; virtual std::unique_ptr CreateFetcher() const = 0; - virtual std::string GetValueFromStore(std::string key) const = 0; - virtual void PurgeStore() const = 0; - virtual void UpdateStoreValue(std::string key, std::string value) const = 0; + virtual void GetValueFromStore( + const std::string& key, + rust::cxxbridge1::Fn, + rust::String value, + bool success)> done, + rust::cxxbridge1::Box st_ctx) const = 0; + virtual void PurgeStore( + rust::cxxbridge1::Fn< + void(rust::cxxbridge1::Box, bool success)> + done, + rust::cxxbridge1::Box st_ctx) const = 0; + virtual void UpdateStoreValue( + const std::string& key, + const std::string& value, + rust::cxxbridge1::Fn, + bool success)> done, + rust::cxxbridge1::Box st_ctx) const = 0; }; using RefreshOrderCallback = void (*)(RefreshOrderCallbackState* callback_state, @@ -114,12 +131,25 @@ void shim_logMessage(rust::cxxbridge1::Str file, TracingLevel level, rust::cxxbridge1::Str message); -void shim_purge(skus::SkusContext& ctx); // NOLINT -void shim_set(skus::SkusContext& ctx, // NOLINT - rust::cxxbridge1::Str key, - rust::cxxbridge1::Str value); -::rust::String shim_get(skus::SkusContext& ctx, // NOLINT - rust::cxxbridge1::Str key); +void shim_purge( + skus::SkusContext& ctx, // NOLINT + rust::cxxbridge1::Fn, + bool success)> done, + rust::cxxbridge1::Box st_ctx); +void shim_set( + skus::SkusContext& ctx, // NOLINT + rust::cxxbridge1::Str key, + rust::cxxbridge1::Str value, + rust::cxxbridge1::Fn, + bool success)> done, + rust::cxxbridge1::Box st_ctx); +void shim_get( + skus::SkusContext& ctx, // NOLINT + rust::cxxbridge1::Str key, + rust::cxxbridge1::Fn, + rust::String value, + bool success)> done, + rust::cxxbridge1::Box st_ctx); void shim_scheduleWakeup( ::std::uint64_t delay_ms, diff --git a/components/skus/browser/rs/cxx/src/storage.rs b/components/skus/browser/rs/cxx/src/storage.rs index 75e8ced8684b..2254aa96a646 100644 --- a/components/skus/browser/rs/cxx/src/storage.rs +++ b/components/skus/browser/rs/cxx/src/storage.rs @@ -1,31 +1,128 @@ -use std::cell::RefMut; +// Copyright (c) 2021 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. -use crate::{ffi, NativeClient, NativeClientContext}; -use skus::{errors, Environment, KVClient, KVStore}; +use async_trait::async_trait; +use futures::channel::oneshot; +use futures::lock::MutexGuard; +use tracing::debug; +use crate::{ffi, NativeClient, NativeClientInner}; +use skus::{errors::InternalError, Environment, KVClient, KVStore}; + +pub struct StoragePurgeContext { + tx: oneshot::Sender>, + client: NativeClientInner, +} + +pub struct StorageSetContext { + tx: oneshot::Sender>, + client: NativeClientInner, +} + +pub struct StorageGetContext { + tx: oneshot::Sender, InternalError>>, + client: NativeClientInner, +} + +#[async_trait(?Send)] impl KVClient for NativeClient { - type Store = NativeClientContext; + type Store = NativeClientInner; + type StoreRef<'a> = MutexGuard<'a, NativeClientInner>; - #[allow(clippy::needless_lifetimes)] - fn get_store<'a>(&'a self) -> Result, errors::InternalError> { - self.ctx.try_borrow_mut().or(Err(errors::InternalError::BorrowFailed)) + async fn get_store<'a>(&'a self) -> Result, InternalError> { + Ok(self.inner.lock().await) } } -impl KVStore for NativeClientContext { +#[async_trait(?Send)] +impl KVStore for NativeClientInner { fn env(&self) -> &Environment { &self.environment } - fn purge(&mut self) -> Result<(), errors::InternalError> { - ffi::shim_purge(self.ctx.pin_mut()); - Ok(()) + async fn purge(&mut self) -> Result<(), InternalError> { + let (tx, rx) = oneshot::channel(); + let context = Box::new(StoragePurgeContext { tx, client: self.clone() }); + + ffi::shim_purge( + self.ctx.try_borrow_mut().map_err(|_| InternalError::BorrowFailed)?.pin_mut(), + |context, success| { + let _ = + context.tx.send(success.then_some(()).ok_or( + InternalError::StorageWriteFailed("prefs write failed".to_string()), + )); + + if let Ok(mut executor) = context.client.executor.try_borrow_mut() { + executor.try_run_until_stalled() + } + }, + context, + ); + match rx.await { + Ok(ret) => ret, + Err(_) => { + debug!("purge response channel was cancelled"); + Err(InternalError::FutureCancelled) + } + } } - fn set(&mut self, key: &str, value: &str) -> Result<(), errors::InternalError> { - ffi::shim_set(self.ctx.pin_mut(), key, value); - Ok(()) + + async fn set(&mut self, key: &str, value: &str) -> Result<(), InternalError> { + let (tx, rx) = oneshot::channel(); + let context = Box::new(StorageSetContext { tx, client: self.clone() }); + + ffi::shim_set( + self.ctx.try_borrow_mut().map_err(|_| InternalError::BorrowFailed)?.pin_mut(), + key, + value, + |context, success| { + let _ = + context.tx.send(success.then_some(()).ok_or( + InternalError::StorageWriteFailed("prefs write failed".to_string()), + )); + + if let Ok(mut executor) = context.client.executor.try_borrow_mut() { + executor.try_run_until_stalled() + } + }, + context, + ); + match rx.await { + Ok(ret) => ret, + Err(_) => { + debug!("purge response channel was cancelled"); + Err(InternalError::FutureCancelled) + } + } } - fn get(&mut self, key: &str) -> Result, errors::InternalError> { - let ret = ffi::shim_get(self.ctx.pin_mut(), key); - Ok(if !ret.is_empty() { Some(ret) } else { None }) + + async fn get(&mut self, key: &str) -> Result, InternalError> { + let (tx, rx) = oneshot::channel(); + let context = Box::new(StorageGetContext { tx, client: self.clone() }); + + ffi::shim_get( + self.ctx.try_borrow_mut().map_err(|_| InternalError::BorrowFailed)?.pin_mut(), + key, + |context, resp, success| { + let _ = context.tx.send( + success + .then_some(if !resp.is_empty() { Some(resp) } else { None }) + .ok_or(InternalError::StorageReadFailed("prefs read failed".to_string())), + ); + + if let Ok(mut executor) = context.client.executor.try_borrow_mut() { + executor.try_run_until_stalled() + } + }, + context, + ); + match rx.await { + Ok(ret) => ret, + Err(_) => { + debug!("purge response channel was cancelled"); + Err(InternalError::FutureCancelled) + } + } } } diff --git a/components/skus/browser/rs/lib/src/storage/kv.rs b/components/skus/browser/rs/lib/src/storage/kv.rs index 3d303ade2be2..e31506529359 100644 --- a/components/skus/browser/rs/lib/src/storage/kv.rs +++ b/components/skus/browser/rs/lib/src/storage/kv.rs @@ -1,4 +1,9 @@ -use std::cell::RefMut; +// Copyright (c) 2021 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +use std::ops::DerefMut; use std::collections::HashMap; use std::fmt::Debug; @@ -35,52 +40,53 @@ struct KVState { pub credentials: Option, } -trait KVStoreHelpers { - fn get_state(&mut self) -> Result; - fn set_state(&mut self, state: &KVState) -> Result<(), InternalError>; +#[async_trait(?Send)] +trait KVStoreHelpers { + async fn get_state(&mut self) -> Result; + async fn set_state(&mut self, state: &KVState) -> Result<(), InternalError>; } +#[async_trait(?Send)] pub trait KVClient { - type Store; + type Store: KVStore; + type StoreRef<'a>: DerefMut where Self: 'a; - #[allow(clippy::needless_lifetimes)] - fn get_store<'a>(&'a self) -> Result, InternalError> - where - Self::Store: KVStore; + async fn get_store<'a>(&'a self) -> Result, InternalError>; } +#[async_trait(?Send)] pub trait KVStore: Sized { fn env(&self) -> &Environment; - fn purge(&mut self) -> Result<(), InternalError>; - fn set(&mut self, key: &str, value: &str) -> Result<(), InternalError>; - fn get(&mut self, key: &str) -> Result, InternalError>; + async fn purge(&mut self) -> Result<(), InternalError>; + async fn set(&mut self, key: &str, value: &str) -> Result<(), InternalError>; + async fn get(&mut self, key: &str) -> Result, InternalError>; } fn key_from_environment(env: &Environment) -> String { format!("skus:{}", env) } +#[async_trait(?Send)] impl KVStoreHelpers for C where C: KVStore, { - fn get_state(&mut self) -> Result { + async fn get_state(&mut self) -> Result { let key = key_from_environment(self.env()); - if let Ok(Some(state)) = self.get("rewards:local") { + if let Ok(Some(state)) = self.get("rewards:local").await { // Perform a one time migration, clearing any old values - self.purge()?; + self.purge().await?; // and setting a new key with the prior value - self.set(&key, &state)?; + self.set(&key, &state).await?; } - let state = self.get(&key)?.unwrap_or_else(|| "{}".to_string()); + let state = self.get(&key).await?.unwrap_or_else(|| "{}".to_string()); Ok(serde_json::from_str(&state)?) } - fn set_state(&mut self, state: &KVState) -> Result<(), InternalError> { + async fn set_state(&mut self, state: &KVState) -> Result<(), InternalError> { let key = key_from_environment(self.env()); event!(Level::DEBUG, "set state"); - event!(Level::TRACE, state = %format!("{:#?}", state), "set state",); - self.set(&key, &serde_json::to_string(state)?) + self.set(&key, &serde_json::to_string(state)?).await } } @@ -88,40 +94,39 @@ where impl StorageClient for C where C: KVClient + Debug, - ::Store: KVStore, { #[instrument] async fn clear(&self) -> Result<(), InternalError> { - let mut store = self.get_store()?; - store.purge() + let mut store = self.get_store().await?; + store.purge().await } #[instrument] async fn insert_wallet(&self, wallet: &Wallet) -> Result<(), InternalError> { - let mut store = self.get_store()?; - let mut state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let mut state: KVState = store.get_state().await?; if state.wallet.is_none() { state.wallet = Some(wallet.clone()); } - store.set_state(&state) + store.set_state(&state).await } #[instrument] async fn replace_promotions(&self, promotions: &[Promotion]) -> Result<(), InternalError> { - let mut store = self.get_store()?; - let mut state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let mut state: KVState = store.get_state().await?; state.promotions = Some(promotions.to_vec()); - store.set_state(&state) + store.set_state(&state).await } #[instrument] async fn get_orders(&self) -> Result>, InternalError> { - let mut store = self.get_store()?; - let state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let state: KVState = store.get_state().await?; let orders = state.orders.map(|os| os.into_values().collect()); event!(Level::DEBUG, orders = ?orders, "got orders"); Ok(orders) @@ -129,8 +134,8 @@ where #[instrument] async fn get_order(&self, order_id: &str) -> Result, InternalError> { - let mut store = self.get_store()?; - let state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let state: KVState = store.get_state().await?; let order = state.orders.and_then(|mut orders| orders.remove(order_id)); event!(Level::DEBUG, order = ?order, "got order"); Ok(order) @@ -139,8 +144,8 @@ where #[instrument] async fn has_credentials(&self, order_id: &str) -> Result { let mut result: bool = false; - let mut store = self.get_store()?; - let state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let state: KVState = store.get_state().await?; event!(Level::DEBUG, has_credentials = ?!state.credentials.is_none(), "does order have credentials"); if let Some(credentials) = state.credentials { @@ -160,8 +165,8 @@ where #[instrument] async fn upsert_order(&self, order: &Order) -> Result<(), InternalError> { - let mut store = self.get_store()?; - let mut state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let mut state: KVState = store.get_state().await?; if state.orders.is_none() { state.orders = Some(HashMap::new()); @@ -171,14 +176,14 @@ where orders.insert(order.id.clone(), order.clone()); } - store.set_state(&state) + store.set_state(&state).await } #[instrument] #[cfg(feature = "e2e_test")] async fn delete_n_item_creds(&self, item_id: &str, n: usize) -> Result<(), InternalError> { - let mut store = self.get_store()?; - let mut state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let mut state: KVState = store.get_state().await?; if let Some(mut credentials) = state.credentials { // remove old creds @@ -216,20 +221,20 @@ where } state.credentials = Some(credentials); } - store.set_state(&state) + store.set_state(&state).await } #[instrument] async fn delete_item_creds(&self, item_id: &str) -> Result<(), InternalError> { - let mut store = self.get_store()?; - let mut state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let mut state: KVState = store.get_state().await?; if let Some(mut credentials) = state.credentials { credentials.items.remove(item_id); state.credentials = Some(credentials); } - store.set_state(&state) + store.set_state(&state).await } #[instrument] @@ -237,8 +242,8 @@ where &self, item_id: &str, ) -> Result, InternalError> { - let mut store = self.get_store()?; - let mut state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let mut state: KVState = store.get_state().await?; if state.credentials.is_none() { state.credentials = Some(CredentialsState::default()); @@ -269,8 +274,8 @@ where &self, item_id: &str, ) -> Result, InternalError> { - let mut store = self.get_store()?; - let state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let state: KVState = store.get_state().await?; let credentials = state.credentials.and_then(|mut credentials| { if let Some(Credentials::SingleUse(credentials)) = credentials.items.remove(item_id) { Some(credentials) @@ -289,8 +294,8 @@ where request_id: &str, creds: Vec, ) -> Result<(), InternalError> { - let mut store = self.get_store()?; - let mut state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let mut state: KVState = store.get_state().await?; if state.credentials.is_none() { state.credentials = Some(CredentialsState::default()); @@ -328,7 +333,7 @@ where } } - store.set_state(&state) + store.set_state(&state).await } #[instrument] @@ -337,8 +342,8 @@ where item_id: &str, request_id: &str, ) -> Result<(), InternalError> { - let mut store = self.get_store()?; - let mut state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let mut state: KVState = store.get_state().await?; if state.credentials.is_none() { state.credentials = Some(CredentialsState::default()); @@ -358,7 +363,7 @@ where } } - store.set_state(&state) + store.set_state(&state).await } #[instrument] @@ -367,8 +372,8 @@ where item_id: &str, creds: Vec, ) -> Result<(), InternalError> { - let mut store = self.get_store()?; - let mut state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let mut state: KVState = store.get_state().await?; if state.credentials.is_none() { state.credentials = Some(CredentialsState { items: HashMap::new() }); @@ -388,7 +393,7 @@ where } } - store.set_state(&state) + store.set_state(&state).await } #[instrument] @@ -402,8 +407,8 @@ where ) -> Result<(), InternalError> { // each time this function is run, we are going to append a credential - let mut store = self.get_store()?; - let mut state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let mut state: KVState = store.get_state().await?; if let Some(credentials) = state.credentials.as_mut() { if let Some(item_credentials) = credentials.items.get_mut(item_id) { @@ -447,7 +452,7 @@ where )); } } - return store.set_state(&state); + return store.set_state(&state).await; } } Err(InternalError::StorageWriteFailed("Item credentials were not initiated".to_string())) @@ -460,8 +465,8 @@ where issuer_id: &str, unblinded_creds: Vec, ) -> Result<(), InternalError> { - let mut store = self.get_store()?; - let mut state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let mut state: KVState = store.get_state().await?; if let Some(credentials) = state.credentials.as_mut() { if let Some(item_credentials) = credentials.items.get_mut(item_id) { @@ -484,7 +489,7 @@ where )); } } - return store.set_state(&state); + return store.set_state(&state).await; } } Err(InternalError::StorageWriteFailed("Item credentials were not initiated".to_string())) @@ -496,8 +501,8 @@ where item_id: &str, index: usize, ) -> Result<(), InternalError> { - let mut store = self.get_store()?; - let mut state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let mut state: KVState = store.get_state().await?; if let Some(credentials) = state.credentials.as_mut() { if let Some(item_credentials) = credentials.items.get_mut(item_id) { @@ -514,7 +519,7 @@ where tlv2_cred.unblinded_creds.as_mut() { unblinded_creds[index].spent = true; - return store.set_state(&state); + return store.set_state(&state).await; } } } @@ -537,8 +542,8 @@ where item_id: &str, index: usize, ) -> Result<(), InternalError> { - let mut store = self.get_store()?; - let mut state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let mut state: KVState = store.get_state().await?; if let Some(credentials) = state.credentials.as_mut() { if let Some(item_credentials) = credentials.items.get_mut(item_id) { @@ -547,7 +552,7 @@ where if let Some(unblinded_creds) = item_credentials.unblinded_creds.as_mut() { unblinded_creds[index].spent = true; - return store.set_state(&state); + return store.set_state(&state).await; } } _ => { @@ -566,8 +571,8 @@ where &self, item_id: &str, ) -> Result, InternalError> { - let mut store = self.get_store()?; - let mut state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let mut state: KVState = store.get_state().await?; if state.credentials.is_none() { state.credentials = Some(CredentialsState { items: HashMap::new() }); @@ -599,8 +604,8 @@ where item_id: &str, creds: Vec, ) -> Result<(), InternalError> { - let mut store = self.get_store()?; - let mut state: KVState = store.get_state()?; + let mut store = self.get_store().await?; + let mut state: KVState = store.get_state().await?; if state.credentials.is_none() { state.credentials = Some(CredentialsState { items: HashMap::new() }); @@ -617,6 +622,6 @@ where credentials.items.insert(item_id.to_string(), creds); } - store.set_state(&state) + store.set_state(&state).await } } diff --git a/components/skus/browser/skus_context_impl.cc b/components/skus/browser/skus_context_impl.cc index ebd51fb44076..53ea38a199d2 100644 --- a/components/skus/browser/skus_context_impl.cc +++ b/components/skus/browser/skus_context_impl.cc @@ -10,11 +10,8 @@ #include "base/logging.h" #include "base/task/sequenced_task_runner.h" -#include "brave/components/skus/browser/pref_names.h" #include "brave/components/skus/browser/rs/cxx/src/lib.rs.h" #include "brave/components/skus/browser/skus_url_loader_impl.h" -#include "components/prefs/pref_service.h" -#include "components/prefs/scoped_user_pref_update.h" namespace { @@ -91,20 +88,35 @@ void shim_logMessage(rust::cxxbridge1::Str file, } } -void shim_purge(skus::SkusContext& ctx) { // NOLINT - ctx.PurgeStore(); +void shim_purge( + skus::SkusContext& ctx, // NOLINT + rust::cxxbridge1::Fn, + bool success)> done, + rust::cxxbridge1::Box st_ctx) { + ctx.PurgeStore(std::move(done), std::move(st_ctx)); } -void shim_set(skus::SkusContext& ctx, // NOLINT - rust::cxxbridge1::Str key, - rust::cxxbridge1::Str value) { +void shim_set( + skus::SkusContext& ctx, // NOLINT + rust::cxxbridge1::Str key, + rust::cxxbridge1::Str value, + rust::cxxbridge1::Fn, + bool success)> done, + rust::cxxbridge1::Box st_ctx) { ctx.UpdateStoreValue(static_cast(key), - static_cast(value)); + static_cast(value), std::move(done), + std::move(st_ctx)); } -::rust::String shim_get(skus::SkusContext& ctx, // NOLINT - rust::cxxbridge1::Str key) { - return ::rust::String(ctx.GetValueFromStore(static_cast(key))); +void shim_get( + skus::SkusContext& ctx, // NOLINT + rust::cxxbridge1::Str key, + rust::cxxbridge1::Fn, + rust::String value, + bool success)> done, + rust::cxxbridge1::Box st_ctx) { + ctx.GetValueFromStore(static_cast(key), std::move(done), + std::move(st_ctx)); } void shim_scheduleWakeup( @@ -112,8 +124,6 @@ void shim_scheduleWakeup( rust::cxxbridge1::Fn)> done, rust::cxxbridge1::Box ctx) { int buffer_ms = 10; - VLOG(1) << "shim_scheduleWakeup " << (delay_ms + buffer_ms) << " (" - << delay_ms << "ms plus " << buffer_ms << "ms buffer)"; base::SequencedTaskRunner::GetCurrentDefault()->PostDelayedTask( FROM_HERE, base::BindOnce(&OnScheduleWakeup, std::move(done), std::move(ctx)), @@ -132,37 +142,57 @@ std::unique_ptr shim_executeRequest( } SkusContextImpl::SkusContextImpl( - PrefService* prefs, - scoped_refptr url_loader_factory) - : prefs_(*prefs), url_loader_factory_(url_loader_factory) {} - + std::unique_ptr + pending_url_loader_factory, + scoped_refptr ui_task_runner, + base::WeakPtr skus_service) + : pending_url_loader_factory_(std::move(pending_url_loader_factory)), + ui_task_runner_(ui_task_runner), + skus_service_(skus_service) {} SkusContextImpl::~SkusContextImpl() = default; std::unique_ptr SkusContextImpl::CreateFetcher() const { - return std::make_unique(url_loader_factory_); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + auto url_loader_factory = network::SharedURLLoaderFactory::Create( + std::move(pending_url_loader_factory_)); + pending_url_loader_factory_ = url_loader_factory->Clone(); + return std::make_unique(url_loader_factory); } -std::string SkusContextImpl::GetValueFromStore(std::string key) const { - VLOG(1) << "shim_get: `" << key << "`"; - const auto& state = prefs_->GetDict(prefs::kSkusState); - const base::Value* value = state.Find(key); - if (value) { - return value->GetString(); - } - return ""; +void SkusContextImpl::GetValueFromStore( + const std::string& key, + rust::cxxbridge1::Fn, + rust::String value, + bool success)> done, + rust::cxxbridge1::Box st_ctx) const { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + ui_task_runner_->PostTask( + FROM_HERE, + base::BindOnce(&SkusServiceImpl::GetValueFromStore, skus_service_, key, + std::move(done), std::move(st_ctx))); } -void SkusContextImpl::PurgeStore() const { - VLOG(1) << "shim_purge"; - ScopedDictPrefUpdate state(&*prefs_, prefs::kSkusState); - state->clear(); +void SkusContextImpl::PurgeStore( + rust::cxxbridge1::Fn, + bool success)> done, + rust::cxxbridge1::Box st_ctx) const { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + ui_task_runner_->PostTask( + FROM_HERE, base::BindOnce(&SkusServiceImpl::PurgeStore, skus_service_, + std::move(done), std::move(st_ctx))); } -void SkusContextImpl::UpdateStoreValue(std::string key, - std::string value) const { - VLOG(1) << "shim_set: `" << key << "` = `" << value << "`"; - ScopedDictPrefUpdate state(&*prefs_, prefs::kSkusState); - state->Set(key, value); +void SkusContextImpl::UpdateStoreValue( + const std::string& key, + const std::string& value, + rust::cxxbridge1::Fn, + bool success)> done, + rust::cxxbridge1::Box st_ctx) const { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + ui_task_runner_->PostTask( + FROM_HERE, + base::BindOnce(&SkusServiceImpl::UpdateStoreValue, skus_service_, key, + value, std::move(done), std::move(st_ctx))); } } // namespace skus diff --git a/components/skus/browser/skus_context_impl.h b/components/skus/browser/skus_context_impl.h index 6450d43283cf..2954b5a4a1a5 100644 --- a/components/skus/browser/skus_context_impl.h +++ b/components/skus/browser/skus_context_impl.h @@ -12,12 +12,8 @@ #include "base/memory/raw_ref.h" #include "base/memory/scoped_refptr.h" #include "brave/components/skus/browser/rs/cxx/src/shim.h" - -class PrefService; - -namespace network { -class SharedURLLoaderFactory; -} // namespace network +#include "brave/components/skus/browser/skus_service_impl.h" +#include "services/network/public/cpp/shared_url_loader_factory.h" namespace skus { class SkusUrlLoader; @@ -38,21 +34,38 @@ class SkusContextImpl : public SkusContext { SkusContextImpl& operator=(const SkusContextImpl&) = delete; explicit SkusContextImpl( - PrefService* prefs, - scoped_refptr url_loader_factory); + std::unique_ptr + pending_url_loader_factory, + scoped_refptr ui_task_runner, + base::WeakPtr); ~SkusContextImpl() override; std::unique_ptr CreateFetcher() const override; - std::string GetValueFromStore(std::string key) const override; - void PurgeStore() const override; - void UpdateStoreValue(std::string key, std::string value) const override; + void GetValueFromStore( + const std::string& key, + rust::cxxbridge1::Fn, + rust::String value, + bool success)> done, + rust::cxxbridge1::Box st_ctx) const override; + void PurgeStore( + rust::cxxbridge1::Fn< + void(rust::cxxbridge1::Box, bool success)> + done, + rust::cxxbridge1::Box st_ctx) const override; + void UpdateStoreValue( + const std::string& key, + const std::string& value, + rust::cxxbridge1::Fn, + bool success)> done, + rust::cxxbridge1::Box st_ctx) const override; private: - // used to store the credential - const raw_ref prefs_; - + SEQUENCE_CHECKER(sequence_checker_); // used for making requests to SKU server - scoped_refptr url_loader_factory_; + mutable std::unique_ptr + pending_url_loader_factory_ GUARDED_BY_CONTEXT(sequence_checker_); + scoped_refptr ui_task_runner_; + base::WeakPtr skus_service_; }; } // namespace skus diff --git a/components/skus/browser/skus_service_impl.cc b/components/skus/browser/skus_service_impl.cc index 9e0915b44f52..896051213256 100644 --- a/components/skus/browser/skus_service_impl.cc +++ b/components/skus/browser/skus_service_impl.cc @@ -8,12 +8,12 @@ #include #include -#include "base/json/json_reader.h" #include "brave/components/skus/browser/pref_names.h" #include "brave/components/skus/browser/rs/cxx/src/lib.rs.h" #include "brave/components/skus/browser/skus_context_impl.h" #include "brave/components/skus/browser/skus_utils.h" #include "components/prefs/pref_service.h" +#include "components/prefs/scoped_user_pref_update.h" #include "services/network/public/cpp/shared_url_loader_factory.h" namespace { @@ -21,9 +21,8 @@ namespace { void OnRefreshOrder(skus::RefreshOrderCallbackState* callback_state, skus::SkusResult result, rust::cxxbridge1::Str order) { - std::string order_str = static_cast(order); if (callback_state->cb) { - std::move(callback_state->cb).Run(order_str); + std::move(callback_state->cb).Run(static_cast(order)); } delete callback_state; } @@ -87,11 +86,27 @@ namespace skus { SkusServiceImpl::SkusServiceImpl( PrefService* prefs, scoped_refptr url_loader_factory) - : prefs_(prefs), url_loader_factory_(url_loader_factory) {} + : prefs_(prefs), url_loader_factory_(url_loader_factory) { + sdk_task_runner_ = base::ThreadPool::CreateSingleThreadTaskRunner( + {base::TaskPriority::USER_BLOCKING}); + ui_task_runner_ = base::SequencedTaskRunner::GetCurrentDefault(); +} SkusServiceImpl::~SkusServiceImpl() = default; -void SkusServiceImpl::Shutdown() {} +void SkusServiceImpl::Shutdown() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + // Disconnect remotes. + receivers_.ClearWithReason(0, "Shutting down"); + + for (auto it = sdks_.begin(); it != sdks_.end();) { + // CppSDK must be destroyed on the sdk task runner. + sdk_task_runner_->PostTask( + FROM_HERE, base::BindOnce([](::rust::Box sdk) {}, + std::move(sdks_.extract(it++).mapped()))); + } +} mojo::PendingRemote SkusServiceImpl::MakeRemote() { mojo::PendingRemote remote; @@ -107,69 +122,110 @@ void SkusServiceImpl::RefreshOrder( const std::string& domain, const std::string& order_id, mojom::SkusService::RefreshOrderCallback callback) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); std::unique_ptr cbs( new skus::RefreshOrderCallbackState); - cbs->cb = std::move(callback); - GetOrCreateSDK(domain)->refresh_order(OnRefreshOrder, std::move(cbs), - order_id); + + cbs->cb = base::BindOnce( + [](mojom::SkusService::RefreshOrderCallback cb, + scoped_refptr ui_task_runner, + const std::string& result) { + ui_task_runner->PostTask(FROM_HERE, + base::BindOnce(std::move(cb), result)); + }, + std::move(callback), ui_task_runner_); + + PostTaskWithSDK(domain, + base::BindOnce( + [](std::unique_ptr cbs, + const std::string& order_id, skus::CppSDK* sdk) { + sdk->refresh_order(OnRefreshOrder, std::move(cbs), + order_id); + }, + std::move(cbs), order_id)); } void SkusServiceImpl::FetchOrderCredentials( const std::string& domain, const std::string& order_id, mojom::SkusService::FetchOrderCredentialsCallback callback) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); std::unique_ptr cbs( new skus::FetchOrderCredentialsCallbackState); - cbs->cb = std::move(callback); - GetOrCreateSDK(domain)->fetch_order_credentials(OnFetchOrderCredentials, - std::move(cbs), order_id); + cbs->cb = base::BindOnce( + [](mojom::SkusService::FetchOrderCredentialsCallback cb, + scoped_refptr ui_task_runner, + const std::string& result) { + ui_task_runner->PostTask(FROM_HERE, + base::BindOnce(std::move(cb), result)); + }, + std::move(callback), ui_task_runner_); + + PostTaskWithSDK( + domain, + base::BindOnce( + [](std::unique_ptr cbs, + const std::string& order_id, skus::CppSDK* sdk) { + sdk->fetch_order_credentials(OnFetchOrderCredentials, + std::move(cbs), order_id); + }, + std::move(cbs), order_id)); } void SkusServiceImpl::PrepareCredentialsPresentation( const std::string& domain, const std::string& path, mojom::SkusService::PrepareCredentialsPresentationCallback callback) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); std::unique_ptr cbs( new skus::PrepareCredentialsPresentationCallbackState); - cbs->cb = std::move(callback); - GetOrCreateSDK(domain)->prepare_credentials_presentation( - OnPrepareCredentialsPresentation, std::move(cbs), domain, path); -} -::rust::Box& SkusServiceImpl::GetOrCreateSDK( - const std::string& domain) { - auto env = GetEnvironmentForDomain(domain); - if (sdk_.count(env)) { - return sdk_.at(env); - } + cbs->cb = base::BindOnce( + [](mojom::SkusService::PrepareCredentialsPresentationCallback cb, + scoped_refptr ui_task_runner, + const std::string& result) { + ui_task_runner->PostTask(FROM_HERE, + base::BindOnce(std::move(cb), result)); + }, + std::move(callback), ui_task_runner_); - auto sdk = initialize_sdk( - std::make_unique(prefs_, url_loader_factory_), - env); - sdk_.insert_or_assign(env, std::move(sdk)); - return sdk_.at(env); + PostTaskWithSDK( + domain, + base::BindOnce( + [](std::unique_ptr + cbs, + const std::string& domain, const std::string& path, + skus::CppSDK* sdk) { + sdk->prepare_credentials_presentation( + OnPrepareCredentialsPresentation, std::move(cbs), domain, path); + }, + std::move(cbs), domain, path)); } void SkusServiceImpl::CredentialSummary( const std::string& domain, mojom::SkusService::CredentialSummaryCallback callback) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); std::unique_ptr cbs( new skus::CredentialSummaryCallbackState); - cbs->cb = - base::BindOnce(&SkusServiceImpl::OnCredentialSummary, - weak_factory_.GetWeakPtr(), domain, std::move(callback)); - GetOrCreateSDK(domain)->credential_summary(::OnCredentialSummary, - std::move(cbs), domain); -} + cbs->cb = base::BindOnce( + [](mojom::SkusService::CredentialSummaryCallback cb, + scoped_refptr ui_task_runner, + const std::string& result) { + ui_task_runner->PostTask(FROM_HERE, + base::BindOnce(std::move(cb), result)); + }, + std::move(callback), ui_task_runner_); -void SkusServiceImpl::OnCredentialSummary( - const std::string& domain, - mojom::SkusService::CredentialSummaryCallback callback, - const std::string& summary_string) { - if (callback) { - std::move(callback).Run(summary_string); - } + PostTaskWithSDK( + domain, base::BindOnce( + [](std::unique_ptr cbs, + const std::string& domain, skus::CppSDK* sdk) { + sdk->credential_summary(OnCredentialSummary, std::move(cbs), + domain); + }, + std::move(cbs), domain)); } void SkusServiceImpl::SubmitReceipt( @@ -177,22 +233,164 @@ void SkusServiceImpl::SubmitReceipt( const std::string& order_id, const std::string& receipt, skus::mojom::SkusService::SubmitReceiptCallback callback) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); std::unique_ptr cbs( new skus::SubmitReceiptCallbackState); - cbs->cb = std::move(callback); - GetOrCreateSDK(domain)->submit_receipt(OnSubmitReceipt, std::move(cbs), - order_id, receipt); + + cbs->cb = base::BindOnce( + [](skus::mojom::SkusService::SubmitReceiptCallback cb, + scoped_refptr ui_task_runner, + const std::string& result) { + ui_task_runner->PostTask(FROM_HERE, + base::BindOnce(std::move(cb), result)); + }, + std::move(callback), ui_task_runner_); + + PostTaskWithSDK( + domain, base::BindOnce( + [](std::unique_ptr cbs, + const std::string& order_id, const std::string& receipt, + skus::CppSDK* sdk) { + sdk->submit_receipt(OnSubmitReceipt, std::move(cbs), + order_id, receipt); + }, + std::move(cbs), order_id, receipt)); } void SkusServiceImpl::CreateOrderFromReceipt( const std::string& domain, const std::string& receipt, skus::mojom::SkusService::CreateOrderFromReceiptCallback callback) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); std::unique_ptr cbs( new skus::CreateOrderFromReceiptCallbackState); - cbs->cb = std::move(callback); - GetOrCreateSDK(domain)->create_order_from_receipt(::OnCreateOrderFromReceipt, - std::move(cbs), receipt); + + cbs->cb = base::BindOnce( + [](skus::mojom::SkusService::CreateOrderFromReceiptCallback cb, + scoped_refptr ui_task_runner, + const std::string& result) { + ui_task_runner->PostTask(FROM_HERE, + base::BindOnce(std::move(cb), result)); + }, + std::move(callback), ui_task_runner_); + + PostTaskWithSDK( + domain, + base::BindOnce( + [](std::unique_ptr cbs, + const std::string& receipt, skus::CppSDK* sdk) { + sdk->create_order_from_receipt(OnCreateOrderFromReceipt, + std::move(cbs), receipt); + }, + std::move(cbs), receipt)); +} + +void SkusServiceImpl::PurgeStore( + rust::cxxbridge1::Fn, + bool success)> done, + rust::cxxbridge1::Box st_ctx) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + ScopedDictPrefUpdate state(&*prefs_, prefs::kSkusState); + state->clear(); + + sdk_task_runner_->PostTask( + FROM_HERE, + base::BindOnce( + [](rust::cxxbridge1::Fn, bool)> done, + rust::cxxbridge1::Box ctx) { + done(std::move(ctx), true); + }, + std::move(done), std::move(st_ctx))); +} + +void SkusServiceImpl::GetValueFromStore( + const std::string& key, + rust::cxxbridge1::Fn, + rust::String, + bool)> done, + rust::cxxbridge1::Box ctx) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + const auto& state = prefs_->GetDict(prefs::kSkusState); + const base::Value* value = state.Find(key); + std::string result; + if (value) { + result = value->GetString(); + } + + sdk_task_runner_->PostTask( + FROM_HERE, base::BindOnce( + [](rust::cxxbridge1::Fn, + rust::String, bool)> done, + rust::cxxbridge1::Box ctx, + std::string value) { + done(std::move(ctx), ::rust::String(value), true); + }, + std::move(done), std::move(ctx), result)); +} + +void SkusServiceImpl::UpdateStoreValue( + const std::string& key, + const std::string& value, + rust::cxxbridge1::Fn, + bool success)> done, + rust::cxxbridge1::Box st_ctx) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + ScopedDictPrefUpdate state(&*prefs_, prefs::kSkusState); + state->Set(key, value); + sdk_task_runner_->PostTask( + FROM_HERE, + base::BindOnce( + [](rust::cxxbridge1::Fn, bool)> done, + rust::cxxbridge1::Box ctx) { + done(std::move(ctx), true); + }, + std::move(done), std::move(st_ctx))); +} + +void SkusServiceImpl::PostTaskWithSDK( + const std::string& domain, + base::OnceCallback cb) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + auto env = GetEnvironmentForDomain(domain); + if (sdks_.count(env)) { + sdk_task_runner_->PostTask( + FROM_HERE, base::BindOnce(std::move(cb), &*(sdks_.at(env)))); + } else { + sdk_task_runner_->PostTaskAndReplyWithResult( + FROM_HERE, + base::BindOnce( + [](const std::string& env, + base::WeakPtr skus_service, + std::unique_ptr + pending_url_loader_factory, + scoped_refptr ui_task_runner) { + auto sdk = + initialize_sdk(std::make_unique( + std::move(pending_url_loader_factory), + ui_task_runner, skus_service), + env); + return sdk; + }, + env, weak_factory_.GetWeakPtr(), url_loader_factory_->Clone(), + ui_task_runner_), + base::BindOnce(&SkusServiceImpl::OnSDKInitialized, + weak_factory_.GetWeakPtr(), env, std::move(cb))); + } +} + +void SkusServiceImpl::OnSDKInitialized( + const std::string& env, + base::OnceCallback cb, + ::rust::Box sdk) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (!sdks_.count(env)) { + sdks_.insert_or_assign(env, std::move(sdk)); + } + sdk_task_runner_->PostTask(FROM_HERE, + base::BindOnce(std::move(cb), &*(sdks_.at(env)))); } } // namespace skus diff --git a/components/skus/browser/skus_service_impl.h b/components/skus/browser/skus_service_impl.h index 19a8d4c619dc..1fb86d46364b 100644 --- a/components/skus/browser/skus_service_impl.h +++ b/components/skus/browser/skus_service_impl.h @@ -11,6 +11,9 @@ #include #include "base/memory/weak_ptr.h" +#include "base/task/single_thread_task_runner.h" +#include "base/task/thread_pool.h" +#include "base/threading/sequence_bound.h" #include "brave/components/skus/browser/rs/cxx/src/shim.h" #include "brave/components/skus/common/skus_sdk.mojom.h" #include "components/keyed_service/core/keyed_service.h" @@ -99,22 +102,39 @@ class SkusServiceImpl : public KeyedService, public mojom::SkusService { const std::string& receipt, skus::mojom::SkusService::CreateOrderFromReceiptCallback callback) override; - - ::rust::Box& GetOrCreateSDK(const std::string& domain); + void PurgeStore(rust::cxxbridge1::Fn< + void(rust::cxxbridge1::Box, + bool success)> done, + rust::cxxbridge1::Box st_ctx); + void GetValueFromStore( + const std::string& key, + rust::cxxbridge1::Fn, + rust::String, + bool)> done, + rust::cxxbridge1::Box ctx); + void UpdateStoreValue( + const std::string& key, + const std::string& value, + rust::cxxbridge1::Fn, + bool success)> done, + rust::cxxbridge1::Box st_ctx); private: - void OnCredentialSummary( - const std::string& domain, - mojom::SkusService::CredentialSummaryCallback callback, - const std::string& summary_string); - - void OnCreateOrderFromReceipt( - mojom::SkusService::CredentialSummaryCallback callback, - const std::string& order_id_string); - - raw_ptr prefs_; - scoped_refptr url_loader_factory_; - std::unordered_map> sdk_; + void PostTaskWithSDK(const std::string& domain, + base::OnceCallback cb); + + void OnSDKInitialized(const std::string& env, + base::OnceCallback cb, + ::rust::Box cpp_sdk); + + SEQUENCE_CHECKER(sequence_checker_); + raw_ptr prefs_ GUARDED_BY_CONTEXT(sequence_checker_); + scoped_refptr url_loader_factory_ + GUARDED_BY_CONTEXT(sequence_checker_); + scoped_refptr sdk_task_runner_; + scoped_refptr ui_task_runner_; + std::unordered_map> sdks_ + GUARDED_BY_CONTEXT(sequence_checker_); mojo::ReceiverSet receivers_; base::WeakPtrFactory weak_factory_{this}; }; diff --git a/components/skus/browser/skus_service_unittest.cc b/components/skus/browser/skus_service_unittest.cc index 08f410115dca..b4bd4ee436ee 100644 --- a/components/skus/browser/skus_service_unittest.cc +++ b/components/skus/browser/skus_service_unittest.cc @@ -9,7 +9,6 @@ #include "base/json/json_reader.h" #include "base/json/json_writer.h" -#include "base/run_loop.h" #include "base/strings/string_util.h" #include "base/test/bind.h" #include "base/test/task_environment.h" @@ -254,7 +253,7 @@ class SkusServiceTestUnitTest : public testing::Test { callback_called = true; result = summary; })); - base::RunLoop().RunUntilIdle(); + task_environment_.RunUntilIdle(); EXPECT_TRUE(callback_called); return result; }