@@ -67,8 +67,8 @@ use crate::{
67
67
requests:: { IncomingResponse , OutgoingRequest , UploadSigningKeysRequest } ,
68
68
session_manager:: { GroupSessionManager , SessionManager } ,
69
69
store:: {
70
- Changes , DeviceChanges , DynCryptoStore , IdentityChanges , IntoCryptoStore , MemoryStore ,
71
- Result as StoreResult , SecretImportError , Store ,
70
+ locks :: LockStoreError , Changes , DeviceChanges , DynCryptoStore , IdentityChanges ,
71
+ IntoCryptoStore , MemoryStore , Result as StoreResult , SecretImportError , Store ,
72
72
} ,
73
73
types:: {
74
74
events:: {
@@ -129,6 +129,18 @@ pub struct OlmMachineInner {
129
129
/// A state machine that handles creating room key backups.
130
130
#[ cfg( feature = "backups_v1" ) ]
131
131
backup_machine : BackupMachine ,
132
+ /// Latest "generation" of data known by the crypto store.
133
+ ///
134
+ /// This is a counter that only increments, set in the database (and can
135
+ /// wrap). It's incremented whenever some process acquires a lock for the
136
+ /// first time. *This assumes the crypto store lock is being held, to
137
+ /// avoid data races on writing to this value in the store*.
138
+ ///
139
+ /// The current process will maintain this value in local memory and in the
140
+ /// DB over time. Observing a different value than the one read in
141
+ /// memory, when reading from the store indicates that somebody else has
142
+ /// written into the database under our feet.
143
+ pub ( crate ) crypto_store_generation : Arc < Mutex < Option < u64 > > > ,
132
144
}
133
145
134
146
#[ cfg( not( tarpaulin_include) ) ]
@@ -142,6 +154,8 @@ impl std::fmt::Debug for OlmMachine {
142
154
}
143
155
144
156
impl OlmMachine {
157
+ const CURRENT_GENERATION_STORE_KEY : & str = "generation-counter" ;
158
+
145
159
/// Create a new memory based OlmMachine.
146
160
///
147
161
/// The created machine will keep the encryption keys only in memory and
@@ -212,6 +226,7 @@ impl OlmMachine {
212
226
identity_manager,
213
227
#[ cfg( feature = "backups_v1" ) ]
214
228
backup_machine,
229
+ crypto_store_generation : Arc :: new ( Mutex :: new ( None ) ) ,
215
230
} ) ;
216
231
217
232
Self { inner }
@@ -1728,6 +1743,106 @@ impl OlmMachine {
1728
1743
pub fn backup_machine ( & self ) -> & BackupMachine {
1729
1744
& self . inner . backup_machine
1730
1745
}
1746
+
1747
+ /// Syncs the database and in-memory generation counter.
1748
+ ///
1749
+ /// This requires that the crypto store lock has been acquired already.
1750
+ pub async fn initialize_crypto_store_generation ( & self ) -> StoreResult < ( ) > {
1751
+ // Avoid reentrant initialization by taking the lock for the entire's function
1752
+ // scope.
1753
+ let mut gen_guard = self . inner . crypto_store_generation . lock ( ) . await ;
1754
+
1755
+ let prev_generation =
1756
+ self . inner . store . get_custom_value ( Self :: CURRENT_GENERATION_STORE_KEY ) . await ?;
1757
+
1758
+ let gen = match prev_generation {
1759
+ Some ( val) => {
1760
+ // There was a value in the store. We need to signal that we're a different
1761
+ // process, so we don't just reuse the value but increment it.
1762
+ u64:: from_le_bytes (
1763
+ val. try_into ( ) . map_err ( |_| LockStoreError :: InvalidGenerationFormat ) ?,
1764
+ )
1765
+ . wrapping_add ( 1 )
1766
+ }
1767
+ None => 0 ,
1768
+ } ;
1769
+
1770
+ self . inner
1771
+ . store
1772
+ . set_custom_value ( Self :: CURRENT_GENERATION_STORE_KEY , gen. to_le_bytes ( ) . to_vec ( ) )
1773
+ . await ?;
1774
+
1775
+ * gen_guard = Some ( gen) ;
1776
+
1777
+ Ok ( ( ) )
1778
+ }
1779
+
1780
+ /// If needs be, update the local and on-disk crypto store generation.
1781
+ ///
1782
+ /// Returns true whether another user has modified the internal generation
1783
+ /// counter, and as such we've incremented and updated it in the
1784
+ /// database.
1785
+ ///
1786
+ /// ## Requirements
1787
+ ///
1788
+ /// - This assumes that `initialize_crypto_store_generation` has been called
1789
+ /// beforehand.
1790
+ /// - This requires that the crypto store lock has been acquired.
1791
+ pub async fn maintain_crypto_store_generation ( & self ) -> StoreResult < bool > {
1792
+ let mut gen_guard = self . inner . crypto_store_generation . lock ( ) . await ;
1793
+
1794
+ // The database value must be there:
1795
+ // - either we could initialize beforehand, thus write into the database,
1796
+ // - or we couldn't, and then another process was holding onto the database's
1797
+ // lock, thus
1798
+ // has written a generation counter in there.
1799
+ let actual_gen = self
1800
+ . inner
1801
+ . store
1802
+ . get_custom_value ( Self :: CURRENT_GENERATION_STORE_KEY )
1803
+ . await ?
1804
+ . ok_or ( LockStoreError :: MissingGeneration ) ?;
1805
+
1806
+ let actual_gen = u64:: from_le_bytes (
1807
+ actual_gen. try_into ( ) . map_err ( |_| LockStoreError :: InvalidGenerationFormat ) ?,
1808
+ ) ;
1809
+
1810
+ let expected_gen = match gen_guard. as_ref ( ) {
1811
+ Some ( expected_gen) => {
1812
+ if actual_gen == * expected_gen {
1813
+ return Ok ( false ) ;
1814
+ }
1815
+ // Increment the biggest, and store it everywhere.
1816
+ actual_gen. max ( * expected_gen) . wrapping_add ( 1 )
1817
+ }
1818
+ None => {
1819
+ // Some other process hold onto the lock when initializing, so we must reload.
1820
+ // Increment database value, and store it everywhere.
1821
+ actual_gen. wrapping_add ( 1 )
1822
+ }
1823
+ } ;
1824
+
1825
+ tracing:: debug!(
1826
+ "Crypto store generation mismatch: previously known was {:?}, actual is {:?}, next is {}" ,
1827
+ * gen_guard,
1828
+ actual_gen,
1829
+ expected_gen
1830
+ ) ;
1831
+
1832
+ // Update known value.
1833
+ * gen_guard = Some ( expected_gen) ;
1834
+
1835
+ // Update value in database.
1836
+ self . inner
1837
+ . store
1838
+ . set_custom_value (
1839
+ Self :: CURRENT_GENERATION_STORE_KEY ,
1840
+ expected_gen. to_le_bytes ( ) . to_vec ( ) ,
1841
+ )
1842
+ . await ?;
1843
+
1844
+ Ok ( true )
1845
+ }
1731
1846
}
1732
1847
1733
1848
#[ cfg( any( feature = "testing" , test) ) ]
0 commit comments