diff options
Diffstat (limited to 'bridge/matrix/helpers.go')
-rw-r--r-- | bridge/matrix/helpers.go | 367 |
1 files changed, 267 insertions, 100 deletions
diff --git a/bridge/matrix/helpers.go b/bridge/matrix/helpers.go index 5a91f748..93031e1d 100644 --- a/bridge/matrix/helpers.go +++ b/bridge/matrix/helpers.go @@ -1,16 +1,21 @@ package bmatrix import ( - "encoding/json" "errors" "fmt" "html" - "strings" + "sort" + "sync" "time" - matrix "github.com/matterbridge/gomatrix" + matrix "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" ) +// arbitrary limit to determine when to cleanup nickname cache entries +const MaxNumberOfUsersInCache = 50_000 + func newMatrixUsername(username string) *matrixUsername { mUsername := new(matrixUsername) @@ -28,11 +33,11 @@ func newMatrixUsername(username string) *matrixUsername { } // getRoomID retrieves a matching room ID from the channel name. -func (b *Bmatrix) getRoomID(channel string) string { +func (b *Bmatrix) getRoomID(channelName string) id.RoomID { b.RLock() defer b.RUnlock() - for ID, name := range b.RoomMap { - if name == channel { + for ID, channel := range b.RoomMap { + if channelName == channel.name { return ID } } @@ -40,31 +45,59 @@ func (b *Bmatrix) getRoomID(channel string) string { return "" } -// interface2Struct marshals and immediately unmarshals an interface. -// Useful for converting map[string]interface{} to a struct. -func interface2Struct(in interface{}, out interface{}) error { - jsonObj, err := json.Marshal(in) - if err != nil { - return err //nolint:wrapcheck - } +type NicknameCacheEntry struct { + displayName string + lastUpdated time.Time + conflictWithOtherUsername bool +} - return json.Unmarshal(jsonObj, out) +type NicknameUserEntry struct { + globalEntry *NicknameCacheEntry + perChannel map[id.RoomID]NicknameCacheEntry } -// getDisplayName retrieves the displayName for mxid, querying the homeserver if the mxid is not in the cache. -func (b *Bmatrix) getDisplayName(mxid string) string { - if b.GetBool("UseUserName") { - return mxid[1:] +type NicknameCache struct { + users map[id.UserID]NicknameUserEntry + sync.RWMutex +} + +func NewNicknameCache() *NicknameCache { + return &NicknameCache{ + users: make(map[id.UserID]NicknameUserEntry), + RWMutex: sync.RWMutex{}, } +} - b.RLock() - if val, present := b.NicknameMap[mxid]; present { - b.RUnlock() +// note: cache is not locked here +func (c *NicknameCache) retrieveDisplaynameFromCache(channelID id.RoomID, mxid id.UserID) string { + var cachedEntry *NicknameCacheEntry = nil + + c.RLock() + if user, userPresent := c.users[mxid]; userPresent { + // try first the name of the user in the room, then globally + if roomCachedEntry, roomPresent := user.perChannel[channelID]; roomPresent { + cachedEntry = &roomCachedEntry + } else if user.globalEntry != nil { + cachedEntry = user.globalEntry + } + } + c.RUnlock() - return val.displayName + if cachedEntry == nil { + return "" + } + + if cachedEntry.conflictWithOtherUsername { + // TODO: the current behavior is that only users with clashing usernames and *that have + // spoken since the bridge started* will get their mxids shown, and this doesn't + // feel right + return fmt.Sprintf("%s (%s)", cachedEntry.displayName, mxid) } - b.RUnlock() + return cachedEntry.displayName +} + +func (b *Bmatrix) retrieveGlobalDisplayname(mxid id.UserID) string { displayName, err := b.mc.GetDisplayName(mxid) var httpError *matrix.HTTPError if errors.As(err, &httpError) { @@ -72,127 +105,198 @@ func (b *Bmatrix) getDisplayName(mxid string) string { } if err != nil { - return b.cacheDisplayName(mxid, mxid[1:]) + return string(mxid)[1:] } - return b.cacheDisplayName(mxid, displayName.DisplayName) + return displayName.DisplayName } -// cacheDisplayName stores the mapping between a mxid and a display name, to be reused later without performing a query to the homserver. -// Note that old entries are cleaned when this function is called. -func (b *Bmatrix) cacheDisplayName(mxid string, displayName string) string { - now := time.Now() +// getDisplayName retrieves the displayName for mxid, querying the homeserver if the mxid is not in the cache. +func (b *Bmatrix) getDisplayName(channelID id.RoomID, mxid id.UserID) string { + if b.GetBool("UseUserName") { + return string(mxid)[1:] + } - // scan to delete old entries, to stop memory usage from becoming too high with old entries. - // In addition, we also detect if another user have the same username, and if so, we append their mxids to their usernames to differentiate them. - toDelete := []string{} - conflict := false + displayname := b.NicknameCache.retrieveDisplaynameFromCache(channelID, mxid) + if displayname != "" { + return displayname + } - b.Lock() - for mxid, v := range b.NicknameMap { - // to prevent username reuse across matrix servers - or even on the same server, append - // the mxid to the username when there is a conflict - if v.displayName == displayName { - conflict = true - // TODO: it would be nice to be able to rename previous messages from this user. - // The current behavior is that only users with clashing usernames and *that have spoken since the bridge last started* will get their mxids shown, and I don't know if that's the expected behavior. - v.displayName = fmt.Sprintf("%s (%s)", displayName, mxid) - b.NicknameMap[mxid] = v + // retrieve the global display name + return b.cacheDisplayName("", mxid, b.retrieveGlobalDisplayname(mxid)) +} + +// scan to delete old entries, to stop memory usage from becoming high with obsolete entries. +// note: assume the cache is already write-locked +// TODO: should we update the timestamp when the entry is used? +func (c *NicknameCache) clearObsoleteEntries(mxid id.UserID) { + // we have a "off-by-one" to account for when the user being added to the + // cache already have obsolete cache entries, as we want to keep it because + // we will be refreshing it in a minute + if len(c.users) <= MaxNumberOfUsersInCache+1 { + return + } + + usersLastTimestamp := make(map[id.UserID]int64, len(c.users)) + // compute the last updated timestamp entry for each user + for mxidIter, NicknameCacheIter := range c.users { + userLastTimestamp := time.Unix(0, 0) + for _, userInChannelCacheEntry := range NicknameCacheIter.perChannel { + if userInChannelCacheEntry.lastUpdated.After(userLastTimestamp) { + userLastTimestamp = userInChannelCacheEntry.lastUpdated + } } - if now.Sub(v.lastUpdated) > 10*time.Minute { - toDelete = append(toDelete, mxid) + if NicknameCacheIter.globalEntry != nil { + if NicknameCacheIter.globalEntry.lastUpdated.After(userLastTimestamp) { + userLastTimestamp = NicknameCacheIter.globalEntry.lastUpdated + } } - } - if conflict { - displayName = fmt.Sprintf("%s (%s)", displayName, mxid) + usersLastTimestamp[mxidIter] = userLastTimestamp.UnixNano() } - for _, v := range toDelete { - delete(b.NicknameMap, v) + // get the limit timestamp before which we must clear entries as obsolete + sortedTimestamps := make([]int64, 0, len(usersLastTimestamp)) + for _, value := range usersLastTimestamp { + sortedTimestamps = append(sortedTimestamps, value) } - - b.NicknameMap[mxid] = NicknameCacheEntry{ - displayName: displayName, - lastUpdated: now, + sort.Slice(sortedTimestamps, func(i, j int) bool { return sortedTimestamps[i] < sortedTimestamps[j] }) + limitTimestamp := sortedTimestamps[len(sortedTimestamps)-MaxNumberOfUsersInCache] + + // delete entries older than the limit + for mxidIter, timestamp := range usersLastTimestamp { + // do not clear the user that we are adding to the cache + if timestamp <= limitTimestamp && mxidIter != mxid { + delete(c.users, mxidIter) + } } - b.Unlock() - - return displayName } -// handleError converts errors into httpError. -//nolint:exhaustivestruct -func handleError(err error) *httpError { - var mErr matrix.HTTPError - if !errors.As(err, &mErr) { - return &httpError{ - Err: "not a HTTPError", +// to prevent username reuse across matrix rooms - or even inside the same room, if a user uses multiple servers - +// identify users with naming conflicts +func (c *NicknameCache) detectConflict(mxid id.UserID, displayName string) bool { + conflict := false + + for mxidIter, NicknameCacheIter := range c.users { + // skip conflict detection against ourselves, obviously + if mxidIter == mxid { + continue } - } - var httpErr httpError + for channelID, userInChannelCacheEntry := range NicknameCacheIter.perChannel { + if userInChannelCacheEntry.displayName == displayName { + userInChannelCacheEntry.conflictWithOtherUsername = true + c.users[mxidIter].perChannel[channelID] = userInChannelCacheEntry + conflict = true + } + } - if err := json.Unmarshal(mErr.Contents, &httpErr); err != nil { - return &httpError{ - Err: "unmarshal failed", + if NicknameCacheIter.globalEntry != nil && NicknameCacheIter.globalEntry.displayName == displayName { + c.users[mxidIter].globalEntry.conflictWithOtherUsername = true + conflict = true } } - return &httpErr + return conflict } -func (b *Bmatrix) containsAttachment(content map[string]interface{}) bool { - // Skip empty messages - if content["msgtype"] == nil { - return false +// cacheDisplayName stores the mapping between a mxid and a display name, to be reused +// later without performing a query to the homeserver. +// Note that old entries are cleaned when this function is called. +func (b *Bmatrix) cacheDisplayName(channelID id.RoomID, mxid id.UserID, displayName string) string { + now := time.Now() + + cache := b.NicknameCache + + cache.Lock() + defer cache.Unlock() + + conflict := cache.detectConflict(mxid, displayName) + + cache.clearObsoleteEntries(mxid) + + var newEntry NicknameUserEntry + if user, userPresent := cache.users[mxid]; userPresent { + newEntry = user + } else { + newEntry = NicknameUserEntry{ + globalEntry: nil, + perChannel: make(map[id.RoomID]NicknameCacheEntry), + } } - // Only allow image,video or file msgtypes - if !(content["msgtype"].(string) == "m.image" || - content["msgtype"].(string) == "m.video" || - content["msgtype"].(string) == "m.file") { - return false + cacheEntry := NicknameCacheEntry{ + displayName: displayName, + lastUpdated: now, + conflictWithOtherUsername: conflict, } - return true + // this is a local (room-specific) display name, let's cache it as such + if channelID == "" { + newEntry.globalEntry = &cacheEntry + } else { + globalDisplayName := b.retrieveGlobalDisplayname(mxid) + // updating the global display name or resetting the room name to the global name + if globalDisplayName == displayName { + delete(newEntry.perChannel, channelID) + newEntry.globalEntry = &cacheEntry + } else { + newEntry.perChannel[channelID] = cacheEntry + } + } + + cache.users[mxid] = newEntry + + return displayName } -// getAvatarURL returns the avatar URL of the specified sender. -func (b *Bmatrix) getAvatarURL(sender string) string { - urlPath := b.mc.BuildURL("profile", sender, "avatar_url") +func (b *Bmatrix) removeDisplayNameFromCache(mxid id.UserID) { + cache := b.NicknameCache - s := struct { - AvatarURL string `json:"avatar_url"` - }{} + cache.Lock() + defer cache.Unlock() - err := b.mc.MakeRequest("GET", urlPath, nil, &s) - if err != nil { - b.Log.Errorf("getAvatarURL failed: %s", err) + delete(cache.users, mxid) +} +// getAvatarURL returns the avatar URL of the specified sender. +func (b *Bmatrix) getAvatarURL(sender id.UserID) string { + url, err := b.mc.GetAvatarURL(sender) + if err != nil { + b.Log.Errorf("Couldn't retrieve the URL of the avatar for MXID %s", sender) return "" } - url := strings.ReplaceAll(s.AvatarURL, "mxc://", b.GetString("Server")+"/_matrix/media/r0/thumbnail/") - if url != "" { - url += "?width=37&height=37&method=crop" - } - - return url + return url.String() } // handleRatelimit handles the ratelimit errors and return if we're ratelimited and the amount of time to sleep func (b *Bmatrix) handleRatelimit(err error) (time.Duration, bool) { - httpErr := handleError(err) - if httpErr.Errcode != "M_LIMIT_EXCEEDED" { + var mErr matrix.HTTPError + if !errors.As(err, &mErr) { + b.Log.Errorf("Received a non-HTTPError, don't know what to make of it:\n%#v", err) + return 0, false + } + + if mErr.RespError.ErrCode != "M_LIMIT_EXCEEDED" { return 0, false } - b.Log.Debugf("ratelimited: %s", httpErr.Err) - b.Log.Infof("getting ratelimited by matrix, sleeping approx %d seconds before retrying", httpErr.RetryAfterMs/1000) + b.Log.Debugf("ratelimited: %s", mErr.RespError.Err) + + // fallback to a one-second delay + retryDelayMs := 1000 + + if retryDelayString, present := mErr.RespError.ExtraData["retry_after_ms"]; present { + if retryDelayInt, correct := retryDelayString.(int); correct && retryDelayInt > retryDelayMs { + retryDelayMs = retryDelayInt + } + } + + b.Log.Infof("getting ratelimited by matrix, sleeping approx %d seconds before retrying", retryDelayMs/1000) - return time.Duration(httpErr.RetryAfterMs) * time.Millisecond, true + return time.Duration(retryDelayMs) * time.Millisecond, true } // retry function will check if we're ratelimited and retries again when backoff time expired @@ -213,3 +317,66 @@ func (b *Bmatrix) retry(f func() error) error { } } } + +type SendMessageEventWrapper struct { + inner *matrix.Client +} + +//nolint: wrapcheck +func (w SendMessageEventWrapper) SendMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*matrix.RespSendEvent, error) { + return w.inner.SendMessageEvent(roomID, eventType, contentJSON) +} + +//nolint: wrapcheck +func (b *Bmatrix) sendMessageEventWithRetries(channel id.RoomID, message event.MessageEventContent, username string) (string, error) { + var ( + resp *matrix.RespSendEvent + client interface { + SendMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}) (resp *matrix.RespSendEvent, err error) + } + err error + ) + + b.RLock() + appservice := b.RoomMap[channel].appService + b.RUnlock() + + client = SendMessageEventWrapper{inner: b.mc} + + // only try to send messages through the app Service *once* we have received + // events through it (otherwise we don't really know if the appservice works) + // Additionally, even if we're receiving messages in that room via the appService listener, + // let's check that the appservice "covers" that room + if appservice && b.appService.namespaces.containsRoom(channel) && len(b.appService.namespaces.prefixes) > 0 { + b.Log.Debugf("Sending with appService") + // we take the first prefix + bridgeUserID := fmt.Sprintf("@%s%s:%s", b.appService.namespaces.prefixes[0], id.EncodeUserLocalpart(username), b.appService.appService.HomeserverDomain) + intent := b.appService.appService.Intent(id.UserID(bridgeUserID)) + // if we can't change the display name it's not great but not the end of the world either, ignore it + // TODO: do not perform this action on every message, with an in-memory cache or something + _ = intent.SetDisplayName(username) + client = intent + } else { + applyUsernametoMessage(&message, username) + } + + err = b.retry(func() error { + resp, err = client.SendMessageEvent(channel, event.EventMessage, message) + + return err + }) + if err != nil { + return "", err + } + + return string(resp.EventID), err +} + +func applyUsernametoMessage(newMsg *event.MessageEventContent, username string) { + matrixUsername := newMatrixUsername(username) + + newMsg.Body = matrixUsername.plain + newMsg.Body + if newMsg.FormattedBody != "" { + newMsg.FormattedBody = matrixUsername.formatted + newMsg.FormattedBody + } +} |