vibeStonk/server/repository/sqliteSession.go
2025-06-12 16:57:42 -04:00

166 lines
4.0 KiB
Go

package repository
import (
"database/sql"
"errors"
"fmt"
"github.com/google/uuid"
"google.golang.org/protobuf/types/known/timestamppb"
"time"
models "vibeStonk/server/models/v1"
)
// ErrSessionNotFound is returned by Get when no session row has the
// requested token, and by Revoke when no session row has the given id.
var ErrSessionNotFound = errors.New("session not found")
// newSqliteSessionRepo wraps db in a sqliteSessionRepo and creates the
// sessions schema (via initialize) before handing the repository back.
//
// The SessionRepo-only return type leaves no way to report a schema setup
// failure, so such a failure panics with a descriptive message.
// TODO(review): an error-returning constructor would let startup code handle
// schema failures without recovering from a panic.
func newSqliteSessionRepo(db *sql.DB) SessionRepo {
	sessions := &sqliteSessionRepo{db: db}
	err := sessions.initialize()
	if err == nil {
		return sessions
	}
	panic(fmt.Sprintf("failed to initialize session repository: %v", err))
}
// sqliteSessionRepo is the SQLite-backed implementation returned by
// newSqliteSessionRepo; all of its methods read and write the "sessions" table.
type sqliteSessionRepo struct {
// db is the database handle every query in this repository is executed on.
db *sql.DB
}
// initialize creates the sessions table and its supporting index if they do
// not already exist. Every statement uses IF NOT EXISTS, so calling it on a
// database that already has the schema is a no-op.
//
// token is declared UNIQUE, and SQLite implements UNIQUE constraints with an
// automatic unique index, so Get's lookup by token is already indexed; a
// second index on token would only duplicate it and slow down every insert.
// Instead, expires is indexed so DeleteExpired's "expires < ?" delete does not
// have to scan the whole table. A sessions_token index created by an earlier
// schema is left in place; it is harmless but redundant.
//
// NOTE(review): SQLite only enforces the FOREIGN KEY clause when the
// connection runs with PRAGMA foreign_keys=ON — confirm where the DB is opened.
func (s *sqliteSessionRepo) initialize() error {
	createTable := `
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
revoked BOOLEAN NOT NULL DEFAULT 0,
token TEXT UNIQUE NOT NULL,
created TIMESTAMP NOT NULL,
expires TIMESTAMP NOT NULL,
FOREIGN KEY (user_id) REFERENCES users(id)
);
`
	if _, err := s.db.Exec(createTable); err != nil {
		return fmt.Errorf("failed to create sessions table: %w", err)
	}

	createExpiresIndex := `
CREATE INDEX IF NOT EXISTS sessions_expires
ON sessions(expires);
`
	if _, err := s.db.Exec(createExpiresIndex); err != nil {
		return fmt.Errorf("failed to create expires index for sessions table: %w", err)
	}
	return nil
}
// Create opens a new session for user, inserts it into the sessions table and
// returns it.
//
// The session gets a fresh uuid.New() id and a separate uuid.New() token, is
// not revoked, and expires 24 hours after its creation time. The returned
// session carries exactly the values that were inserted.
//
// A nil user is rejected with an error rather than panicking on user.Id.
func (s *sqliteSessionRepo) Create(user *models.User) (*models.Session, error) {
	if user == nil {
		return nil, errors.New("failed to create session: user is nil")
	}

	now := time.Now()
	session := &models.Session{
		Id:      uuid.New().String(),
		UserID:  user.Id,
		Revoked: false,
		Token:   uuid.New().String(),
		Created: timestamppb.New(now),
		Expires: timestamppb.New(now.Add(24 * time.Hour)),
	}

	query := `
INSERT INTO sessions (id, user_id, revoked, token, created, expires)
VALUES (?, ?, ?, ?, ?, ?)
`
	// AsTime converts the protobuf timestamps back to time.Time (in UTC) so the
	// stored values are exactly what the returned session reports.
	_, err := s.db.Exec(query, session.Id, session.UserID, session.Revoked, session.Token,
		session.Created.AsTime(), session.Expires.AsTime())
	if err != nil {
		return nil, fmt.Errorf("failed to create session: %w", err)
	}
	return session, nil
}
// Get fetches the session whose token column equals token.
//
// It returns ErrSessionNotFound when no row matches, and a wrapped error for
// any other query or scan failure. The stored created/expires values are
// converted to protobuf Timestamps on the returned session. The query filters
// only on token, so revoked or expired sessions are returned as-is; callers
// must inspect those fields themselves.
func (s *sqliteSessionRepo) Get(token string) (*models.Session, error) {
	selectByToken := `
SELECT id, user_id, revoked, token, created, expires
FROM sessions
WHERE token = ?
`
	found := new(models.Session)
	var created, expires time.Time

	scanErr := s.db.QueryRow(selectByToken, token).Scan(
		&found.Id, &found.UserID, &found.Revoked, &found.Token, &created, &expires,
	)
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		return nil, ErrSessionNotFound
	case scanErr != nil:
		return nil, fmt.Errorf("failed to get session: %w", scanErr)
	}

	found.Created = timestamppb.New(created)
	found.Expires = timestamppb.New(expires)
	return found, nil
}
// Revoke sets revoked = 1 on the row whose id equals session.Id and, on
// success, sets session.Revoked to true on the caller's object.
//
// It returns ErrSessionNotFound when the update affects no rows, a wrapped
// error when the update or the rows-affected lookup fails, and an error for a
// nil session (previously this panicked on session.Id).
func (s *sqliteSessionRepo) Revoke(session *models.Session) error {
	if session == nil {
		return errors.New("failed to revoke session: session is nil")
	}

	query := `
UPDATE sessions
SET revoked = 1
WHERE id = ?
`
	result, err := s.db.Exec(query, session.Id)
	if err != nil {
		return fmt.Errorf("failed to revoke session: %w", err)
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}
	if rowsAffected == 0 {
		return ErrSessionNotFound
	}

	// Keep the caller's in-memory copy consistent with the row just updated.
	session.Revoked = true
	return nil
}
// DeleteExpired deletes every session whose expires value is earlier than the
// current time, then prints the number of deleted rows to stdout.
//
// The cutoff is taken in UTC because Create stores its timestamps through
// timestamppb's AsTime, which always yields UTC. Binding a local-zone
// time.Time here (as before) could make the comparison disagree with the
// stored values. If the driver persists timestamps as zone-suffixed text,
// SQLite compares them as strings. On hosts east of UTC, live sessions would
// then be deleted early; on hosts west of UTC, expired ones would linger.
func (s *sqliteSessionRepo) DeleteExpired() error {
	query := `
DELETE FROM sessions
WHERE expires < ?
`
	cutoff := time.Now().UTC()
	result, err := s.db.Exec(query, cutoff)
	if err != nil {
		return fmt.Errorf("failed to delete expired sessions: %w", err)
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}
	// TODO(review): route this through the application's logger instead of stdout.
	fmt.Printf("Deleted %d expired sessions\n", rowsAffected)
	return nil
}