Skip to content

Commit 80769fd

Browse files
committed
Add parent message to threaded replies and reactions
1 parent 6251d20 commit 80769fd

File tree

1 file changed

+83
-13
lines changed

1 file changed

+83
-13
lines changed

bridge/matrix/matrix.go

Lines changed: 83 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/42wim/matterbridge/bridge/helper"
1616
"github.com/42wim/matterircd/bridge"
1717
"github.com/davecgh/go-spew/spew"
18+
lru "github.com/hashicorp/golang-lru"
1819
prefixed "github.com/matterbridge/logrus-prefixed-formatter"
1920
"github.com/sirupsen/logrus"
2021
"github.com/spf13/viper"
@@ -32,6 +33,9 @@ type Matrix struct {
3233
channels map[id.RoomID]*Channel
3334
users map[id.UserID]*User
3435
sync.RWMutex
36+
37+
msgParentCache *lru.Cache
38+
msgLastSentCache *lru.Cache
3539
}
3640

3741
var logger *logrus.Entry
@@ -45,6 +49,8 @@ func New(v *viper.Viper, cred bridge.Credentials, eventChan chan *bridge.Event,
4549
dmChannels: make(map[id.RoomID][]id.UserID),
4650
users: make(map[id.UserID]*User),
4751
}
52+
m.msgParentCache, _ = lru.New(100)
53+
m.msgLastSentCache, _ = lru.New(10)
4854

4955
ourlog := logrus.New()
5056
ourlog.SetFormatter(&prefixed.TextFormatter{
@@ -285,23 +291,30 @@ func (m *Matrix) handleMessageEvent(source mautrix.EventSource, ev *event.Event)
285291
}
286292

287293
var text string
288-
var parentID string
294+
var parentID id.EventID
289295

290296
switch {
291297
case ev.Type.String() == "m.text" || ev.Type.String() == "m.room.message":
292298
msgEventContent, _ := ev.Content.Parsed.(*event.MessageEventContent)
293299
text = msgEventContent.Body
294300
if msgEventContent.RelatesTo != nil {
295-
parentID = msgEventContent.RelatesTo.EventID.String()
301+
parentID = msgEventContent.RelatesTo.EventID
296302
}
297303
default:
298304
logger.Warnf("handleMessageEvent unsupported event type %s", ev.Type.String())
299305
}
300306

307+
if !m.v.GetBool("matrix.hidereplies") && parentID.String() != "" {
308+
message, err := m.addParentMsg(ev.RoomID, parentID, text, m.v.GetInt("matrix.shortenrepliesto"), "@", m.v.GetBool("matrix.unicode"))
309+
if err != nil {
310+
logger.Errorf("Unable to get parent post for %#v", ev) //nolint:govet
311+
}
312+
text = message
313+
}
314+
301315
m.RLock()
302316
_, ok := m.dmChannels[ev.RoomID]
303317
m.RUnlock()
304-
305318
if ok {
306319
event := &bridge.Event{ //nolint:gocritic
307320
Type: "direct_message",
@@ -313,7 +326,7 @@ func (m *Matrix) handleMessageEvent(source mautrix.EventSource, ev *event.Event)
313326
// Files: m.getFilesFromData(data),
314327
MessageID: string(ev.ID),
315328
// Event: rmsg.Event,
316-
ParentID: parentID,
329+
ParentID: parentID.String(),
317330
},
318331
}
319332

@@ -331,7 +344,7 @@ func (m *Matrix) handleMessageEvent(source mautrix.EventSource, ev *event.Event)
331344
// Files: m.getFilesFromData(data),
332345
MessageID: string(ev.ID),
333346
// Event: rmsg.Event,
334-
ParentID: parentID,
347+
ParentID: parentID.String(),
335348
},
336349
}
337350

@@ -348,22 +361,30 @@ func (m *Matrix) handleReactionEvent(source mautrix.EventSource, ev *event.Event
348361
return
349362
}
350363

364+
var text string
351365
var reaction string
352-
var parentID string
366+
var parentID id.EventID
353367

354368
switch {
355369
case ev.Type.String() == "m.reaction":
356370
reactionEventContent, _ := ev.Content.Parsed.(*event.ReactionEventContent)
357371
reaction = reactionEventContent.RelatesTo.Key
358-
parentID = reactionEventContent.RelatesTo.EventID.String()
372+
parentID = reactionEventContent.RelatesTo.EventID
359373
default:
360374
logger.Warnf("handleEvent unsupported event type %s", ev.Type.String())
361375
}
362376

377+
if !m.v.GetBool("matrix.hidereplies") {
378+
message, err := m.addParentMsg(ev.RoomID, parentID, text, m.v.GetInt("matrix.shortenrepliesto"), "@", m.v.GetBool("matrix.unicode"))
379+
if err != nil {
380+
logger.Errorf("Unable to get parent post for %#v", ev) //nolint:govet
381+
}
382+
text = message
383+
}
384+
363385
m.RLock()
364386
_, ok := m.dmChannels[ev.RoomID]
365387
m.RUnlock()
366-
367388
channelType := ""
368389
if ok {
369390
channelType = "D"
@@ -377,8 +398,8 @@ func (m *Matrix) handleReactionEvent(source mautrix.EventSource, ev *event.Event
377398
Sender: ghost,
378399
Reaction: reaction,
379400
ChannelType: channelType,
380-
Message: "",
381-
ParentID: parentID,
401+
Message: text,
402+
ParentID: parentID.String(),
382403
},
383404
}
384405

