207 lines
4.6 KiB
Go
207 lines
4.6 KiB
Go
package repository
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
models "vibeStonk/server/models/v1"
|
|
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
)
|
|
|
|
// ErrSaleNotFound is returned when a sale lookup or delete targets an id
// that has no matching row in the sales table.
var ErrSaleNotFound = errors.New("sale not found")
|
|
|
|
func newSqliteSaleRepo(db *sql.DB) (SaleRepo, error) {
|
|
repo := &sqliteSaleRepo{db: db}
|
|
if err := repo.initialize(); err != nil {
|
|
return nil, err
|
|
}
|
|
return repo, nil
|
|
}
|
|
|
|
// sqliteSaleRepo is the SaleRepo implementation returned by
// newSqliteSaleRepo; it stores sales in the "sales" table of a SQLite
// database.
type sqliteSaleRepo struct {
	// db is the database handle used for every query and transaction.
	db *sql.DB
}
|
|
|
|
// initialize creates the sales table if it doesn't exist
|
|
func (s *sqliteSaleRepo) initialize() error {
|
|
query := `
|
|
CREATE TABLE IF NOT EXISTS sales (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
sale_date TIMESTAMP NOT NULL,
|
|
stock_id INTEGER NOT NULL,
|
|
qty REAL NOT NULL,
|
|
price REAL NOT NULL
|
|
);
|
|
`
|
|
_, err := s.db.Exec(query)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create sales table: %w", err)
|
|
}
|
|
|
|
query = `
|
|
CREATE INDEX IF NOT EXISTS sales_stock_id
|
|
ON sales(stock_id);
|
|
`
|
|
_, err = s.db.Exec(query)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create stock_id index for sales table: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *sqliteSaleRepo) BeginSale(stockID int, saleDate time.Time, qty, price float64) (*models.Sale, *sql.Tx, error) {
|
|
query := `
|
|
INSERT INTO sales (sale_date, stock_id, qty, price)
|
|
VALUES (?, ?, ?, ?)
|
|
`
|
|
|
|
tx, err := s.db.Begin()
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to begin database tx: %w", err)
|
|
}
|
|
|
|
result, err := s.db.Exec(query, saleDate, stockID, qty, price)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return nil, nil, fmt.Errorf("failed to create sale: %w", err)
|
|
}
|
|
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return nil, nil, fmt.Errorf("failed to get last insert ID: %w", err)
|
|
}
|
|
|
|
return &models.Sale{
|
|
Id: id,
|
|
SaleDate: timestamppb.New(saleDate),
|
|
StockID: int64(stockID),
|
|
Qty: qty,
|
|
Price: price,
|
|
}, tx, nil
|
|
}
|
|
|
|
func (s *sqliteSaleRepo) Get(id int64) (*models.Sale, error) {
|
|
query := `
|
|
SELECT id, sale_date, stock_id, qty, price
|
|
FROM sales
|
|
WHERE id = ?
|
|
`
|
|
row := s.db.QueryRow(query, id)
|
|
|
|
sale := &models.Sale{}
|
|
var saleDate time.Time
|
|
err := row.Scan(&sale.Id, &saleDate, &sale.StockID, &sale.Qty, &sale.Price)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrSaleNotFound
|
|
}
|
|
return nil, fmt.Errorf("failed to get sale: %w", err)
|
|
}
|
|
|
|
// Convert time.Time to protobuf timestamp
|
|
sale.SaleDate = timestamppb.New(saleDate)
|
|
|
|
return sale, nil
|
|
}
|
|
|
|
func (s *sqliteSaleRepo) GetByStockID(stockID int64) ([]*models.Sale, error) {
|
|
query := `
|
|
SELECT id, sale_date, stock_id, qty, price
|
|
FROM sales
|
|
WHERE stock_id = ?
|
|
`
|
|
rows, err := s.db.Query(query, stockID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get sales by stock ID: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var sales []*models.Sale
|
|
for rows.Next() {
|
|
sale := &models.Sale{}
|
|
var saleDate time.Time
|
|
err := rows.Scan(&sale.Id, &saleDate, &sale.StockID, &sale.Qty, &sale.Price)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan sale row: %w", err)
|
|
}
|
|
// Convert time.Time to protobuf timestamp
|
|
sale.SaleDate = timestamppb.New(saleDate)
|
|
sales = append(sales, sale)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("error iterating sale rows: %w", err)
|
|
}
|
|
|
|
return sales, nil
|
|
}
|
|
|
|
func (s *sqliteSaleRepo) BeginDelete(sale *models.Sale) (*sql.Tx, error) {
|
|
query := `
|
|
DELETE FROM sales
|
|
WHERE id = ?
|
|
`
|
|
|
|
tx, err := s.db.Begin()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to begin database transaction: %w", err)
|
|
}
|
|
|
|
result, err := s.db.Exec(query, sale.Id)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return nil, fmt.Errorf("failed to delete sale: %w", err)
|
|
}
|
|
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return nil, fmt.Errorf("failed to get rows affected: %w", err)
|
|
}
|
|
|
|
if rowsAffected == 0 {
|
|
tx.Rollback()
|
|
return nil, ErrSaleNotFound
|
|
}
|
|
|
|
return tx, nil
|
|
}
|
|
|
|
func (s *sqliteSaleRepo) List() ([]*models.Sale, error) {
|
|
query := `
|
|
SELECT id, sale_date, stock_id, qty, price
|
|
FROM sales
|
|
ORDER BY sale_date DESC
|
|
`
|
|
rows, err := s.db.Query(query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list sales: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var sales []*models.Sale
|
|
for rows.Next() {
|
|
sale := &models.Sale{}
|
|
var saleDate time.Time
|
|
err := rows.Scan(&sale.Id, &saleDate, &sale.StockID, &sale.Qty, &sale.Price)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan sale row: %w", err)
|
|
}
|
|
// Convert time.Time to protobuf timestamp
|
|
sale.SaleDate = timestamppb.New(saleDate)
|
|
sales = append(sales, sale)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("error iterating sale rows: %w", err)
|
|
}
|
|
|
|
return sales, nil
|
|
}
|