package repository

import (
	"database/sql"
	"errors"
	"fmt"
	"time"

	"google.golang.org/protobuf/types/known/timestamppb"

	models "vibeStonk/server/models/v1"
)

var (
	// ErrSaleNotFound is returned when no sale row matches the requested ID.
	ErrSaleNotFound = errors.New("sale not found")
)

// newSqliteSaleRepo returns a SaleRepo backed by db, creating the sales table
// and its stock_id index first if they do not already exist.
func newSqliteSaleRepo(db *sql.DB) (SaleRepo, error) {
	repo := &sqliteSaleRepo{db: db}
	if err := repo.initialize(); err != nil {
		return nil, err
	}
	return repo, nil
}

// sqliteSaleRepo stores sales in the "sales" table of db, using SQLite DDL
// (INTEGER PRIMARY KEY AUTOINCREMENT).
type sqliteSaleRepo struct {
	db *sql.DB
}

// initialize creates the sales table and its stock_id index if they don't exist.
func (s *sqliteSaleRepo) initialize() error {
	query := `
		CREATE TABLE IF NOT EXISTS sales (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			sale_date TIMESTAMP NOT NULL,
			stock_id INTEGER NOT NULL,
			qty REAL NOT NULL,
			price REAL NOT NULL
		);
	`
	if _, err := s.db.Exec(query); err != nil {
		return fmt.Errorf("failed to create sales table: %w", err)
	}

	query = `
		CREATE INDEX IF NOT EXISTS sales_stock_id ON sales(stock_id);
	`
	if _, err := s.db.Exec(query); err != nil {
		return fmt.Errorf("failed to create stock_id index for sales table: %w", err)
	}

	return nil
}

// BeginSale inserts a sale inside a newly started transaction and returns the
// stored sale (with its generated Id) together with that still-open
// transaction. The caller owns the transaction: it must Commit to persist the
// sale or Rollback to discard it. On error the transaction has already been
// rolled back and the sale and tx results are nil.
func (s *sqliteSaleRepo) BeginSale(stockID int, saleDate time.Time, qty, price float64) (*models.Sale, *sql.Tx, error) {
	query := `
		INSERT INTO sales (sale_date, stock_id, qty, price)
		VALUES (?, ?, ?, ?)
	`

	tx, err := s.db.Begin()
	if err != nil {
		return nil, nil, fmt.Errorf("failed to begin database tx: %w", err)
	}

	// Run the statement on tx, not s.db: a statement issued through the pool
	// is auto-committed outside the transaction (so Rollback could not undo
	// it), and with a one-connection pool it would wait on the connection tx
	// already holds.
	result, err := tx.Exec(query, saleDate, stockID, qty, price)
	if err != nil {
		_ = tx.Rollback() // best effort; the Exec error is the one worth reporting
		return nil, nil, fmt.Errorf("failed to create sale: %w", err)
	}

	id, err := result.LastInsertId()
	if err != nil {
		_ = tx.Rollback() // best effort; the LastInsertId error is the one worth reporting
		return nil, nil, fmt.Errorf("failed to get last insert ID: %w", err)
	}

	return &models.Sale{
		Id:       id,
		SaleDate: timestamppb.New(saleDate),
		StockID:  int64(stockID),
		Qty:      qty,
		Price:    price,
	}, tx, nil
}

// Get returns the sale with the given id, or ErrSaleNotFound if no row has
// that id.
func (s *sqliteSaleRepo) Get(id int64) (*models.Sale, error) {
	query := `
		SELECT id, sale_date, stock_id, qty, price
		FROM sales
		WHERE id = ?
	`

	sale, err := s.scanSale(s.db.QueryRow(query, id))
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, ErrSaleNotFound
		}
		return nil, fmt.Errorf("failed to get sale: %w", err)
	}
	return sale, nil
}

// GetByStockID returns every sale recorded for stockID. The query has no
// ORDER BY, so row order is whatever the database yields. A nil slice is
// returned when there are no matching sales.
func (s *sqliteSaleRepo) GetByStockID(stockID int64) ([]*models.Sale, error) {
	query := `
		SELECT id, sale_date, stock_id, qty, price
		FROM sales
		WHERE stock_id = ?
	`

	rows, err := s.db.Query(query, stockID)
	if err != nil {
		return nil, fmt.Errorf("failed to get sales by stock ID: %w", err)
	}
	return s.collectSales(rows)
}

// BeginDelete deletes the row whose id equals sale.Id inside a newly started
// transaction and returns that still-open transaction. The caller must Commit
// to make the deletion permanent or Rollback to undo it. If no row has that
// id, the transaction is rolled back and ErrSaleNotFound is returned; on any
// other error the transaction is also rolled back and tx is nil.
func (s *sqliteSaleRepo) BeginDelete(sale *models.Sale) (*sql.Tx, error) {
	// Reject nil before opening a transaction so there is nothing to clean up.
	if sale == nil {
		return nil, errors.New("failed to delete sale: sale is nil")
	}

	query := `
		DELETE FROM sales WHERE id = ?
	`

	tx, err := s.db.Begin()
	if err != nil {
		return nil, fmt.Errorf("failed to begin database transaction: %w", err)
	}

	// Run the statement on tx so the caller's Commit/Rollback (and the
	// not-found rollback below) actually govern the delete.
	result, err := tx.Exec(query, sale.Id)
	if err != nil {
		_ = tx.Rollback() // best effort; the Exec error is the one worth reporting
		return nil, fmt.Errorf("failed to delete sale: %w", err)
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		_ = tx.Rollback() // best effort; the RowsAffected error is the one worth reporting
		return nil, fmt.Errorf("failed to get rows affected: %w", err)
	}

	if rowsAffected == 0 {
		_ = tx.Rollback() // nothing was deleted; release the transaction
		return nil, ErrSaleNotFound
	}

	return tx, nil
}

// List returns all sales ordered by sale_date, newest first. A nil slice is
// returned when the table is empty.
func (s *sqliteSaleRepo) List() ([]*models.Sale, error) {
	query := `
		SELECT id, sale_date, stock_id, qty, price
		FROM sales
		ORDER BY sale_date DESC
	`

	rows, err := s.db.Query(query)
	if err != nil {
		return nil, fmt.Errorf("failed to list sales: %w", err)
	}
	return s.collectSales(rows)
}

// scanSale reads one row laid out as (id, sale_date, stock_id, qty, price)
// and converts sale_date to a protobuf timestamp. It accepts both *sql.Row
// and *sql.Rows. Scan errors, including sql.ErrNoRows from *sql.Row, are
// returned unwrapped so callers can test them with errors.Is.
func (s *sqliteSaleRepo) scanSale(row interface{ Scan(dest ...interface{}) error }) (*models.Sale, error) {
	sale := &models.Sale{}
	var saleDate time.Time
	if err := row.Scan(&sale.Id, &saleDate, &sale.StockID, &sale.Qty, &sale.Price); err != nil {
		return nil, err
	}
	sale.SaleDate = timestamppb.New(saleDate)
	return sale, nil
}

// collectSales scans every remaining row of rows into a Sale and always
// closes rows. It returns a nil slice when there are no rows.
func (s *sqliteSaleRepo) collectSales(rows *sql.Rows) ([]*models.Sale, error) {
	defer rows.Close()

	var sales []*models.Sale
	for rows.Next() {
		sale, err := s.scanSale(rows)
		if err != nil {
			return nil, fmt.Errorf("failed to scan sale row: %w", err)
		}
		sales = append(sales, sale)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error iterating sale rows: %w", err)
	}
	return sales, nil
}