musiclink/internal/matrixbot/state_store.go

198 lines
4.7 KiB
Go

// Package matrixbot provides Matrix state storage utilities.
package matrixbot
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"sync"
	"time"

	"maunium.net/go/mautrix/id"
	_ "modernc.org/sqlite" // registers the "sqlite" driver passed to sql.Open
)
// defaultEventTTL is how long a processed event ID is retained for
// deduplication before it becomes eligible for pruning.
const defaultEventTTL = 24 * time.Hour
// StateStore persists Matrix sync state (filter IDs and next_batch tokens,
// keyed by user ID) together with the set of event IDs that have already
// been processed, all in one SQLite database.
type StateStore struct {
	// mu serializes event marking and pruning, and guards lastCleanup.
	mu sync.Mutex
	// lastCleanup is when expired processed events were last pruned; the
	// zero value means no prune has run yet.
	lastCleanup time.Time

	// db is the database handle opened by NewStateStore.
	db *sql.DB
	// dedupeTTL is how long a processed event ID is kept before pruning.
	dedupeTTL time.Duration
}
// NewStateStore opens the SQLite database at path, creating the file and
// its parent directory if needed. It then ensures the schema exists. The
// returned store prunes processed events using defaultEventTTL.
func NewStateStore(path string) (*StateStore, error) {
	dir := filepath.Dir(path)
	if mkErr := os.MkdirAll(dir, 0o755); mkErr != nil {
		return nil, fmt.Errorf("create state store dir: %w", mkErr)
	}

	conn, openErr := sql.Open("sqlite", path)
	if openErr != nil {
		return nil, fmt.Errorf("open state store: %w", openErr)
	}

	s := &StateStore{
		db:        conn,
		dedupeTTL: defaultEventTTL,
	}
	if initErr := s.init(); initErr != nil {
		// Best-effort close: the schema error is the one worth reporting.
		_ = conn.Close()
		return nil, initErr
	}
	return s, nil
}
// init creates the sync_tokens and processed_events tables when they are
// missing. Every statement uses IF NOT EXISTS, so running it against an
// already-initialized database leaves the schema unchanged.
func (s *StateStore) init() error {
	const createSyncTokens = `CREATE TABLE IF NOT EXISTS sync_tokens (
user_id TEXT PRIMARY KEY,
filter_id TEXT,
next_batch TEXT
);`
	const createProcessedEvents = `CREATE TABLE IF NOT EXISTS processed_events (
event_id TEXT PRIMARY KEY,
processed_at INTEGER NOT NULL
);`

	for _, ddl := range [...]string{createSyncTokens, createProcessedEvents} {
		if _, err := s.db.Exec(ddl); err != nil {
			return fmt.Errorf("init state store: %w", err)
		}
	}
	return nil
}
// Close closes the underlying database handle and returns any error from
// doing so. The store should not be used after Close returns.
func (s *StateStore) Close() error {
	if err := s.db.Close(); err != nil {
		return err
	}
	return nil
}
// SaveFilterID upserts the sync filter ID for userID. If a row for the user
// already exists, only filter_id is overwritten; its next_batch is kept.
func (s *StateStore) SaveFilterID(ctx context.Context, userID id.UserID, filterID string) error {
	const upsert = `INSERT INTO sync_tokens (user_id, filter_id) VALUES (?, ?)
ON CONFLICT(user_id) DO UPDATE SET filter_id = excluded.filter_id`
	if _, err := s.db.ExecContext(ctx, upsert, userID.String(), filterID); err != nil {
		return err
	}
	return nil
}
// LoadFilterID returns the stored sync filter ID for userID.
//
// It returns "" with a nil error when there is no row for the user, or when
// the row's filter_id is NULL (for example, when only a next_batch token has
// been saved so far).
func (s *StateStore) LoadFilterID(ctx context.Context, userID id.UserID) (string, error) {
	var filterID sql.NullString
	err := s.db.QueryRowContext(ctx,
		`SELECT filter_id FROM sync_tokens WHERE user_id = ?`,
		userID.String(),
	).Scan(&filterID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return "", nil
	case err != nil:
		return "", err
	}
	// A NULL column scans as Valid=false with String="", so String is the
	// correct result whether or not the value was set.
	return filterID.String, nil
}
// SaveNextBatch upserts the sync next_batch token for userID. If a row for
// the user already exists, only next_batch is overwritten; its filter_id is
// kept.
func (s *StateStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error {
	const upsert = `INSERT INTO sync_tokens (user_id, next_batch) VALUES (?, ?)
ON CONFLICT(user_id) DO UPDATE SET next_batch = excluded.next_batch`
	if _, err := s.db.ExecContext(ctx, upsert, userID.String(), nextBatchToken); err != nil {
		return err
	}
	return nil
}
// LoadNextBatch returns the stored sync next_batch token for userID.
//
// It returns "" with a nil error when there is no row for the user, or when
// the row's next_batch is NULL (for example, when only a filter ID has been
// saved so far).
func (s *StateStore) LoadNextBatch(ctx context.Context, userID id.UserID) (string, error) {
	var nextBatch sql.NullString
	err := s.db.QueryRowContext(ctx,
		`SELECT next_batch FROM sync_tokens WHERE user_id = ?`,
		userID.String(),
	).Scan(&nextBatch)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return "", nil
	case err != nil:
		return "", err
	}
	// A NULL column scans as Valid=false with String="", so String is the
	// correct result whether or not the value was set.
	return nextBatch.String, nil
}
// WasEventProcessed reports whether eventID has a row in processed_events.
//
// This method does not check the dedupe TTL. Rows are removed only by the
// periodic prune, so an event older than the TTL is still reported as
// processed until the next prune deletes it.
func (s *StateStore) WasEventProcessed(ctx context.Context, eventID string) (bool, error) {
	var found bool
	err := s.db.QueryRowContext(ctx,
		`SELECT 1 FROM processed_events WHERE event_id = ?`,
		eventID,
	).Scan(&found)
	if errors.Is(err, sql.ErrNoRows) {
		return false, nil
	}
	if err != nil {
		return false, err
	}
	return true, nil
}
// MarkEventProcessed inserts eventID into processed_events and reports
// whether a new row was written. A false result with a nil error means the
// event was already recorded.
//
// Before inserting, it prunes expired rows if no prune has run yet or the
// last prune is older than dedupeTTL. The whole operation holds s.mu.
func (s *StateStore) MarkEventProcessed(ctx context.Context, eventID string) (bool, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now()
	pruneDue := s.lastCleanup.IsZero() || now.Sub(s.lastCleanup) > s.dedupeTTL
	if pruneDue {
		if err := s.cleanupLocked(ctx, now); err != nil {
			return false, err
		}
		s.lastCleanup = now
	}

	result, err := s.db.ExecContext(ctx,
		`INSERT OR IGNORE INTO processed_events (event_id, processed_at) VALUES (?, ?)`,
		eventID, now.Unix(),
	)
	if err != nil {
		return false, err
	}
	inserted, err := result.RowsAffected()
	if err != nil {
		return false, err
	}
	// INSERT OR IGNORE affects zero rows when event_id is already present.
	return inserted > 0, nil
}
// Cleanup deletes processed event IDs older than dedupeTTL. On success it
// also records the prune time, which MarkEventProcessed uses to decide when
// to prune next. It holds s.mu for the duration.
func (s *StateStore) Cleanup(ctx context.Context) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	started := time.Now()
	err := s.cleanupLocked(ctx, started)
	if err == nil {
		s.lastCleanup = started
	}
	return err
}
// cleanupLocked deletes processed_events rows whose processed_at (Unix
// seconds) is earlier than now minus dedupeTTL. Both callers in this file
// hold s.mu when calling it, as the Locked suffix indicates.
func (s *StateStore) cleanupLocked(ctx context.Context, now time.Time) error {
	threshold := now.Add(-s.dedupeTTL)
	_, err := s.db.ExecContext(ctx, `DELETE FROM processed_events WHERE processed_at < ?`, threshold.Unix())
	return err
}