diff --git a/api.go b/api.go index 1348453f5..d757ea708 100644 --- a/api.go +++ b/api.go @@ -111,6 +111,12 @@ type Raft struct { // leaderCh is used to notify of leadership changes leaderCh chan bool + leaderChs []chan bool + leaderChsLock sync.Mutex + // leaderChLastMessage is the last message sent over leaderCh that has been processed. + // It is sent to new channels created by LeaderCh(). + leaderChLastMessage bool + // leaderState used only while state is leader leaderState leaderState @@ -957,14 +963,37 @@ func (r *Raft) State() RaftState { // lose it. // // Receivers can expect to receive a notification only if leadership -// transition has occured. +// transition has occured and immediately after LeaderCh() returns with the +// current state. // // If receivers aren't ready for the signal, signals may drop and only the // latest leadership transition. For example, if a receiver receives subsequent // `true` values, they may deduce that leadership was lost and regained while -// the the receiver was processing first leadership transition. +// the receiver was processing first leadership transition. func (r *Raft) LeaderCh() <-chan bool { - return r.leaderCh + ch := make(chan bool, 1) + r.leaderChsLock.Lock() + if len(r.leaderChs) == 0 { + select { + case v := <-r.leaderCh: + r.leaderChLastMessage = v + default: + } + go func() { + for v := range r.leaderCh { + r.leaderChsLock.Lock() + r.leaderChLastMessage = v + for _, c := range r.leaderChs { + overrideNotifyBool(c, v) + } + r.leaderChsLock.Unlock() + } + }() + } + r.leaderChs = append(r.leaderChs, ch) + ch <- r.leaderChLastMessage + r.leaderChsLock.Unlock() + return ch } // String returns a string representation of this Raft node. diff --git a/raft_test.go b/raft_test.go index e0086b085..3b1205b0a 100644 --- a/raft_test.go +++ b/raft_test.go @@ -263,13 +263,15 @@ func TestRaft_SingleNode(t *testing.T) { raft := c.rafts[0] // Watch leaderCh for change - select { - case v := <-raft.LeaderCh(): - if !v { - c.FailNowf("should become leader") + ch := raft.LeaderCh() + isLeader := false + for !isLeader { + select { + case v := <-ch: + isLeader = v + case <-time.After(conf.HeartbeatTimeout * 3): + c.FailNowf("timeout becoming leader") } - case <-time.After(conf.HeartbeatTimeout * 3): - c.FailNowf("timeout becoming leader") } // Should be leader @@ -1507,10 +1509,11 @@ func TestRaft_LeaderLeaseExpire(t *testing.T) { // Watch the leaderCh timeout := time.After(conf.LeaderLeaseTimeout * 2) + lch := leader.LeaderCh() LOOP: for { select { - case v := <-leader.LeaderCh(): + case v := <-lch: if !v { break LOOP }