vibeStonk/server/repository/sqliteStock.go
2025-06-12 16:57:42 -04:00

229 lines
4.9 KiB
Go

package repository
import (
"database/sql"
"errors"
"fmt"
"strings"
models "vibeStonk/server/models/v1"
)
// ErrStockNotFound is the sentinel returned by the stock repository when the
// targeted row does not exist (lookups with no match, or an update/delete
// that affected zero rows). Compare against it with errors.Is.
var ErrStockNotFound = errors.New("stock not found")
// newSqliteStockRepo builds a SQLite-backed StockRepo on top of db and runs
// schema initialization before returning it. If initialization fails, it
// returns a nil repo together with that error.
func newSqliteStockRepo(db *sql.DB) (StockRepo, error) {
	repo := new(sqliteStockRepo)
	repo.db = db

	err := repo.initialize()
	if err != nil {
		// Return a literal nil so callers never see a non-nil interface
		// wrapping a half-initialized repo.
		return nil, err
	}
	return repo, nil
}
type sqliteStockRepo struct {
db *sql.DB
}
// initialize creates the stocks table if it does not already exist.
//
// No separate index on symbol is created. SQLite implements the UNIQUE
// constraint on symbol with an automatic unique index, and that index already
// serves lookups by symbol and ORDER BY symbol. An additional explicit index
// would only duplicate it and double the write cost for that column.
// Databases created by earlier versions may still contain the redundant
// "stocks_symbols" index. It is harmless and is left untouched here.
func (s *sqliteStockRepo) initialize() error {
	const createStocksTable = `
CREATE TABLE IF NOT EXISTS stocks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
symbol TEXT UNIQUE NOT NULL,
name TEXT NOT NULL,
color TEXT
);
`
	if _, err := s.db.Exec(createStocksTable); err != nil {
		return fmt.Errorf("failed to create stocks table: %w", err)
	}
	return nil
}
// Create inserts stock's symbol, name and color as a new row. It writes the
// generated row id into stock.Id and returns the same pointer.
//
// On failure it returns nil and a wrapped error, for example when the insert
// violates the UNIQUE constraint on symbol. In that case stock.Id is not
// modified.
func (s *sqliteStockRepo) Create(stock *models.Stock) (*models.Stock, error) {
	const insertStock = `
INSERT INTO stocks (symbol, name, color)
VALUES (?, ?, ?)
`
	res, execErr := s.db.Exec(insertStock, stock.Symbol, stock.Name, stock.Color)
	if execErr != nil {
		return nil, fmt.Errorf("failed to create stock: %w", execErr)
	}

	newID, idErr := res.LastInsertId()
	if idErr != nil {
		return nil, fmt.Errorf("failed to get last insert ID: %w", idErr)
	}

	stock.Id = newID
	return stock, nil
}
// Get loads the stock whose primary key equals id.
//
// It returns ErrStockNotFound when no row has that id. Any other scan or
// query failure is returned wrapped with context.
func (s *sqliteStockRepo) Get(id int64) (*models.Stock, error) {
	const selectByID = `
SELECT id, symbol, name, color
FROM stocks
WHERE id = ?
`
	found := new(models.Stock)
	scanErr := s.db.QueryRow(selectByID, id).
		Scan(&found.Id, &found.Symbol, &found.Name, &found.Color)

	switch {
	case scanErr == nil:
		return found, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		return nil, ErrStockNotFound
	default:
		return nil, fmt.Errorf("failed to get stock: %w", scanErr)
	}
}
// GetBySymbol loads the stock whose symbol column equals symbol (a plain
// SQL `=` comparison).
//
// It returns ErrStockNotFound when no row matches. Other failures are
// returned wrapped with context.
func (s *sqliteStockRepo) GetBySymbol(symbol string) (*models.Stock, error) {
	const selectBySymbol = `
SELECT id, symbol, name, color
FROM stocks
WHERE symbol = ?
`
	result := new(models.Stock)
	err := s.db.QueryRow(selectBySymbol, symbol).
		Scan(&result.Id, &result.Symbol, &result.Name, &result.Color)

	// QueryRow defers "no rows" until Scan, so translate it to the
	// repository's sentinel here.
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrStockNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get stock by symbol: %w", err)
	}
	return result, nil
}
// Update overwrites the symbol, name and color of the row identified by
// stock.Id.
//
// It returns ErrStockNotFound when the statement affected zero rows. Database
// failures are returned wrapped with context.
func (s *sqliteStockRepo) Update(stock *models.Stock) error {
	const updateStock = `
UPDATE stocks
SET symbol = ?, name = ?, color = ?
WHERE id = ?
`
	res, err := s.db.Exec(updateStock, stock.Symbol, stock.Name, stock.Color, stock.Id)
	if err != nil {
		return fmt.Errorf("failed to update stock: %w", err)
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("failed to get rows affected: %w", err)
	case affected == 0:
		return ErrStockNotFound
	default:
		return nil
	}
}
// Delete removes the row whose id equals stock.Id. Only the Id field of stock
// is used.
//
// It returns ErrStockNotFound when nothing was deleted. Database failures are
// returned wrapped with context.
func (s *sqliteStockRepo) Delete(stock *models.Stock) error {
	const deleteByID = `
DELETE FROM stocks
WHERE id = ?
`
	res, execErr := s.db.Exec(deleteByID, stock.Id)
	if execErr != nil {
		return fmt.Errorf("failed to delete stock: %w", execErr)
	}

	removed, countErr := res.RowsAffected()
	if countErr != nil {
		return fmt.Errorf("failed to get rows affected: %w", countErr)
	}
	if removed != 0 {
		return nil
	}
	return ErrStockNotFound
}
// List returns every stock, ordered by symbol.
//
// An empty table yields a nil slice and a nil error. Query, scan and
// iteration failures are returned wrapped with context.
func (s *sqliteStockRepo) List() ([]*models.Stock, error) {
	const selectAll = `
SELECT id, symbol, name, color
FROM stocks
ORDER BY symbol
`
	rows, queryErr := s.db.Query(selectAll)
	if queryErr != nil {
		return nil, fmt.Errorf("failed to list stocks: %w", queryErr)
	}
	defer rows.Close()

	var out []*models.Stock
	for rows.Next() {
		st := new(models.Stock)
		if scanErr := rows.Scan(&st.Id, &st.Symbol, &st.Name, &st.Color); scanErr != nil {
			return nil, fmt.Errorf("failed to scan stock row: %w", scanErr)
		}
		out = append(out, st)
	}

	// rows.Next returning false can mean "done" or "failed"; Err tells which.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating stock rows: %w", iterErr)
	}
	return out, nil
}
// GetByIDs returns the stocks whose id appears in ids, ordered by symbol.
//
// IDs that match no row are skipped, so the result may be shorter than ids.
// If no row matches at all, the result is a nil slice with a nil error. An
// empty ids returns (nil, nil) without querying the database.
//
// NOTE(review): each id is bound as its own SQL parameter. Very large inputs
// could exceed SQLite's host-parameter limit (SQLITE_MAX_VARIABLE_NUMBER).
// Consider chunking if callers may pass thousands of ids.
func (s *sqliteStockRepo) GetByIDs(ids []int) ([]*models.Stock, error) {
	if len(ids) == 0 {
		// Nothing can match. Skipping the round trip also avoids issuing an
		// "IN ()" clause, which SQLite accepts only as a non-standard extension.
		return nil, nil
	}

	query := buildGetByIDsQuery(ids)
	params := make([]interface{}, len(ids))
	for i, id := range ids {
		params[i] = id
	}

	rows, err := s.db.Query(query, params...)
	if err != nil {
		return nil, fmt.Errorf("failed to get stocks by IDs: %w", err)
	}
	defer rows.Close()

	var stocks []*models.Stock
	for rows.Next() {
		stock := &models.Stock{}
		if err := rows.Scan(&stock.Id, &stock.Symbol, &stock.Name, &stock.Color); err != nil {
			return nil, fmt.Errorf("failed to scan stock row: %w", err)
		}
		stocks = append(stocks, stock)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating stock rows: %w", err)
	}
	return stocks, nil
}
// buildGetByIDsQuery renders the SELECT used to fetch stocks by id, with one
// "?" placeholder per element of ids. For example, three ids produce
// "... WHERE id in (?,?,?) ORDER BY symbol", and an empty ids produces an
// empty "in ()" list.
//
// Only len(ids) is used; the id values are never interpolated into the SQL
// text, so they must be bound as query parameters.
func buildGetByIDsQuery(ids []int) string {
	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(ids)), ",")
	return "SELECT id, symbol, name, color FROM stocks WHERE id in (" +
		placeholders +
		") ORDER BY symbol"
}