diff options
Diffstat (limited to 'vendor/maunium.net/go/mautrix/appservice/http.go')
-rw-r--r-- | vendor/maunium.net/go/mautrix/appservice/http.go | 348 |
1 files changed, 348 insertions, 0 deletions
diff --git a/vendor/maunium.net/go/mautrix/appservice/http.go b/vendor/maunium.net/go/mautrix/appservice/http.go new file mode 100644 index 00000000..06ac7788 --- /dev/null +++ b/vendor/maunium.net/go/mautrix/appservice/http.go @@ -0,0 +1,348 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package appservice + +import ( + "context" + "encoding/json" + "errors" + "io" + "net" + "net/http" + "strings" + "syscall" + "time" + + "github.com/gorilla/mux" + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// Start starts the HTTP server that listens for calls from the Matrix homeserver. +func (as *AppService) Start() { + as.server = &http.Server{ + Handler: as.Router, + } + var err error + if as.Host.IsUnixSocket() { + err = as.listenUnix() + } else { + as.server.Addr = as.Host.Address() + err = as.listenTCP() + } + if err != nil && !errors.Is(err, http.ErrServerClosed) { + as.Log.Error().Err(err).Msg("Error in HTTP listener") + } else { + as.Log.Debug().Msg("HTTP listener stopped") + } +} + +func (as *AppService) listenUnix() error { + socket := as.Host.Hostname + _ = syscall.Unlink(socket) + defer func() { + _ = syscall.Unlink(socket) + }() + listener, err := net.Listen("unix", socket) + if err != nil { + return err + } + as.Log.Info().Str("socket", socket).Msg("Starting unix socket HTTP listener") + return as.server.Serve(listener) +} + +func (as *AppService) listenTCP() error { + if len(as.Host.TLSCert) == 0 || len(as.Host.TLSKey) == 0 { + as.Log.Info().Str("address", as.server.Addr).Msg("Starting HTTP listener") + return as.server.ListenAndServe() + } else { + as.Log.Info().Str("address", as.server.Addr).Msg("Starting HTTP listener with TLS") + return as.server.ListenAndServeTLS(as.Host.TLSCert, as.Host.TLSKey) + } +} + +func (as *AppService) Stop() { + if as.server == nil { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = as.server.Shutdown(ctx) + as.server = nil +} + +// CheckServerToken checks if the given request originated from the Matrix homeserver. +func (as *AppService) CheckServerToken(w http.ResponseWriter, r *http.Request) (isValid bool) { + authHeader := r.Header.Get("Authorization") + if len(authHeader) > 0 && strings.HasPrefix(authHeader, "Bearer ") { + isValid = authHeader[len("Bearer "):] == as.Registration.ServerToken + } else { + queryToken := r.URL.Query().Get("access_token") + if len(queryToken) > 0 { + isValid = queryToken == as.Registration.ServerToken + } else { + Error{ + ErrorCode: ErrUnknownToken, + HTTPStatus: http.StatusForbidden, + Message: "Missing access token", + }.Write(w) + return + } + } + if !isValid { + Error{ + ErrorCode: ErrUnknownToken, + HTTPStatus: http.StatusForbidden, + Message: "Incorrect access token", + }.Write(w) + } + return +} + +// PutTransaction handles a /transactions PUT call from the homeserver. +func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { + if !as.CheckServerToken(w, r) { + return + } + + vars := mux.Vars(r) + txnID := vars["txnID"] + if len(txnID) == 0 { + Error{ + ErrorCode: ErrNoTransactionID, + HTTPStatus: http.StatusBadRequest, + Message: "Missing transaction ID", + }.Write(w) + return + } + defer r.Body.Close() + body, err := io.ReadAll(r.Body) + if err != nil || len(body) == 0 { + Error{ + ErrorCode: ErrNotJSON, + HTTPStatus: http.StatusBadRequest, + Message: "Missing request body", + }.Write(w) + return + } + log := as.Log.With().Str("transaction_id", txnID).Logger() + ctx := context.Background() + ctx = log.WithContext(ctx) + if as.txnIDC.IsProcessed(txnID) { + // Duplicate transaction ID: no-op + WriteBlankOK(w) + log.Debug().Msg("Ignoring duplicate transaction") + return + } + + var txn Transaction + err = json.Unmarshal(body, &txn) + if err != nil { + log.Error().Err(err).Msg("Failed to parse transaction content") + Error{ + ErrorCode: ErrBadJSON, + HTTPStatus: http.StatusBadRequest, + Message: "Failed to parse body JSON", + }.Write(w) + } else { + as.handleTransaction(ctx, txnID, &txn) + WriteBlankOK(w) + } +} + +func (as *AppService) handleTransaction(ctx context.Context, id string, txn *Transaction) { + log := zerolog.Ctx(ctx) + log.Debug().Object("content", txn).Msg("Starting handling of transaction") + if as.Registration.EphemeralEvents { + if txn.EphemeralEvents != nil { + as.handleEvents(ctx, txn.EphemeralEvents, event.EphemeralEventType) + } else if txn.MSC2409EphemeralEvents != nil { + as.handleEvents(ctx, txn.MSC2409EphemeralEvents, event.EphemeralEventType) + } + if txn.ToDeviceEvents != nil { + as.handleEvents(ctx, txn.ToDeviceEvents, event.ToDeviceEventType) + } else if txn.MSC2409ToDeviceEvents != nil { + as.handleEvents(ctx, txn.MSC2409ToDeviceEvents, event.ToDeviceEventType) + } + } + as.handleEvents(ctx, txn.Events, event.UnknownEventType) + if txn.DeviceLists != nil { + as.handleDeviceLists(ctx, txn.DeviceLists) + } else if txn.MSC3202DeviceLists != nil { + as.handleDeviceLists(ctx, txn.MSC3202DeviceLists) + } + if txn.DeviceOTKCount != nil { + as.handleOTKCounts(ctx, txn.DeviceOTKCount) + } else if txn.MSC3202DeviceOTKCount != nil { + as.handleOTKCounts(ctx, txn.MSC3202DeviceOTKCount) + } + as.txnIDC.MarkProcessed(id) + log.Debug().Msg("Finished dispatching events from transaction") +} + +func (as *AppService) handleOTKCounts(ctx context.Context, otks OTKCountMap) { + for userID, devices := range otks { + for deviceID, otkCounts := range devices { + otkCounts.UserID = userID + otkCounts.DeviceID = deviceID + select { + case as.OTKCounts <- &otkCounts: + default: + zerolog.Ctx(ctx).Warn(). + Str("user_id", userID.String()). + Msg("Dropped OTK count update for user because channel is full") + } + } + } +} + +func (as *AppService) handleDeviceLists(ctx context.Context, dl *mautrix.DeviceLists) { + select { + case as.DeviceLists <- dl: + default: + zerolog.Ctx(ctx).Warn().Msg("Dropped device list update because channel is full") + } +} + +func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, defaultTypeClass event.TypeClass) { + log := zerolog.Ctx(ctx) + for _, evt := range evts { + evt.Mautrix.ReceivedAt = time.Now() + if defaultTypeClass != event.UnknownEventType { + evt.Type.Class = defaultTypeClass + } else if evt.StateKey != nil { + evt.Type.Class = event.StateEventType + } else { + evt.Type.Class = event.MessageEventType + } + err := evt.Content.ParseRaw(evt.Type) + if errors.Is(err, event.ErrUnsupportedContentType) { + log.Debug().Str("event_id", evt.ID.String()).Msg("Not parsing content of unsupported event") + } else if err != nil { + log.Warn().Err(err). + Str("event_id", evt.ID.String()). + Str("event_type", evt.Type.Type). + Str("event_type_class", evt.Type.Class.Name()). + Msg("Failed to parse content of event") + } + + if evt.Type.IsState() { + // TODO remove this check after making sure the log doesn't happen + historical, ok := evt.Content.Raw["org.matrix.msc2716.historical"].(bool) + if ok && historical { + log.Warn(). + Str("event_id", evt.ID.String()). + Str("event_type", evt.Type.Type). + Str("state_key", evt.GetStateKey()). + Msg("Received historical state event") + } else { + mautrix.UpdateStateStore(as.StateStore, evt) + } + } + var ch chan *event.Event + if evt.Type.Class == event.ToDeviceEventType { + ch = as.ToDeviceEvents + } else { + ch = as.Events + } + select { + case ch <- evt: + default: + log.Warn(). + Str("event_id", evt.ID.String()). + Str("event_type", evt.Type.Type). + Str("event_type_class", evt.Type.Class.Name()). + Msg("Event channel is full") + ch <- evt + } + } +} + +// GetRoom handles a /rooms GET call from the homeserver. +func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) { + if !as.CheckServerToken(w, r) { + return + } + + vars := mux.Vars(r) + roomAlias := vars["roomAlias"] + ok := as.QueryHandler.QueryAlias(roomAlias) + if ok { + WriteBlankOK(w) + } else { + Error{ + ErrorCode: ErrUnknown, + HTTPStatus: http.StatusNotFound, + }.Write(w) + } +} + +// GetUser handles a /users GET call from the homeserver. +func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) { + if !as.CheckServerToken(w, r) { + return + } + + vars := mux.Vars(r) + userID := id.UserID(vars["userID"]) + ok := as.QueryHandler.QueryUser(userID) + if ok { + WriteBlankOK(w) + } else { + Error{ + ErrorCode: ErrUnknown, + HTTPStatus: http.StatusNotFound, + }.Write(w) + } +} + +func (as *AppService) PostPing(w http.ResponseWriter, r *http.Request) { + if !as.CheckServerToken(w, r) { + return + } + body, err := io.ReadAll(r.Body) + if err != nil || len(body) == 0 || !json.Valid(body) { + Error{ + ErrorCode: ErrNotJSON, + HTTPStatus: http.StatusBadRequest, + Message: "Missing request body", + }.Write(w) + return + } + + var txn mautrix.ReqAppservicePing + _ = json.Unmarshal(body, &txn) + as.Log.Debug().Str("txn_id", txn.TxnID).Msg("Received ping from homeserver") + + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte("{}")) +} + +func (as *AppService) GetLive(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "application/json") + if as.Live { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusInternalServerError) + } + w.Write([]byte("{}")) +} + +func (as *AppService) GetReady(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "application/json") + if as.Ready { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusInternalServerError) + } + w.Write([]byte("{}")) +} |