package repository

import (
	"database/sql"
	"errors"
	"fmt"
	"strings"

	models "vibeStonk/server/models/v1"
)

var (
	// ErrStockNotFound is returned when a lookup, update or delete targets a
	// stock row that does not exist.
	ErrStockNotFound = errors.New("stock not found")
)

// newSqliteStockRepo builds a StockRepo on top of db and makes sure the
// backing table and index exist before handing the repository out.
func newSqliteStockRepo(db *sql.DB) (StockRepo, error) {
	repo := &sqliteStockRepo{db: db}
	if err := repo.initialize(); err != nil {
		return nil, err
	}
	return repo, nil
}

// sqliteStockRepo stores stocks in the "stocks" table of a database/sql
// handle.
type sqliteStockRepo struct {
	db *sql.DB
}

// initialize creates the stocks table and its symbol index if they don't
// already exist.
func (s *sqliteStockRepo) initialize() error {
	query := ` CREATE TABLE IF NOT EXISTS stocks ( id INTEGER PRIMARY KEY AUTOINCREMENT, symbol TEXT UNIQUE NOT NULL, name TEXT NOT NULL, color TEXT ); `
	if _, err := s.db.Exec(query); err != nil {
		return fmt.Errorf("failed to create stocks table: %w", err)
	}

	// NOTE(review): the UNIQUE constraint on symbol should already give SQLite
	// an implicit index, so this explicit one is likely redundant — confirm
	// before removing it from existing databases.
	query = ` CREATE INDEX IF NOT EXISTS stocks_symbols ON stocks(symbol); `
	if _, err := s.db.Exec(query); err != nil {
		return fmt.Errorf("failed to create symbol index for stocks table: %w", err)
	}

	return nil
}

// Create inserts stock and writes the generated row ID back into stock.Id.
// The returned pointer is the same value that was passed in.
func (s *sqliteStockRepo) Create(stock *models.Stock) (*models.Stock, error) {
	query := ` INSERT INTO stocks (symbol, name, color) VALUES (?, ?, ?) `

	result, err := s.db.Exec(query, stock.Symbol, stock.Name, stock.Color)
	if err != nil {
		return nil, fmt.Errorf("failed to create stock: %w", err)
	}

	id, err := result.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("failed to get last insert ID: %w", err)
	}

	stock.Id = id
	return stock, nil
}

// Get returns the stock with the given id, or ErrStockNotFound when no such
// row exists.
func (s *sqliteStockRepo) Get(id int64) (*models.Stock, error) {
	query := ` SELECT id, symbol, name, color FROM stocks WHERE id = ? `

	stock := &models.Stock{}
	err := s.db.QueryRow(query, id).Scan(&stock.Id, &stock.Symbol, &stock.Name, &stock.Color)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrStockNotFound
		}
		return nil, fmt.Errorf("failed to get stock: %w", err)
	}

	return stock, nil
}

// GetBySymbol returns the stock with the given symbol, or ErrStockNotFound
// when no such row exists.
func (s *sqliteStockRepo) GetBySymbol(symbol string) (*models.Stock, error) {
	query := ` SELECT id, symbol, name, color FROM stocks WHERE symbol = ? `

	stock := &models.Stock{}
	err := s.db.QueryRow(query, symbol).Scan(&stock.Id, &stock.Symbol, &stock.Name, &stock.Color)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrStockNotFound
		}
		return nil, fmt.Errorf("failed to get stock by symbol: %w", err)
	}

	return stock, nil
}

// Update overwrites symbol, name and color of the row identified by stock.Id.
// It returns ErrStockNotFound when the statement affects no rows.
func (s *sqliteStockRepo) Update(stock *models.Stock) error {
	query := ` UPDATE stocks SET symbol = ?, name = ?, color = ? WHERE id = ? `

	result, err := s.db.Exec(query, stock.Symbol, stock.Name, stock.Color, stock.Id)
	if err != nil {
		return fmt.Errorf("failed to update stock: %w", err)
	}

	return ensureStockAffected(result)
}

// Delete removes the row identified by stock.Id. It returns ErrStockNotFound
// when the statement affects no rows.
func (s *sqliteStockRepo) Delete(stock *models.Stock) error {
	query := ` DELETE FROM stocks WHERE id = ? `

	result, err := s.db.Exec(query, stock.Id)
	if err != nil {
		return fmt.Errorf("failed to delete stock: %w", err)
	}

	return ensureStockAffected(result)
}

// List returns every stock ordered by symbol. The result is nil when the
// table is empty.
func (s *sqliteStockRepo) List() ([]*models.Stock, error) {
	query := ` SELECT id, symbol, name, color FROM stocks ORDER BY symbol `

	rows, err := s.db.Query(query)
	if err != nil {
		return nil, fmt.Errorf("failed to list stocks: %w", err)
	}
	defer rows.Close()

	return scanStocks(rows)
}

// GetByIDs returns the stocks whose id is in ids, ordered by symbol. IDs with
// no matching row are silently omitted, and an empty ids slice yields a nil
// result without touching the database.
func (s *sqliteStockRepo) GetByIDs(ids []int) ([]*models.Stock, error) {
	// An empty "IN ()" list is not standard SQL and can never match anything,
	// so skip the round trip entirely.
	if len(ids) == 0 {
		return nil, nil
	}

	params := make([]interface{}, 0, len(ids))
	for _, id := range ids {
		params = append(params, id)
	}

	rows, err := s.db.Query(buildGetByIDsQuery(ids), params...)
	if err != nil {
		return nil, fmt.Errorf("failed to get stocks by ids: %w", err)
	}
	defer rows.Close()

	return scanStocks(rows)
}

// ensureStockAffected maps a write that touched zero rows to
// ErrStockNotFound.
func ensureStockAffected(result sql.Result) error {
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get rows affected: %w", err)
	}
	if rowsAffected == 0 {
		return ErrStockNotFound
	}
	return nil
}

// scanStocks drains rows, which must yield (id, symbol, name, color) columns
// in that order, into a slice of stocks. It does not close rows; that stays
// the caller's responsibility.
func scanStocks(rows *sql.Rows) ([]*models.Stock, error) {
	var stocks []*models.Stock
	for rows.Next() {
		stock := &models.Stock{}
		if err := rows.Scan(&stock.Id, &stock.Symbol, &stock.Name, &stock.Color); err != nil {
			return nil, fmt.Errorf("failed to scan stock row: %w", err)
		}
		stocks = append(stocks, stock)
	}

	// rows.Next returns false on both exhaustion and failure; Err tells them
	// apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating stock rows: %w", err)
	}

	return stocks, nil
}

// buildGetByIDsQuery renders the SELECT used by GetByIDs with one "?"
// placeholder per element of ids, comma separated. Only the length of ids is
// used; the values themselves are bound as query parameters by the caller.
func buildGetByIDsQuery(ids []int) string {
	sb := &strings.Builder{}
	sb.WriteString("SELECT id, symbol, name, color FROM stocks WHERE id in (")
	for i := range ids {
		if i > 0 {
			sb.WriteByte(',')
		}
		sb.WriteByte('?')
	}
	sb.WriteString(") ORDER BY symbol")
	return sb.String()
}