package repository

import (
	"database/sql"
	"errors"
	"fmt"
	"log"
	"time"

	"github.com/google/uuid"
	"google.golang.org/protobuf/types/known/timestamppb"

	models "vibeStonk/server/models/v1"
)

var (
	// ErrSessionNotFound is returned by Get when no session has the requested
	// token, and by Revoke when no session has the given ID.
	ErrSessionNotFound = errors.New("session not found")
)

// newSqliteSessionRepo returns a SessionRepo backed by db, creating the
// sessions table and its indexes if they do not already exist.
//
// The constructor's signature has no error return, so it panics if schema
// initialization fails.
func newSqliteSessionRepo(db *sql.DB) SessionRepo {
	repo := &sqliteSessionRepo{db: db}
	if err := repo.initialize(); err != nil {
		// NOTE(review): consider changing the signature to return an error so
		// callers can handle schema failures without recovering from a panic.
		panic(fmt.Sprintf("failed to initialize session repository: %v", err))
	}
	return repo
}

// sqliteSessionRepo stores sessions in the "sessions" table of a SQLite
// database.
type sqliteSessionRepo struct {
	db *sql.DB
}

// initialize creates the sessions table and its indexes if they don't exist.
//
// It is safe to call against an already-initialized database because every
// statement uses IF NOT EXISTS.
func (s *sqliteSessionRepo) initialize() error {
	// NOTE(review): SQLite only enforces the FOREIGN KEY below when the
	// connection has PRAGMA foreign_keys=ON — confirm the DB setup enables it.
	const tableQuery = `
		CREATE TABLE IF NOT EXISTS sessions (
			id TEXT PRIMARY KEY,
			user_id TEXT NOT NULL,
			revoked BOOLEAN NOT NULL DEFAULT 0,
			token TEXT UNIQUE NOT NULL,
			created TIMESTAMP NOT NULL,
			expires TIMESTAMP NOT NULL,
			FOREIGN KEY (user_id) REFERENCES users(id)
		);
	`
	if _, err := s.db.Exec(tableQuery); err != nil {
		return fmt.Errorf("failed to create sessions table: %w", err)
	}

	// The UNIQUE constraint on token already gives SQLite an implicit unique
	// index, so this explicit index is likely redundant. It is kept so the
	// schema matches databases created by earlier versions.
	const tokenIndexQuery = `
		CREATE INDEX IF NOT EXISTS sessions_token ON sessions(token);
	`
	if _, err := s.db.Exec(tokenIndexQuery); err != nil {
		return fmt.Errorf("failed to create token index for sessions table: %w", err)
	}

	// DeleteExpired filters on expires; without an index it scans every row.
	const expiresIndexQuery = `
		CREATE INDEX IF NOT EXISTS sessions_expires ON sessions(expires);
	`
	if _, err := s.db.Exec(expiresIndexQuery); err != nil {
		return fmt.Errorf("failed to create expires index for sessions table: %w", err)
	}

	return nil
}

// Create starts a new, non-revoked session for user and persists it.
//
// The session gets a fresh UUID as its ID and another fresh UUID as its token.
// It expires 24 hours after creation. Create returns an error if user is nil
// or the insert fails. On error the returned session is nil.
func (s *sqliteSessionRepo) Create(user *models.User) (*models.Session, error) {
	if user == nil {
		return nil, errors.New("failed to create session: nil user")
	}

	const sessionTTL = 24 * time.Hour

	// Keep every stored timestamp in UTC so comparisons in DeleteExpired see
	// a single offset (timestamppb.AsTime also yields UTC).
	now := time.Now().UTC()

	session := &models.Session{
		Id:      uuid.New().String(),
		UserID:  user.Id,
		Revoked: false,
		Token:   uuid.New().String(),
		Created: timestamppb.New(now),
		Expires: timestamppb.New(now.Add(sessionTTL)),
	}

	const query = `
		INSERT INTO sessions (id, user_id, revoked, token, created, expires)
		VALUES (?, ?, ?, ?, ?, ?)
	`
	_, err := s.db.Exec(query,
		session.Id,
		session.UserID,
		session.Revoked,
		session.Token,
		session.Created.AsTime(),
		session.Expires.AsTime(),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create session: %w", err)
	}
	return session, nil
}

// Get returns the session whose token equals token.
//
// It returns ErrSessionNotFound when no row matches. Get does not check
// whether the session is revoked or expired. Callers that care must inspect
// the Revoked and Expires fields themselves.
func (s *sqliteSessionRepo) Get(token string) (*models.Session, error) {
	const query = `
		SELECT id, user_id, revoked, token, created, expires
		FROM sessions
		WHERE token = ?
	`

	session := &models.Session{}
	// NOTE(review): scanning TIMESTAMP columns into time.Time relies on the
	// SQLite driver parsing them; confirm the driver in use supports this.
	var created, expires time.Time
	err := s.db.QueryRow(query, token).Scan(
		&session.Id,
		&session.UserID,
		&session.Revoked,
		&session.Token,
		&created,
		&expires,
	)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrSessionNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	session.Created = timestamppb.New(created)
	session.Expires = timestamppb.New(expires)
	return session, nil
}

// Revoke marks the session with session.Id as revoked in the database and,
// on success, sets session.Revoked to true.
//
// It returns ErrSessionNotFound if no row has that ID, and an error if
// session is nil.
func (s *sqliteSessionRepo) Revoke(session *models.Session) error {
	if session == nil {
		return errors.New("failed to revoke session: nil session")
	}

	const query = `
		UPDATE sessions
		SET revoked = 1
		WHERE id = ?
	`
	result, err := s.db.Exec(query, session.Id)
	if err != nil {
		return fmt.Errorf("failed to revoke session: %w", err)
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}
	if rowsAffected == 0 {
		return ErrSessionNotFound
	}

	// Keep the caller's in-memory copy consistent with the stored row.
	session.Revoked = true
	return nil
}

// DeleteExpired removes every session whose expires timestamp is earlier than
// the current time, and logs how many rows were deleted.
func (s *sqliteSessionRepo) DeleteExpired() error {
	const query = `
		DELETE FROM sessions
		WHERE expires < ?
	`
	// Compare in UTC: Create stores UTC timestamps. NOTE(review): drivers such
	// as mattn/go-sqlite3 bind time.Time as formatted text, so this comparison
	// is lexicographic and only correct when both sides use the same offset.
	result, err := s.db.Exec(query, time.Now().UTC())
	if err != nil {
		return fmt.Errorf("failed to delete expired sessions: %w", err)
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}

	// Use the standard logger rather than stdout so the application can
	// redirect or silence it via log.SetOutput.
	log.Printf("Deleted %d expired sessions", rowsAffected)
	return nil
}