@@ -5,6 +5,7 @@ mod test;
5
5
6
6
use std:: cmp:: max;
7
7
use std:: collections:: BTreeMap ;
8
+ use std:: collections:: Bound ;
8
9
use std:: collections:: HashMap ;
9
10
use std:: fmt:: Debug ;
10
11
use std:: io:: Cursor ;
@@ -307,18 +308,84 @@ impl MemStore {
307
308
Ok ( ( ) )
308
309
}
309
310
310
- pub async fn defensive_nonempty_range < RT , RNG : RangeBounds < RT > + Clone + Debug + Send + Iterator > (
311
+ pub async fn defensive_nonempty_range < RNG : RangeBounds < u64 > + Clone + Debug + Send > (
311
312
& self ,
312
313
range : RNG ,
313
314
) -> anyhow:: Result < ( ) > {
314
315
if !* self . defensive . read ( ) . await {
315
316
return Ok ( ( ) ) ;
316
317
}
317
- for _ in range. clone ( ) {
318
+ let start = match range. start_bound ( ) {
319
+ Bound :: Included ( i) => Some ( * i) ,
320
+ Bound :: Excluded ( i) => Some ( * i + 1 ) ,
321
+ Bound :: Unbounded => None ,
322
+ } ;
323
+
324
+ let end = match range. end_bound ( ) {
325
+ Bound :: Included ( i) => Some ( * i) ,
326
+ Bound :: Excluded ( i) => Some ( * i - 1 ) ,
327
+ Bound :: Unbounded => None ,
328
+ } ;
329
+
330
+ if start. is_none ( ) || end. is_none ( ) {
331
+ return Ok ( ( ) ) ;
332
+ }
333
+
334
+ if start > end {
335
+ return Err ( anyhow:: anyhow!( "range must be nonempty: {:?}" , range) ) ;
336
+ }
337
+
338
+ Ok ( ( ) )
339
+ }
340
+
341
+ pub async fn defensive_range_hits_logs < T : AppData , RNG : RangeBounds < u64 > + Debug + Send > (
342
+ & self ,
343
+ range : RNG ,
344
+ logs : & [ Entry < T > ] ,
345
+ ) -> anyhow:: Result < ( ) > {
346
+ if !* self . defensive . read ( ) . await {
318
347
return Ok ( ( ) ) ;
319
348
}
320
349
321
- Err ( anyhow:: anyhow!( "range must be nonempty: {:?}" , range) )
350
+ {
351
+ let want_first = match range. start_bound ( ) {
352
+ Bound :: Included ( i) => Some ( * i) ,
353
+ Bound :: Excluded ( i) => Some ( * i + 1 ) ,
354
+ Bound :: Unbounded => None ,
355
+ } ;
356
+
357
+ let first = logs. first ( ) . map ( |x| x. log_id . index ) ;
358
+
359
+ if want_first. is_some ( ) && first != want_first {
360
+ return Err ( anyhow:: anyhow!(
361
+ "{:?} want first: {:?}, but {:?}" ,
362
+ range,
363
+ want_first,
364
+ first
365
+ ) ) ;
366
+ }
367
+ }
368
+
369
+ {
370
+ let want_last = match range. end_bound ( ) {
371
+ Bound :: Included ( i) => Some ( * i) ,
372
+ Bound :: Excluded ( i) => Some ( * i - 1 ) ,
373
+ Bound :: Unbounded => None ,
374
+ } ;
375
+
376
+ let last = logs. last ( ) . map ( |x| x. log_id . index ) ;
377
+
378
+ if want_last. is_some ( ) && last != want_last {
379
+ return Err ( anyhow:: anyhow!(
380
+ "{:?} want last: {:?}, but {:?}" ,
381
+ range,
382
+ want_last,
383
+ last
384
+ ) ) ;
385
+ }
386
+ }
387
+
388
+ Ok ( ( ) )
322
389
}
323
390
324
391
pub async fn defensive_apply_log_id_gt_last < D : AppData > ( & self , entries : & [ & Entry < D > ] ) -> anyhow:: Result < ( ) > {
@@ -477,14 +544,26 @@ impl RaftStorage<ClientRequest, ClientResponse> for MemStore {
477
544
}
478
545
479
546
#[ tracing:: instrument( level = "trace" , skip( self ) ) ]
480
- async fn get_log_entries ( & self , start : u64 , stop : u64 ) -> Result < Vec < Entry < ClientRequest > > > {
481
- // Invalid request, return empty vec.
482
- if start > stop {
483
- tracing:: error!( "get_log_entries: invalid request, start({}) > stop({})" , start, stop) ;
484
- return Ok ( vec ! [ ] ) ;
485
- }
547
+ async fn get_log_entries < RNG : RangeBounds < u64 > + Clone + Debug + Send + Sync > (
548
+ & self ,
549
+ range : RNG ,
550
+ ) -> Result < Vec < Entry < ClientRequest > > > {
551
+ self . defensive_nonempty_range ( range. clone ( ) ) . await ?;
552
+
553
+ let res = {
554
+ let log = self . log . read ( ) . await ;
555
+ log. range ( range. clone ( ) ) . map ( |( _, val) | val. clone ( ) ) . collect :: < Vec < _ > > ( )
556
+ } ;
557
+
558
+ self . defensive_range_hits_logs ( range, & res) . await ?;
559
+
560
+ Ok ( res)
561
+ }
562
+
563
+ #[ tracing:: instrument( level = "trace" , skip( self ) ) ]
564
+ async fn try_get_log_entry ( & self , log_index : u64 ) -> Result < Option < Entry < ClientRequest > > > {
486
565
let log = self . log . read ( ) . await ;
487
- Ok ( log. range ( start..stop ) . map ( | ( _ , val ) | val . clone ( ) ) . collect ( ) )
566
+ Ok ( log. get ( & log_index ) . cloned ( ) )
488
567
}
489
568
490
569
#[ tracing:: instrument( level = "trace" , skip( self , range) , fields( range=?range) ) ]
0 commit comments