summaryrefslogtreecommitdiffstats
path: root/vendor/maunium.net/go/mautrix/appservice
diff options
context:
space:
mode:
authormsglm <msglm@techchud.xyz>2023-10-27 07:08:25 -0500
committermsglm <msglm@techchud.xyz>2023-10-27 07:08:25 -0500
commit032a7e0c1188d3507b8d9a9571f2446a43cf775b (patch)
tree2bd38c01bc7761a6195e426082ce7191ebc765a1 /vendor/maunium.net/go/mautrix/appservice
parent56e7bd01ca09ad52b0c4f48f146a20a4f1b78696 (diff)
downloadmatterbridge-msglm-032a7e0c1188d3507b8d9a9571f2446a43cf775b.tar.gz
matterbridge-msglm-032a7e0c1188d3507b8d9a9571f2446a43cf775b.tar.bz2
matterbridge-msglm-032a7e0c1188d3507b8d9a9571f2446a43cf775b.zip
apply https://github.com/42wim/matterbridge/pull/1864v1.26.0+0.1.0
Diffstat (limited to 'vendor/maunium.net/go/mautrix/appservice')
-rw-r--r--vendor/maunium.net/go/mautrix/appservice/appservice.go350
-rw-r--r--vendor/maunium.net/go/mautrix/appservice/eventprocessor.go175
-rw-r--r--vendor/maunium.net/go/mautrix/appservice/http.go348
-rw-r--r--vendor/maunium.net/go/mautrix/appservice/intent.go419
-rw-r--r--vendor/maunium.net/go/mautrix/appservice/protocol.go152
-rw-r--r--vendor/maunium.net/go/mautrix/appservice/registration.go100
-rw-r--r--vendor/maunium.net/go/mautrix/appservice/txnid.go43
-rw-r--r--vendor/maunium.net/go/mautrix/appservice/websocket.go408
8 files changed, 1995 insertions, 0 deletions
diff --git a/vendor/maunium.net/go/mautrix/appservice/appservice.go b/vendor/maunium.net/go/mautrix/appservice/appservice.go
new file mode 100644
index 00000000..099e4b27
--- /dev/null
+++ b/vendor/maunium.net/go/mautrix/appservice/appservice.go
@@ -0,0 +1,350 @@
+// Copyright (c) 2023 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package appservice
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ "net/http/cookiejar"
+ "net/url"
+ "os"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/gorilla/mux"
+ "github.com/gorilla/websocket"
+ "github.com/rs/zerolog"
+ "golang.org/x/net/publicsuffix"
+ "gopkg.in/yaml.v3"
+ "maunium.net/go/maulogger/v2/maulogadapt"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+// EventChannelSize is the size for the Events channel in Appservice instances.
+var EventChannelSize = 64
+var OTKChannelSize = 4
+
+// Create a blank appservice instance.
+func Create() *AppService {
+ jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
+ as := &AppService{
+ Log: zerolog.Nop(),
+ clients: make(map[id.UserID]*mautrix.Client),
+ intents: make(map[id.UserID]*IntentAPI),
+ HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar},
+ StateStore: mautrix.NewMemoryStateStore().(StateStore),
+ Router: mux.NewRouter(),
+ UserAgent: mautrix.DefaultUserAgent,
+ txnIDC: NewTransactionIDCache(128),
+ Live: true,
+ Ready: false,
+ ProcessID: getDefaultProcessID(),
+
+ Events: make(chan *event.Event, EventChannelSize),
+ ToDeviceEvents: make(chan *event.Event, EventChannelSize),
+ OTKCounts: make(chan *mautrix.OTKCount, OTKChannelSize),
+ DeviceLists: make(chan *mautrix.DeviceLists, EventChannelSize),
+ QueryHandler: &QueryHandlerStub{},
+ }
+
+ as.Router.HandleFunc("/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
+ as.Router.HandleFunc("/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
+ as.Router.HandleFunc("/users/{userID}", as.GetUser).Methods(http.MethodGet)
+ as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
+ as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
+ as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet)
+ as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost)
+ as.Router.HandleFunc("/_matrix/app/unstable/fi.mau.msc2659/ping", as.PostPing).Methods(http.MethodPost)
+ as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet)
+ as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet)
+
+ return as
+}
+
+// QueryHandler handles room alias and user ID queries from the homeserver.
+type QueryHandler interface {
+ QueryAlias(alias string) bool
+ QueryUser(userID id.UserID) bool
+}
+
+type QueryHandlerStub struct{}
+
+func (qh *QueryHandlerStub) QueryAlias(alias string) bool {
+ return false
+}
+
+func (qh *QueryHandlerStub) QueryUser(userID id.UserID) bool {
+ return false
+}
+
+type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{})
+
+type StateStore interface {
+ mautrix.StateStore
+
+ IsRegistered(userID id.UserID) bool
+ MarkRegistered(userID id.UserID)
+
+ GetPowerLevel(roomID id.RoomID, userID id.UserID) int
+ GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int
+ HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool
+}
+
+// AppService is the main config for all appservices.
+// It also serves as the appservice instance struct.
+type AppService struct {
+ HomeserverDomain string
+ hsURLForClient *url.URL
+ Host HostConfig
+
+ Registration *Registration
+ Log zerolog.Logger
+
+ txnIDC *TransactionIDCache
+
+ Events chan *event.Event
+ ToDeviceEvents chan *event.Event
+ DeviceLists chan *mautrix.DeviceLists
+ OTKCounts chan *mautrix.OTKCount
+ QueryHandler QueryHandler
+ StateStore StateStore
+
+ Router *mux.Router
+ UserAgent string
+ server *http.Server
+ HTTPClient *http.Client
+ botClient *mautrix.Client
+ botIntent *IntentAPI
+
+ DefaultHTTPRetries int
+
+ Live bool
+ Ready bool
+
+ clients map[id.UserID]*mautrix.Client
+ clientsLock sync.RWMutex
+ intents map[id.UserID]*IntentAPI
+ intentsLock sync.RWMutex
+
+ ws *websocket.Conn
+ wsWriteLock sync.Mutex
+ StopWebsocket func(error)
+ websocketHandlers map[string]WebsocketHandler
+ websocketHandlersLock sync.RWMutex
+ websocketRequests map[int]chan<- *WebsocketCommand
+ websocketRequestsLock sync.RWMutex
+ websocketRequestID int32
+ // ProcessID is an identifier sent to the websocket proxy for debugging connections
+ ProcessID string
+
+ DoublePuppetValue string
+ GetProfile func(userID id.UserID, roomID id.RoomID) *event.MemberEventContent
+}
+
+const DoublePuppetKey = "fi.mau.double_puppet_source"
+
+func getDefaultProcessID() string {
+ pid := syscall.Getpid()
+ uid := syscall.Getuid()
+ hostname, _ := os.Hostname()
+ return fmt.Sprintf("%s-%d-%d", hostname, uid, pid)
+}
+
+func (as *AppService) PrepareWebsocket() {
+ as.websocketHandlersLock.Lock()
+ defer as.websocketHandlersLock.Unlock()
+ if as.websocketHandlers == nil {
+ as.websocketHandlers = make(map[string]WebsocketHandler, 32)
+ as.websocketRequests = make(map[int]chan<- *WebsocketCommand)
+ }
+}
+
+// HostConfig contains info about how to host the appservice.
+type HostConfig struct {
+ Hostname string `yaml:"hostname"`
+ Port uint16 `yaml:"port"`
+ TLSKey string `yaml:"tls_key,omitempty"`
+ TLSCert string `yaml:"tls_cert,omitempty"`
+}
+
+// Address gets the whole address of the Appservice.
+func (hc *HostConfig) Address() string {
+ return fmt.Sprintf("%s:%d", hc.Hostname, hc.Port)
+}
+
+func (hc *HostConfig) IsUnixSocket() bool {
+ return strings.HasPrefix(hc.Hostname, "/")
+}
+
+func (hc *HostConfig) IsConfigured() bool {
+ return hc.IsUnixSocket() || hc.Port != 0
+}
+
+// Save saves this config into a file at the given path.
+func (as *AppService) Save(path string) error {
+ data, err := yaml.Marshal(as)
+ if err != nil {
+ return err
+ }
+ return os.WriteFile(path, data, 0644)
+}
+
+// YAML returns the config in YAML format.
+func (as *AppService) YAML() (string, error) {
+ data, err := yaml.Marshal(as)
+ if err != nil {
+ return "", err
+ }
+ return string(data), nil
+}
+
+func (as *AppService) BotMXID() id.UserID {
+ return id.NewUserID(as.Registration.SenderLocalpart, as.HomeserverDomain)
+}
+
+func (as *AppService) makeIntent(userID id.UserID) *IntentAPI {
+ as.intentsLock.Lock()
+ defer as.intentsLock.Unlock()
+
+ intent, ok := as.intents[userID]
+ if ok {
+ return intent
+ }
+
+ localpart, homeserver, err := userID.Parse()
+ if err != nil || len(localpart) == 0 || homeserver != as.HomeserverDomain {
+ if err != nil {
+ as.Log.Error().Err(err).
+ Str("user_id", userID.String()).
+ Msg("Failed to parse user ID")
+ } else if len(localpart) == 0 {
+ as.Log.Error().Err(err).
+ Str("user_id", userID.String()).
+ Msg("Failed to make intent: localpart is empty")
+ } else if homeserver != as.HomeserverDomain {
+ as.Log.Error().Err(err).
+ Str("user_id", userID.String()).
+ Str("expected_homeserver", as.HomeserverDomain).
+ Msg("Failed to make intent: homeserver doesn't match")
+ }
+ return nil
+ }
+ intent = as.NewIntentAPI(localpart)
+ as.intents[userID] = intent
+ return intent
+}
+
+func (as *AppService) Intent(userID id.UserID) *IntentAPI {
+ as.intentsLock.RLock()
+ intent, ok := as.intents[userID]
+ as.intentsLock.RUnlock()
+ if !ok {
+ return as.makeIntent(userID)
+ }
+ return intent
+}
+
+func (as *AppService) BotIntent() *IntentAPI {
+ if as.botIntent == nil {
+ as.botIntent = as.makeIntent(as.BotMXID())
+ }
+ return as.botIntent
+}
+
+func (as *AppService) SetHomeserverURL(homeserverURL string) error {
+ parsedURL, err := url.Parse(homeserverURL)
+ if err != nil {
+ return err
+ }
+
+ as.hsURLForClient = parsedURL
+ if as.hsURLForClient.Scheme == "unix" {
+ as.hsURLForClient.Scheme = "http"
+ as.hsURLForClient.Host = "unix"
+ as.hsURLForClient.Path = ""
+ } else if as.hsURLForClient.Scheme == "" {
+ as.hsURLForClient.Scheme = "https"
+ }
+ as.hsURLForClient.RawPath = parsedURL.EscapedPath()
+
+ jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
+ as.HTTPClient = &http.Client{Timeout: 180 * time.Second, Jar: jar}
+ if parsedURL.Scheme == "unix" {
+ as.HTTPClient.Transport = &http.Transport{
+ DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
+ return net.Dial("unix", parsedURL.Path)
+ },
+ }
+ }
+ return nil
+}
+
+func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client {
+ client := &mautrix.Client{
+ HomeserverURL: as.hsURLForClient,
+ UserID: userID,
+ SetAppServiceUserID: true,
+ AccessToken: as.Registration.AppToken,
+ UserAgent: as.UserAgent,
+ StateStore: as.StateStore,
+ Log: as.Log.With().Str("as_user_id", userID.String()).Logger(),
+ Client: as.HTTPClient,
+ DefaultHTTPRetries: as.DefaultHTTPRetries,
+ }
+ client.Logger = maulogadapt.ZeroAsMau(&client.Log)
+ return client
+}
+
+func (as *AppService) NewExternalMautrixClient(userID id.UserID, token string, homeserverURL string) (*mautrix.Client, error) {
+ client := as.NewMautrixClient(userID)
+ client.AccessToken = token
+ if homeserverURL != "" {
+ client.Client = &http.Client{Timeout: 180 * time.Second}
+ var err error
+ client.HomeserverURL, err = mautrix.ParseAndNormalizeBaseURL(homeserverURL)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return client, nil
+}
+
+func (as *AppService) makeClient(userID id.UserID) *mautrix.Client {
+ as.clientsLock.Lock()
+ defer as.clientsLock.Unlock()
+
+ client, ok := as.clients[userID]
+ if !ok {
+ client = as.NewMautrixClient(userID)
+ as.clients[userID] = client
+ }
+ return client
+}
+
+func (as *AppService) Client(userID id.UserID) *mautrix.Client {
+ as.clientsLock.RLock()
+ client, ok := as.clients[userID]
+ as.clientsLock.RUnlock()
+ if !ok {
+ return as.makeClient(userID)
+ }
+ return client
+}
+
+func (as *AppService) BotClient() *mautrix.Client {
+ if as.botClient == nil {
+ as.botClient = as.makeClient(as.BotMXID())
+ }
+ return as.botClient
+}
diff --git a/vendor/maunium.net/go/mautrix/appservice/eventprocessor.go b/vendor/maunium.net/go/mautrix/appservice/eventprocessor.go
new file mode 100644
index 00000000..437d8536
--- /dev/null
+++ b/vendor/maunium.net/go/mautrix/appservice/eventprocessor.go
@@ -0,0 +1,175 @@
+// Copyright (c) 2023 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package appservice
+
+import (
+ "encoding/json"
+ "runtime/debug"
+
+ "github.com/rs/zerolog"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/event"
+)
+
+type ExecMode uint8
+
+const (
+ AsyncHandlers ExecMode = iota
+ AsyncLoop
+ Sync
+)
+
+type EventHandler = func(evt *event.Event)
+type OTKHandler = func(otk *mautrix.OTKCount)
+type DeviceListHandler = func(lists *mautrix.DeviceLists, since string)
+
+type EventProcessor struct {
+ ExecMode ExecMode
+
+ as *AppService
+ stop chan struct{}
+ handlers map[event.Type][]EventHandler
+
+ otkHandlers []OTKHandler
+ deviceListHandlers []DeviceListHandler
+}
+
+func NewEventProcessor(as *AppService) *EventProcessor {
+ return &EventProcessor{
+ ExecMode: AsyncHandlers,
+ as: as,
+ stop: make(chan struct{}, 1),
+ handlers: make(map[event.Type][]EventHandler),
+
+ otkHandlers: make([]OTKHandler, 0),
+ deviceListHandlers: make([]DeviceListHandler, 0),
+ }
+}
+
+func (ep *EventProcessor) On(evtType event.Type, handler EventHandler) {
+ handlers, ok := ep.handlers[evtType]
+ if !ok {
+ handlers = []EventHandler{handler}
+ } else {
+ handlers = append(handlers, handler)
+ }
+ ep.handlers[evtType] = handlers
+}
+
+func (ep *EventProcessor) PrependHandler(evtType event.Type, handler EventHandler) {
+ handlers, ok := ep.handlers[evtType]
+ if !ok {
+ handlers = []EventHandler{handler}
+ } else {
+ handlers = append([]EventHandler{handler}, handlers...)
+ }
+ ep.handlers[evtType] = handlers
+}
+
+func (ep *EventProcessor) OnOTK(handler OTKHandler) {
+ ep.otkHandlers = append(ep.otkHandlers, handler)
+}
+
+func (ep *EventProcessor) OnDeviceList(handler DeviceListHandler) {
+ ep.deviceListHandlers = append(ep.deviceListHandlers, handler)
+}
+
+func (ep *EventProcessor) recoverFunc(data interface{}) {
+ if err := recover(); err != nil {
+ d, _ := json.Marshal(data)
+ ep.as.Log.Error().
+ Str(zerolog.ErrorStackFieldName, string(debug.Stack())).
+ Interface(zerolog.ErrorFieldName, err).
+ Str("event_content", string(d)).
+ Msg("Panic in Matrix event handler")
+ }
+}
+
+func (ep *EventProcessor) callHandler(handler EventHandler, evt *event.Event) {
+ defer ep.recoverFunc(evt)
+ handler(evt)
+}
+
+func (ep *EventProcessor) callOTKHandler(handler OTKHandler, otk *mautrix.OTKCount) {
+ defer ep.recoverFunc(otk)
+ handler(otk)
+}
+
+func (ep *EventProcessor) callDeviceListHandler(handler DeviceListHandler, dl *mautrix.DeviceLists) {
+ defer ep.recoverFunc(dl)
+ handler(dl, "")
+}
+
+func (ep *EventProcessor) DispatchOTK(otk *mautrix.OTKCount) {
+ for _, handler := range ep.otkHandlers {
+ go ep.callOTKHandler(handler, otk)
+ }
+}
+
+func (ep *EventProcessor) DispatchDeviceList(dl *mautrix.DeviceLists) {
+ for _, handler := range ep.deviceListHandlers {
+ go ep.callDeviceListHandler(handler, dl)
+ }
+}
+
+func (ep *EventProcessor) Dispatch(evt *event.Event) {
+ handlers, ok := ep.handlers[evt.Type]
+ if !ok {
+ return
+ }
+ switch ep.ExecMode {
+ case AsyncHandlers:
+ for _, handler := range handlers {
+ go ep.callHandler(handler, evt)
+ }
+ case AsyncLoop:
+ go func() {
+ for _, handler := range handlers {
+ ep.callHandler(handler, evt)
+ }
+ }()
+ case Sync:
+ for _, handler := range handlers {
+ ep.callHandler(handler, evt)
+ }
+ }
+}
+func (ep *EventProcessor) startEvents() {
+ for {
+ select {
+ case evt := <-ep.as.Events:
+ ep.Dispatch(evt)
+ case <-ep.stop:
+ return
+ }
+ }
+}
+
+func (ep *EventProcessor) startEncryption() {
+ for {
+ select {
+ case evt := <-ep.as.ToDeviceEvents:
+ ep.Dispatch(evt)
+ case otk := <-ep.as.OTKCounts:
+ ep.DispatchOTK(otk)
+ case dl := <-ep.as.DeviceLists:
+ ep.DispatchDeviceList(dl)
+ case <-ep.stop:
+ return
+ }
+ }
+}
+
+func (ep *EventProcessor) Start() {
+ go ep.startEvents()
+ go ep.startEncryption()
+}
+
+func (ep *EventProcessor) Stop() {
+ close(ep.stop)
+}
diff --git a/vendor/maunium.net/go/mautrix/appservice/http.go b/vendor/maunium.net/go/mautrix/appservice/http.go
new file mode 100644
index 00000000..06ac7788
--- /dev/null
+++ b/vendor/maunium.net/go/mautrix/appservice/http.go
@@ -0,0 +1,348 @@
+// Copyright (c) 2023 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package appservice
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "io"
+ "net"
+ "net/http"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/gorilla/mux"
+ "github.com/rs/zerolog"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+// Start starts the HTTP server that listens for calls from the Matrix homeserver.
+func (as *AppService) Start() {
+ as.server = &http.Server{
+ Handler: as.Router,
+ }
+ var err error
+ if as.Host.IsUnixSocket() {
+ err = as.listenUnix()
+ } else {
+ as.server.Addr = as.Host.Address()
+ err = as.listenTCP()
+ }
+ if err != nil && !errors.Is(err, http.ErrServerClosed) {
+ as.Log.Error().Err(err).Msg("Error in HTTP listener")
+ } else {
+ as.Log.Debug().Msg("HTTP listener stopped")
+ }
+}
+
+func (as *AppService) listenUnix() error {
+ socket := as.Host.Hostname
+ _ = syscall.Unlink(socket)
+ defer func() {
+ _ = syscall.Unlink(socket)
+ }()
+ listener, err := net.Listen("unix", socket)
+ if err != nil {
+ return err
+ }
+ as.Log.Info().Str("socket", socket).Msg("Starting unix socket HTTP listener")
+ return as.server.Serve(listener)
+}
+
+func (as *AppService) listenTCP() error {
+ if len(as.Host.TLSCert) == 0 || len(as.Host.TLSKey) == 0 {
+ as.Log.Info().Str("address", as.server.Addr).Msg("Starting HTTP listener")
+ return as.server.ListenAndServe()
+ } else {
+ as.Log.Info().Str("address", as.server.Addr).Msg("Starting HTTP listener with TLS")
+ return as.server.ListenAndServeTLS(as.Host.TLSCert, as.Host.TLSKey)
+ }
+}
+
+func (as *AppService) Stop() {
+ if as.server == nil {
+ return
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ _ = as.server.Shutdown(ctx)
+ as.server = nil
+}
+
+// CheckServerToken checks if the given request originated from the Matrix homeserver.
+func (as *AppService) CheckServerToken(w http.ResponseWriter, r *http.Request) (isValid bool) {
+ authHeader := r.Header.Get("Authorization")
+ if len(authHeader) > 0 && strings.HasPrefix(authHeader, "Bearer ") {
+ isValid = authHeader[len("Bearer "):] == as.Registration.ServerToken
+ } else {
+ queryToken := r.URL.Query().Get("access_token")
+ if len(queryToken) > 0 {
+ isValid = queryToken == as.Registration.ServerToken
+ } else {
+ Error{
+ ErrorCode: ErrUnknownToken,
+ HTTPStatus: http.StatusForbidden,
+ Message: "Missing access token",
+ }.Write(w)
+ return
+ }
+ }
+ if !isValid {
+ Error{
+ ErrorCode: ErrUnknownToken,
+ HTTPStatus: http.StatusForbidden,
+ Message: "Incorrect access token",
+ }.Write(w)
+ }
+ return
+}
+
+// PutTransaction handles a /transactions PUT call from the homeserver.
+func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) {
+ if !as.CheckServerToken(w, r) {
+ return
+ }
+
+ vars := mux.Vars(r)
+ txnID := vars["txnID"]
+ if len(txnID) == 0 {
+ Error{
+ ErrorCode: ErrNoTransactionID,
+ HTTPStatus: http.StatusBadRequest,
+ Message: "Missing transaction ID",
+ }.Write(w)
+ return
+ }
+ defer r.Body.Close()
+ body, err := io.ReadAll(r.Body)
+ if err != nil || len(body) == 0 {
+ Error{
+ ErrorCode: ErrNotJSON,
+ HTTPStatus: http.StatusBadRequest,
+ Message: "Missing request body",
+ }.Write(w)
+ return
+ }
+ log := as.Log.With().Str("transaction_id", txnID).Logger()
+ ctx := context.Background()
+ ctx = log.WithContext(ctx)
+ if as.txnIDC.IsProcessed(txnID) {
+ // Duplicate transaction ID: no-op
+ WriteBlankOK(w)
+ log.Debug().Msg("Ignoring duplicate transaction")
+ return
+ }
+
+ var txn Transaction
+ err = json.Unmarshal(body, &txn)
+ if err != nil {
+ log.Error().Err(err).Msg("Failed to parse transaction content")
+ Error{
+ ErrorCode: ErrBadJSON,
+ HTTPStatus: http.StatusBadRequest,
+ Message: "Failed to parse body JSON",
+ }.Write(w)
+ } else {
+ as.handleTransaction(ctx, txnID, &txn)
+ WriteBlankOK(w)
+ }
+}
+
+func (as *AppService) handleTransaction(ctx context.Context, id string, txn *Transaction) {
+ log := zerolog.Ctx(ctx)
+ log.Debug().Object("content", txn).Msg("Starting handling of transaction")
+ if as.Registration.EphemeralEvents {
+ if txn.EphemeralEvents != nil {
+ as.handleEvents(ctx, txn.EphemeralEvents, event.EphemeralEventType)
+ } else if txn.MSC2409EphemeralEvents != nil {
+ as.handleEvents(ctx, txn.MSC2409EphemeralEvents, event.EphemeralEventType)
+ }
+ if txn.ToDeviceEvents != nil {
+ as.handleEvents(ctx, txn.ToDeviceEvents, event.ToDeviceEventType)
+ } else if txn.MSC2409ToDeviceEvents != nil {
+ as.handleEvents(ctx, txn.MSC2409ToDeviceEvents, event.ToDeviceEventType)
+ }
+ }
+ as.handleEvents(ctx, txn.Events, event.UnknownEventType)
+ if txn.DeviceLists != nil {
+ as.handleDeviceLists(ctx, txn.DeviceLists)
+ } else if txn.MSC3202DeviceLists != nil {
+ as.handleDeviceLists(ctx, txn.MSC3202DeviceLists)
+ }
+ if txn.DeviceOTKCount != nil {
+ as.handleOTKCounts(ctx, txn.DeviceOTKCount)
+ } else if txn.MSC3202DeviceOTKCount != nil {
+ as.handleOTKCounts(ctx, txn.MSC3202DeviceOTKCount)
+ }
+ as.txnIDC.MarkProcessed(id)
+ log.Debug().Msg("Finished dispatching events from transaction")
+}
+
+func (as *AppService) handleOTKCounts(ctx context.Context, otks OTKCountMap) {
+ for userID, devices := range otks {
+ for deviceID, otkCounts := range devices {
+ otkCounts.UserID = userID
+ otkCounts.DeviceID = deviceID
+ select {
+ case as.OTKCounts <- &otkCounts:
+ default:
+ zerolog.Ctx(ctx).Warn().
+ Str("user_id", userID.String()).
+ Msg("Dropped OTK count update for user because channel is full")
+ }
+ }
+ }
+}
+
+func (as *AppService) handleDeviceLists(ctx context.Context, dl *mautrix.DeviceLists) {
+ select {
+ case as.DeviceLists <- dl:
+ default:
+ zerolog.Ctx(ctx).Warn().Msg("Dropped device list update because channel is full")
+ }
+}
+
+func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, defaultTypeClass event.TypeClass) {
+ log := zerolog.Ctx(ctx)
+ for _, evt := range evts {
+ evt.Mautrix.ReceivedAt = time.Now()
+ if defaultTypeClass != event.UnknownEventType {
+ evt.Type.Class = defaultTypeClass
+ } else if evt.StateKey != nil {
+ evt.Type.Class = event.StateEventType
+ } else {
+ evt.Type.Class = event.MessageEventType
+ }
+ err := evt.Content.ParseRaw(evt.Type)
+ if errors.Is(err, event.ErrUnsupportedContentType) {
+ log.Debug().Str("event_id", evt.ID.String()).Msg("Not parsing content of unsupported event")
+ } else if err != nil {
+ log.Warn().Err(err).
+ Str("event_id", evt.ID.String()).
+ Str("event_type", evt.Type.Type).
+ Str("event_type_class", evt.Type.Class.Name()).
+ Msg("Failed to parse content of event")
+ }
+
+ if evt.Type.IsState() {
+ // TODO remove this check after making sure the log doesn't happen
+ historical, ok := evt.Content.Raw["org.matrix.msc2716.historical"].(bool)
+ if ok && historical {
+ log.Warn().
+ Str("event_id", evt.ID.String()).
+ Str("event_type", evt.Type.Type).
+ Str("state_key", evt.GetStateKey()).
+ Msg("Received historical state event")
+ } else {
+ mautrix.UpdateStateStore(as.StateStore, evt)
+ }
+ }
+ var ch chan *event.Event
+ if evt.Type.Class == event.ToDeviceEventType {
+ ch = as.ToDeviceEvents
+ } else {
+ ch = as.Events
+ }
+ select {
+ case ch <- evt:
+ default:
+ log.Warn().
+ Str("event_id", evt.ID.String()).
+ Str("event_type", evt.Type.Type).
+ Str("event_type_class", evt.Type.Class.Name()).
+ Msg("Event channel is full")
+ ch <- evt
+ }
+ }
+}
+
+// GetRoom handles a /rooms GET call from the homeserver.
+func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) {
+ if !as.CheckServerToken(w, r) {
+ return
+ }
+
+ vars := mux.Vars(r)
+ roomAlias := vars["roomAlias"]
+ ok := as.QueryHandler.QueryAlias(roomAlias)
+ if ok {
+ WriteBlankOK(w)
+ } else {
+ Error{
+ ErrorCode: ErrUnknown,
+ HTTPStatus: http.StatusNotFound,
+ }.Write(w)
+ }
+}
+
+// GetUser handles a /users GET call from the homeserver.
+func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) {
+ if !as.CheckServerToken(w, r) {
+ return
+ }
+
+ vars := mux.Vars(r)
+ userID := id.UserID(vars["userID"])
+ ok := as.QueryHandler.QueryUser(userID)
+ if ok {
+ WriteBlankOK(w)
+ } else {
+ Error{
+ ErrorCode: ErrUnknown,
+ HTTPStatus: http.StatusNotFound,
+ }.Write(w)
+ }
+}
+
+func (as *AppService) PostPing(w http.ResponseWriter, r *http.Request) {
+ if !as.CheckServerToken(w, r) {
+ return
+ }
+ body, err := io.ReadAll(r.Body)
+ if err != nil || len(body) == 0 || !json.Valid(body) {
+ Error{
+ ErrorCode: ErrNotJSON,
+ HTTPStatus: http.StatusBadRequest,
+ Message: "Missing request body",
+ }.Write(w)
+ return
+ }
+
+ var txn mautrix.ReqAppservicePing
+ _ = json.Unmarshal(body, &txn)
+ as.Log.Debug().Str("txn_id", txn.TxnID).Msg("Received ping from homeserver")
+
+ w.Header().Add("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte("{}"))
+}
+
+func (as *AppService) GetLive(w http.ResponseWriter, r *http.Request) {
+ w.Header().Add("Content-Type", "application/json")
+ if as.Live {
+ w.WriteHeader(http.StatusOK)
+ } else {
+ w.WriteHeader(http.StatusInternalServerError)
+ }
+ w.Write([]byte("{}"))
+}
+
+func (as *AppService) GetReady(w http.ResponseWriter, r *http.Request) {
+ w.Header().Add("Content-Type", "application/json")
+ if as.Ready {
+ w.WriteHeader(http.StatusOK)
+ } else {
+ w.WriteHeader(http.StatusInternalServerError)
+ }
+ w.Write([]byte("{}"))
+}
diff --git a/vendor/maunium.net/go/mautrix/appservice/intent.go b/vendor/maunium.net/go/mautrix/appservice/intent.go
new file mode 100644
index 00000000..af6fea37
--- /dev/null
+++ b/vendor/maunium.net/go/mautrix/appservice/intent.go
@@ -0,0 +1,419 @@
+// Copyright (c) 2020 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package appservice
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+type IntentAPI struct {
+ *mautrix.Client
+ bot *mautrix.Client
+ as *AppService
+ Localpart string
+ UserID id.UserID
+
+ IsCustomPuppet bool
+}
+
+func (as *AppService) NewIntentAPI(localpart string) *IntentAPI {
+ userID := id.NewUserID(localpart, as.HomeserverDomain)
+ bot := as.BotClient()
+ if userID == bot.UserID {
+ bot = nil
+ }
+ return &IntentAPI{
+ Client: as.Client(userID),
+ bot: bot,
+ as: as,
+ Localpart: localpart,
+ UserID: userID,
+
+ IsCustomPuppet: false,
+ }
+}
+
+func (intent *IntentAPI) Register() error {
+ _, _, err := intent.Client.Register(&mautrix.ReqRegister{
+ Username: intent.Localpart,
+ Type: mautrix.AuthTypeAppservice,
+ InhibitLogin: true,
+ })
+ return err
+}
+
+func (intent *IntentAPI) EnsureRegistered() error {
+ if intent.IsCustomPuppet || intent.as.StateStore.IsRegistered(intent.UserID) {
+ return nil
+ }
+
+ err := intent.Register()
+ if err != nil && !errors.Is(err, mautrix.MUserInUse) {
+ return fmt.Errorf("failed to ensure registered: %w", err)
+ }
+ intent.as.StateStore.MarkRegistered(intent.UserID)
+ return nil
+}
+
+type EnsureJoinedParams struct {
+ IgnoreCache bool
+ BotOverride *mautrix.Client
+}
+
+func (intent *IntentAPI) EnsureJoined(roomID id.RoomID, extra ...EnsureJoinedParams) error {
+ var params EnsureJoinedParams
+ if len(extra) > 1 {
+ panic("invalid number of extra parameters")
+ } else if len(extra) == 1 {
+ params = extra[0]
+ }
+ if intent.as.StateStore.IsInRoom(roomID, intent.UserID) && !params.IgnoreCache {
+ return nil
+ }
+
+ if err := intent.EnsureRegistered(); err != nil {
+ return fmt.Errorf("failed to ensure joined: %w", err)
+ }
+
+ resp, err := intent.JoinRoomByID(roomID)
+ if err != nil {
+ bot := intent.bot
+ if params.BotOverride != nil {
+ bot = params.BotOverride
+ }
+ if !errors.Is(err, mautrix.MForbidden) || bot == nil {
+ return fmt.Errorf("failed to ensure joined: %w", err)
+ }
+ _, inviteErr := bot.InviteUser(roomID, &mautrix.ReqInviteUser{
+ UserID: intent.UserID,
+ })
+ if inviteErr != nil {
+ return fmt.Errorf("failed to invite in ensure joined: %w", inviteErr)
+ }
+ resp, err = intent.JoinRoomByID(roomID)
+ if err != nil {
+ return fmt.Errorf("failed to ensure joined after invite: %w", err)
+ }
+ }
+ intent.as.StateStore.SetMembership(resp.RoomID, intent.UserID, event.MembershipJoin)
+ return nil
+}
+
+func (intent *IntentAPI) AddDoublePuppetValue(into interface{}) interface{} {
+ if !intent.IsCustomPuppet || intent.as.DoublePuppetValue == "" {
+ return into
+ }
+ switch val := into.(type) {
+ case *map[string]interface{}:
+ if *val == nil {
+ valNonPtr := make(map[string]interface{})
+ *val = valNonPtr
+ }
+ (*val)[DoublePuppetKey] = intent.as.DoublePuppetValue
+ return val
+ case map[string]interface{}:
+ val[DoublePuppetKey] = intent.as.DoublePuppetValue
+ return val
+ case *event.Content:
+ if val.Raw == nil {
+ val.Raw = make(map[string]interface{})
+ }
+ val.Raw[DoublePuppetKey] = intent.as.DoublePuppetValue
+ return val
+ case event.Content:
+ if val.Raw == nil {
+ val.Raw = make(map[string]interface{})
+ }
+ val.Raw[DoublePuppetKey] = intent.as.DoublePuppetValue
+ return val
+ default:
+ return &event.Content{
+ Raw: map[string]interface{}{
+ DoublePuppetKey: intent.as.DoublePuppetValue,
+ },
+ Parsed: val,
+ }
+ }
+}
+
+func (intent *IntentAPI) SendMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
+ if err := intent.EnsureJoined(roomID); err != nil {
+ return nil, err
+ }
+ contentJSON = intent.AddDoublePuppetValue(contentJSON)
+ return intent.Client.SendMessageEvent(roomID, eventType, contentJSON)
+}
+
+func (intent *IntentAPI) SendMassagedMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
+ if err := intent.EnsureJoined(roomID); err != nil {
+ return nil, err
+ }
+ contentJSON = intent.AddDoublePuppetValue(contentJSON)
+ return intent.Client.SendMessageEvent(roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
+}
+
+func (intent *IntentAPI) SendStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
+ if eventType != event.StateMember || stateKey != string(intent.UserID) {
+ if err := intent.EnsureJoined(roomID); err != nil {
+ return nil, err
+ }
+ }
+ contentJSON = intent.AddDoublePuppetValue(contentJSON)
+ return intent.Client.SendStateEvent(roomID, eventType, stateKey, contentJSON)
+}
+
+func (intent *IntentAPI) SendMassagedStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
+ if err := intent.EnsureJoined(roomID); err != nil {
+ return nil, err
+ }
+ contentJSON = intent.AddDoublePuppetValue(contentJSON)
+ return intent.Client.SendMassagedStateEvent(roomID, eventType, stateKey, contentJSON, ts)
+}
+
+func (intent *IntentAPI) StateEvent(roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error {
+ if err := intent.EnsureJoined(roomID); err != nil {
+ return err
+ }
+ return intent.Client.StateEvent(roomID, eventType, stateKey, outContent)
+}
+
+func (intent *IntentAPI) State(roomID id.RoomID) (mautrix.RoomStateMap, error) {
+ if err := intent.EnsureJoined(roomID); err != nil {
+ return nil, err
+ }
+ return intent.Client.State(roomID)
+}
+
+func (intent *IntentAPI) SendCustomMembershipEvent(roomID id.RoomID, target id.UserID, membership event.Membership, reason string, extraContent ...map[string]interface{}) (*mautrix.RespSendEvent, error) {
+ content := &event.MemberEventContent{
+ Membership: membership,
+ Reason: reason,
+ }
+ memberContent, ok := intent.as.StateStore.TryGetMember(roomID, target)
+ if !ok {
+ if intent.as.GetProfile != nil {
+ memberContent = intent.as.GetProfile(target, roomID)
+ ok = memberContent != nil
+ }
+ if !ok {
+ profile, err := intent.GetProfile(target)
+ if err != nil {
+ intent.Log.Debug().Err(err).
+ Str("target_user_id", target.String()).
+ Str("membership", string(membership)).
+ Msg("Failed to get profile to fill new membership event")
+ } else {
+ content.Displayname = profile.DisplayName
+ content.AvatarURL = profile.AvatarURL.CUString()
+ }
+ }
+ }
+ if ok && memberContent != nil {
+ content.Displayname = memberContent.Displayname
+ content.AvatarURL = memberContent.AvatarURL
+ }
+ var extra map[string]interface{}
+ if len(extraContent) > 0 {
+ extra = extraContent[0]
+ }
+ return intent.SendStateEvent(roomID, event.StateMember, target.String(), &event.Content{
+ Parsed: content,
+ Raw: extra,
+ })
+}
+
+func (intent *IntentAPI) JoinRoomByID(roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) {
+ if intent.IsCustomPuppet || len(extraContent) > 0 {
+ _, err = intent.SendCustomMembershipEvent(roomID, intent.UserID, event.MembershipJoin, "", extraContent...)
+ return &mautrix.RespJoinRoom{}, err
+ }
+ return intent.Client.JoinRoomByID(roomID)
+}
+
+func (intent *IntentAPI) LeaveRoom(roomID id.RoomID, extra ...interface{}) (resp *mautrix.RespLeaveRoom, err error) {
+ var extraContent map[string]interface{}
+ leaveReq := &mautrix.ReqLeave{}
+ for _, item := range extra {
+ switch val := item.(type) {
+ case map[string]interface{}:
+ extraContent = val
+ case *mautrix.ReqLeave:
+ leaveReq = val
+ }
+ }
+ if intent.IsCustomPuppet || extraContent != nil {
+ _, err = intent.SendCustomMembershipEvent(roomID, intent.UserID, event.MembershipLeave, leaveReq.Reason, extraContent)
+ return &mautrix.RespLeaveRoom{}, err
+ }
+ return intent.Client.LeaveRoom(roomID, leaveReq)
+}
+
+func (intent *IntentAPI) InviteUser(roomID id.RoomID, req *mautrix.ReqInviteUser, extraContent ...map[string]interface{}) (resp *mautrix.RespInviteUser, err error) {
+ if intent.IsCustomPuppet || len(extraContent) > 0 {
+ _, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipInvite, req.Reason, extraContent...)
+ return &mautrix.RespInviteUser{}, err
+ }
+ return intent.Client.InviteUser(roomID, req)
+}
+
+func (intent *IntentAPI) KickUser(roomID id.RoomID, req *mautrix.ReqKickUser, extraContent ...map[string]interface{}) (resp *mautrix.RespKickUser, err error) {
+ if intent.IsCustomPuppet || len(extraContent) > 0 {
+ _, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipLeave, req.Reason, extraContent...)
+ return &mautrix.RespKickUser{}, err
+ }
+ return intent.Client.KickUser(roomID, req)
+}
+
+func (intent *IntentAPI) BanUser(roomID id.RoomID, req *mautrix.ReqBanUser, extraContent ...map[string]interface{}) (resp *mautrix.RespBanUser, err error) {
+ if intent.IsCustomPuppet || len(extraContent) > 0 {
+ _, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipBan, req.Reason, extraContent...)
+ return &mautrix.RespBanUser{}, err
+ }
+ return intent.Client.BanUser(roomID, req)
+}
+
+func (intent *IntentAPI) UnbanUser(roomID id.RoomID, req *mautrix.ReqUnbanUser, extraContent ...map[string]interface{}) (resp *mautrix.RespUnbanUser, err error) {
+ if intent.IsCustomPuppet || len(extraContent) > 0 {
+ _, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipLeave, req.Reason, extraContent...)
+ return &mautrix.RespUnbanUser{}, err
+ }
+ return intent.Client.UnbanUser(roomID, req)
+}
+
+func (intent *IntentAPI) Member(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
+ member, ok := intent.as.StateStore.TryGetMember(roomID, userID)
+ if !ok {
+ _ = intent.StateEvent(roomID, event.StateMember, string(userID), &member)
+ }
+ return member
+}
+
+func (intent *IntentAPI) PowerLevels(roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) {
+ pl = intent.as.StateStore.GetPowerLevels(roomID)
+ if pl == nil {
+ pl = &event.PowerLevelsEventContent{}
+ err = intent.StateEvent(roomID, event.StatePowerLevels, "", pl)
+ }
+ return
+}
+
+func (intent *IntentAPI) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) (resp *mautrix.RespSendEvent, err error) {
+ return intent.SendStateEvent(roomID, event.StatePowerLevels, "", &levels)
+}
+
+func (intent *IntentAPI) SetPowerLevel(roomID id.RoomID, userID id.UserID, level int) (*mautrix.RespSendEvent, error) {
+ pl, err := intent.PowerLevels(roomID)
+ if err != nil {
+ return nil, err
+ }
+
+ if pl.GetUserLevel(userID) != level {
+ pl.SetUserLevel(userID, level)
+ return intent.SendStateEvent(roomID, event.StatePowerLevels, "", &pl)
+ }
+ return nil, nil
+}
+
+func (intent *IntentAPI) SendText(roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) {
+ if err := intent.EnsureJoined(roomID); err != nil {
+ return nil, err
+ }
+ return intent.Client.SendText(roomID, text)
+}
+
+func (intent *IntentAPI) SendNotice(roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) {
+ if err := intent.EnsureJoined(roomID); err != nil {
+ return nil, err
+ }
+ return intent.Client.SendNotice(roomID, text)
+}
+
+func (intent *IntentAPI) RedactEvent(roomID id.RoomID, eventID id.EventID, extra ...mautrix.ReqRedact) (*mautrix.RespSendEvent, error) {
+ if err := intent.EnsureJoined(roomID); err != nil {
+ return nil, err
+ }
+ var req mautrix.ReqRedact
+ if len(extra) > 0 {
+ req = extra[0]
+ }
+ intent.AddDoublePuppetValue(&req.Extra)
+ return intent.Client.RedactEvent(roomID, eventID, req)
+}
+
+func (intent *IntentAPI) SetRoomName(roomID id.RoomID, roomName string) (*mautrix.RespSendEvent, error) {
+ return intent.SendStateEvent(roomID, event.StateRoomName, "", map[string]interface{}{
+ "name": roomName,
+ })
+}
+
+func (intent *IntentAPI) SetRoomAvatar(roomID id.RoomID, avatarURL id.ContentURI) (*mautrix.RespSendEvent, error) {
+ return intent.SendStateEvent(roomID, event.StateRoomAvatar, "", map[string]interface{}{
+ "url": avatarURL.String(),
+ })
+}
+
+func (intent *IntentAPI) SetRoomTopic(roomID id.RoomID, topic string) (*mautrix.RespSendEvent, error) {
+ return intent.SendStateEvent(roomID, event.StateTopic, "", map[string]interface{}{
+ "topic": topic,
+ })
+}
+
+func (intent *IntentAPI) SetDisplayName(displayName string) error {
+ if err := intent.EnsureRegistered(); err != nil {
+ return err
+ }
+ resp, err := intent.Client.GetOwnDisplayName()
+ if err != nil {
+ return fmt.Errorf("failed to check current displayname: %w", err)
+ } else if resp.DisplayName == displayName {
+ // No need to update
+ return nil
+ }
+ return intent.Client.SetDisplayName(displayName)
+}
+
+func (intent *IntentAPI) SetAvatarURL(avatarURL id.ContentURI) error {
+ if err := intent.EnsureRegistered(); err != nil {
+ return err
+ }
+ resp, err := intent.Client.GetOwnAvatarURL()
+ if err != nil {
+ return fmt.Errorf("failed to check current avatar URL: %w", err)
+ } else if resp.FileID == avatarURL.FileID && resp.Homeserver == avatarURL.Homeserver {
+ // No need to update
+ return nil
+ }
+ return intent.Client.SetAvatarURL(avatarURL)
+}
+
+func (intent *IntentAPI) Whoami() (*mautrix.RespWhoami, error) {
+ if err := intent.EnsureRegistered(); err != nil {
+ return nil, err
+ }
+ return intent.Client.Whoami()
+}
+
+func (intent *IntentAPI) EnsureInvited(roomID id.RoomID, userID id.UserID) error {
+ if !intent.as.StateStore.IsInvited(roomID, userID) {
+ _, err := intent.InviteUser(roomID, &mautrix.ReqInviteUser{
+ UserID: userID,
+ })
+ if httpErr, ok := err.(mautrix.HTTPError); ok &&
+ httpErr.RespError != nil &&
+ (strings.Contains(httpErr.RespError.Err, "is already in the room") || strings.Contains(httpErr.RespError.Err, "is already joined to room")) {
+ return nil
+ }
+ return err
+ }
+ return nil
+}
diff --git a/vendor/maunium.net/go/mautrix/appservice/protocol.go b/vendor/maunium.net/go/mautrix/appservice/protocol.go
new file mode 100644
index 00000000..7a9891ef
--- /dev/null
+++ b/vendor/maunium.net/go/mautrix/appservice/protocol.go
@@ -0,0 +1,152 @@
+// Copyright (c) 2023 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package appservice
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+
+ "github.com/rs/zerolog"
+
+ "maunium.net/go/mautrix"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+)
+
+type OTKCountMap = map[id.UserID]map[id.DeviceID]mautrix.OTKCount
+type FallbackKeyMap = map[id.UserID]map[id.DeviceID][]id.KeyAlgorithm
+
+// Transaction contains a list of events.
+type Transaction struct {
+ Events []*event.Event `json:"events"`
+ EphemeralEvents []*event.Event `json:"ephemeral,omitempty"`
+ ToDeviceEvents []*event.Event `json:"to_device,omitempty"`
+
+ DeviceLists *mautrix.DeviceLists `json:"device_lists,omitempty"`
+ DeviceOTKCount OTKCountMap `json:"device_one_time_keys_count,omitempty"`
+ FallbackKeys FallbackKeyMap `json:"device_unused_fallback_key_types,omitempty"`
+
+ MSC2409EphemeralEvents []*event.Event `json:"de.sorunome.msc2409.ephemeral,omitempty"`
+ MSC2409ToDeviceEvents []*event.Event `json:"de.sorunome.msc2409.to_device,omitempty"`
+ MSC3202DeviceLists *mautrix.DeviceLists `json:"org.matrix.msc3202.device_lists,omitempty"`
+ MSC3202DeviceOTKCount OTKCountMap `json:"org.matrix.msc3202.device_one_time_keys_count,omitempty"`
+ MSC3202FallbackKeys FallbackKeyMap `json:"org.matrix.msc3202.device_unused_fallback_key_types,omitempty"`
+}
+
+func (txn *Transaction) MarshalZerologObject(ctx *zerolog.Event) {
+ ctx.Int("pdu", len(txn.Events))
+ if txn.EphemeralEvents != nil {
+ ctx.Int("edu", len(txn.EphemeralEvents))
+ } else if txn.MSC2409EphemeralEvents != nil {
+ ctx.Int("unstable_edu", len(txn.MSC2409EphemeralEvents))
+ }
+ if txn.ToDeviceEvents != nil {
+ ctx.Int("to_device", len(txn.ToDeviceEvents))
+ } else if txn.MSC2409ToDeviceEvents != nil {
+ ctx.Int("unstable_to_device", len(txn.MSC2409ToDeviceEvents))
+ }
+ if len(txn.DeviceOTKCount) > 0 {
+ ctx.Int("otk_count_users", len(txn.DeviceOTKCount))
+ } else if len(txn.MSC3202DeviceOTKCount) > 0 {
+ ctx.Int("unstable_otk_count_users", len(txn.MSC3202DeviceOTKCount))
+ }
+ if txn.DeviceLists != nil {
+ ctx.Int("device_changes", len(txn.DeviceLists.Changed))
+ } else if txn.MSC3202DeviceLists != nil {
+ ctx.Int("unstable_device_changes", len(txn.MSC3202DeviceLists.Changed))
+ }
+ if txn.FallbackKeys != nil {
+ ctx.Int("fallback_key_users", len(txn.FallbackKeys))
+ } else if txn.MSC3202FallbackKeys != nil {
+ ctx.Int("unstable_fallback_key_users", len(txn.MSC3202FallbackKeys))
+ }
+}
+
+func (txn *Transaction) ContentString() string {
+ var parts []string
+ if len(txn.Events) > 0 {
+ parts = append(parts, fmt.Sprintf("%d PDUs", len(txn.Events)))
+ }
+ if len(txn.EphemeralEvents) > 0 {
+ parts = append(parts, fmt.Sprintf("%d EDUs", len(txn.EphemeralEvents)))
+ } else if len(txn.MSC2409EphemeralEvents) > 0 {
+ parts = append(parts, fmt.Sprintf("%d EDUs (unstable)", len(txn.MSC2409EphemeralEvents)))
+ }
+ if len(txn.ToDeviceEvents) > 0 {
+ parts = append(parts, fmt.Sprintf("%d to-device events", len(txn.ToDeviceEvents)))
+ } else if len(txn.MSC2409ToDeviceEvents) > 0 {
+ parts = append(parts, fmt.Sprintf("%d to-device events (unstable)", len(txn.MSC2409ToDeviceEvents)))
+ }
+ if len(txn.DeviceOTKCount) > 0 {
+ parts = append(parts, fmt.Sprintf("OTK counts for %d users", len(txn.DeviceOTKCount)))
+ } else if len(txn.MSC3202DeviceOTKCount) > 0 {
+ parts = append(parts, fmt.Sprintf("OTK counts for %d users (unstable)", len(txn.MSC3202DeviceOTKCount)))
+ }
+ if txn.DeviceLists != nil {
+ parts = append(parts, fmt.Sprintf("%d device list changes", len(txn.DeviceLists.Changed)))
+ } else if txn.MSC3202DeviceLists != nil {
+ parts = append(parts, fmt.Sprintf("%d device list changes (unstable)", len(txn.MSC3202DeviceLists.Changed)))
+ }
+ if txn.FallbackKeys != nil {
+ parts = append(parts, fmt.Sprintf("unused fallback key counts for %d users", len(txn.FallbackKeys)))
+ } else if txn.MSC3202FallbackKeys != nil {
+ parts = append(parts, fmt.Sprintf("unused fallback key counts for %d users (unstable)", len(txn.MSC3202FallbackKeys)))
+ }
+ return strings.Join(parts, ", ")
+}
+
+// EventListener is a function that receives events.
+type EventListener func(evt *event.Event)
+
+// WriteBlankOK writes a blank OK message as a reply to a HTTP request.
+func WriteBlankOK(w http.ResponseWriter) {
+ w.Header().Add("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write([]byte("{}"))
+}
+
+// Respond responds to a HTTP request with a JSON object.
+func Respond(w http.ResponseWriter, data interface{}) error {
+ w.Header().Add("Content-Type", "application/json")
+ dataStr, err := json.Marshal(data)
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(dataStr)
+ return err
+}
+
+// Error represents a Matrix protocol error.
+type Error struct {
+ HTTPStatus int `json:"-"`
+ ErrorCode ErrorCode `json:"errcode"`
+ Message string `json:"error"`
+}
+
+func (err Error) Write(w http.ResponseWriter) {
+ w.Header().Add("Content-Type", "application/json")
+ w.WriteHeader(err.HTTPStatus)
+ _ = Respond(w, &err)
+}
+
+// ErrorCode is the machine-readable code in an Error.
+type ErrorCode string
+
+// Native ErrorCodes
+const (
+ ErrUnknownToken ErrorCode = "M_UNKNOWN_TOKEN"
+ ErrBadJSON ErrorCode = "M_BAD_JSON"
+ ErrNotJSON ErrorCode = "M_NOT_JSON"
+ ErrUnknown ErrorCode = "M_UNKNOWN"
+)
+
+// Custom ErrorCodes
+const (
+ ErrNoTransactionID ErrorCode = "NET.MAUNIUM.NO_TRANSACTION_ID"
+)
diff --git a/vendor/maunium.net/go/mautrix/appservice/registration.go b/vendor/maunium.net/go/mautrix/appservice/registration.go
new file mode 100644
index 00000000..f9c93fe4
--- /dev/null
+++ b/vendor/maunium.net/go/mautrix/appservice/registration.go
@@ -0,0 +1,100 @@
+// Copyright (c) 2022 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package appservice
+
+import (
+ "os"
+ "regexp"
+
+ "gopkg.in/yaml.v3"
+
+ "maunium.net/go/mautrix/util"
+)
+
+// Registration contains the data in a Matrix appservice registration.
+// See https://spec.matrix.org/v1.2/application-service-api/#registration
+type Registration struct {
+ ID string `yaml:"id" json:"id"`
+ URL string `yaml:"url" json:"url"`
+ AppToken string `yaml:"as_token" json:"as_token"`
+ ServerToken string `yaml:"hs_token" json:"hs_token"`
+ SenderLocalpart string `yaml:"sender_localpart" json:"sender_localpart"`
+ RateLimited *bool `yaml:"rate_limited,omitempty" json:"rate_limited,omitempty"`
+ Namespaces Namespaces `yaml:"namespaces" json:"namespaces"`
+ Protocols []string `yaml:"protocols,omitempty" json:"protocols,omitempty"`
+
+ SoruEphemeralEvents bool `yaml:"de.sorunome.msc2409.push_ephemeral,omitempty" json:"de.sorunome.msc2409.push_ephemeral,omitempty"`
+ EphemeralEvents bool `yaml:"push_ephemeral,omitempty" json:"push_ephemeral,omitempty"`
+}
+
+// CreateRegistration creates a Registration with random appservice and homeserver tokens.
+func CreateRegistration() *Registration {
+ return &Registration{
+ AppToken: util.RandomString(64),
+ ServerToken: util.RandomString(64),
+ }
+}
+
+// LoadRegistration loads a YAML file and turns it into a Registration.
+func LoadRegistration(path string) (*Registration, error) {
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+
+ reg := &Registration{}
+ err = yaml.Unmarshal(data, reg)
+ if err != nil {
+ return nil, err
+ }
+ return reg, nil
+}
+
+// Save saves this Registration into a file at the given path.
+func (reg *Registration) Save(path string) error {
+ data, err := yaml.Marshal(reg)
+ if err != nil {
+ return err
+ }
+ return os.WriteFile(path, data, 0600)
+}
+
+// YAML returns the registration in YAML format.
+func (reg *Registration) YAML() (string, error) {
+ data, err := yaml.Marshal(reg)
+ if err != nil {
+ return "", err
+ }
+ return string(data), nil
+}
+
+// Namespaces contains the three areas that appservices can reserve parts of.
+type Namespaces struct {
+ UserIDs NamespaceList `yaml:"users,omitempty" json:"users,omitempty"`
+ RoomAliases NamespaceList `yaml:"aliases,omitempty" json:"aliases,omitempty"`
+ RoomIDs NamespaceList `yaml:"rooms,omitempty" json:"rooms,omitempty"`
+}
+
+// Namespace is a reserved namespace in any area.
+type Namespace struct {
+ Regex string `yaml:"regex" json:"regex"`
+ Exclusive bool `yaml:"exclusive" json:"exclusive"`
+}
+
+type NamespaceList []Namespace
+
+func (nsl *NamespaceList) Register(regex *regexp.Regexp, exclusive bool) {
+ ns := Namespace{
+ Regex: regex.String(),
+ Exclusive: exclusive,
+ }
+ if nsl == nil {
+ *nsl = []Namespace{ns}
+ } else {
+ *nsl = append(*nsl, ns)
+ }
+}
diff --git a/vendor/maunium.net/go/mautrix/appservice/txnid.go b/vendor/maunium.net/go/mautrix/appservice/txnid.go
new file mode 100644
index 00000000..213703c5
--- /dev/null
+++ b/vendor/maunium.net/go/mautrix/appservice/txnid.go
@@ -0,0 +1,43 @@
+// Copyright (c) 2021 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package appservice
+
+import "sync"
+
+type TransactionIDCache struct {
+ array []string
+ arrayPtr int
+ hash map[string]struct{}
+ lock sync.RWMutex
+}
+
+func NewTransactionIDCache(size int) *TransactionIDCache {
+ return &TransactionIDCache{
+ array: make([]string, size),
+ hash: make(map[string]struct{}),
+ }
+}
+
+func (txnIDC *TransactionIDCache) IsProcessed(txnID string) bool {
+ txnIDC.lock.RLock()
+ _, exists := txnIDC.hash[txnID]
+ txnIDC.lock.RUnlock()
+ return exists
+}
+
+func (txnIDC *TransactionIDCache) MarkProcessed(txnID string) {
+ txnIDC.lock.Lock()
+ txnIDC.hash[txnID] = struct{}{}
+ if txnIDC.array[txnIDC.arrayPtr] != "" {
+ for i := 0; i < len(txnIDC.array)/8; i++ {
+ delete(txnIDC.hash, txnIDC.array[txnIDC.arrayPtr+i])
+ txnIDC.array[txnIDC.arrayPtr+i] = ""
+ }
+ }
+ txnIDC.array[txnIDC.arrayPtr] = txnID
+ txnIDC.lock.Unlock()
+}
diff --git a/vendor/maunium.net/go/mautrix/appservice/websocket.go b/vendor/maunium.net/go/mautrix/appservice/websocket.go
new file mode 100644
index 00000000..671222b8
--- /dev/null
+++ b/vendor/maunium.net/go/mautrix/appservice/websocket.go
@@ -0,0 +1,408 @@
+// Copyright (c) 2023 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package appservice
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "net/url"
+ "path/filepath"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/gorilla/websocket"
+ "github.com/rs/zerolog"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+type WebsocketRequest struct {
+ ReqID int `json:"id,omitempty"`
+ Command string `json:"command"`
+ Data interface{} `json:"data"`
+
+ Deadline time.Duration `json:"-"`
+}
+
+type WebsocketCommand struct {
+ ReqID int `json:"id,omitempty"`
+ Command string `json:"command"`
+ Data json.RawMessage `json:"data"`
+
+ Ctx context.Context `json:"-"`
+}
+
+func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketRequest {
+ if wsc.ReqID == 0 || wsc.Command == "response" || wsc.Command == "error" {
+ return nil
+ }
+ cmd := "response"
+ if !ok {
+ cmd = "error"
+ }
+ if err, isError := data.(error); isError {
+ var errorData json.RawMessage
+ var jsonErr error
+ unwrappedErr := err
+ var prefixMessage string
+ for unwrappedErr != nil {
+ errorData, jsonErr = json.Marshal(unwrappedErr)
+ if errorData != nil && len(errorData) > 2 && jsonErr == nil {
+ prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1)
+ prefixMessage = strings.TrimRight(prefixMessage, ": ")
+ break
+ }
+ unwrappedErr = errors.Unwrap(unwrappedErr)
+ }
+ if errorData != nil {
+ if !gjson.GetBytes(errorData, "message").Exists() {
+ errorData, _ = sjson.SetBytes(errorData, "message", err.Error())
+ } // else: marshaled error contains a message already
+ } else {
+ errorData, _ = sjson.SetBytes(nil, "message", err.Error())
+ }
+ if len(prefixMessage) > 0 {
+ errorData, _ = sjson.SetBytes(errorData, "prefix_message", prefixMessage)
+ }
+ data = errorData
+ }
+ return &WebsocketRequest{
+ ReqID: wsc.ReqID,
+ Command: cmd,
+ Data: data,
+ }
+}
+
+type WebsocketTransaction struct {
+ Status string `json:"status"`
+ TxnID string `json:"txn_id"`
+ Transaction
+}
+
+type WebsocketTransactionResponse struct {
+ TxnID string `json:"txn_id"`
+}
+
+type WebsocketMessage struct {
+ WebsocketTransaction
+ WebsocketCommand
+}
+
+const (
+ WebsocketCloseConnReplaced = 4001
+ WebsocketCloseTxnNotAcknowledged = 4002
+)
+
+type MeowWebsocketCloseCode string
+
+const (
+ MeowServerShuttingDown MeowWebsocketCloseCode = "server_shutting_down"
+ MeowConnectionReplaced MeowWebsocketCloseCode = "conn_replaced"
+ MeowTxnNotAcknowledged MeowWebsocketCloseCode = "transactions_not_acknowledged"
+)
+
+var (
+ ErrWebsocketManualStop = errors.New("the websocket was disconnected manually")
+ ErrWebsocketOverridden = errors.New("a new call to StartWebsocket overrode the previous connection")
+ ErrWebsocketUnknownError = errors.New("an unknown error occurred")
+
+ ErrWebsocketNotConnected = errors.New("websocket not connected")
+ ErrWebsocketClosed = errors.New("websocket closed before response received")
+)
+
+func (mwcc MeowWebsocketCloseCode) String() string {
+ switch mwcc {
+ case MeowServerShuttingDown:
+ return "the server is shutting down"
+ case MeowConnectionReplaced:
+ return "the connection was replaced by another client"
+ case MeowTxnNotAcknowledged:
+ return "transactions were not acknowledged"
+ default:
+ return string(mwcc)
+ }
+}
+
+type CloseCommand struct {
+ Code int `json:"-"`
+ Command string `json:"command"`
+ Status MeowWebsocketCloseCode `json:"status"`
+}
+
+func (cc CloseCommand) Error() string {
+ return fmt.Sprintf("websocket: close %d: %s", cc.Code, cc.Status.String())
+}
+
+func parseCloseError(err error) error {
+ closeError := &websocket.CloseError{}
+ if !errors.As(err, &closeError) {
+ return err
+ }
+ var closeCommand CloseCommand
+ closeCommand.Code = closeError.Code
+ closeCommand.Command = "disconnect"
+ if len(closeError.Text) > 0 {
+ jsonErr := json.Unmarshal([]byte(closeError.Text), &closeCommand)
+ if jsonErr != nil {
+ return err
+ }
+ }
+ if len(closeCommand.Status) == 0 {
+ if closeCommand.Code == WebsocketCloseConnReplaced {
+ closeCommand.Status = MeowConnectionReplaced
+ } else if closeCommand.Code == websocket.CloseServiceRestart {
+ closeCommand.Status = MeowServerShuttingDown
+ }
+ }
+ return &closeCommand
+}
+
+func (as *AppService) HasWebsocket() bool {
+ return as.ws != nil
+}
+
+func (as *AppService) SendWebsocket(cmd *WebsocketRequest) error {
+ ws := as.ws
+ if cmd == nil {
+ return nil
+ } else if ws == nil {
+ return ErrWebsocketNotConnected
+ }
+ as.wsWriteLock.Lock()
+ defer as.wsWriteLock.Unlock()
+ if cmd.Deadline == 0 {
+ cmd.Deadline = 3 * time.Minute
+ }
+ _ = ws.SetWriteDeadline(time.Now().Add(cmd.Deadline))
+ return ws.WriteJSON(cmd)
+}
+
+func (as *AppService) clearWebsocketResponseWaiters() {
+ as.websocketRequestsLock.Lock()
+ for _, waiter := range as.websocketRequests {
+ waiter <- &WebsocketCommand{Command: "__websocket_closed"}
+ }
+ as.websocketRequests = make(map[int]chan<- *WebsocketCommand)
+ as.websocketRequestsLock.Unlock()
+}
+
+func (as *AppService) addWebsocketResponseWaiter(reqID int, waiter chan<- *WebsocketCommand) {
+ as.websocketRequestsLock.Lock()
+ as.websocketRequests[reqID] = waiter
+ as.websocketRequestsLock.Unlock()
+}
+
+func (as *AppService) removeWebsocketResponseWaiter(reqID int, waiter chan<- *WebsocketCommand) {
+ as.websocketRequestsLock.Lock()
+ existingWaiter, ok := as.websocketRequests[reqID]
+ if ok && existingWaiter == waiter {
+ delete(as.websocketRequests, reqID)
+ }
+ close(waiter)
+ as.websocketRequestsLock.Unlock()
+}
+
+type ErrorResponse struct {
+ Code string `json:"code"`
+ Message string `json:"message"`
+}
+
+func (er *ErrorResponse) Error() string {
+ return fmt.Sprintf("%s: %s", er.Code, er.Message)
+}
+
+func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response interface{}) error {
+ cmd.ReqID = int(atomic.AddInt32(&as.websocketRequestID, 1))
+ respChan := make(chan *WebsocketCommand, 1)
+ as.addWebsocketResponseWaiter(cmd.ReqID, respChan)
+ defer as.removeWebsocketResponseWaiter(cmd.ReqID, respChan)
+ err := as.SendWebsocket(cmd)
+ if err != nil {
+ return err
+ }
+ select {
+ case resp := <-respChan:
+ if resp.Command == "__websocket_closed" {
+ return ErrWebsocketClosed
+ } else if resp.Command == "error" {
+ var respErr ErrorResponse
+ err = json.Unmarshal(resp.Data, &respErr)
+ if err != nil {
+ return fmt.Errorf("failed to parse error JSON: %w", err)
+ }
+ return &respErr
+ } else if response != nil {
+ err = json.Unmarshal(resp.Data, &response)
+ if err != nil {
+ return fmt.Errorf("failed to parse response JSON: %w", err)
+ }
+ return nil
+ } else {
+ return nil
+ }
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
+
+func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, interface{}) {
+ zerolog.Ctx(cmd.Ctx).Warn().Msg("No handler for websocket command")
+ return false, fmt.Errorf("unknown request type")
+}
+
+func (as *AppService) SetWebsocketCommandHandler(cmd string, handler WebsocketHandler) {
+ as.websocketHandlersLock.Lock()
+ as.websocketHandlers[cmd] = handler
+ as.websocketHandlersLock.Unlock()
+}
+
+func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) {
+ defer stopFunc(ErrWebsocketUnknownError)
+ ctx := context.Background()
+ for {
+ var msg WebsocketMessage
+ err := ws.ReadJSON(&msg)
+ if err != nil {
+ as.Log.Debug().Err(err).Msg("Error reading from websocket")
+ stopFunc(parseCloseError(err))
+ return
+ }
+ with := as.Log.With().
+ Int("req_id", msg.ReqID).
+ Str("ws_command", msg.Command)
+ if msg.TxnID != "" {
+ with = with.Str("transaction_id", msg.TxnID)
+ }
+ log := with.Logger()
+ ctx = log.WithContext(ctx)
+ if msg.Command == "" || msg.Command == "transaction" {
+ if msg.TxnID == "" || !as.txnIDC.IsProcessed(msg.TxnID) {
+ as.handleTransaction(ctx, msg.TxnID, &msg.Transaction)
+ } else {
+ log.Debug().
+ Object("content", &msg.Transaction).
+ Msg("Ignoring duplicate transaction")
+ }
+ go func() {
+ err = as.SendWebsocket(msg.MakeResponse(true, &WebsocketTransactionResponse{TxnID: msg.TxnID}))
+ if err != nil {
+ log.Warn().Err(err).Msg("Failed to send response to websocket transaction")
+ } else {
+ log.Debug().Msg("Sent response to transaction")
+ }
+ }()
+ } else if msg.Command == "connect" {
+ log.Debug().Msg("Websocket connect confirmation received")
+ } else if msg.Command == "response" || msg.Command == "error" {
+ as.websocketRequestsLock.RLock()
+ respChan, ok := as.websocketRequests[msg.ReqID]
+ if ok {
+ select {
+ case respChan <- &msg.WebsocketCommand:
+ default:
+ log.Warn().Msg("Failed to handle response: channel didn't accept response")
+ }
+ } else {
+ log.Warn().Msg("Dropping response to unknown request ID")
+ }
+ as.websocketRequestsLock.RUnlock()
+ } else {
+ log.Debug().Msg("Received websocket command")
+ as.websocketHandlersLock.RLock()
+ handler, ok := as.websocketHandlers[msg.Command]
+ as.websocketHandlersLock.RUnlock()
+ if !ok {
+ handler = as.unknownCommandHandler
+ }
+ go func() {
+ okResp, data := handler(msg.WebsocketCommand)
+ err = as.SendWebsocket(msg.MakeResponse(okResp, data))
+ if err != nil {
+ log.Error().Err(err).Msg("Failed to send response to websocket command")
+ } else if okResp {
+ log.Debug().Msg("Sent success response to websocket command")
+ } else {
+ log.Debug().Msg("Sent error response to websocket command")
+ }
+ }()
+ }
+ }
+}
+
+func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
+ parsed, err := url.Parse(baseURL)
+ if err != nil {
+ return fmt.Errorf("failed to parse URL: %w", err)
+ }
+ parsed.Path = filepath.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync")
+ if parsed.Scheme == "http" {
+ parsed.Scheme = "ws"
+ } else if parsed.Scheme == "https" {
+ parsed.Scheme = "wss"
+ }
+ ws, resp, err := websocket.DefaultDialer.Dial(parsed.String(), http.Header{
+ "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)},
+ "User-Agent": []string{as.BotClient().UserAgent},
+
+ "X-Mautrix-Process-ID": []string{as.ProcessID},
+ "X-Mautrix-Websocket-Version": []string{"3"},
+ })
+ if resp != nil && resp.StatusCode >= 400 {
+ var errResp Error
+ err = json.NewDecoder(resp.Body).Decode(&errResp)
+ if err != nil {
+ return fmt.Errorf("websocket request returned HTTP %d with non-JSON body", resp.StatusCode)
+ } else {
+ return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrorCode, resp.StatusCode, errResp.Message)
+ }
+ } else if err != nil {
+ return fmt.Errorf("failed to open websocket: %w", err)
+ }
+ if as.StopWebsocket != nil {
+ as.StopWebsocket(ErrWebsocketOverridden)
+ }
+ closeChan := make(chan error)
+ closeChanOnce := sync.Once{}
+ stopFunc := func(err error) {
+ closeChanOnce.Do(func() {
+ closeChan <- err
+ })
+ }
+ as.ws = ws
+ as.StopWebsocket = stopFunc
+ as.PrepareWebsocket()
+ as.Log.Debug().Msg("Appservice transaction websocket opened")
+
+ go as.consumeWebsocket(stopFunc, ws)
+
+ if onConnect != nil {
+ onConnect()
+ }
+
+ closeErr := <-closeChan
+
+ if as.ws == ws {
+ as.clearWebsocketResponseWaiters()
+ as.ws = nil
+ }
+
+ _ = ws.SetWriteDeadline(time.Now().Add(3 * time.Second))
+ err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""))
+ if err != nil && !errors.Is(err, websocket.ErrCloseSent) {
+ as.Log.Warn().Err(err).Msg("Error writing close message to websocket")
+ }
+ err = ws.Close()
+ if err != nil {
+ as.Log.Warn().Err(err).Msg("Error closing websocket")
+ }
+ return closeErr
+}