summaryrefslogtreecommitdiffstats
path: root/bridge/matrix/helpers.go
diff options
context:
space:
mode:
Diffstat (limited to 'bridge/matrix/helpers.go')
-rw-r--r--bridge/matrix/helpers.go367
1 files changed, 267 insertions, 100 deletions
diff --git a/bridge/matrix/helpers.go b/bridge/matrix/helpers.go
index 5a91f748..93031e1d 100644
--- a/bridge/matrix/helpers.go
+++ b/bridge/matrix/helpers.go
@@ -1,16 +1,21 @@
package bmatrix
import (
- "encoding/json"
"errors"
"fmt"
"html"
- "strings"
+ "sort"
+ "sync"
"time"
- matrix "github.com/matterbridge/gomatrix"
+ matrix "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
)
+// arbitrary limit to determine when to cleanup nickname cache entries
+const MaxNumberOfUsersInCache = 50_000
+
func newMatrixUsername(username string) *matrixUsername {
mUsername := new(matrixUsername)
@@ -28,11 +33,11 @@ func newMatrixUsername(username string) *matrixUsername {
}
// getRoomID retrieves a matching room ID from the channel name.
-func (b *Bmatrix) getRoomID(channel string) string {
+func (b *Bmatrix) getRoomID(channelName string) id.RoomID {
b.RLock()
defer b.RUnlock()
- for ID, name := range b.RoomMap {
- if name == channel {
+ for ID, channel := range b.RoomMap {
+ if channelName == channel.name {
return ID
}
}
@@ -40,31 +45,59 @@ func (b *Bmatrix) getRoomID(channel string) string {
return ""
}
-// interface2Struct marshals and immediately unmarshals an interface.
-// Useful for converting map[string]interface{} to a struct.
-func interface2Struct(in interface{}, out interface{}) error {
- jsonObj, err := json.Marshal(in)
- if err != nil {
- return err //nolint:wrapcheck
- }
+type NicknameCacheEntry struct {
+ displayName string
+ lastUpdated time.Time
+ conflictWithOtherUsername bool
+}
- return json.Unmarshal(jsonObj, out)
+type NicknameUserEntry struct {
+ globalEntry *NicknameCacheEntry
+ perChannel map[id.RoomID]NicknameCacheEntry
}
-// getDisplayName retrieves the displayName for mxid, querying the homeserver if the mxid is not in the cache.
-func (b *Bmatrix) getDisplayName(mxid string) string {
- if b.GetBool("UseUserName") {
- return mxid[1:]
+type NicknameCache struct {
+ users map[id.UserID]NicknameUserEntry
+ sync.RWMutex
+}
+
+func NewNicknameCache() *NicknameCache {
+ return &NicknameCache{
+ users: make(map[id.UserID]NicknameUserEntry),
+ RWMutex: sync.RWMutex{},
}
+}
- b.RLock()
- if val, present := b.NicknameMap[mxid]; present {
- b.RUnlock()
+// note: cache is not locked here
+func (c *NicknameCache) retrieveDisplaynameFromCache(channelID id.RoomID, mxid id.UserID) string {
+ var cachedEntry *NicknameCacheEntry = nil
+
+ c.RLock()
+ if user, userPresent := c.users[mxid]; userPresent {
+ // try first the name of the user in the room, then globally
+ if roomCachedEntry, roomPresent := user.perChannel[channelID]; roomPresent {
+ cachedEntry = &roomCachedEntry
+ } else if user.globalEntry != nil {
+ cachedEntry = user.globalEntry
+ }
+ }
+ c.RUnlock()
- return val.displayName
+ if cachedEntry == nil {
+ return ""
+ }
+
+ if cachedEntry.conflictWithOtherUsername {
+ // TODO: the current behavior is that only users with clashing usernames and *that have
+ // spoken since the bridge started* will get their mxids shown, and this doesn't
+ // feel right
+ return fmt.Sprintf("%s (%s)", cachedEntry.displayName, mxid)
}
- b.RUnlock()
+ return cachedEntry.displayName
+}
+
+func (b *Bmatrix) retrieveGlobalDisplayname(mxid id.UserID) string {
displayName, err := b.mc.GetDisplayName(mxid)
var httpError *matrix.HTTPError
if errors.As(err, &httpError) {
@@ -72,127 +105,198 @@ func (b *Bmatrix) getDisplayName(mxid string) string {
}
if err != nil {
- return b.cacheDisplayName(mxid, mxid[1:])
+ return string(mxid)[1:]
}
- return b.cacheDisplayName(mxid, displayName.DisplayName)
+ return displayName.DisplayName
}
-// cacheDisplayName stores the mapping between a mxid and a display name, to be reused later without performing a query to the homserver.
-// Note that old entries are cleaned when this function is called.
-func (b *Bmatrix) cacheDisplayName(mxid string, displayName string) string {
- now := time.Now()
+// getDisplayName retrieves the displayName for mxid, querying the homeserver if the mxid is not in the cache.
+func (b *Bmatrix) getDisplayName(channelID id.RoomID, mxid id.UserID) string {
+ if b.GetBool("UseUserName") {
+ return string(mxid)[1:]
+ }
- // scan to delete old entries, to stop memory usage from becoming too high with old entries.
- // In addition, we also detect if another user have the same username, and if so, we append their mxids to their usernames to differentiate them.
- toDelete := []string{}
- conflict := false
+ displayname := b.NicknameCache.retrieveDisplaynameFromCache(channelID, mxid)
+ if displayname != "" {
+ return displayname
+ }
- b.Lock()
- for mxid, v := range b.NicknameMap {
- // to prevent username reuse across matrix servers - or even on the same server, append
- // the mxid to the username when there is a conflict
- if v.displayName == displayName {
- conflict = true
- // TODO: it would be nice to be able to rename previous messages from this user.
- // The current behavior is that only users with clashing usernames and *that have spoken since the bridge last started* will get their mxids shown, and I don't know if that's the expected behavior.
- v.displayName = fmt.Sprintf("%s (%s)", displayName, mxid)
- b.NicknameMap[mxid] = v
+ // retrieve the global display name
+ return b.cacheDisplayName("", mxid, b.retrieveGlobalDisplayname(mxid))
+}
+
+// scan to delete old entries, to stop memory usage from becoming high with obsolete entries.
+// note: assume the cache is already write-locked
+// TODO: should we update the timestamp when the entry is used?
+func (c *NicknameCache) clearObsoleteEntries(mxid id.UserID) {
+ // we have a "off-by-one" to account for when the user being added to the
+ // cache already have obsolete cache entries, as we want to keep it because
+ // we will be refreshing it in a minute
+ if len(c.users) <= MaxNumberOfUsersInCache+1 {
+ return
+ }
+
+ usersLastTimestamp := make(map[id.UserID]int64, len(c.users))
+ // compute the last updated timestamp entry for each user
+ for mxidIter, NicknameCacheIter := range c.users {
+ userLastTimestamp := time.Unix(0, 0)
+ for _, userInChannelCacheEntry := range NicknameCacheIter.perChannel {
+ if userInChannelCacheEntry.lastUpdated.After(userLastTimestamp) {
+ userLastTimestamp = userInChannelCacheEntry.lastUpdated
+ }
}
- if now.Sub(v.lastUpdated) > 10*time.Minute {
- toDelete = append(toDelete, mxid)
+ if NicknameCacheIter.globalEntry != nil {
+ if NicknameCacheIter.globalEntry.lastUpdated.After(userLastTimestamp) {
+ userLastTimestamp = NicknameCacheIter.globalEntry.lastUpdated
+ }
}
- }
- if conflict {
- displayName = fmt.Sprintf("%s (%s)", displayName, mxid)
+ usersLastTimestamp[mxidIter] = userLastTimestamp.UnixNano()
}
- for _, v := range toDelete {
- delete(b.NicknameMap, v)
+ // get the limit timestamp before which we must clear entries as obsolete
+ sortedTimestamps := make([]int64, 0, len(usersLastTimestamp))
+ for _, value := range usersLastTimestamp {
+ sortedTimestamps = append(sortedTimestamps, value)
}
-
- b.NicknameMap[mxid] = NicknameCacheEntry{
- displayName: displayName,
- lastUpdated: now,
+ sort.Slice(sortedTimestamps, func(i, j int) bool { return sortedTimestamps[i] < sortedTimestamps[j] })
+ limitTimestamp := sortedTimestamps[len(sortedTimestamps)-MaxNumberOfUsersInCache]
+
+ // delete entries older than the limit
+ for mxidIter, timestamp := range usersLastTimestamp {
+ // do not clear the user that we are adding to the cache
+ if timestamp <= limitTimestamp && mxidIter != mxid {
+ delete(c.users, mxidIter)
+ }
}
- b.Unlock()
-
- return displayName
}
-// handleError converts errors into httpError.
-//nolint:exhaustivestruct
-func handleError(err error) *httpError {
- var mErr matrix.HTTPError
- if !errors.As(err, &mErr) {
- return &httpError{
- Err: "not a HTTPError",
+// to prevent username reuse across matrix rooms - or even inside the same room, if a user uses multiple servers -
+// identify users with naming conflicts
+func (c *NicknameCache) detectConflict(mxid id.UserID, displayName string) bool {
+ conflict := false
+
+ for mxidIter, NicknameCacheIter := range c.users {
+ // skip conflict detection against ourselves, obviously
+ if mxidIter == mxid {
+ continue
}
- }
- var httpErr httpError
+ for channelID, userInChannelCacheEntry := range NicknameCacheIter.perChannel {
+ if userInChannelCacheEntry.displayName == displayName {
+ userInChannelCacheEntry.conflictWithOtherUsername = true
+ c.users[mxidIter].perChannel[channelID] = userInChannelCacheEntry
+ conflict = true
+ }
+ }
- if err := json.Unmarshal(mErr.Contents, &httpErr); err != nil {
- return &httpError{
- Err: "unmarshal failed",
+ if NicknameCacheIter.globalEntry != nil && NicknameCacheIter.globalEntry.displayName == displayName {
+ c.users[mxidIter].globalEntry.conflictWithOtherUsername = true
+ conflict = true
}
}
- return &httpErr
+ return conflict
}
-func (b *Bmatrix) containsAttachment(content map[string]interface{}) bool {
- // Skip empty messages
- if content["msgtype"] == nil {
- return false
+// cacheDisplayName stores the mapping between a mxid and a display name, to be reused
+// later without performing a query to the homeserver.
+// Note that old entries are cleaned when this function is called.
+func (b *Bmatrix) cacheDisplayName(channelID id.RoomID, mxid id.UserID, displayName string) string {
+ now := time.Now()
+
+ cache := b.NicknameCache
+
+ cache.Lock()
+ defer cache.Unlock()
+
+ conflict := cache.detectConflict(mxid, displayName)
+
+ cache.clearObsoleteEntries(mxid)
+
+ var newEntry NicknameUserEntry
+ if user, userPresent := cache.users[mxid]; userPresent {
+ newEntry = user
+ } else {
+ newEntry = NicknameUserEntry{
+ globalEntry: nil,
+ perChannel: make(map[id.RoomID]NicknameCacheEntry),
+ }
}
- // Only allow image,video or file msgtypes
- if !(content["msgtype"].(string) == "m.image" ||
- content["msgtype"].(string) == "m.video" ||
- content["msgtype"].(string) == "m.file") {
- return false
+ cacheEntry := NicknameCacheEntry{
+ displayName: displayName,
+ lastUpdated: now,
+ conflictWithOtherUsername: conflict,
}
- return true
+ // this is a local (room-specific) display name, let's cache it as such
+ if channelID == "" {
+ newEntry.globalEntry = &cacheEntry
+ } else {
+ globalDisplayName := b.retrieveGlobalDisplayname(mxid)
+ // updating the global display name or resetting the room name to the global name
+ if globalDisplayName == displayName {
+ delete(newEntry.perChannel, channelID)
+ newEntry.globalEntry = &cacheEntry
+ } else {
+ newEntry.perChannel[channelID] = cacheEntry
+ }
+ }
+
+ cache.users[mxid] = newEntry
+
+ return displayName
}
-// getAvatarURL returns the avatar URL of the specified sender.
-func (b *Bmatrix) getAvatarURL(sender string) string {
- urlPath := b.mc.BuildURL("profile", sender, "avatar_url")
+func (b *Bmatrix) removeDisplayNameFromCache(mxid id.UserID) {
+ cache := b.NicknameCache
- s := struct {
- AvatarURL string `json:"avatar_url"`
- }{}
+ cache.Lock()
+ defer cache.Unlock()
- err := b.mc.MakeRequest("GET", urlPath, nil, &s)
- if err != nil {
- b.Log.Errorf("getAvatarURL failed: %s", err)
+ delete(cache.users, mxid)
+}
+// getAvatarURL returns the avatar URL of the specified sender.
+func (b *Bmatrix) getAvatarURL(sender id.UserID) string {
+ url, err := b.mc.GetAvatarURL(sender)
+ if err != nil {
+ b.Log.Errorf("Couldn't retrieve the URL of the avatar for MXID %s", sender)
return ""
}
- url := strings.ReplaceAll(s.AvatarURL, "mxc://", b.GetString("Server")+"/_matrix/media/r0/thumbnail/")
- if url != "" {
- url += "?width=37&height=37&method=crop"
- }
-
- return url
+ return url.String()
}
// handleRatelimit handles the ratelimit errors and return if we're ratelimited and the amount of time to sleep
func (b *Bmatrix) handleRatelimit(err error) (time.Duration, bool) {
- httpErr := handleError(err)
- if httpErr.Errcode != "M_LIMIT_EXCEEDED" {
+ var mErr matrix.HTTPError
+ if !errors.As(err, &mErr) {
+ b.Log.Errorf("Received a non-HTTPError, don't know what to make of it:\n%#v", err)
+ return 0, false
+ }
+
+ if mErr.RespError.ErrCode != "M_LIMIT_EXCEEDED" {
return 0, false
}
- b.Log.Debugf("ratelimited: %s", httpErr.Err)
- b.Log.Infof("getting ratelimited by matrix, sleeping approx %d seconds before retrying", httpErr.RetryAfterMs/1000)
+ b.Log.Debugf("ratelimited: %s", mErr.RespError.Err)
+
+ // fallback to a one-second delay
+ retryDelayMs := 1000
+
+ if retryDelayString, present := mErr.RespError.ExtraData["retry_after_ms"]; present {
+ if retryDelayInt, correct := retryDelayString.(int); correct && retryDelayInt > retryDelayMs {
+ retryDelayMs = retryDelayInt
+ }
+ }
+
+ b.Log.Infof("getting ratelimited by matrix, sleeping approx %d seconds before retrying", retryDelayMs/1000)
- return time.Duration(httpErr.RetryAfterMs) * time.Millisecond, true
+ return time.Duration(retryDelayMs) * time.Millisecond, true
}
// retry function will check if we're ratelimited and retries again when backoff time expired
@@ -213,3 +317,66 @@ func (b *Bmatrix) retry(f func() error) error {
}
}
}
+
+type SendMessageEventWrapper struct {
+ inner *matrix.Client
+}
+
+//nolint: wrapcheck
+func (w SendMessageEventWrapper) SendMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*matrix.RespSendEvent, error) {
+ return w.inner.SendMessageEvent(roomID, eventType, contentJSON)
+}
+
+//nolint: wrapcheck
+func (b *Bmatrix) sendMessageEventWithRetries(channel id.RoomID, message event.MessageEventContent, username string) (string, error) {
+ var (
+ resp *matrix.RespSendEvent
+ client interface {
+ SendMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}) (resp *matrix.RespSendEvent, err error)
+ }
+ err error
+ )
+
+ b.RLock()
+ appservice := b.RoomMap[channel].appService
+ b.RUnlock()
+
+ client = SendMessageEventWrapper{inner: b.mc}
+
+ // only try to send messages through the app Service *once* we have received
+ // events through it (otherwise we don't really know if the appservice works)
+ // Additionally, even if we're receiving messages in that room via the appService listener,
+ // let's check that the appservice "covers" that room
+ if appservice && b.appService.namespaces.containsRoom(channel) && len(b.appService.namespaces.prefixes) > 0 {
+ b.Log.Debugf("Sending with appService")
+ // we take the first prefix
+ bridgeUserID := fmt.Sprintf("@%s%s:%s", b.appService.namespaces.prefixes[0], id.EncodeUserLocalpart(username), b.appService.appService.HomeserverDomain)
+ intent := b.appService.appService.Intent(id.UserID(bridgeUserID))
+ // if we can't change the display name it's not great but not the end of the world either, ignore it
+ // TODO: do not perform this action on every message, with an in-memory cache or something
+ _ = intent.SetDisplayName(username)
+ client = intent
+ } else {
+ applyUsernametoMessage(&message, username)
+ }
+
+ err = b.retry(func() error {
+ resp, err = client.SendMessageEvent(channel, event.EventMessage, message)
+
+ return err
+ })
+ if err != nil {
+ return "", err
+ }
+
+ return string(resp.EventID), err
+}
+
+func applyUsernametoMessage(newMsg *event.MessageEventContent, username string) {
+ matrixUsername := newMatrixUsername(username)
+
+ newMsg.Body = matrixUsername.plain + newMsg.Body
+ if newMsg.FormattedBody != "" {
+ newMsg.FormattedBody = matrixUsername.formatted + newMsg.FormattedBody
+ }
+}