Skip to content

Commit a1823b0

Browse files
authored
Merge pull request #8801 from element-hq/feature/bca/stop_double_reporting_utds
Analyics: stop double reporting posthog utds
2 parents 5ccc486 + 0a284bb commit a1823b0

File tree

4 files changed

+234
-10
lines changed

4 files changed

+234
-10
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Copyright (c) 2024 New Vector Ltd
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package im.vector.app.features
18+
19+
import androidx.test.ext.junit.runners.AndroidJUnit4
20+
import androidx.test.platform.app.InstrumentationRegistry
21+
import im.vector.app.InstrumentedTest
22+
import im.vector.app.features.analytics.ReportedDecryptionFailurePersistence
23+
import kotlinx.coroutines.test.runTest
24+
import org.amshove.kluent.shouldBeEqualTo
25+
import org.junit.Test
26+
import org.junit.runner.RunWith
27+
28+
@RunWith(AndroidJUnit4::class)
29+
class ReportedDecryptionFailurePersistenceTest : InstrumentedTest {
30+
31+
private val context = InstrumentationRegistry.getInstrumentation().targetContext
32+
33+
@Test
34+
fun shouldPersistReportedUtds() = runTest {
35+
val persistence = ReportedDecryptionFailurePersistence(context)
36+
persistence.load()
37+
38+
val eventIds = listOf("$0000", "$0001", "$0002", "$0003")
39+
eventIds.forEach {
40+
persistence.markAsReported(it)
41+
}
42+
43+
eventIds.forEach {
44+
persistence.hasBeenReported(it) shouldBeEqualTo true
45+
}
46+
47+
persistence.hasBeenReported("$0004") shouldBeEqualTo false
48+
49+
persistence.persist()
50+
51+
// Load a new one
52+
val persistence2 = ReportedDecryptionFailurePersistence(context)
53+
persistence2.load()
54+
55+
eventIds.forEach {
56+
persistence2.hasBeenReported(it) shouldBeEqualTo true
57+
}
58+
}
59+
60+
@Test
61+
fun testSaturation() = runTest {
62+
val persistence = ReportedDecryptionFailurePersistence(context)
63+
64+
for (i in 1..6000) {
65+
persistence.markAsReported("000$i")
66+
}
67+
68+
// This should have saturated the bloom filter, making the rate of false positives too high.
69+
// A new bloom filter should have been created to avoid that and the recent reported events should still be in the new filter.
70+
for (i in 5800..6000) {
71+
persistence.hasBeenReported("000$i") shouldBeEqualTo true
72+
}
73+
74+
// Old ones should not be there though
75+
for (i in 1..1000) {
76+
persistence.hasBeenReported("000$i") shouldBeEqualTo false
77+
}
78+
}
79+
}

vector/src/main/java/im/vector/app/features/analytics/DecryptionFailureTracker.kt

