429 lines
11 KiB
Go
429 lines
11 KiB
Go
// Package matrixbot handles Matrix-native bot behavior.
|
|
package matrixbot
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"html"
|
|
"log"
|
|
"strings"
|
|
"time"
|
|
|
|
"maunium.net/go/mautrix"
|
|
"maunium.net/go/mautrix/event"
|
|
"maunium.net/go/mautrix/id"
|
|
|
|
"musiclink/pkg/config"
|
|
)
|
|
|
|
// TextHandler processes plain text messages and returns a response when applicable.
//
// The bot calls it with the message body as text and the sender's full Matrix
// user ID (the event sender, not a display name) as username. The returned bool
// reports whether a reply should be sent; when it is false the string is ignored.
type TextHandler func(ctx context.Context, text, username string) (string, bool)
|
|
|
|
// Bot manages Matrix-native sync and message handling.
//
// Build one with New, drive it with Run, and release its resources with Close.
type Bot struct {
	// cfg is the Matrix configuration the bot was built from (server,
	// credentials, room allowlist, shadow mode, health address, store path).
	cfg config.MatrixConfig
	// client is the mautrix client used for syncing, joining/leaving rooms
	// and sending messages.
	client *mautrix.Client
	// handler turns an incoming text message into an optional reply.
	handler TextHandler
	// allowedRooms is the allowlist built from cfg.Rooms: invites to these
	// rooms are joined and only their messages are handled.
	allowedRooms map[id.RoomID]struct{}
	// encryptedRooms holds allowlisted rooms known to be encrypted; messages
	// in them are skipped. NOTE(review): accessed without a lock — this
	// assumes the syncer invokes event handlers on a single goroutine; confirm.
	encryptedRooms map[id.RoomID]struct{}
	// stateStore records processed event IDs for de-duplication and is also
	// installed as the mautrix client's Store.
	stateStore *StateStore
	// sendQueue buffers replies for the send loop; when it is full, new
	// replies are dropped rather than blocking the sync handler.
	sendQueue chan sendRequest
	// stats tracks counters (received, responded, dropped, rate-limited, ...)
	// and the last send error.
	stats botStats
}
|
|
|
|
// sendRequest is a reply queued by onMessage for the send loop.
type sendRequest struct {
	// roomID is the room the triggering message arrived in (used for logging).
	roomID id.RoomID
	// event is the message being answered; sendReply posts into its room and
	// its ID is recorded as processed after a successful send.
	event *event.Event
	// response is the handler's reply text, formatted by sendReply.
	response string
}
|
|
|
|
// New creates a new Matrix-native bot instance.
|
|
func New(cfg config.MatrixConfig, handler TextHandler) (*Bot, error) {
|
|
client, err := mautrix.NewClient(cfg.Server, id.UserID(cfg.UserID), cfg.AccessToken)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create matrix client: %w", err)
|
|
}
|
|
|
|
allowed := make(map[id.RoomID]struct{}, len(cfg.Rooms))
|
|
for _, room := range cfg.Rooms {
|
|
allowed[id.RoomID(room)] = struct{}{}
|
|
}
|
|
|
|
store, err := NewStateStore(cfg.StateStorePath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("init state store: %w", err)
|
|
}
|
|
client.Store = store
|
|
log.Printf("Matrix state store: %s", cfg.StateStorePath)
|
|
|
|
bot := &Bot{
|
|
cfg: cfg,
|
|
client: client,
|
|
handler: handler,
|
|
allowedRooms: allowed,
|
|
encryptedRooms: make(map[id.RoomID]struct{}),
|
|
stateStore: store,
|
|
sendQueue: make(chan sendRequest, 100),
|
|
}
|
|
|
|
syncer := client.Syncer.(*mautrix.DefaultSyncer)
|
|
syncer.OnEventType(event.EventMessage, bot.onMessage)
|
|
syncer.OnEventType(event.StateMember, bot.onMember)
|
|
syncer.OnEventType(event.StateEncryption, bot.onEncryption)
|
|
syncer.OnSync(bot.onSync)
|
|
|
|
return bot, nil
|
|
}
|
|
|
|
// Run starts the sync loop.
|
|
func (b *Bot) Run(ctx context.Context) error {
|
|
mode := "active"
|
|
if b.cfg.Shadow {
|
|
mode = "shadow"
|
|
}
|
|
log.Printf("Matrix bot starting (%s mode, rooms: %d)", mode, len(b.allowedRooms))
|
|
log.Printf("Matrix allowlist rooms: %s", strings.Join(roomList(b.allowedRooms), ", "))
|
|
if b.cfg.HealthAddr != "" {
|
|
log.Printf("Matrix health server listening on %s", b.cfg.HealthAddr)
|
|
go b.startHealthServer(ctx)
|
|
}
|
|
if err := b.prefetchEncryptionState(ctx); err != nil {
|
|
log.Printf("Matrix encryption state check failed: %v", err)
|
|
}
|
|
go b.cleanupLoop(ctx)
|
|
go b.sendLoop(ctx)
|
|
return b.client.SyncWithContext(ctx)
|
|
}
|
|
|
|
// Close releases Matrix bot resources.
|
|
func (b *Bot) Close() error {
|
|
if b.stateStore != nil {
|
|
return b.stateStore.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (b *Bot) onMember(ctx context.Context, evt *event.Event) {
|
|
if evt == nil {
|
|
return
|
|
}
|
|
if evt.GetStateKey() != b.client.UserID.String() {
|
|
return
|
|
}
|
|
member := evt.Content.AsMember()
|
|
if member.Membership != event.MembershipInvite {
|
|
return
|
|
}
|
|
if isEncryptedInvite(evt) {
|
|
log.Printf("Matrix invite ignored (encrypted room): %s", evt.RoomID)
|
|
if _, err := b.client.LeaveRoom(ctx, evt.RoomID); err != nil {
|
|
log.Printf("Matrix invite leave failed for %s: %v", evt.RoomID, err)
|
|
}
|
|
return
|
|
}
|
|
if _, ok := b.allowedRooms[evt.RoomID]; ok {
|
|
if _, err := b.client.JoinRoomByID(ctx, evt.RoomID); err != nil {
|
|
log.Printf("Matrix invite join failed for %s: %v", evt.RoomID, err)
|
|
}
|
|
return
|
|
}
|
|
if _, err := b.client.LeaveRoom(ctx, evt.RoomID); err != nil {
|
|
log.Printf("Matrix invite leave failed for %s: %v", evt.RoomID, err)
|
|
}
|
|
}
|
|
|
|
func (b *Bot) onMessage(ctx context.Context, evt *event.Event) {
|
|
if evt == nil {
|
|
return
|
|
}
|
|
if _, ok := b.allowedRooms[evt.RoomID]; !ok {
|
|
return
|
|
}
|
|
if evt.Sender == b.client.UserID {
|
|
return
|
|
}
|
|
if b.isRoomEncrypted(evt.RoomID) {
|
|
b.stats.markEncryptedSkipped()
|
|
return
|
|
}
|
|
|
|
b.stats.markReceived()
|
|
|
|
if processed, err := b.stateStore.WasEventProcessed(ctx, evt.ID.String()); err != nil {
|
|
log.Printf("Matrix dedupe check failed (event %s): %v", evt.ID, err)
|
|
} else if processed {
|
|
return
|
|
}
|
|
|
|
content := evt.Content.AsMessage()
|
|
if content == nil {
|
|
return
|
|
}
|
|
if !content.MsgType.IsText() {
|
|
return
|
|
}
|
|
if content.RelatesTo != nil && content.RelatesTo.GetReplaceID() != "" {
|
|
return
|
|
}
|
|
if content.Body == "" {
|
|
return
|
|
}
|
|
|
|
response, ok := b.handler(ctx, content.Body, evt.Sender.String())
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if b.cfg.Shadow {
|
|
log.Printf("Matrix shadow reply (room %s, event %s): %s", evt.RoomID, evt.ID, response)
|
|
b.stats.markResponded()
|
|
if _, err := b.stateStore.MarkEventProcessed(ctx, evt.ID.String()); err != nil {
|
|
log.Printf("Matrix dedupe record failed (event %s): %v", evt.ID, err)
|
|
}
|
|
return
|
|
}
|
|
|
|
select {
|
|
case b.sendQueue <- sendRequest{roomID: evt.RoomID, event: evt, response: response}:
|
|
default:
|
|
b.stats.markDropped()
|
|
if _, err := b.stateStore.MarkEventProcessed(ctx, evt.ID.String()); err != nil {
|
|
log.Printf("Matrix dedupe record failed after drop (event %s): %v", evt.ID, err)
|
|
}
|
|
log.Printf("Matrix send queue full; dropping response for event %s", evt.ID)
|
|
}
|
|
}
|
|
|
|
func (b *Bot) sendReply(ctx context.Context, evt *event.Event, response string) error {
|
|
body, formatted := formatNoPreview(response)
|
|
content := &event.MessageEventContent{
|
|
MsgType: event.MsgText,
|
|
Body: body,
|
|
BeeperLinkPreviews: []*event.BeeperLinkPreview{},
|
|
}
|
|
if formatted != "" {
|
|
content.Format = event.FormatHTML
|
|
content.FormattedBody = formatted
|
|
}
|
|
|
|
_, err := b.client.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, content)
|
|
return err
|
|
}
|
|
|
|
func (b *Bot) sendLoop(ctx context.Context) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
if pending := len(b.sendQueue); pending > 0 {
|
|
log.Printf("Matrix send queue pending on shutdown: %d", pending)
|
|
}
|
|
return
|
|
case req := <-b.sendQueue:
|
|
b.sendWithRetry(ctx, req)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (b *Bot) cleanupLoop(ctx context.Context) {
|
|
ticker := time.NewTicker(6 * time.Hour)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
if err := b.stateStore.Cleanup(ctx); err != nil {
|
|
log.Printf("Matrix state store cleanup failed: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (b *Bot) sendWithRetry(ctx context.Context, req sendRequest) {
|
|
const maxAttempts = 5
|
|
backoff := time.Second
|
|
|
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
|
if err := b.sendReply(ctx, req.event, req.response); err != nil {
|
|
delay, ok := rateLimitDelay(err)
|
|
if ok {
|
|
b.stats.markRateLimited()
|
|
b.stats.setLastSendError(err.Error())
|
|
if delay <= 0 {
|
|
delay = backoff
|
|
}
|
|
log.Printf("Matrix rate limited (event %s), retrying in %s", req.event.ID, delay)
|
|
if !sleepContext(ctx, delay) {
|
|
return
|
|
}
|
|
backoff = minDuration(backoff*2, 30*time.Second)
|
|
continue
|
|
}
|
|
b.stats.setLastSendError(err.Error())
|
|
log.Printf("Matrix send failed (room %s, event %s): %v", req.roomID, req.event.ID, err)
|
|
return
|
|
}
|
|
|
|
b.stats.markResponded()
|
|
if _, err := b.stateStore.MarkEventProcessed(ctx, req.event.ID.String()); err != nil {
|
|
log.Printf("Matrix dedupe record failed (event %s): %v", req.event.ID, err)
|
|
}
|
|
return
|
|
}
|
|
|
|
b.stats.setLastSendError("max retries exceeded")
|
|
log.Printf("Matrix send failed after retries (room %s, event %s)", req.roomID, req.event.ID)
|
|
}
|
|
|
|
func rateLimitDelay(err error) (time.Duration, bool) {
|
|
if !errors.Is(err, mautrix.MLimitExceeded) {
|
|
return 0, false
|
|
}
|
|
|
|
var httpErr mautrix.HTTPError
|
|
if errors.As(err, &httpErr) && httpErr.RespError != nil {
|
|
if retry, ok := httpErr.RespError.ExtraData["retry_after_ms"]; ok {
|
|
switch value := retry.(type) {
|
|
case float64:
|
|
return time.Duration(value) * time.Millisecond, true
|
|
case int64:
|
|
return time.Duration(value) * time.Millisecond, true
|
|
case int:
|
|
return time.Duration(value) * time.Millisecond, true
|
|
}
|
|
}
|
|
}
|
|
|
|
return 0, true
|
|
}
|
|
|
|
// sleepContext waits for delay to elapse or for ctx to be cancelled,
// whichever happens first. It returns true when the full delay elapsed and
// false when ctx was cancelled first. The timer is always stopped on return.
func sleepContext(ctx context.Context, delay time.Duration) bool {
	t := time.NewTimer(delay)
	defer t.Stop()

	select {
	case <-t.C:
		return true
	case <-ctx.Done():
		return false
	}
}
|
|
|
|
// minDuration returns the smaller of two durations.
func minDuration(a, b time.Duration) time.Duration {
	if b <= a {
		return b
	}
	return a
}
|
|
|
|
func (b *Bot) onEncryption(ctx context.Context, evt *event.Event) {
|
|
if evt == nil {
|
|
return
|
|
}
|
|
if _, ok := b.allowedRooms[evt.RoomID]; !ok {
|
|
return
|
|
}
|
|
if !b.isRoomEncrypted(evt.RoomID) {
|
|
b.encryptedRooms[evt.RoomID] = struct{}{}
|
|
log.Printf("Matrix room marked encrypted; skipping messages: %s", evt.RoomID)
|
|
}
|
|
}
|
|
|
|
func (b *Bot) onSync(ctx context.Context, _ *mautrix.RespSync, _ string) bool {
|
|
b.stats.markSync()
|
|
return true
|
|
}
|
|
|
|
func (b *Bot) isRoomEncrypted(roomID id.RoomID) bool {
|
|
_, ok := b.encryptedRooms[roomID]
|
|
return ok
|
|
}
|
|
|
|
func (b *Bot) prefetchEncryptionState(ctx context.Context) error {
|
|
for roomID := range b.allowedRooms {
|
|
var content event.EncryptionEventContent
|
|
if err := b.client.StateEvent(ctx, roomID, event.StateEncryption, "", &content); err != nil {
|
|
if errors.Is(err, mautrix.MNotFound) {
|
|
continue
|
|
}
|
|
log.Printf("Matrix encryption state fetch failed for %s: %v", roomID, err)
|
|
continue
|
|
}
|
|
if !b.isRoomEncrypted(roomID) {
|
|
b.encryptedRooms[roomID] = struct{}{}
|
|
log.Printf("Matrix room marked encrypted (state fetch); skipping messages: %s", roomID)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func roomList(rooms map[id.RoomID]struct{}) []string {
|
|
list := make([]string, 0, len(rooms))
|
|
for roomID := range rooms {
|
|
list = append(list, roomID.String())
|
|
}
|
|
return list
|
|
}
|
|
|
|
func formatNoPreview(response string) (string, string) {
|
|
lines := strings.Split(response, "\n")
|
|
bodyLines := make([]string, 0, len(lines))
|
|
formattedLines := make([]string, 0, len(lines))
|
|
|
|
for _, line := range lines {
|
|
trimmed := strings.TrimSpace(line)
|
|
if trimmed == "" {
|
|
bodyLines = append(bodyLines, "")
|
|
formattedLines = append(formattedLines, "")
|
|
continue
|
|
}
|
|
|
|
label, url, ok := splitLinkLine(trimmed)
|
|
if ok {
|
|
bodyLines = append(bodyLines, fmt.Sprintf("%s: %s", label, obfuscateURL(url)))
|
|
formattedLines = append(formattedLines, fmt.Sprintf("%s: <a href=\"%s\">%s</a>", html.EscapeString(label), html.EscapeString(url), html.EscapeString(url)))
|
|
continue
|
|
}
|
|
|
|
bodyLines = append(bodyLines, line)
|
|
formattedLines = append(formattedLines, html.EscapeString(line))
|
|
}
|
|
|
|
body := strings.Join(bodyLines, "\n")
|
|
formatted := strings.Join(formattedLines, "<br/>")
|
|
return body, formatted
|
|
}
|
|
|
|
// splitLinkLine recognises lines shaped like "Label: http(s)://...".
//
// The line is split at the first ": ". The label is everything before the
// separator. The URL is everything after it, trimmed of surrounding
// whitespace. It returns ok=false (and empty strings) when there is no
// separator or when the URL part does not start with "http://" or "https://".
func splitLinkLine(line string) (string, string, bool) {
	label, rest, found := strings.Cut(line, ": ")
	if !found {
		return "", "", false
	}
	link := strings.TrimSpace(rest)
	if strings.HasPrefix(link, "https://") || strings.HasPrefix(link, "http://") {
		return label, link, true
	}
	return "", "", false
}
|
|
|
|
// obfuscateURL is the hook for rewriting URLs shown in the plain-text message
// body. It currently performs no rewriting and returns url unchanged.
func obfuscateURL(url string) string {
	rewritten := url
	return rewritten
}
|
|
|
|
func isEncryptedInvite(evt *event.Event) bool {
|
|
if evt == nil || evt.Unsigned.InviteRoomState == nil {
|
|
return false
|
|
}
|
|
for _, state := range evt.Unsigned.InviteRoomState {
|
|
if state == nil {
|
|
continue
|
|
}
|
|
if state.Type == event.StateEncryption {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|