198 lines
4.7 KiB
Go
198 lines
4.7 KiB
Go
// Package matrixbot provides Matrix state storage utilities.
|
|
package matrixbot
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"sync"
|
|
"time"
|
|
|
|
_ "modernc.org/sqlite"
|
|
|
|
"maunium.net/go/mautrix/id"
|
|
)
|
|
|
|
// defaultEventTTL is the dedupe window given to new stores: processed event
// IDs older than this become eligible for pruning.
const defaultEventTTL = 24 * time.Hour
|
|
|
|
// StateStore is a SQLite-backed store for per-user sync state (filter ID and
// next-batch token) and for the IDs of events that have already been handled.
type StateStore struct {
	// db holds the sync_tokens and processed_events tables.
	db *sql.DB

	// mu serializes the event-dedupe bookkeeping done by MarkEventProcessed
	// and Cleanup.
	mu sync.Mutex
	// dedupeTTL is how long a processed event ID is kept before pruning.
	dedupeTTL time.Duration
	// lastCleanup is when expired event IDs were last pruned; zero means never.
	lastCleanup time.Time
}
|
|
|
|
// NewStateStore opens or creates the state store database.
|
|
func NewStateStore(path string) (*StateStore, error) {
|
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
|
return nil, fmt.Errorf("create state store dir: %w", err)
|
|
}
|
|
|
|
db, err := sql.Open("sqlite", path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open state store: %w", err)
|
|
}
|
|
|
|
store := &StateStore{db: db, dedupeTTL: defaultEventTTL}
|
|
if err := store.init(); err != nil {
|
|
_ = db.Close()
|
|
return nil, err
|
|
}
|
|
|
|
return store, nil
|
|
}
|
|
|
|
func (s *StateStore) init() error {
|
|
stmts := []string{
|
|
`CREATE TABLE IF NOT EXISTS sync_tokens (
|
|
user_id TEXT PRIMARY KEY,
|
|
filter_id TEXT,
|
|
next_batch TEXT
|
|
);`,
|
|
`CREATE TABLE IF NOT EXISTS processed_events (
|
|
event_id TEXT PRIMARY KEY,
|
|
processed_at INTEGER NOT NULL
|
|
);`,
|
|
}
|
|
for _, stmt := range stmts {
|
|
if _, err := s.db.Exec(stmt); err != nil {
|
|
return fmt.Errorf("init state store: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Close closes the state store database.
|
|
func (s *StateStore) Close() error {
|
|
return s.db.Close()
|
|
}
|
|
|
|
// SaveFilterID stores the filter ID for the given user.
|
|
func (s *StateStore) SaveFilterID(ctx context.Context, userID id.UserID, filterID string) error {
|
|
_, err := s.db.ExecContext(ctx,
|
|
`INSERT INTO sync_tokens (user_id, filter_id) VALUES (?, ?)
|
|
ON CONFLICT(user_id) DO UPDATE SET filter_id = excluded.filter_id`,
|
|
userID.String(), filterID,
|
|
)
|
|
return err
|
|
}
|
|
|
|
// LoadFilterID loads the filter ID for the given user.
|
|
func (s *StateStore) LoadFilterID(ctx context.Context, userID id.UserID) (string, error) {
|
|
var filterID sql.NullString
|
|
err := s.db.QueryRowContext(ctx,
|
|
`SELECT filter_id FROM sync_tokens WHERE user_id = ?`,
|
|
userID.String(),
|
|
).Scan(&filterID)
|
|
if err == sql.ErrNoRows {
|
|
return "", nil
|
|
}
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if filterID.Valid {
|
|
return filterID.String, nil
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
// SaveNextBatch stores the next batch token for the given user.
|
|
func (s *StateStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error {
|
|
_, err := s.db.ExecContext(ctx,
|
|
`INSERT INTO sync_tokens (user_id, next_batch) VALUES (?, ?)
|
|
ON CONFLICT(user_id) DO UPDATE SET next_batch = excluded.next_batch`,
|
|
userID.String(), nextBatchToken,
|
|
)
|
|
return err
|
|
}
|
|
|
|
// LoadNextBatch loads the next batch token for the given user.
|
|
func (s *StateStore) LoadNextBatch(ctx context.Context, userID id.UserID) (string, error) {
|
|
var nextBatch sql.NullString
|
|
err := s.db.QueryRowContext(ctx,
|
|
`SELECT next_batch FROM sync_tokens WHERE user_id = ?`,
|
|
userID.String(),
|
|
).Scan(&nextBatch)
|
|
if err == sql.ErrNoRows {
|
|
return "", nil
|
|
}
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if nextBatch.Valid {
|
|
return nextBatch.String, nil
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
// WasEventProcessed checks whether an event ID has already been recorded.
|
|
func (s *StateStore) WasEventProcessed(ctx context.Context, eventID string) (bool, error) {
|
|
var exists bool
|
|
err := s.db.QueryRowContext(ctx,
|
|
`SELECT 1 FROM processed_events WHERE event_id = ?`,
|
|
eventID,
|
|
).Scan(&exists)
|
|
if err == sql.ErrNoRows {
|
|
return false, nil
|
|
}
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
// MarkEventProcessed records a processed event. Returns true if newly recorded.
|
|
func (s *StateStore) MarkEventProcessed(ctx context.Context, eventID string) (bool, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
if s.lastCleanup.IsZero() || now.Sub(s.lastCleanup) > s.dedupeTTL {
|
|
if err := s.cleanupLocked(ctx, now); err != nil {
|
|
return false, err
|
|
}
|
|
s.lastCleanup = now
|
|
}
|
|
|
|
res, err := s.db.ExecContext(ctx,
|
|
`INSERT OR IGNORE INTO processed_events (event_id, processed_at) VALUES (?, ?)`,
|
|
eventID, now.Unix(),
|
|
)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
affected, err := res.RowsAffected()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return affected > 0, nil
|
|
}
|
|
|
|
// Cleanup prunes expired processed event IDs.
|
|
func (s *StateStore) Cleanup(ctx context.Context) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
if err := s.cleanupLocked(ctx, now); err != nil {
|
|
return err
|
|
}
|
|
s.lastCleanup = now
|
|
return nil
|
|
}
|
|
|
|
func (s *StateStore) cleanupLocked(ctx context.Context, now time.Time) error {
|
|
cutoff := now.Add(-s.dedupeTTL).Unix()
|
|
_, err := s.db.ExecContext(ctx,
|
|
`DELETE FROM processed_events WHERE processed_at < ?`,
|
|
cutoff,
|
|
)
|
|
return err
|
|
}
|