199 lines
4.8 KiB
Go
199 lines
4.8 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"fmt"
|
|
"math/big"
|
|
"strings"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// User is a row of the users table as read back by the queries in this file
// (id, username, email, is_verified, created_at). It has no password-hash
// field, so encoding it to JSON does not expose credentials.
type User struct {
	// ID is the users.id column.
	ID int `json:"id"`
	// Username is users.username, stored exactly as passed to CreateUser.
	Username string `json:"username"`
	// Email is users.email; CreateUser stores it lower-cased.
	Email string `json:"email"`
	// IsVerified is users.is_verified; VerifyCode sets it to true.
	IsVerified bool `json:"is_verified"`
	// CreatedAt is users.created_at as returned by the database.
	CreatedAt time.Time `json:"created_at"`
}
|
|
|
|
// CreateUserRequest is the input to CreateUser. It carries JSON tags,
// presumably so a handler can decode it from a request body (confirm at the
// caller).
type CreateUserRequest struct {
	// Username is stored as-is (no case folding).
	Username string `json:"username"`
	// Email is lower-cased by CreateUser before it is stored.
	Email string `json:"email"`
	// Password is the plaintext password; CreateUser stores only its bcrypt hash.
	Password string `json:"password"`
}
|
|
|
|
// LoginRequest holds the credentials for a login attempt. It carries JSON
// tags; presumably a handler decodes it and passes the fields to
// ValidateUserCredentials (the caller is not visible here — verify).
type LoginRequest struct {
	// Username is matched exactly against users.username.
	Username string `json:"username"`
	// Password is the plaintext password to check against the stored hash.
	Password string `json:"password"`
}
|
|
|
|
// VerifyCodeRequest carries the arguments for VerifyCode: the account email
// and the verification code the user received.
type VerifyCodeRequest struct {
	// Email identifies the account; VerifyCode lower-cases it before lookup.
	Email string `json:"email"`
	// Code is compared exactly against verification_codes.code.
	Code string `json:"code"`
}
|
|
|
|
// VerificationCode is a row of the verification_codes table, as returned by
// CreateVerificationCode via INSERT ... RETURNING.
type VerificationCode struct {
	// ID is verification_codes.id.
	ID int `json:"id"`
	// UserID references the users.id the code was issued for.
	UserID int `json:"user_id"`
	// Code is the secret value produced by GenerateVerificationCode.
	// NOTE(review): this field is JSON-tagged; if this struct is ever encoded
	// into a response, the code itself is disclosed — confirm at callers.
	Code string `json:"code"`
	// ExpiresAt is set by CreateVerificationCode to now + 10 minutes.
	ExpiresAt time.Time `json:"expires_at"`
	// Used is set to true by VerifyCode once the code has been redeemed.
	Used bool `json:"used"`
	// CreatedAt is verification_codes.created_at as returned by the database.
	CreatedAt time.Time `json:"created_at"`
}
|
|
|
|
// GenerateVerificationCode returns a 6-character code whose characters are
// drawn uniformly from the uppercase ASCII letters and the decimal digits.
// Every character comes from crypto/rand, so the result is suitable as a
// one-time secret. An error is returned only if the randomness source fails,
// in which case the returned string is empty.
func GenerateVerificationCode() (string, error) {
	const (
		alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
		codeLen  = 6
	)

	// rand.Int does not modify its bound, so one *big.Int serves every draw.
	bound := big.NewInt(int64(len(alphabet)))

	code := make([]byte, 0, codeLen)
	for len(code) < codeLen {
		pick, err := rand.Int(rand.Reader, bound)
		if err != nil {
			return "", err
		}
		code = append(code, alphabet[pick.Int64()])
	}
	return string(code), nil
}
|
|
|
|
func CreateUser(ctx context.Context, req CreateUserRequest) (*User, error) {
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var user User
|
|
err = Pool.QueryRow(ctx,
|
|
"INSERT INTO users (username, email, password_hash) VALUES ($1, $2, $3) RETURNING id, username, email, is_verified, created_at",
|
|
req.Username, strings.ToLower(req.Email), string(hashedPassword),
|
|
).Scan(&user.ID, &user.Username, &user.Email, &user.IsVerified, &user.CreatedAt)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
func CreateVerificationCode(ctx context.Context, userID int) (*VerificationCode, error) {
|
|
code, err := GenerateVerificationCode()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
expiresAt := time.Now().Add(10 * time.Minute)
|
|
|
|
var verificationCode VerificationCode
|
|
err = Pool.QueryRow(ctx,
|
|
"INSERT INTO verification_codes (user_id, code, expires_at) VALUES ($1, $2, $3) RETURNING id, user_id, code, expires_at, used, created_at",
|
|
userID, code, expiresAt,
|
|
).Scan(&verificationCode.ID, &verificationCode.UserID, &verificationCode.Code, &verificationCode.ExpiresAt, &verificationCode.Used, &verificationCode.CreatedAt)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &verificationCode, nil
|
|
}
|
|
|
|
func VerifyCode(ctx context.Context, email, code string) error {
|
|
tx, err := Pool.Begin(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
var userID int
|
|
var codeID int
|
|
var expiresAt time.Time
|
|
var used bool
|
|
|
|
err = tx.QueryRow(ctx, `
|
|
SELECT u.id, vc.id, vc.expires_at, vc.used
|
|
FROM users u
|
|
JOIN verification_codes vc ON u.id = vc.user_id
|
|
WHERE u.email = $1 AND vc.code = $2
|
|
ORDER BY vc.created_at DESC
|
|
LIMIT 1
|
|
`, strings.ToLower(email), code).Scan(&userID, &codeID, &expiresAt, &used)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("invalid verification code")
|
|
}
|
|
|
|
if used {
|
|
return fmt.Errorf("verification code already used")
|
|
}
|
|
|
|
if time.Now().After(expiresAt) {
|
|
return fmt.Errorf("verification code expired")
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, "UPDATE verification_codes SET used = true WHERE id = $1", codeID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, "UPDATE users SET is_verified = true WHERE id = $1", userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit(ctx)
|
|
}
|
|
|
|
func GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
|
var user User
|
|
err := Pool.QueryRow(ctx,
|
|
"SELECT id, username, email, is_verified, created_at FROM users WHERE username = $1",
|
|
username,
|
|
).Scan(&user.ID, &user.Username, &user.Email, &user.IsVerified, &user.CreatedAt)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
func GetUserByEmail(ctx context.Context, email string) (*User, error) {
|
|
var user User
|
|
err := Pool.QueryRow(ctx,
|
|
"SELECT id, username, email, is_verified, created_at FROM users WHERE email = $1",
|
|
strings.ToLower(email),
|
|
).Scan(&user.ID, &user.Username, &user.Email, &user.IsVerified, &user.CreatedAt)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
func ValidateUserCredentials(ctx context.Context, username, password string) (*User, error) {
|
|
var user User
|
|
var passwordHash string
|
|
err := Pool.QueryRow(ctx,
|
|
"SELECT id, username, email, password_hash, is_verified, created_at FROM users WHERE username = $1",
|
|
username,
|
|
).Scan(&user.ID, &user.Username, &user.Email, &passwordHash, &user.IsVerified, &user.CreatedAt)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !user.IsVerified {
|
|
return nil, fmt.Errorf("account not verified")
|
|
}
|
|
|
|
err = bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &user, nil
|
|
}
|