229 lines
4.9 KiB
Go
229 lines
4.9 KiB
Go
package repository
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
models "vibeStonk/server/models/v1"
|
|
)
|
|
|
|
// ErrStockNotFound is the sentinel returned by the stock repository when a
// lookup, update, or delete matches no row. Callers should test for it with
// errors.Is.
var ErrStockNotFound = errors.New("stock not found")
|
|
|
|
func newSqliteStockRepo(db *sql.DB) (StockRepo, error) {
|
|
repo := &sqliteStockRepo{db: db}
|
|
if err := repo.initialize(); err != nil {
|
|
return nil, err
|
|
}
|
|
return repo, nil
|
|
}
|
|
|
|
type sqliteStockRepo struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
// initialize creates the stocks table if it doesn't exist
|
|
func (s *sqliteStockRepo) initialize() error {
|
|
query := `
|
|
CREATE TABLE IF NOT EXISTS stocks (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
symbol TEXT UNIQUE NOT NULL,
|
|
name TEXT NOT NULL,
|
|
color TEXT
|
|
);
|
|
`
|
|
_, err := s.db.Exec(query)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create stocks table: %w", err)
|
|
}
|
|
|
|
query = `
|
|
CREATE INDEX IF NOT EXISTS stocks_symbols
|
|
ON stocks(symbol);
|
|
`
|
|
_, err = s.db.Exec(query)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create symbol index for stocks table: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *sqliteStockRepo) Create(stock *models.Stock) (*models.Stock, error) {
|
|
query := `
|
|
INSERT INTO stocks (symbol, name, color)
|
|
VALUES (?, ?, ?)
|
|
`
|
|
result, err := s.db.Exec(query, stock.Symbol, stock.Name, stock.Color)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create stock: %w", err)
|
|
}
|
|
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get last insert ID: %w", err)
|
|
}
|
|
|
|
stock.Id = id
|
|
return stock, nil
|
|
}
|
|
|
|
func (s *sqliteStockRepo) Get(id int64) (*models.Stock, error) {
|
|
query := `
|
|
SELECT id, symbol, name, color
|
|
FROM stocks
|
|
WHERE id = ?
|
|
`
|
|
row := s.db.QueryRow(query, id)
|
|
|
|
stock := &models.Stock{}
|
|
err := row.Scan(&stock.Id, &stock.Symbol, &stock.Name, &stock.Color)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrStockNotFound
|
|
}
|
|
return nil, fmt.Errorf("failed to get stock: %w", err)
|
|
}
|
|
|
|
return stock, nil
|
|
}
|
|
|
|
func (s *sqliteStockRepo) GetBySymbol(symbol string) (*models.Stock, error) {
|
|
query := `
|
|
SELECT id, symbol, name, color
|
|
FROM stocks
|
|
WHERE symbol = ?
|
|
`
|
|
row := s.db.QueryRow(query, symbol)
|
|
|
|
stock := &models.Stock{}
|
|
err := row.Scan(&stock.Id, &stock.Symbol, &stock.Name, &stock.Color)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrStockNotFound
|
|
}
|
|
return nil, fmt.Errorf("failed to get stock by symbol: %w", err)
|
|
}
|
|
|
|
return stock, nil
|
|
}
|
|
|
|
func (s *sqliteStockRepo) Update(stock *models.Stock) error {
|
|
query := `
|
|
UPDATE stocks
|
|
SET symbol = ?, name = ?, color = ?
|
|
WHERE id = ?
|
|
`
|
|
result, err := s.db.Exec(query, stock.Symbol, stock.Name, stock.Color, stock.Id)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update stock: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return ErrStockNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *sqliteStockRepo) Delete(stock *models.Stock) error {
|
|
query := `
|
|
DELETE FROM stocks
|
|
WHERE id = ?
|
|
`
|
|
result, err := s.db.Exec(query, stock.Id)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete stock: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
return ErrStockNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *sqliteStockRepo) List() ([]*models.Stock, error) {
|
|
query := `
|
|
SELECT id, symbol, name, color
|
|
FROM stocks
|
|
ORDER BY symbol
|
|
`
|
|
rows, err := s.db.Query(query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list stocks: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var stocks []*models.Stock
|
|
for rows.Next() {
|
|
stock := &models.Stock{}
|
|
err := rows.Scan(&stock.Id, &stock.Symbol, &stock.Name, &stock.Color)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan stock row: %w", err)
|
|
}
|
|
stocks = append(stocks, stock)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("error iterating stock rows: %w", err)
|
|
}
|
|
|
|
return stocks, nil
|
|
}
|
|
|
|
func (s *sqliteStockRepo) GetByIDs(ids []int) ([]*models.Stock, error) {
|
|
query := buildGetByIDsQuery(ids)
|
|
params := make([]interface{}, 0, len(ids))
|
|
for _, id := range ids {
|
|
params = append(params, id)
|
|
}
|
|
|
|
rows, err := s.db.Query(query, params...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list stocks: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var stocks []*models.Stock
|
|
for rows.Next() {
|
|
stock := &models.Stock{}
|
|
err := rows.Scan(&stock.Id, &stock.Symbol, &stock.Name, &stock.Color)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan stock row: %w", err)
|
|
}
|
|
stocks = append(stocks, stock)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("error iterating stock rows: %w", err)
|
|
}
|
|
|
|
return stocks, nil
|
|
}
|
|
|
|
// buildGetByIDsQuery returns a SELECT over the stocks table, filtered by
// `id in (...)` and ordered by symbol. The list holds one "?" placeholder per
// element of ids. Only len(ids) is used; the id values themselves are never
// written into the SQL text. An empty ids produces the empty list "()".
func buildGetByIDsQuery(ids []int) string {
	// "?,?,...,?": repeat "?," n times, then drop the trailing comma (no-op when n == 0).
	marks := strings.TrimSuffix(strings.Repeat("?,", len(ids)), ",")

	var b strings.Builder
	b.WriteString("SELECT id, symbol, name, color FROM stocks WHERE id in (")
	b.WriteString(marks)
	b.WriteString(") ORDER BY symbol")
	return b.String()
}
|