@@ -498,6 +519,7 @@ func (m *Matrix) MsgChannelThread(channelID, parentID, text string) (string, err
498519

499520
logger.Trace("msgchannelthread: error,resp ", err, resp)
500521

522+
m.msgLastSentCache.Add(resp.EventID.String(), fmt.Sprintf("%s: %s", id.RoomAlias(channelID), text))
501523
return resp.EventID.String(), nil
502524
}
503525

@@ -592,7 +614,8 @@ func (m *Matrix) GetChannelUsers(channelID string) ([]*bridge.UserInfo, error) {
592614
return nil, err
593615
}
594616

595-
logger.Tracef("GetChannelUsers %s %d", channelID, len(resp.Joined))
617+
logger.Debugf("GetChannelUsers %s %d", channelID, len(resp.Joined))
618+
logger.Tracef("GetChannelUsers %s", spew.Sdump(resp.Joined))
596619

597620
for user := range resp.Joined {
598621
users = append(users, m.createUser(user))
@@ -606,6 +629,7 @@ func (m *Matrix) GetUsers() []*bridge.UserInfo {
606629

607630
logger.Trace("GetUsers ", m.users)
608631
logger.Trace("GetUsers ", spew.Sdump(m.users))
632+
logger.Debugf("GetUsers %d", len(m.users))
609633

610634
m.RLock()
611635
for userID := range m.users {
@@ -625,6 +649,8 @@ func (m *Matrix) GetChannels() []*bridge.ChannelInfo {
625649
m.RLock()
626650
defer m.RUnlock()
627651

652+
logger.Tracef("GetChannels %s", spew.Sdump(m.channels))
653+
628654
for roomID, channel := range m.channels {
629655
channel.RLock()
630656

@@ -644,6 +670,8 @@ func (m *Matrix) GetChannels() []*bridge.ChannelInfo {
644670
Private: false,
645671
})
646672

673+
logger.Debugf("GetChannels %s (%s)", channel.Alias.String(), roomID.String())
674+
647675
channel.RUnlock()
648676
}
649677

@@ -761,6 +789,39 @@ func isValidNick(s string) bool {
761789
return true
762790
}
763791

792+
func (m *Matrix) addParentMsg(roomID id.RoomID, parentID id.EventID, msg string, newLen int, uncounted string, unicode bool) (string, error) {
793+
var replyMessage string
794+
795+
// Search and use cached reply if it exists.
796+
// None found, so we'll need to create one and save it for future uses.
797+
if v, ok := m.msgParentCache.Get(parentID); !ok {
798+
resp, err := m.mc.GetEvent(roomID, parentID)
799+
// Retry once on failure.
800+
if err != nil {
801+
resp, err = m.mc.GetEvent(roomID, parentID)
802+
}
803+
if err != nil {
804+
return msg, err
805+
}
806+
807+
body := ""
808+
if val, ok := resp.Content.Raw["body"]; ok {
809+
body = val.(string)
810+
}
811+
812+
parentUser := m.GetUser(resp.Sender.String())
813+
parentMessage := maybeShorten(body, newLen, uncounted, unicode)
814+
replyMessage = fmt.Sprintf(" (re @%s: %s)", parentUser.Nick, parentMessage)
815+
logger.Debugf("Created reply for parent post %s:%s", parentID.String(), replyMessage)
816+
817+
m.msgParentCache.Add(parentID, replyMessage)
818+
} else if replyMessage, ok = v.(string); ok {
819+
logger.Debugf("Found saved reply for parent post %s, using:%s", parentID, replyMessage)
820+
}
821+
822+
return strings.TrimRight(msg, "\n") + replyMessage, nil
823+
}
824+
764825
// maybeShorten returns a prefix of msg that is approximately newLen
765826
// characters long, followed by "...". Words that start with uncounted
766827
// are included in the result but are not reckoned against newLen.
@@ -840,7 +901,7 @@ func (m *Matrix) SearchUsers(query string) ([]*bridge.UserInfo, error) {
840901
return brusers, nil
841902
}
842903

843-
func (m *Matrix) GetPostThread(channelID string) interface{} {
904+
func (m *Matrix) GetPostThread(postID string) interface{} {
844905
return nil
845906
}
846907

@@ -873,5 +934,14 @@ func (m *Matrix) RemoveReaction(msgID, emoji string) error {
873934
}
874935

875936
func (m *Matrix) GetLastSentMsgs() []string {
876-
return []string{}
937+
data := make([]string, 0)
938+
939+
for _, k := range m.msgLastSentCache.Keys() {
940+
if v, ok := m.msgLastSentCache.Get(k); ok {
941+
msg, _ := v.(string)
942+
data = append(data, fmt.Sprintf("[@@%s] %s", k, msg))
943+
}
944+
}
945+
946+
return data
877947
}

0 commit comments

Comments
 (0)