166 lines
4.0 KiB
Go
166 lines
4.0 KiB
Go
package repository
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/google/uuid"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
"time"
|
|
models "vibeStonk/server/models/v1"
|
|
)
|
|
|
|
// ErrSessionNotFound is returned when no stored session matches the requested
// token (Get) or id (Revoke). Callers should test for it with errors.Is.
var ErrSessionNotFound = errors.New("session not found")
|
|
|
|
func newSqliteSessionRepo(db *sql.DB) SessionRepo {
|
|
repo := &sqliteSessionRepo{db: db}
|
|
if err := repo.initialize(); err != nil {
|
|
// Since we can't return an error from this function, we'll panic
|
|
// In a production environment, this should be handled differently
|
|
panic(fmt.Sprintf("failed to initialize session repository: %v", err))
|
|
}
|
|
return repo
|
|
}
|
|
|
|
// sqliteSessionRepo stores sessions in the "sessions" table reachable through
// db. Build it with newSqliteSessionRepo, which also prepares the schema.
type sqliteSessionRepo struct {
	// db is the handle every repository method issues its SQL through.
	db *sql.DB
}
|
|
|
|
// initialize creates the sessions table if it doesn't exist
|
|
func (s *sqliteSessionRepo) initialize() error {
|
|
query := `
|
|
CREATE TABLE IF NOT EXISTS sessions (
|
|
id TEXT PRIMARY KEY,
|
|
user_id TEXT NOT NULL,
|
|
revoked BOOLEAN NOT NULL DEFAULT 0,
|
|
token TEXT UNIQUE NOT NULL,
|
|
created TIMESTAMP NOT NULL,
|
|
expires TIMESTAMP NOT NULL,
|
|
FOREIGN KEY (user_id) REFERENCES users(id)
|
|
);
|
|
`
|
|
_, err := s.db.Exec(query)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create sessions table: %w", err)
|
|
}
|
|
|
|
query = `
|
|
CREATE INDEX IF NOT EXISTS sessions_token
|
|
ON sessions(token);
|
|
`
|
|
_, err = s.db.Exec(query)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create token index for sessions table: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *sqliteSessionRepo) Create(user *models.User) (*models.Session, error) {
|
|
// Generate a unique ID and token for the session
|
|
id := uuid.New().String()
|
|
token := uuid.New().String()
|
|
|
|
// Set creation time to now and expiration to 24 hours from now
|
|
now := time.Now()
|
|
expires := now.Add(24 * time.Hour)
|
|
|
|
// Create the session object
|
|
session := &models.Session{
|
|
Id: id,
|
|
UserID: user.Id,
|
|
Revoked: false,
|
|
Token: token,
|
|
Created: timestamppb.New(now),
|
|
Expires: timestamppb.New(expires),
|
|
}
|
|
|
|
// Insert the session into the database
|
|
query := `
|
|
INSERT INTO sessions (id, user_id, revoked, token, created, expires)
|
|
VALUES (?, ?, ?, ?, ?, ?)
|
|
`
|
|
_, err := s.db.Exec(query, session.Id, session.UserID, session.Revoked, session.Token,
|
|
session.Created.AsTime(), session.Expires.AsTime())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create session: %w", err)
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
func (s *sqliteSessionRepo) Get(token string) (*models.Session, error) {
|
|
query := `
|
|
SELECT id, user_id, revoked, token, created, expires
|
|
FROM sessions
|
|
WHERE token = ?
|
|
`
|
|
row := s.db.QueryRow(query, token)
|
|
|
|
session := &models.Session{}
|
|
var createdTime, expiresTime time.Time
|
|
|
|
err := row.Scan(&session.Id, &session.UserID, &session.Revoked, &session.Token, &createdTime, &expiresTime)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrSessionNotFound
|
|
}
|
|
return nil, fmt.Errorf("failed to get session: %w", err)
|
|
}
|
|
|
|
// Convert time.Time to protobuf Timestamp
|
|
session.Created = timestamppb.New(createdTime)
|
|
session.Expires = timestamppb.New(expiresTime)
|
|
|
|
return session, nil
|
|
}
|
|
|
|
func (s *sqliteSessionRepo) Revoke(session *models.Session) error {
|
|
query := `
|
|
UPDATE sessions
|
|
SET revoked = 1
|
|
WHERE id = ?
|
|
`
|
|
result, err := s.db.Exec(query, session.Id)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to revoke session: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return ErrSessionNotFound
|
|
}
|
|
|
|
// BeginUpdate the session object to reflect the change
|
|
session.Revoked = true
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *sqliteSessionRepo) DeleteExpired() error {
|
|
query := `
|
|
DELETE FROM sessions
|
|
WHERE expires < ?
|
|
`
|
|
result, err := s.db.Exec(query, time.Now())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete expired sessions: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
// Log the number of deleted sessions (in a real application, you might want to use a logger)
|
|
fmt.Printf("Deleted %d expired sessions\n", rowsAffected)
|
|
|
|
return nil
|
|
}
|