termcloud/internal/db/account.go

435 lines
13 KiB
Go

package db
import (
"context"
"crypto/rand"
"fmt"
"math/big"
"time"
"git.keircn.com/keiran/termcloud/internal/config"
"github.com/jackc/pgx/v5/pgxpool"
)
// Account is a prepaid storage account. It is identified by an account
// number (generated as a zero-padded 16-digit string) and authenticated with
// an opaque access token.
type Account struct {
ID int64 `json:"id"`
AccountNumber string `json:"accountNumber"`
// AccessToken is populated by CreateAccount and GetAccountByNumber only;
// GetAccountByToken leaves it empty, and omitempty then drops it from JSON.
AccessToken string `json:"accessToken,omitempty"`
// BalanceUSD is the prepaid credit: increased when a payment is confirmed,
// decreased by monthly storage charges.
BalanceUSD float64 `json:"balanceUsd"`
// IsActive is set when a payment is confirmed and cleared when a monthly
// charge cannot be covered by the balance.
IsActive bool `json:"isActive"`
CreatedAt time.Time `json:"createdAt"`
// ActivatedAt is stamped by the first confirmed payment; nil before that.
ActivatedAt *time.Time `json:"activatedAt,omitempty"`
// LastBillingDate is the start of the most recently charged billing period.
LastBillingDate time.Time `json:"lastBillingDate"`
}
// Payment is a request to top up an account's balance. CreatePayment only
// creates payments of type "bitcoin". Status starts as "pending" and moves to
// "confirmed" (ConfirmPayment) or "expired" (CheckPaymentStatus).
type Payment struct {
ID int64 `json:"id"`
AccountID int64 `json:"accountId"`
PaymentType string `json:"paymentType"`
// BTCAddress is the deposit address assigned when the payment was created.
BTCAddress string `json:"btcAddress,omitempty"`
// BTCAmount is USDAmount converted at the BTC rate in effect at creation.
BTCAmount float64 `json:"btcAmount,omitempty"`
// USDAmount is credited to the account balance on confirmation.
USDAmount float64 `json:"usdAmount"`
Confirmations int `json:"confirmations"`
// TxHash is empty until the payment is confirmed.
TxHash string `json:"txHash,omitempty"`
Status string `json:"status"`
// ExpiresAt is 24 hours after creation for payments made by CreatePayment.
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
CreatedAt time.Time `json:"createdAt"`
ConfirmedAt *time.Time `json:"confirmedAt,omitempty"`
}
// CryptoRate is a cached USD exchange rate for one currency (e.g. "BTC").
// getBTCPrice treats a row older than one hour as stale.
type CryptoRate struct {
Currency string `json:"currency"`
USDRate float64 `json:"usdRate"`
UpdatedAt time.Time `json:"updatedAt"`
}
// UsageRecord tracks one account's storage for one calendar-month billing
// period. RecordUsage keeps MaxStorageBytes at the peak value reported during
// the period; ProcessMonthlyBilling fills in ChargeUSD and ChargedAt.
type UsageRecord struct {
ID int64 `json:"id"`
AccountID int64 `json:"accountId"`
// BillingPeriodStart is midnight on the first day of the month (server local time).
BillingPeriodStart time.Time `json:"billingPeriodStart"`
// BillingPeriodEnd is one second before the next period starts.
BillingPeriodEnd time.Time `json:"billingPeriodEnd"`
MaxStorageBytes int64 `json:"maxStorageBytes"`
ChargeUSD float64 `json:"chargeUsd"`
// ChargedAt is nil until the period has been billed.
ChargedAt *time.Time `json:"chargedAt,omitempty"`
CreatedAt time.Time `json:"createdAt"`
}
// AccountService implements account, payment and storage-billing operations
// on top of a PostgreSQL connection pool. Construct it with NewAccountService.
type AccountService struct {
pool *pgxpool.Pool
// pricePerGB is the USD charged per GiB (2^30 bytes) of peak monthly storage.
pricePerGB float64
// bitcoinMasterAddr, if non-empty, is used as the deposit address for every
// payment (see generateBTCAddress).
bitcoinMasterAddr string
}
// NewAccountService builds an AccountService backed by pool.
//
// The storage price comes from cfg.PricePerGBUSD. Any value that is not
// strictly positive (unset, negative, or NaN) falls back to the default of
// 0.50 USD per GiB. A non-positive price would otherwise make monthly charges
// zero or negative and make the affordable-storage calculation in
// CheckResourceLimits meaningless.
func NewAccountService(pool *pgxpool.Pool, cfg *config.Config) *AccountService {
	pricePerGB := cfg.PricePerGBUSD
	// Written as !(x > 0) so NaN also takes the default.
	if !(pricePerGB > 0) {
		pricePerGB = 0.50
	}
	return &AccountService{
		pool:              pool,
		pricePerGB:        pricePerGB,
		bitcoinMasterAddr: cfg.BitcoinMasterAddress,
	}
}
// GenerateAccountNumber returns a random, zero-padded 16-digit account number
// that was not present in the accounts table when it was checked.
//
// The existence check and the later INSERT are not atomic, so uniqueness must
// ultimately be enforced by the database. Presumably there is a UNIQUE
// constraint on account_number; confirm against the schema. If the existence
// query itself fails, the candidate is returned anyway and the subsequent
// INSERT is left to surface the database problem.
//
// It panics if the cryptographically secure random source fails. The
// signature has no error return, and continuing without randomness would
// produce an invalid or predictable account number.
func (s *AccountService) GenerateAccountNumber() string {
	// The upper bound is exclusive, so 10^16 covers every value from
	// 0000000000000000 through 9999999999999999.
	limit := big.NewInt(10_000_000_000_000_000)
	for {
		n, err := rand.Int(rand.Reader, limit)
		if err != nil {
			panic(fmt.Errorf("db: reading random account number: %w", err))
		}
		candidate := fmt.Sprintf("%016d", n)

		var exists bool
		err = s.pool.QueryRow(context.Background(),
			"SELECT EXISTS(SELECT 1 FROM accounts WHERE account_number = $1)",
			candidate).Scan(&exists)
		// A failed check is treated as "available" (best effort, see above).
		if err != nil || !exists {
			return candidate
		}
	}
}
// GenerateAccessToken returns a new 256-bit random access token, encoded as 64
// lowercase hexadecimal characters.
//
// It panics if the cryptographically secure random source fails. Ignoring
// that error would hand out an all-zero, trivially guessable token, which is
// far worse than failing loudly.
func (s *AccountService) GenerateAccessToken() string {
	bytes := make([]byte, 32)
	if _, err := rand.Read(bytes); err != nil {
		panic(fmt.Errorf("db: reading random access token: %w", err))
	}
	return fmt.Sprintf("%x", bytes)
}
// CreateAccount inserts a new account with a freshly generated account number
// and access token, and returns it. The database supplies the ID, balance,
// active flag, creation time and last billing date. The access token is
// copied onto the result because the INSERT does not return it.
func (s *AccountService) CreateAccount(ctx context.Context) (*Account, error) {
	const insertAccountSQL = `
INSERT INTO accounts (account_number, access_token)
VALUES ($1, $2)
RETURNING id, account_number, balance_usd, is_active, created_at, last_billing_date`

	number := s.GenerateAccountNumber()
	token := s.GenerateAccessToken()

	acct := &Account{}
	row := s.pool.QueryRow(ctx, insertAccountSQL, number, token)
	if err := row.Scan(
		&acct.ID,
		&acct.AccountNumber,
		&acct.BalanceUSD,
		&acct.IsActive,
		&acct.CreatedAt,
		&acct.LastBillingDate,
	); err != nil {
		return nil, fmt.Errorf("failed to create account: %w", err)
	}

	acct.AccessToken = token
	return acct, nil
}
// GetAccountByToken looks up the account that owns the given access token.
// The token column is not selected, so AccessToken is left empty on the
// result. Any query or scan failure, including no matching row, is returned
// as "account not found" wrapping the underlying error.
func (s *AccountService) GetAccountByToken(ctx context.Context, token string) (*Account, error) {
	acct := new(Account)
	row := s.pool.QueryRow(ctx, `
SELECT id, account_number, balance_usd, is_active, created_at, activated_at, last_billing_date
FROM accounts WHERE access_token = $1`, token)

	// activated_at may be NULL (it is only set on first activation), so it is
	// scanned straight into the pointer field.
	err := row.Scan(
		&acct.ID, &acct.AccountNumber, &acct.BalanceUSD,
		&acct.IsActive, &acct.CreatedAt, &acct.ActivatedAt, &acct.LastBillingDate,
	)
	if err != nil {
		return nil, fmt.Errorf("account not found: %w", err)
	}
	return acct, nil
}
// GetAccountByNumber looks up an account by its account number. Unlike
// GetAccountByToken, it also returns the account's access token. Any query or
// scan failure, including no matching row, is returned as "account not found"
// wrapping the underlying error.
func (s *AccountService) GetAccountByNumber(ctx context.Context, accountNumber string) (*Account, error) {
	const query = `
SELECT id, account_number, access_token, balance_usd, is_active, created_at, activated_at, last_billing_date
FROM accounts WHERE account_number = $1`

	acct := new(Account)
	// Destinations in the same order as the SELECT list; activated_at may be
	// NULL and scans into the pointer field directly.
	dest := []any{
		&acct.ID, &acct.AccountNumber, &acct.AccessToken, &acct.BalanceUSD,
		&acct.IsActive, &acct.CreatedAt, &acct.ActivatedAt, &acct.LastBillingDate,
	}
	if err := s.pool.QueryRow(ctx, query, accountNumber).Scan(dest...); err != nil {
		return nil, fmt.Errorf("account not found: %w", err)
	}
	return acct, nil
}
// CreatePayment opens a pending Bitcoin payment of usdAmount US dollars for
// the given account and returns the stored row. It does the following:
//   - derives the BTC amount from the current BTC/USD rate (see getBTCPrice);
//   - takes the deposit address from generateBTCAddress;
//   - sets the payment to expire 24 hours after creation.
//
// It returns an error in two cases. The first is when usdAmount is not
// strictly positive (including NaN). The second is when the BTC rate is not
// strictly positive. Either would produce a nonsensical BTC amount, and a
// negative payment, once confirmed, would reduce the account balance.
func (s *AccountService) CreatePayment(ctx context.Context, accountID int64, usdAmount float64) (*Payment, error) {
	// Written as !(x > 0) so NaN is rejected along with zero and negatives.
	if !(usdAmount > 0) {
		return nil, fmt.Errorf("invalid payment amount %v: must be positive", usdAmount)
	}
	btcPrice, err := s.getBTCPrice(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get BTC price: %w", err)
	}
	if !(btcPrice > 0) {
		return nil, fmt.Errorf("invalid BTC price %v", btcPrice)
	}

	btcAmount := usdAmount / btcPrice
	btcAddress := s.generateBTCAddress()
	expiresAt := time.Now().Add(24 * time.Hour)

	var payment Payment
	err = s.pool.QueryRow(ctx, `
INSERT INTO payments (account_id, payment_type, btc_address, btc_amount, usd_amount, status, expires_at)
VALUES ($1, 'bitcoin', $2, $3, $4, 'pending', $5)
RETURNING id, account_id, payment_type, btc_address, btc_amount, usd_amount, confirmations, status, expires_at, created_at`,
		accountID, btcAddress, btcAmount, usdAmount, expiresAt).Scan(
		&payment.ID, &payment.AccountID, &payment.PaymentType, &payment.BTCAddress,
		&payment.BTCAmount, &payment.USDAmount, &payment.Confirmations, &payment.Status,
		&payment.ExpiresAt, &payment.CreatedAt)
	if err != nil {
		return nil, fmt.Errorf("failed to create payment: %w", err)
	}
	return &payment, nil
}
// ConfirmPayment marks a pending payment as confirmed with the given
// transaction hash. It credits the payment's USD amount to the owning account,
// activates the account, and stamps activated_at if it is not already set.
// All of this happens in a single transaction.
//
// The payment row is read with SELECT ... FOR UPDATE. Two concurrent
// confirmations of the same payment therefore serialize: the second waits for
// the first to commit, then sees status "confirmed" and fails with "payment
// already processed" instead of crediting the account a second time.
//
// It returns an error if the payment does not exist or is not "pending"
// (already confirmed or expired), or if any database step fails.
func (s *AccountService) ConfirmPayment(ctx context.Context, paymentID int64, txHash string) error {
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Discards partial work on every early return. After a successful Commit
	// this is a no-op whose error is irrelevant.
	defer tx.Rollback(ctx)

	var (
		accountID int64
		usdAmount float64
		status    string
	)
	err = tx.QueryRow(ctx, `
SELECT account_id, usd_amount, status
FROM payments WHERE id = $1
FOR UPDATE`, paymentID).Scan(&accountID, &usdAmount, &status)
	if err != nil {
		return fmt.Errorf("payment not found: %w", err)
	}
	if status != "pending" {
		return fmt.Errorf("payment already processed")
	}

	now := time.Now()
	// NOTE(review): confirmations is recorded as 6 unconditionally; on-chain
	// verification of txHash presumably happens in the caller — confirm.
	if _, err = tx.Exec(ctx, `
UPDATE payments SET status = 'confirmed', tx_hash = $1, confirmations = 6, confirmed_at = $2
WHERE id = $3`, txHash, now, paymentID); err != nil {
		return fmt.Errorf("failed to mark payment %d confirmed: %w", paymentID, err)
	}
	if _, err = tx.Exec(ctx, `
UPDATE accounts SET balance_usd = balance_usd + $1, is_active = TRUE, activated_at = COALESCE(activated_at, $2)
WHERE id = $3`, usdAmount, now, accountID); err != nil {
		return fmt.Errorf("failed to credit account %d: %w", accountID, err)
	}
	return tx.Commit(ctx)
}
// RecordUsage records storageBytes as an observation of the account's storage
// for the current calendar month, using the server's local time zone.
//
// There is one row per (account, period start). A repeated call for the same
// month keeps the largest value seen, so max_storage_bytes ends up as the
// month's peak. Any database error is returned unchanged.
func (s *AccountService) RecordUsage(ctx context.Context, accountID int64, storageBytes int64) error {
	now := time.Now()
	year, month, _ := now.Date()
	periodStart := time.Date(year, month, 1, 0, 0, 0, 0, now.Location())
	// Last second of the month: the start of next month minus one second.
	periodEnd := periodStart.AddDate(0, 1, 0).Add(-time.Second)

	_, err := s.pool.Exec(ctx, `
INSERT INTO usage_records (account_id, billing_period_start, billing_period_end, max_storage_bytes)
VALUES ($1, $2, $3, $4)
ON CONFLICT (account_id, billing_period_start)
DO UPDATE SET max_storage_bytes = GREATEST(usage_records.max_storage_bytes, $4)`,
		accountID, periodStart, periodEnd, storageBytes)
	return err
}
// ProcessMonthlyBilling bills every account for its peak storage in the
// previous calendar month (server local time) if that month has not been
// charged yet.
//
// Each uncharged usage record is handled as follows:
//   - The charge is maxStorageBytes / 2^30 * pricePerGB USD.
//   - A positive charge is deducted via chargeAccount. If that fails (for
//     example, insufficient balance), the record is deliberately left
//     uncharged so a later run retries it.
//   - Otherwise the record is stamped with the charge and the current time.
//
// It returns an error in three cases:
//   - the uncharged records cannot be read;
//   - any record was debited but could not be stamped as charged, which
//     would cause it to be billed again on the next run;
//   - ctx was cancelled during the run.
func (s *AccountService) ProcessMonthlyBilling(ctx context.Context) error {
	now := time.Now()
	// First day of the previous month. time.Date normalizes month 0 to
	// December of the prior year, so this is also correct in January.
	// now.AddDate(0, -1, 0) is NOT: on March 31 it yields "February 31",
	// which normalizes to March 3 and would bill the current month.
	billingStart := time.Date(now.Year(), now.Month()-1, 1, 0, 0, 0, 0, now.Location())

	type pendingUsage struct {
		accountID       int64
		maxStorageBytes int64
	}
	// Read all uncharged records before issuing any writes. This returns the
	// cursor's connection to the pool first, instead of needing a second
	// connection per write while the result set is still open.
	pending, err := func() ([]pendingUsage, error) {
		rows, err := s.pool.Query(ctx, `
SELECT account_id, max_storage_bytes
FROM usage_records
WHERE billing_period_start = $1 AND charged_at IS NULL`,
			billingStart)
		if err != nil {
			return nil, err
		}
		defer rows.Close()
		var out []pendingUsage
		for rows.Next() {
			var u pendingUsage
			if err := rows.Scan(&u.accountID, &u.maxStorageBytes); err != nil {
				return nil, fmt.Errorf("failed to scan usage record: %w", err)
			}
			out = append(out, u)
		}
		return out, rows.Err()
	}()
	if err != nil {
		// Nothing has been charged yet, so aborting here is safe.
		return err
	}

	var (
		unmarked int
		firstErr error
	)
	for _, u := range pending {
		storageGB := float64(u.maxStorageBytes) / (1024 * 1024 * 1024)
		charge := storageGB * s.pricePerGB
		if charge > 0 {
			if err := s.chargeAccount(ctx, u.accountID, charge, billingStart); err != nil {
				// Best effort: leave the record uncharged for a later run.
				// chargeAccount deactivates accounts it cannot charge.
				continue
			}
		}
		if _, err := s.pool.Exec(ctx, `
UPDATE usage_records
SET charge_usd = $1, charged_at = $2
WHERE account_id = $3 AND billing_period_start = $4`,
			charge, now, u.accountID, billingStart); err != nil {
			// Any positive charge has already been debited above; an
			// unstamped record would be debited again next run, so report it.
			unmarked++
			if firstErr == nil {
				firstErr = fmt.Errorf("account %d: %w", u.accountID, err)
			}
		}
	}

	if unmarked > 0 {
		return fmt.Errorf("%d usage record(s) for period starting %s were processed but not marked as charged and may be billed again: %w",
			unmarked, billingStart.Format("2006-01-02"), firstErr)
	}
	// Reports cancellation, which otherwise shows up only as skipped charges.
	return ctx.Err()
}
// chargeAccount deducts amount USD from the account's balance and records
// billingPeriod as its last billing date, but only if the balance covers the
// full amount.
//
// If no row is updated (insufficient balance, or no account with that ID), the
// account is deactivated and an "insufficient balance for account N" error is
// returned. If the deactivation itself fails, that failure is wrapped into the
// returned error as well, rather than being silently dropped.
func (s *AccountService) chargeAccount(ctx context.Context, accountID int64, amount float64, billingPeriod time.Time) error {
	// The WHERE clause guarantees balance_usd >= amount, so the is_active
	// expression evaluates to TRUE for any row that is actually updated.
	result, err := s.pool.Exec(ctx, `
UPDATE accounts
SET balance_usd = balance_usd - $1, last_billing_date = $2, is_active = (balance_usd - $1 >= 0)
WHERE id = $3 AND balance_usd >= $1`,
		amount, billingPeriod, accountID)
	if err != nil {
		return err
	}
	if result.RowsAffected() > 0 {
		return nil
	}

	insufficient := fmt.Errorf("insufficient balance for account %d", accountID)
	if _, err := s.pool.Exec(ctx, `
UPDATE accounts SET is_active = FALSE WHERE id = $1`, accountID); err != nil {
		return fmt.Errorf("%w; deactivating account: %w", insufficient, err)
	}
	return insufficient
}
// getBTCPrice returns the BTC/USD rate cached in crypto_rates if it was
// updated within the last hour.
//
// Otherwise (no row, a stale row, or any query error) it falls back to a fixed
// rate of 45000 USD. It upserts that value as the current BTC rate and returns
// it. The upsert is best effort: its failure does not change the result. As
// written, the returned error is always nil.
//
// NOTE(review): the hard-coded fallback overwrites the cached rate and marks
// it fresh, so it is served for the next hour. This looks like a placeholder
// until a real price feed exists; confirm before relying on it for payments.
func (s *AccountService) getBTCPrice(ctx context.Context) (float64, error) {
	const fallbackRate = 45000.0

	var cached float64
	err := s.pool.QueryRow(ctx, `
SELECT usd_rate FROM crypto_rates WHERE currency = 'BTC'
AND updated_at > NOW() - INTERVAL '1 hour'`).Scan(&cached)
	if err == nil {
		return cached, nil
	}

	// Refresh the cache with the fallback. The error is ignored on purpose,
	// because the fallback is returned either way.
	_, _ = s.pool.Exec(ctx, `
INSERT INTO crypto_rates (currency, usd_rate) VALUES ('BTC', $1)
ON CONFLICT (currency) DO UPDATE SET usd_rate = $1, updated_at = NOW()`,
		fallbackRate)
	return fallbackRate, nil
}
// CheckPaymentStatus loads a payment by ID and lazily expires it. A payment
// that is still "pending" after its expires_at is switched to "expired", both
// in the database and in the returned value.
//
// Nullable columns are COALESCEd to zero values before scanning. pgx cannot
// scan NULL into a plain string/float64, and CreatePayment never sets tx_hash,
// so every pending payment would otherwise fail to load. This assumes tx_hash
// has no non-NULL default in the schema; confirm.
//
// The expiry UPDATE only touches rows that are still pending, so a payment
// confirmed concurrently by ConfirmPayment is never overwritten to "expired".
// The expiry is best effort: if it fails, or the row was no longer pending,
// the payment is returned with the status that was read.
func (s *AccountService) CheckPaymentStatus(ctx context.Context, paymentID int64) (*Payment, error) {
	var payment Payment
	err := s.pool.QueryRow(ctx, `
SELECT id, account_id, payment_type, COALESCE(btc_address, ''), COALESCE(btc_amount, 0), usd_amount,
confirmations, COALESCE(tx_hash, ''), status, expires_at, created_at, confirmed_at
FROM payments WHERE id = $1`, paymentID).Scan(
		&payment.ID, &payment.AccountID, &payment.PaymentType, &payment.BTCAddress,
		&payment.BTCAmount, &payment.USDAmount, &payment.Confirmations, &payment.TxHash,
		&payment.Status, &payment.ExpiresAt, &payment.CreatedAt, &payment.ConfirmedAt)
	if err != nil {
		return nil, fmt.Errorf("payment not found: %w", err)
	}

	if payment.Status == "pending" && payment.ExpiresAt != nil && time.Now().After(*payment.ExpiresAt) {
		tag, err := s.pool.Exec(ctx,
			`UPDATE payments SET status = 'expired' WHERE id = $1 AND status = 'pending'`, paymentID)
		if err == nil && tag.RowsAffected() > 0 {
			payment.Status = "expired"
		}
	}
	return &payment, nil
}
// GetAccountPayments lists every payment for the account, newest first. It
// returns nil (not an empty slice) when the account has no payments.
//
// Nullable columns are COALESCEd to zero values before scanning. Pending
// payments have no tx_hash, and pgx cannot scan NULL into a plain string; such
// rows would otherwise be dropped or fail. A scan or iteration error is now
// returned instead of silently producing an incomplete list.
func (s *AccountService) GetAccountPayments(ctx context.Context, accountID int64) ([]Payment, error) {
	rows, err := s.pool.Query(ctx, `
SELECT id, account_id, payment_type, COALESCE(btc_address, ''), COALESCE(btc_amount, 0), usd_amount,
confirmations, COALESCE(tx_hash, ''), status, expires_at, created_at, confirmed_at
FROM payments WHERE account_id = $1 ORDER BY created_at DESC`, accountID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var payments []Payment
	for rows.Next() {
		var p Payment
		if err := rows.Scan(&p.ID, &p.AccountID, &p.PaymentType,
			&p.BTCAddress, &p.BTCAmount, &p.USDAmount,
			&p.Confirmations, &p.TxHash, &p.Status,
			&p.ExpiresAt, &p.CreatedAt, &p.ConfirmedAt); err != nil {
			return nil, fmt.Errorf("failed to scan payment: %w", err)
		}
		payments = append(payments, p)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to list payments: %w", err)
	}
	return payments, nil
}
// GetUsageRecords lists the account's usage records, newest billing period
// first. It returns nil (not an empty slice) when there are none.
//
// RecordUsage inserts rows without charge_usd, so the column is presumably
// NULL until ProcessMonthlyBilling fills it in (confirm against the schema).
// It is COALESCEd to 0 because pgx cannot scan NULL into a plain float64;
// unbilled periods would otherwise be dropped. A scan or iteration error is
// returned instead of silently producing an incomplete list.
func (s *AccountService) GetUsageRecords(ctx context.Context, accountID int64) ([]UsageRecord, error) {
	rows, err := s.pool.Query(ctx, `
SELECT id, account_id, billing_period_start, billing_period_end,
max_storage_bytes, COALESCE(charge_usd, 0), charged_at, created_at
FROM usage_records WHERE account_id = $1 ORDER BY billing_period_start DESC`, accountID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var records []UsageRecord
	for rows.Next() {
		var r UsageRecord
		if err := rows.Scan(&r.ID, &r.AccountID, &r.BillingPeriodStart,
			&r.BillingPeriodEnd, &r.MaxStorageBytes, &r.ChargeUSD,
			&r.ChargedAt, &r.CreatedAt); err != nil {
			return nil, fmt.Errorf("failed to scan usage record: %w", err)
		}
		records = append(records, r)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to list usage records: %w", err)
	}
	return records, nil
}
// CheckResourceLimits reports whether the account's current balance could pay
// for requestedBytes of storage at pricePerGB USD per GiB. Under
// ProcessMonthlyBilling's pricing, that is roughly one month's charge.
//
// It returns an error in three cases:
//   - the account cannot be loaded;
//   - the balance is zero or negative;
//   - requestedBytes exceeds the number of bytes the balance covers.
func (s *AccountService) CheckResourceLimits(ctx context.Context, accountID int64, requestedBytes int64) error {
	var balance float64
	row := s.pool.QueryRow(ctx, `SELECT balance_usd FROM accounts WHERE id = $1`, accountID)
	if err := row.Scan(&balance); err != nil {
		return fmt.Errorf("account not found: %w", err)
	}

	if balance <= 0 {
		return fmt.Errorf("insufficient account balance")
	}

	// GiB the balance pays for, converted to bytes. Scaling by 2^30 at once
	// is exact in float64, just like three successive multiplications by 1024.
	const bytesPerGB = 1024 * 1024 * 1024
	affordable := int64(balance / s.pricePerGB * bytesPerGB)

	switch {
	case requestedBytes > affordable:
		return fmt.Errorf("requested storage (%d bytes) exceeds affordable limit (%d bytes)",
			requestedBytes, affordable)
	default:
		return nil
	}
}
// generateBTCAddress returns the deposit address for a new payment.
//
// When a master address is configured, every payment shares it. Otherwise it
// returns "bc1q" followed by the hex encoding of 20 random bytes (40
// characters).
//
// NOTE(review): the random form is a placeholder, not a wallet-derived
// address. It carries no checksum and is likely not a valid bech32 address,
// since hex output can contain 'b', which bech32 excludes. Funds sent to it
// may be unrecoverable; confirm before production use. The rand.Read error is
// ignored, so on failure the suffix would be all zeros.
func (s *AccountService) generateBTCAddress() string {
	if addr := s.bitcoinMasterAddr; addr != "" {
		return addr
	}
	var raw [20]byte
	rand.Read(raw[:])
	return "bc1q" + fmt.Sprintf("%x", raw[:])
}