musiclink/internal/matrixbot/bot.go

436 lines
11 KiB
Go

// Package matrixbot handles Matrix-native bot behavior.
package matrixbot
import (
"context"
"errors"
"fmt"
"html"
"log"
"strings"
"time"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"musiclink/pkg/config"
)
// TextHandler processes plain text messages and returns a response when applicable.
//
// text is the message body and username is the sender's Matrix user ID as a
// string. The boolean result reports whether the bot should reply; when it is
// false the returned string is ignored.
type TextHandler func(ctx context.Context, text, username string) (string, bool)
// Bot manages Matrix-native sync and message handling.
//
// Invites, messages and encryption state arrive through the mautrix sync
// handlers registered in New; replies are queued on sendQueue and delivered
// by the send loop started in Run.
type Bot struct {
// cfg is the Matrix configuration the bot was created with.
cfg config.MatrixConfig
// client is the mautrix client used for syncing, joining/leaving rooms and sending.
client *mautrix.Client
// handler produces replies for eligible text messages.
handler TextHandler
// allowedRooms is the room allowlist built from cfg.Rooms; messages from other
// rooms are ignored and invites to them are declined.
allowedRooms map[id.RoomID]struct{}
// encryptedRooms holds allowlisted rooms known to be encrypted; messages in
// them are skipped.
// NOTE(review): plain map without a mutex — assumes it is only touched by the
// sync handlers and prefetchEncryptionState on a single goroutine; confirm.
encryptedRooms map[id.RoomID]struct{}
// stateStore records processed event IDs for deduplication and is also
// installed as the client's Store in New.
stateStore *StateStore
// sendQueue buffers pending replies (capacity set in New) for the send loop.
sendQueue chan sendRequest
// stats records bot activity (received/responded/dropped/rate-limited/skipped
// events, syncs) and the last send error via its mark*/set* methods.
stats botStats
}
// sendRequest is a reply queued by onMessage and consumed by the send loop.
type sendRequest struct {
// roomID is copied from the triggering event; sendWithRetry uses it in log messages.
roomID id.RoomID
// event is the message being answered; the reply is threaded or attached as a reply to it.
event *event.Event
// response is the handler's reply text before formatting.
response string
}
// New constructs a Bot for the given Matrix configuration.
//
// It creates the mautrix client from cfg.Server, cfg.UserID and
// cfg.AccessToken, turns cfg.Rooms into the room allowlist, opens the state
// store at cfg.StateStorePath and installs it as the client's Store, and then
// registers the bot's handlers for message, membership and encryption events
// plus sync completion. handler is called for each eligible text message.
// Failures creating the client or the state store are wrapped and returned.
func New(cfg config.MatrixConfig, handler TextHandler) (*Bot, error) {
	client, err := mautrix.NewClient(cfg.Server, id.UserID(cfg.UserID), cfg.AccessToken)
	if err != nil {
		return nil, fmt.Errorf("create matrix client: %w", err)
	}

	allowlist := make(map[id.RoomID]struct{}, len(cfg.Rooms))
	for _, raw := range cfg.Rooms {
		allowlist[id.RoomID(raw)] = struct{}{}
	}

	ss, err := NewStateStore(cfg.StateStorePath)
	if err != nil {
		return nil, fmt.Errorf("init state store: %w", err)
	}
	client.Store = ss
	log.Printf("Matrix state store: %s", cfg.StateStorePath)

	b := &Bot{
		cfg:            cfg,
		client:         client,
		handler:        handler,
		stateStore:     ss,
		allowedRooms:   allowlist,
		encryptedRooms: make(map[id.RoomID]struct{}),
		sendQueue:      make(chan sendRequest, 100),
	}

	ds := client.Syncer.(*mautrix.DefaultSyncer)
	ds.OnEventType(event.EventMessage, b.onMessage)
	ds.OnEventType(event.StateMember, b.onMember)
	ds.OnEventType(event.StateEncryption, b.onEncryption)
	ds.OnSync(b.onSync)
	return b, nil
}
// Run starts the sync loop.
func (b *Bot) Run(ctx context.Context) error {
mode := "active"
if b.cfg.Shadow {
mode = "shadow"
}
log.Printf("Matrix bot starting (%s mode, rooms: %d)", mode, len(b.allowedRooms))
log.Printf("Matrix allowlist rooms: %s", strings.Join(roomList(b.allowedRooms), ", "))
if b.cfg.HealthAddr != "" {
log.Printf("Matrix health server listening on %s", b.cfg.HealthAddr)
go b.startHealthServer(ctx)
}
if err := b.prefetchEncryptionState(ctx); err != nil {
log.Printf("Matrix encryption state check failed: %v", err)
}
go b.cleanupLoop(ctx)
go b.sendLoop(ctx)
return b.client.SyncWithContext(ctx)
}
// Close releases Matrix bot resources.
func (b *Bot) Close() error {
if b.stateStore != nil {
return b.stateStore.Close()
}
return nil
}
// onMember reacts to membership events whose state key is the bot's own user.
// Only invites are acted on:
//   - an invite whose stripped room state shows encryption is declined (left);
//   - an invite to an allowlisted room is accepted (joined);
//   - any other invite is declined.
//
// Join/leave failures are logged rather than returned.
func (b *Bot) onMember(ctx context.Context, evt *event.Event) {
	if evt == nil || evt.GetStateKey() != b.client.UserID.String() {
		return
	}
	if evt.Content.AsMember().Membership != event.MembershipInvite {
		return
	}

	room := evt.RoomID
	leave := func() {
		if _, err := b.client.LeaveRoom(ctx, room); err != nil {
			log.Printf("Matrix invite leave failed for %s: %v", room, err)
		}
	}

	_, allowed := b.allowedRooms[room]
	switch {
	case isEncryptedInvite(evt):
		log.Printf("Matrix invite ignored (encrypted room): %s", room)
		leave()
	case allowed:
		if _, err := b.client.JoinRoomByID(ctx, room); err != nil {
			log.Printf("Matrix invite join failed for %s: %v", room, err)
		}
	default:
		leave()
	}
}
// onMessage handles a timeline message event.
//
// Events are ignored unless they come from an allowlisted room and were sent
// by someone other than the bot. Messages in rooms known to be encrypted are
// counted as skipped and dropped. Otherwise the event is counted as received
// and checked against the state store for deduplication; a failed check is
// logged and processing continues. Only non-empty text messages that are not
// edits are passed to the handler.
//
// In shadow mode the reply is logged instead of sent and the event is marked
// processed. Otherwise the reply is queued for the send loop without
// blocking. If the queue is full the reply is dropped, counted, and the event
// is marked processed so it will not be answered later.
func (b *Bot) onMessage(ctx context.Context, evt *event.Event) {
	if evt == nil {
		return
	}
	if _, allowed := b.allowedRooms[evt.RoomID]; !allowed || evt.Sender == b.client.UserID {
		return
	}
	if b.isRoomEncrypted(evt.RoomID) {
		b.stats.markEncryptedSkipped()
		return
	}
	b.stats.markReceived()

	eventID := evt.ID.String()
	seen, lookupErr := b.stateStore.WasEventProcessed(ctx, eventID)
	switch {
	case lookupErr != nil:
		log.Printf("Matrix dedupe check failed (event %s): %v", evt.ID, lookupErr)
	case seen:
		return
	}

	msg := evt.Content.AsMessage()
	if msg == nil || !msg.MsgType.IsText() {
		return
	}
	isEdit := msg.RelatesTo != nil && msg.RelatesTo.GetReplaceID() != ""
	if isEdit || msg.Body == "" {
		return
	}

	reply, wantsReply := b.handler(ctx, msg.Body, evt.Sender.String())
	if !wantsReply {
		return
	}

	if b.cfg.Shadow {
		log.Printf("Matrix shadow reply (room %s, event %s): %s", evt.RoomID, evt.ID, reply)
		b.stats.markResponded()
		if _, err := b.stateStore.MarkEventProcessed(ctx, eventID); err != nil {
			log.Printf("Matrix dedupe record failed (event %s): %v", evt.ID, err)
		}
		return
	}

	req := sendRequest{roomID: evt.RoomID, event: evt, response: reply}
	select {
	case b.sendQueue <- req:
		return
	default:
	}
	b.stats.markDropped()
	if _, err := b.stateStore.MarkEventProcessed(ctx, eventID); err != nil {
		log.Printf("Matrix dedupe record failed after drop (event %s): %v", evt.ID, err)
	}
	log.Printf("Matrix send queue full; dropping response for event %s", evt.ID)
}
// sendReply posts response as an m.text message in evt's room.
//
// The text is run through formatNoPreview; when that yields an HTML version it
// is attached as the formatted body. The message carries an empty
// BeeperLinkPreviews list. If the original message belongs to a thread, the
// reply is posted in that thread; otherwise it is sent as a plain reply to
// evt. The error from SendMessageEvent is returned.
func (b *Bot) sendReply(ctx context.Context, evt *event.Event, response string) error {
	plain, rich := formatNoPreview(response)
	msg := &event.MessageEventContent{
		MsgType:            event.MsgText,
		Body:               plain,
		BeeperLinkPreviews: []*event.BeeperLinkPreview{},
	}
	if rich != "" {
		msg.Format = event.FormatHTML
		msg.FormattedBody = rich
	}

	inThread := false
	if orig := evt.Content.AsMessage(); orig != nil && orig.RelatesTo != nil {
		inThread = orig.RelatesTo.GetThreadParent() != ""
	}
	if inThread {
		msg.SetThread(evt)
	} else {
		msg.SetReply(evt)
	}

	_, err := b.client.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, msg)
	return err
}
// sendLoop drains sendQueue, delivering each request with sendWithRetry one
// at a time, until ctx is cancelled. On shutdown it logs how many requests
// were still queued.
func (b *Bot) sendLoop(ctx context.Context) {
	done := ctx.Done()
	for {
		select {
		case req := <-b.sendQueue:
			b.sendWithRetry(ctx, req)
		case <-done:
			if n := len(b.sendQueue); n > 0 {
				log.Printf("Matrix send queue pending on shutdown: %d", n)
			}
			return
		}
	}
}
// cleanupLoop calls stateStore.Cleanup every six hours until ctx is
// cancelled. Cleanup failures are logged and the loop keeps running.
func (b *Bot) cleanupLoop(ctx context.Context) {
	const interval = 6 * time.Hour
	tick := time.NewTicker(interval)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			if err := b.stateStore.Cleanup(ctx); err != nil {
				log.Printf("Matrix state store cleanup failed: %v", err)
			}
		case <-ctx.Done():
			return
		}
	}
}
// sendWithRetry delivers a queued reply, retrying only when the homeserver
// rate-limits the request (as classified by rateLimitDelay).
//
// At most maxAttempts sends are made. Between attempts it waits for the
// server-provided delay, or for the current backoff when none is given. The
// backoff starts at 1s and doubles after each rate-limited attempt, capped at
// 30s. No wait happens after the final attempt; the reply is reported as
// exceeding the retry limit instead. Any non-rate-limit error aborts
// immediately.
//
// On success the event is recorded as processed so it is not answered again.
// Failures are recorded in stats and logged rather than returned. Cancelling
// ctx during a wait abandons the reply.
func (b *Bot) sendWithRetry(ctx context.Context, req sendRequest) {
	const maxAttempts = 5
	backoff := time.Second
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		err := b.sendReply(ctx, req.event, req.response)
		if err == nil {
			b.stats.markResponded()
			if _, err := b.stateStore.MarkEventProcessed(ctx, req.event.ID.String()); err != nil {
				log.Printf("Matrix dedupe record failed (event %s): %v", req.event.ID, err)
			}
			return
		}

		delay, limited := rateLimitDelay(err)
		if !limited {
			b.stats.setLastSendError(err.Error())
			log.Printf("Matrix send failed (room %s, event %s): %v", req.roomID, req.event.ID, err)
			return
		}

		b.stats.markRateLimited()
		b.stats.setLastSendError(err.Error())
		if attempt == maxAttempts {
			// Out of attempts: sleeping now would only stall the send queue.
			break
		}
		if delay <= 0 {
			delay = backoff
		}
		log.Printf("Matrix rate limited (event %s), retrying in %s", req.event.ID, delay)
		if !sleepContext(ctx, delay) {
			return
		}
		backoff = minDuration(backoff*2, 30*time.Second)
	}
	b.stats.setLastSendError("max retries exceeded")
	log.Printf("Matrix send failed after retries (room %s, event %s)", req.roomID, req.event.ID)
}
// rateLimitDelay classifies err as a Matrix rate-limit (M_LIMIT_EXCEEDED)
// error.
//
// The boolean is false for any other error. For a rate-limit error it is
// true, and the duration is the response's retry_after_ms value converted to
// a time.Duration. That value is accepted when it is a float64, int64 or int;
// if it is absent, has another type, or no HTTP error body is available, the
// duration is 0 and the caller picks its own delay.
func rateLimitDelay(err error) (time.Duration, bool) {
	if !errors.Is(err, mautrix.MLimitExceeded) {
		return 0, false
	}

	var httpErr mautrix.HTTPError
	if !errors.As(err, &httpErr) || httpErr.RespError == nil {
		return 0, true
	}
	raw, found := httpErr.RespError.ExtraData["retry_after_ms"]
	if !found {
		return 0, true
	}

	var ms time.Duration
	switch v := raw.(type) {
	case float64:
		ms = time.Duration(v)
	case int64:
		ms = time.Duration(v)
	case int:
		ms = time.Duration(v)
	default:
		return 0, true
	}
	return ms * time.Millisecond, true
}
// sleepContext waits for delay to elapse or for ctx to be cancelled,
// whichever happens first. It returns true if the full delay elapsed and
// false if ctx was cancelled. The timer is always stopped on return.
func sleepContext(ctx context.Context, delay time.Duration) bool {
	timer := time.NewTimer(delay)
	defer timer.Stop()
	select {
	case <-timer.C:
		return true
	case <-ctx.Done():
		return false
	}
}
// minDuration returns the smaller of two durations.
func minDuration(a, b time.Duration) time.Duration {
	if b <= a {
		return b
	}
	return a
}
// onEncryption handles encryption state events. For an allowlisted room not
// yet known to be encrypted, it records the room in encryptedRooms (so later
// messages there are skipped) and logs the change. Events for other rooms
// are ignored.
func (b *Bot) onEncryption(ctx context.Context, evt *event.Event) {
	if evt == nil {
		return
	}
	room := evt.RoomID
	if _, tracked := b.allowedRooms[room]; !tracked || b.isRoomEncrypted(room) {
		return
	}
	b.encryptedRooms[room] = struct{}{}
	log.Printf("Matrix room marked encrypted; skipping messages: %s", room)
}
// onSync records a completed sync in stats and always returns true.
// NOTE(review): mautrix's DefaultSyncer treats a false return from an OnSync
// listener as "stop processing this response" — confirm against the pinned
// mautrix version before changing the return value.
func (b *Bot) onSync(ctx context.Context, _ *mautrix.RespSync, _ string) bool {
	const keepProcessing = true
	b.stats.markSync()
	return keepProcessing
}
// isRoomEncrypted reports whether roomID has been recorded in encryptedRooms.
func (b *Bot) isRoomEncrypted(roomID id.RoomID) bool {
	if _, found := b.encryptedRooms[roomID]; found {
		return true
	}
	return false
}
// prefetchEncryptionState fetches the encryption state event for every
// allowlisted room and marks rooms that have one as encrypted.
//
// A not-found response is treated as "room not encrypted" and skipped
// silently. Any other fetch error is logged and that room is skipped. The
// function currently always returns nil.
func (b *Bot) prefetchEncryptionState(ctx context.Context) error {
	for roomID := range b.allowedRooms {
		var content event.EncryptionEventContent
		err := b.client.StateEvent(ctx, roomID, event.StateEncryption, "", &content)
		if err != nil {
			if !errors.Is(err, mautrix.MNotFound) {
				log.Printf("Matrix encryption state fetch failed for %s: %v", roomID, err)
			}
			continue
		}
		if b.isRoomEncrypted(roomID) {
			continue
		}
		b.encryptedRooms[roomID] = struct{}{}
		log.Printf("Matrix room marked encrypted (state fetch); skipping messages: %s", roomID)
	}
	return nil
}
// roomList returns the room IDs in rooms as strings, in map iteration order
// (i.e. unspecified). It is used for logging the allowlist.
func roomList(rooms map[id.RoomID]struct{}) []string {
	names := make([]string, len(rooms))
	i := 0
	for room := range rooms {
		names[i] = room.String()
		i++
	}
	return names
}
// formatNoPreview renders response as a plain-text body and an HTML body.
//
// The input is processed line by line:
//   - blank or whitespace-only lines become empty in both outputs;
//   - lines of the form "Label: http(s)://…" (see splitLinkLine) become
//     "Label: <url>" in the plain body (via obfuscateURL) and an escaped
//     anchor in the HTML body;
//   - any other line is copied unchanged to the plain body and HTML-escaped
//     in the HTML body.
//
// Plain lines are joined with "\n" and HTML lines with "<br/>".
func formatNoPreview(response string) (string, string) {
	src := strings.Split(response, "\n")
	plain := make([]string, len(src))
	rich := make([]string, len(src))
	for i, line := range src {
		trimmed := strings.TrimSpace(line)
		if trimmed == "" {
			// Both slots keep their zero value "".
			continue
		}
		if label, link, isLink := splitLinkLine(trimmed); isLink {
			escapedLink := html.EscapeString(link)
			plain[i] = fmt.Sprintf("%s: %s", label, obfuscateURL(link))
			rich[i] = fmt.Sprintf("%s: <a href=\"%s\">%s</a>", html.EscapeString(label), escapedLink, escapedLink)
			continue
		}
		plain[i] = line
		rich[i] = html.EscapeString(line)
	}
	return strings.Join(plain, "\n"), strings.Join(rich, "<br/>")
}
// splitLinkLine recognizes lines shaped like "Label: http(s)://…".
//
// It splits at the first ": ". The label is the text before it; the URL is
// the whitespace-trimmed text after it. If there is no separator, or the URL
// does not start with http:// or https://, it returns ("", "", false).
func splitLinkLine(line string) (string, string, bool) {
	sep := strings.Index(line, ": ")
	if sep < 0 {
		return "", "", false
	}
	label, link := line[:sep], strings.TrimSpace(line[sep+len(": "):])
	isHTTP := strings.HasPrefix(link, "http://") || strings.HasPrefix(link, "https://")
	if !isHTTP {
		return "", "", false
	}
	return label, link, true
}
// obfuscateURL is the hook for rewriting a URL before it is placed in the
// plain-text body. It currently returns the URL unchanged.
func obfuscateURL(url string) string {
	rewritten := url
	return rewritten
}
// isEncryptedInvite reports whether an invite event's stripped room state
// (evt.Unsigned.InviteRoomState) contains an encryption state event. It
// returns false for a nil event, for missing invite state, and skips nil
// entries.
func isEncryptedInvite(evt *event.Event) bool {
	if evt == nil {
		return false
	}
	for _, stripped := range evt.Unsigned.InviteRoomState {
		if stripped != nil && stripped.Type == event.StateEncryption {
			return true
		}
	}
	return false
}