+16-10
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ private const val MAX_WAIT_MILLIS = 60_000
6363
class DecryptionFailureTracker @Inject constructor(
6464
private val analyticsTracker: AnalyticsTracker,
6565
private val sessionDataSource: ActiveSessionDataSource,
66+
private val decryptionFailurePersistence: ReportedDecryptionFailurePersistence,
6667
private val clock: Clock
6768
) : Session.Listener, LiveEventListener {
6869

@@ -76,9 +77,6 @@ class DecryptionFailureTracker @Inject constructor(
7677
// Only accessed on a `post` call, ensuring sequential access
7778
private val trackedEventsMap = mutableMapOf<String, DecryptionFailure>()
7879

79-
// List of eventId that have been reported, to avoid double reporting
80-
private val alreadyReported = mutableListOf<String>()
81-
8280
// Mutex to ensure sequential access to internal state
8381
private val mutex = Mutex()
8482

@@ -98,10 +96,16 @@ class DecryptionFailureTracker @Inject constructor(
9896
this.scope = scope
9997
}
10098
observeActiveSession()
99+
post {
100+
decryptionFailurePersistence.load()
101+
}
101102
}
102103

103104
fun stop() {
104105
Timber.v("Stop DecryptionFailureTracker")
106+
post {
107+
decryptionFailurePersistence.persist()
108+
}
105109
activeSessionSourceDisposable.cancel(CancellationException("Closing DecryptionFailureTracker"))
106110

107111
activeSession?.removeListener(this)
@@ -123,6 +127,7 @@ class DecryptionFailureTracker @Inject constructor(
123127
delay(CHECK_INTERVAL)
124128
post {
125129
checkFailures()
130+
decryptionFailurePersistence.persist()
126131
currentTicker = null
127132
if (trackedEventsMap.isNotEmpty()) {
128133
// Reschedule
@@ -136,15 +141,15 @@ class DecryptionFailureTracker @Inject constructor(
136141
.distinctUntilChanged()
137142
.onEach {
138143
Timber.v("Active session changed ${it.getOrNull()?.myUserId}")
139-
it.orNull()?.let { session ->
144+
it.getOrNull()?.let { session ->
140145
post {
141146
onSessionActive(session)
142147
}
143148
}
144149
}.launchIn(scope)
145150
}
146151

147-
private fun onSessionActive(session: Session) {
152+
private suspend fun onSessionActive(session: Session) {
148153
Timber.v("onSessionActive ${session.myUserId} previous: ${activeSession?.myUserId}")
149154
val sessionId = session.sessionId
150155
if (sessionId == activeSession?.sessionId) {
@@ -201,7 +206,8 @@ class DecryptionFailureTracker @Inject constructor(
201206
// already tracked
202207
return
203208
}
204-
if (alreadyReported.contains(eventId)) {
209+
if (decryptionFailurePersistence.hasBeenReported(eventId)) {
210+
Timber.v("Event $eventId already reported")
205211
// already reported
206212
return
207213
}
@@ -236,7 +242,7 @@ class DecryptionFailureTracker @Inject constructor(
236242
}
237243
}
238244

239-
private fun handleEventDecrypted(eventId: String) {
245+
private suspend fun handleEventDecrypted(eventId: String) {
240246
Timber.v("Handle event decrypted $eventId time: ${clock.epochMillis()}")
241247
// Only consider if it was tracked as a failure
242248
val trackedFailure = trackedEventsMap[eventId] ?: return
@@ -269,7 +275,7 @@ class DecryptionFailureTracker @Inject constructor(
269275
}
270276

271277
// This will mutate the trackedEventsMap, so don't call it while iterating on it.
272-
private fun reportFailure(decryptionFailure: DecryptionFailure) {
278+
private suspend fun reportFailure(decryptionFailure: DecryptionFailure) {
273279
Timber.v("Report failure for event ${decryptionFailure.failedEventId}")
274280
val error = decryptionFailure.toAnalyticsEvent()
275281

@@ -278,10 +284,10 @@ class DecryptionFailureTracker @Inject constructor(
278284
// now remove from tracked
279285
trackedEventsMap.remove(decryptionFailure.failedEventId)
280286
// mark as already reported
281-
alreadyReported.add(decryptionFailure.failedEventId)
287+
decryptionFailurePersistence.markAsReported(decryptionFailure.failedEventId)
282288
}
283289

284-
private fun checkFailures() {
290+
private suspend fun checkFailures() {
285291
val now = clock.epochMillis()
286292
Timber.v("Check failures now $now")
287293
// report the definitely failed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Copyright (c) 2024 New Vector Ltd
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package im.vector.app.features.analytics
18+
19+
import android.content.Context
20+
import android.util.LruCache
21+
import com.google.common.hash.BloomFilter
22+
import com.google.common.hash.Funnels
23+
import kotlinx.coroutines.Dispatchers
24+
import kotlinx.coroutines.withContext
25+
import timber.log.Timber
26+
import java.io.File
27+
import java.io.FileOutputStream
28+
import javax.inject.Inject
29+
30+
private const val REPORTED_UTD_FILE_NAME = "im.vector.analytics.reported_utd"
31+
private const val EXPECTED_INSERTIONS = 5000
32+
33+
/**
34+
* This class is used to keep track of the reported decryption failures to avoid double reporting.
35+
* It uses a bloom filter to limit the memory/disk usage.
36+
*/
37+
class ReportedDecryptionFailurePersistence @Inject constructor(
38+
private val context: Context,
39+
) {
40+
41+
// Keep a cache of recent reported failures in memory.
42+
// They will be persisted to the a new bloom filter if the previous one is getting saturated.
43+
// Should be around 30KB max in memory.
44+
// Also allows to have 0% false positive rate for recent failures.
45+
private val inMemoryReportedFailures: LruCache<String, Unit> = LruCache(300)
46+
47+
// Thread-safe and lock-free.
48+
// The expected insertions is 5000, and expected false positive probability of 3% when close to max capability.
49+
// The persisted size is expected to be around 5KB (100 times less than if it was raw strings).
50+
private var bloomFilter: BloomFilter<String> = BloomFilter.create<String>(Funnels.stringFunnel(Charsets.UTF_8), EXPECTED_INSERTIONS)
51+
52+
/**
53+
* Mark an event as reported.
54+
* @param eventId the event id to mark as reported.
55+
*/
56+
suspend fun markAsReported(eventId: String) {
57+
// Add to in memory cache.
58+
inMemoryReportedFailures.put(eventId, Unit)
59+
bloomFilter.put(eventId)
60+
61+
// check if the filter is getting saturated? and then replace
62+
if (bloomFilter.approximateElementCount() > EXPECTED_INSERTIONS - 500) {
63+
// The filter is getting saturated, and the false positive rate is increasing.
64+
// It's time to replace the filter with a new one. And move the in-memory cache to the new filter.
65+
bloomFilter = BloomFilter.create<String>(Funnels.stringFunnel(Charsets.UTF_8), EXPECTED_INSERTIONS)
66+
inMemoryReportedFailures.snapshot().keys.forEach {
67+
bloomFilter.put(it)
68+
}
69+
persist()
70+
}
71+
Timber.v("## Bloom filter stats: expectedFpp: ${bloomFilter.expectedFpp()}, size: ${bloomFilter.approximateElementCount()}")
72+
}
73+
74+
/**
75+
* Check if an event has been reported.
76+
* @param eventId the event id to check.
77+
* @return true if the event has been reported.
78+
*/
79+
fun hasBeenReported(eventId: String): Boolean {
80+
// First check in memory cache.
81+
if (inMemoryReportedFailures.get(eventId) != null) {
82+
return true
83+
}
84+
return bloomFilter.mightContain(eventId)
85+
}
86+
87+
/**
88+
* Load the reported failures from disk.
89+
*/
90+
suspend fun load() {
91+
withContext(Dispatchers.IO) {
92+
try {
93+
val file = File(context.applicationContext.cacheDir, REPORTED_UTD_FILE_NAME)
94+
if (file.exists()) {
95+
file.inputStream().use {
96+
bloomFilter = BloomFilter.readFrom(it, Funnels.stringFunnel(Charsets.UTF_8))
97+
}
98+
}
99+
} catch (e: Throwable) {
100+
Timber.e(e, "## Failed to load reported failures")
101+
}
102+
}
103+
}
104+
105+
/**
106+
* Persist the reported failures to disk.
107+
*/
108+
suspend fun persist() {
109+
withContext(Dispatchers.IO) {
110+
try {
111+
val file = File(context.applicationContext.cacheDir, REPORTED_UTD_FILE_NAME)
112+
if (!file.exists()) file.createNewFile()
113+
FileOutputStream(file).buffered().use {
114+
bloomFilter.writeTo(it)
115+
}
116+
Timber.v("## Successfully saved reported failures, size: ${file.length()}")
117+
} catch (e: Throwable) {
118+
Timber.e(e, "## Failed to save reported failures")
119+
}
120+
}
121+
}
122+
}

vector/src/test/java/im/vector/app/features/analytics/DecryptionFailureTrackerTest.kt

+17
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import im.vector.app.test.fakes.FakeAnalyticsTracker
2323
import im.vector.app.test.fakes.FakeClock
2424
import im.vector.app.test.fakes.FakeSession
2525
import im.vector.app.test.shared.createTimberTestRule
26+
import io.mockk.coEvery
2627
import io.mockk.every
2728
import io.mockk.just
2829
import io.mockk.mockk
@@ -60,9 +61,24 @@ class DecryptionFailureTrackerTest {
6061

6162
private val fakeClock = FakeClock()
6263

64+
val reportedEvents = mutableSetOf<String>()
65+
66+
private val fakePersistence = mockk<ReportedDecryptionFailurePersistence> {
67+
68+
coEvery { load() } just runs
69+
coEvery { persist() } just runs
70+
coEvery { markAsReported(any()) } coAnswers {
71+
reportedEvents.add(firstArg())
72+
}
73+
every { hasBeenReported(any()) } answers {
74+
reportedEvents.contains(firstArg())
75+
}
76+
}
77+
6378
private val decryptionFailureTracker = DecryptionFailureTracker(
6479
fakeAnalyticsTracker,
6580
fakeActiveSessionDataSource.instance,
81+
fakePersistence,
6682
fakeClock
6783
)
6884

@@ -101,6 +117,7 @@ class DecryptionFailureTrackerTest {
101117

102118
@Before
103119
fun setupTest() {
120+
reportedEvents.clear()
104121
fakeMxOrgTestSession.fakeCryptoService.fakeCrossSigningService.givenIsCrossSigningVerifiedReturns(false)
105122
}
106123

0 commit comments

Comments
 (0)