diff --git a/server/jetstream_cluster.go b/server/jetstream_cluster.go
index da7147b1e2f..8e250fbff7c 100644
--- a/server/jetstream_cluster.go
+++ b/server/jetstream_cluster.go
@@ -3098,8 +3098,10 @@ func (js *jetStream) applyStreamEntries(mset *stream, ce *CommittedEntry, isReco
 			if subject == _EMPTY_ && ts == 0 && len(msg) == 0 && len(hdr) == 0 {
 				// Skip and update our lseq.
 				last := mset.store.SkipMsg()
+				mset.mu.Lock()
 				mset.setLastSeq(last)
 				mset.clearAllPreAcks(last)
+				mset.mu.Unlock()
 				continue
 			}
@@ -8805,6 +8807,8 @@ func (mset *stream) processCatchupMsg(msg []byte) (uint64, error) {
 		return 0, err
 	}
 
+	mset.mu.Lock()
+	defer mset.mu.Unlock()
 	// Update our lseq.
 	mset.setLastSeq(seq)
@@ -8812,11 +8816,9 @@
 	if len(hdr) > 0 {
 		if msgId := getMsgId(hdr); msgId != _EMPTY_ {
 			if !ddloaded {
-				mset.mu.Lock()
 				mset.rebuildDedupe()
-				mset.mu.Unlock()
 			}
-			mset.storeMsgId(&ddentry{msgId, seq, ts})
+			mset.storeMsgIdLocked(&ddentry{msgId, seq, ts})
 		}
 	}
diff --git a/server/jetstream_cluster_1_test.go b/server/jetstream_cluster_1_test.go
index 8253d109e72..2230d3853bc 100644
--- a/server/jetstream_cluster_1_test.go
+++ b/server/jetstream_cluster_1_test.go
@@ -6999,6 +6999,99 @@ func TestJetStreamClusterStreamUpscalePeersAfterDownscale(t *testing.T) {
 	checkPeerSet()
 }
 
+func TestJetStreamClusterClearAllPreAcksOnRemoveMsg(t *testing.T) {
+	c := createJetStreamClusterExplicit(t, "R3S", 3)
+	defer c.shutdown()
+
+	nc, js := jsClientConnect(t, c.randomServer())
+	defer nc.Close()
+
+	_, err := js.AddStream(&nats.StreamConfig{
+		Name:      "TEST",
+		Subjects:  []string{"foo"},
+		Replicas:  3,
+		Retention: nats.WorkQueuePolicy,
+	})
+	require_NoError(t, err)
+
+	_, err = js.AddConsumer("TEST", &nats.ConsumerConfig{
+		Durable:   "CONSUMER",
+		AckPolicy: nats.AckExplicitPolicy,
+	})
+	require_NoError(t, err)
+
+	for i := 0; i < 3; i++ {
+		_, err = js.Publish("foo", nil)
+		require_NoError(t, err)
+	}
+
+	// Wait for all servers to converge on the same state.
+	checkFor(t, 5*time.Second, 500*time.Millisecond, func() error {
+		return checkState(t, c, globalAccountName, "TEST")
+	})
+
+	// Register pre-acks on all servers.
+	// Normally this can't happen, as the stream leader will still have the
+	// message that's being acked; this is just for testing.
+	for _, s := range c.servers {
+		acc, err := s.lookupAccount(globalAccountName)
+		require_NoError(t, err)
+		mset, err := acc.lookupStream("TEST")
+		require_NoError(t, err)
+		o := mset.lookupConsumer("CONSUMER")
+		require_NotNil(t, o)
+
+		// Register pre-acks for the 3 messages.
+		mset.registerPreAckLock(o, 1)
+		mset.registerPreAckLock(o, 2)
+		mset.registerPreAckLock(o, 3)
+	}
+
+	// Check there's the expected number of pre-acks, and that no pre-acks
+	// remain for the given sequence.
+	checkPreAcks := func(seq uint64, expected int) {
+		t.Helper()
+		checkFor(t, 5*time.Second, time.Second, func() error {
+			for _, s := range c.servers {
+				acc, err := s.lookupAccount(globalAccountName)
+				if err != nil {
+					return err
+				}
+				mset, err := acc.lookupStream("TEST")
+				if err != nil {
+					return err
+				}
+				mset.mu.RLock()
+				numPreAcks := len(mset.preAcks)
+				numSeqPreAcks := len(mset.preAcks[seq])
+				mset.mu.RUnlock()
+				if numPreAcks != expected {
+					return fmt.Errorf("expected %d pre-acks, got %d", expected, numPreAcks)
+				}
+				if seq > 0 && numSeqPreAcks != 0 {
+					return fmt.Errorf("expected 0 pre-acks for seq %d, got %d", seq, numSeqPreAcks)
+				}
+			}
+			return nil
+		})
+	}
+	// Check all pre-acks were registered.
+	checkPreAcks(0, 3)
+
+	// Deleting the message should clear the pre-ack.
+	err = js.DeleteMsg("TEST", 1)
+	require_NoError(t, err)
+	checkPreAcks(1, 2)
+
+	// Erasing the message should clear the pre-ack.
+	err = js.SecureDeleteMsg("TEST", 2)
+	require_NoError(t, err)
+	checkPreAcks(2, 1)
+
+	// Purging should clear all pre-acks below the purged floor.
+	err = js.PurgeStream("TEST", &nats.StreamPurgeRequest{Sequence: 4})
+	require_NoError(t, err)
+	checkPreAcks(3, 0)
+}
+
 //
 // DO NOT ADD NEW TESTS IN THIS FILE (unless to balance test times)
 // Add at the end of jetstream_cluster_<n>_test.go, with <n> being the highest value.
diff --git a/server/norace_test.go b/server/norace_test.go
index 69914b9dc0f..ad68deb3321 100644
--- a/server/norace_test.go
+++ b/server/norace_test.go
@@ -7738,32 +7738,47 @@ func TestNoRaceJetStreamClusterUnbalancedInterestMultipleConsumers(t *testing.T)
 	// make sure we do not remove prematurely.
 	msgs, err := sub.Fetch(100, nats.MaxWait(time.Second))
 	require_NoError(t, err)
-	require_True(t, len(msgs) == 100)
+	require_Len(t, len(msgs), 100)
 	for _, m := range msgs {
 		m.AckSync()
 	}
 
 	ci, err := js.ConsumerInfo("EVENTS", "D")
 	require_NoError(t, err)
-	require_True(t, ci.NumPending == uint64(numToSend-100))
-	require_True(t, ci.NumAckPending == 0)
-	require_True(t, ci.Delivered.Stream == 100)
-	require_True(t, ci.AckFloor.Stream == 100)
+	require_Equal(t, ci.NumPending, uint64(numToSend-100))
+	require_Equal(t, ci.NumAckPending, 0)
+	require_Equal(t, ci.Delivered.Stream, 100)
+	require_Equal(t, ci.AckFloor.Stream, 100)
 
 	// Check stream state on all servers.
-	for _, s := range c.servers {
-		mset, err := s.GlobalAccount().lookupStream("EVENTS")
-		require_NoError(t, err)
-		state := mset.state()
-		require_True(t, state.Msgs == 900)
-		require_True(t, state.FirstSeq == 101)
-		require_True(t, state.LastSeq == 1000)
-		require_True(t, state.Consumers == 2)
-	}
+	// Since acks result in messages being removed through proposals,
+	// it can take some time for this to be reflected in the stream state.
+	checkFor(t, 5*time.Second, 500*time.Millisecond, func() error {
+		for _, s := range c.servers {
+			mset, err := s.GlobalAccount().lookupStream("EVENTS")
+			if err != nil {
+				return err
+			}
+			state := mset.state()
+			if state.Msgs != 900 {
+				return fmt.Errorf("expected state.Msgs=900, got %d", state.Msgs)
+			}
+			if state.FirstSeq != 101 {
+				return fmt.Errorf("expected state.FirstSeq=101, got %d", state.FirstSeq)
+			}
+			if state.LastSeq != 1000 {
+				return fmt.Errorf("expected state.LastSeq=1000, got %d", state.LastSeq)
+			}
+			if state.Consumers != 2 {
+				return fmt.Errorf("expected state.Consumers=2, got %d", state.Consumers)
+			}
+		}
+		return nil
+	})
 
 	msgs, err = sub.Fetch(900, nats.MaxWait(time.Second))
 	require_NoError(t, err)
-	require_True(t, len(msgs) == 900)
+	require_Len(t, len(msgs), 900)
 	for _, m := range msgs {
 		m.AckSync()
 	}
@@ -7776,15 +7791,15 @@ func TestNoRaceJetStreamClusterUnbalancedInterestMultipleConsumers(t *testing.T)
 		mset, err := s.GlobalAccount().lookupStream("EVENTS")
 		require_NoError(t, err)
 		state := mset.state()
-		require_True(t, state.Msgs == 0)
-		require_True(t, state.FirstSeq == 1001)
-		require_True(t, state.LastSeq == 1000)
-		require_True(t, state.Consumers == 2)
+		require_Equal(t, state.Msgs, 0)
+		require_Equal(t, state.FirstSeq, 1001)
+		require_Equal(t, state.LastSeq, 1000)
+		require_Equal(t, state.Consumers, 2)
 
 		// Now check preAcks
 		mset.mu.RLock()
 		numPreAcks := len(mset.preAcks)
 		mset.mu.RUnlock()
-		require_True(t, numPreAcks == 0)
+		require_Len(t, numPreAcks, 0)
 	}
 }
 
@@ -7872,27 +7887,27 @@ func TestNoRaceJetStreamClusterUnbalancedInterestMultipleFilteredConsumers(t *te
 	ci, err := js.ConsumerInfo("EVENTS", "D")
 	require_NoError(t, err)
-	require_True(t, ci.NumPending == 0)
-	require_True(t, ci.NumAckPending == 0)
-	require_True(t, ci.Delivered.Consumer == 500)
-	require_True(t, ci.Delivered.Stream == 1000)
-	require_True(t, ci.AckFloor.Consumer == 500)
-	require_True(t, ci.AckFloor.Stream == 1000)
+	require_Equal(t, ci.NumPending, 0)
+	require_Equal(t, ci.NumAckPending, 0)
+	require_Equal(t, ci.Delivered.Consumer, 500)
+	require_Equal(t, ci.Delivered.Stream, 1000)
+	require_Equal(t, ci.AckFloor.Consumer, 500)
+	require_Equal(t, ci.AckFloor.Stream, 1000)
 
 	// Check final stream state on all servers.
 	for _, s := range c.servers {
 		mset, err := s.GlobalAccount().lookupStream("EVENTS")
 		require_NoError(t, err)
 		state := mset.state()
-		require_True(t, state.Msgs == 0)
-		require_True(t, state.FirstSeq == 1001)
-		require_True(t, state.LastSeq == 1000)
-		require_True(t, state.Consumers == 2)
+		require_Equal(t, state.Msgs, 0)
+		require_Equal(t, state.FirstSeq, 1001)
+		require_Equal(t, state.LastSeq, 1000)
+		require_Equal(t, state.Consumers, 2)
 
 		// Now check preAcks
 		mset.mu.RLock()
 		numPreAcks := len(mset.preAcks)
 		mset.mu.RUnlock()
-		require_True(t, numPreAcks == 0)
+		require_Len(t, numPreAcks, 0)
 	}
 }
diff --git a/server/stream.go b/server/stream.go
index 85c85868c63..ce23941eb05 100644
--- a/server/stream.go
+++ b/server/stream.go
@@ -1126,10 +1126,10 @@ func (mset *stream) lastSeq() uint64 {
 	return mset.lseq
 }
 
+// Set last seq.
+// Write lock should be held.
 func (mset *stream) setLastSeq(lseq uint64) {
-	mset.mu.Lock()
 	mset.lseq = lseq
-	mset.mu.Unlock()
 }
 
 func (mset *stream) sendCreateAdvisory() {
@@ -2188,11 +2188,16 @@ func (mset *stream) purge(preq *JSApiStreamPurgeRequest) (purged uint64, err err
 	store.FastState(&state)
 	fseq, lseq := state.FirstSeq, state.LastSeq
 
+	mset.mu.Lock()
 	// Check if our last has moved past what our original last sequence was, if so reset.
 	if lseq > mlseq {
 		mset.setLastSeq(lseq)
 	}
 
+	// Clear any pending acks below first seq.
+	mset.clearAllPreAcksBelowFloor(fseq)
+	mset.mu.Unlock()
+
 	// Purge consumers.
 	// Check for filtered purge.
 	if preq != nil && preq.Subject != _EMPTY_ {
@@ -2239,7 +2244,14 @@ func (mset *stream) deleteMsg(seq uint64) (bool, error) {
 	if mset.closed.Load() {
 		return false, errStreamClosed
 	}
-	return mset.store.RemoveMsg(seq)
+	removed, err := mset.store.RemoveMsg(seq)
+	if err != nil {
+		return removed, err
+	}
+	mset.mu.Lock()
+	mset.clearAllPreAcks(seq)
+	mset.mu.Unlock()
+	return removed, err
 }
 
 // EraseMsg will securely remove a message and rewrite the data with random data.
@@ -2247,7 +2259,14 @@ func (mset *stream) eraseMsg(seq uint64) (bool, error) {
 	if mset.closed.Load() {
 		return false, errStreamClosed
 	}
-	return mset.store.EraseMsg(seq)
+	removed, err := mset.store.EraseMsg(seq)
+	if err != nil {
+		return removed, err
+	}
+	mset.mu.Lock()
+	mset.clearAllPreAcks(seq)
+	mset.mu.Unlock()
+	return removed, err
 }
 
 // Are we a mirror?
@@ -4138,15 +4157,8 @@ func (mset *stream) purgeMsgIds() {
 	}
 }
 
-// storeMsgId will store the message id for duplicate detection.
-func (mset *stream) storeMsgId(dde *ddentry) {
-	mset.mu.Lock()
-	defer mset.mu.Unlock()
-	mset.storeMsgIdLocked(dde)
-}
-
 // storeMsgIdLocked will store the message id for duplicate detection.
-// Lock should he held.
+// Lock should be held.
 func (mset *stream) storeMsgIdLocked(dde *ddentry) {
 	if mset.ddmap == nil {
 		mset.ddmap = make(map[string]*ddentry)
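
A note for readers outside the nats-server codebase: the patch hinges on the stream's pre-ack map, which records acks that arrive before the corresponding message has been stored locally (for example, acks replicated while a follower is still catching up). The sketch below is a minimal, self-contained reconstruction of that bookkeeping; the stream and consumer types are pared-down stand-ins, and the helper bodies are assumptions inferred from how the diff and its tests use them, not the server's actual implementation.

package main

import "fmt"

// consumer is a stand-in for the server's consumer type.
type consumer struct{ name string }

// stream carries only the pre-ack state relevant here. In nats-server this
// is the mset.preAcks field the tests above read under mset.mu.
type stream struct {
	preAcks map[uint64]map[*consumer]struct{}
}

// registerPreAck records that consumer o acked sequence seq before the
// message itself arrived.
func (mset *stream) registerPreAck(o *consumer, seq uint64) {
	if mset.preAcks == nil {
		mset.preAcks = make(map[uint64]map[*consumer]struct{})
	}
	if mset.preAcks[seq] == nil {
		mset.preAcks[seq] = make(map[*consumer]struct{})
	}
	mset.preAcks[seq][o] = struct{}{}
}

// clearAllPreAcks drops every pre-ack for seq, as deleteMsg and eraseMsg
// now do once the message has been removed from the store.
func (mset *stream) clearAllPreAcks(seq uint64) {
	delete(mset.preAcks, seq)
}

// clearAllPreAcksBelowFloor drops pre-acks for all sequences below floor,
// matching what purge now does with the new first sequence.
func (mset *stream) clearAllPreAcksBelowFloor(floor uint64) {
	for seq := range mset.preAcks {
		if seq < floor {
			delete(mset.preAcks, seq)
		}
	}
}

func main() {
	mset, o := &stream{}, &consumer{name: "CONSUMER"}
	for seq := uint64(1); seq <= 3; seq++ {
		mset.registerPreAck(o, seq)
	}
	mset.clearAllPreAcks(1)           // message 1 deleted
	mset.clearAllPreAcksBelowFloor(4) // purge up to sequence 4
	fmt.Println(len(mset.preAcks))    // 0, mirroring checkPreAcks(3, 0) above
}

Without this cleanup, a pre-ack registered for a message that is later deleted, erased, or purged would linger in the map indefinitely, which is the leak the new TestJetStreamClusterClearAllPreAcksOnRemoveMsg guards against.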
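
The other recurring change is the lock-ownership convention: setLastSeq used to take mset.mu itself, and now documents "Write lock should be held" so callers can combine it with clearAllPreAcks in a single critical section. The following illustration of why that matters is again a hedged sketch with stand-in types, not the server's code.

package main

import "sync"

// stream pares nats-server's type down to the fields this diff touches.
type stream struct {
	mu      sync.RWMutex
	lseq    uint64
	preAcks map[uint64]map[string]struct{}
}

// setLastSeq assumes mset.mu is already held for writing, matching the
// new "Write lock should be held" comment in stream.go.
func (mset *stream) setLastSeq(lseq uint64) { mset.lseq = lseq }

// clearAllPreAcks likewise assumes the write lock is held.
func (mset *stream) clearAllPreAcks(seq uint64) { delete(mset.preAcks, seq) }

// skipMsg mirrors the applyStreamEntries change: because the helpers no
// longer self-lock, both updates happen atomically, so no reader can
// observe the advanced lseq while stale pre-acks for it still remain.
func (mset *stream) skipMsg(last uint64) {
	mset.mu.Lock()
	mset.setLastSeq(last)
	mset.clearAllPreAcks(last)
	mset.mu.Unlock()
}

func main() {
	mset := &stream{preAcks: map[uint64]map[string]struct{}{42: {"CONSUMER": {}}}}
	mset.skipMsg(42)
}

The same reasoning explains the processCatchupMsg hunk: holding mset.mu for the whole body lets it call rebuildDedupe and storeMsgIdLocked directly, which in turn makes the self-locking storeMsgId wrapper dead code and lets the patch delete it.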