345 lines
8.9 KiB
Go
345 lines
8.9 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"fmt"
|
|
"math/big"
|
|
"strings"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// Account and verification data types. JSON tags give the serialized field
// names; field order is significant for encoding order and must be kept.
type (
	// User is the public view of an account row. It has no password field:
	// CreateUser and ValidateUserCredentials handle the bcrypt hash separately.
	// Email is stored lower-cased by CreateUser.
	User struct {
		ID         int       `json:"id"`
		Username   string    `json:"username"`
		Email      string    `json:"email"`
		IsVerified bool      `json:"is_verified"`
		CreatedAt  time.Time `json:"created_at"`
	}

	// CreateUserRequest carries the sign-up form. Password is plaintext and is
	// hashed by CreateUser before it is stored.
	CreateUserRequest struct {
		Username string `json:"username"`
		Email    string `json:"email"`
		Password string `json:"password"`
	}

	// LoginRequest carries a username/password pair to be checked with
	// ValidateUserCredentials.
	LoginRequest struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}

	// VerifyCodeRequest carries the email address and the code the user
	// received, as consumed by VerifyCode.
	VerifyCodeRequest struct {
		Email string `json:"email"`
		Code  string `json:"code"`
	}

	// VerificationCode mirrors a verification_codes row. Code comes from
	// GenerateVerificationCode, and CreateVerificationCode sets ExpiresAt ten
	// minutes after creation. Used is flipped to true by VerifyCode.
	VerificationCode struct {
		ID        int       `json:"id"`
		UserID    int       `json:"user_id"`
		Code      string    `json:"code"`
		ExpiresAt time.Time `json:"expires_at"`
		Used      bool      `json:"used"`
		CreatedAt time.Time `json:"created_at"`
	}
)
|
|
|
|
// GenerateVerificationCode returns a six-character code whose characters are
// drawn independently and uniformly from A-Z and 0-9 using crypto/rand.
// If the random source fails, it returns "" and that error.
func GenerateVerificationCode() (string, error) {
	const (
		alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
		codeLen  = 6
	)

	// rand.Int only reads its bound, so one big.Int serves every draw.
	bound := big.NewInt(int64(len(alphabet)))

	code := make([]byte, 0, codeLen)
	for len(code) < codeLen {
		pick, err := rand.Int(rand.Reader, bound)
		if err != nil {
			return "", err
		}
		code = append(code, alphabet[pick.Int64()])
	}
	return string(code), nil
}
|
|
|
|
func CreateUser(ctx context.Context, req CreateUserRequest) (*User, error) {
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var user User
|
|
err = Pool.QueryRow(ctx,
|
|
"INSERT INTO users (username, email, password_hash) VALUES ($1, $2, $3) RETURNING id, username, email, is_verified, created_at",
|
|
req.Username, strings.ToLower(req.Email), string(hashedPassword),
|
|
).Scan(&user.ID, &user.Username, &user.Email, &user.IsVerified, &user.CreatedAt)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
func CreateVerificationCode(ctx context.Context, userID int) (*VerificationCode, error) {
|
|
code, err := GenerateVerificationCode()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
expiresAt := time.Now().Add(10 * time.Minute)
|
|
|
|
var verificationCode VerificationCode
|
|
err = Pool.QueryRow(ctx,
|
|
"INSERT INTO verification_codes (user_id, code, expires_at) VALUES ($1, $2, $3) RETURNING id, user_id, code, expires_at, used, created_at",
|
|
userID, code, expiresAt,
|
|
).Scan(&verificationCode.ID, &verificationCode.UserID, &verificationCode.Code, &verificationCode.ExpiresAt, &verificationCode.Used, &verificationCode.CreatedAt)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &verificationCode, nil
|
|
}
|
|
|
|
func VerifyCode(ctx context.Context, email, code string) error {
|
|
tx, err := Pool.Begin(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
var userID int
|
|
var codeID int
|
|
var expiresAt time.Time
|
|
var used bool
|
|
|
|
err = tx.QueryRow(ctx, `
|
|
SELECT u.id, vc.id, vc.expires_at, vc.used
|
|
FROM users u
|
|
JOIN verification_codes vc ON u.id = vc.user_id
|
|
WHERE u.email = $1 AND vc.code = $2
|
|
ORDER BY vc.created_at DESC
|
|
LIMIT 1
|
|
`, strings.ToLower(email), code).Scan(&userID, &codeID, &expiresAt, &used)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("invalid verification code")
|
|
}
|
|
|
|
if used {
|
|
return fmt.Errorf("verification code already used")
|
|
}
|
|
|
|
if time.Now().After(expiresAt) {
|
|
return fmt.Errorf("verification code expired")
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, "UPDATE verification_codes SET used = true WHERE id = $1", codeID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, "UPDATE users SET is_verified = true WHERE id = $1", userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit(ctx)
|
|
}
|
|
|
|
func GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
|
var user User
|
|
err := Pool.QueryRow(ctx,
|
|
"SELECT id, username, email, is_verified, created_at FROM users WHERE username = $1",
|
|
username,
|
|
).Scan(&user.ID, &user.Username, &user.Email, &user.IsVerified, &user.CreatedAt)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
func GetUserByEmail(ctx context.Context, email string) (*User, error) {
|
|
var user User
|
|
err := Pool.QueryRow(ctx,
|
|
"SELECT id, username, email, is_verified, created_at FROM users WHERE email = $1",
|
|
strings.ToLower(email),
|
|
).Scan(&user.ID, &user.Username, &user.Email, &user.IsVerified, &user.CreatedAt)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
// Termail (in-app mail) data types. JSON tags give the serialized field names;
// field order is significant for encoding order and must be kept.
type (
	// Termail mirrors a termails row. Sender is filled with the sender's
	// username by GetTermail, GetInbox and SearchTermails. Receiver is not
	// populated by any function in this file, so it is normally omitted from
	// JSON.
	Termail struct {
		ID         int       `json:"id"`
		SenderID   int       `json:"sender_id"`
		ReceiverID int       `json:"receiver_id"`
		Subject    string    `json:"subject"`
		Content    string    `json:"content"`
		IsRead     bool      `json:"is_read"`
		SentAt     time.Time `json:"sent_at"`
		Sender     string    `json:"sender,omitempty"`
		Receiver   string    `json:"receiver,omitempty"`
	}

	// SendTermailRequest is the payload for SendTermail. The recipient is
	// addressed by username and resolved to an ID at send time.
	SendTermailRequest struct {
		ReceiverUsername string `json:"receiver_username"`
		Subject          string `json:"subject"`
		Content          string `json:"content"`
	}
)
|
|
|
|
func GetTermail(ctx context.Context, termailID, userID int) (*Termail, error) {
|
|
var t Termail
|
|
err := Pool.QueryRow(ctx, `
|
|
SELECT t.id, t.sender_id, t.receiver_id, t.subject, t.content, t.is_read, t.sent_at, u.username as sender
|
|
FROM termails t
|
|
JOIN users u ON t.sender_id = u.id
|
|
WHERE t.id = $1 AND t.receiver_id = $2
|
|
`, termailID, userID).Scan(&t.ID, &t.SenderID, &t.ReceiverID, &t.Subject, &t.Content, &t.IsRead, &t.SentAt, &t.Sender)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &t, nil
|
|
}
|
|
|
|
func SendTermail(ctx context.Context, senderID int, req SendTermailRequest) (*Termail, error) {
|
|
receiver, err := GetUserByUsername(ctx, req.ReceiverUsername)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("user not found: %s", req.ReceiverUsername)
|
|
}
|
|
|
|
var termail Termail
|
|
err = Pool.QueryRow(ctx,
|
|
"INSERT INTO termails (sender_id, receiver_id, subject, content) VALUES ($1, $2, $3, $4) RETURNING id, sender_id, receiver_id, subject, content, is_read, sent_at",
|
|
senderID, receiver.ID, req.Subject, req.Content,
|
|
).Scan(&termail.ID, &termail.SenderID, &termail.ReceiverID, &termail.Subject, &termail.Content, &termail.IsRead, &termail.SentAt)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &termail, nil
|
|
}
|
|
|
|
func GetInbox(ctx context.Context, userID int, limit, offset int) ([]Termail, error) {
|
|
query := `
|
|
SELECT t.id, t.sender_id, t.receiver_id, t.subject, t.content, t.is_read, t.sent_at, u.username as sender
|
|
FROM termails t
|
|
JOIN users u ON t.sender_id = u.id
|
|
WHERE t.receiver_id = $1
|
|
ORDER BY t.sent_at DESC
|
|
LIMIT $2 OFFSET $3
|
|
`
|
|
|
|
rows, err := Pool.Query(ctx, query, userID, limit, offset)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var termails []Termail
|
|
for rows.Next() {
|
|
var t Termail
|
|
err := rows.Scan(&t.ID, &t.SenderID, &t.ReceiverID, &t.Subject, &t.Content, &t.IsRead, &t.SentAt, &t.Sender)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
termails = append(termails, t)
|
|
}
|
|
|
|
return termails, nil
|
|
}
|
|
|
|
func MarkTermailAsRead(ctx context.Context, termailID, userID int) error {
|
|
_, err := Pool.Exec(ctx,
|
|
"UPDATE termails SET is_read = true WHERE id = $1 AND receiver_id = $2",
|
|
termailID, userID,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func DeleteTermail(ctx context.Context, termailID, userID int) error {
|
|
result, err := Pool.Exec(ctx,
|
|
"DELETE FROM termails WHERE id = $1 AND receiver_id = $2",
|
|
termailID, userID,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rowsAffected := result.RowsAffected()
|
|
if rowsAffected == 0 {
|
|
return fmt.Errorf("termail not found or access denied")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func SearchTermails(ctx context.Context, userID int, query string, limit, offset int) ([]Termail, error) {
|
|
sqlQuery := `
|
|
SELECT t.id, t.sender_id, t.receiver_id, t.subject, t.content, t.is_read, t.sent_at, u.username as sender
|
|
FROM termails t
|
|
JOIN users u ON t.sender_id = u.id
|
|
WHERE t.receiver_id = $1 AND (t.subject ILIKE $2 OR t.content ILIKE $2 OR u.username ILIKE $2)
|
|
ORDER BY t.sent_at DESC
|
|
LIMIT $3 OFFSET $4
|
|
`
|
|
|
|
searchPattern := "%" + query + "%"
|
|
rows, err := Pool.Query(ctx, sqlQuery, userID, searchPattern, limit, offset)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var termails []Termail
|
|
for rows.Next() {
|
|
var t Termail
|
|
err := rows.Scan(&t.ID, &t.SenderID, &t.ReceiverID, &t.Subject, &t.Content, &t.IsRead, &t.SentAt, &t.Sender)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
termails = append(termails, t)
|
|
}
|
|
|
|
return termails, nil
|
|
}
|
|
|
|
func CleanupUnverifiedUsers(ctx context.Context) error {
|
|
_, err := Pool.Exec(ctx, `
|
|
DELETE FROM users
|
|
WHERE is_verified = false
|
|
AND created_at < NOW() - INTERVAL '1 hour'
|
|
`)
|
|
return err
|
|
}
|
|
|
|
func ValidateUserCredentials(ctx context.Context, username, password string) (*User, error) {
|
|
var user User
|
|
var passwordHash string
|
|
err := Pool.QueryRow(ctx,
|
|
"SELECT id, username, email, password_hash, is_verified, created_at FROM users WHERE username = $1",
|
|
username,
|
|
).Scan(&user.ID, &user.Username, &user.Email, &passwordHash, &user.IsVerified, &user.CreatedAt)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !user.IsVerified {
|
|
return nil, fmt.Errorf("account not verified")
|
|
}
|
|
|
|
err = bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &user, nil
|
|
}
|