269 lines
6.7 KiB
Go
269 lines
6.7 KiB
Go
package routes
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"mime"
	"net/http"

	"github.com/labstack/echo/v4"
	"google.golang.org/protobuf/proto"

	models "vibeStonk/server/models/v1"
	"vibeStonk/server/services"
)
|
|
|
|
// errNoUsername reports a registration or login request that has no user name.
var errNoUsername = errors.New("missing username")

// errNoPrefname reports a registration request that has no preferred name.
var errNoPrefname = errors.New("missing preferred name")

// errNoPassword reports a registration or login request that has no password.
var errNoPassword = errors.New("missing password")

// errMismatched reports a registration whose password confirmation differs
// from the password.
var errMismatched = errors.New("passwords mismatched")
|
|
|
|
func NewUserRoute(system *services.SystemServices) UserRoute {
|
|
return &userRoute{system: system}
|
|
}
|
|
|
|
type UserRoute interface {
|
|
Provide(middlewares *SystemMiddleware) []*Route
|
|
}
|
|
|
|
type userRoute struct {
|
|
system *services.SystemServices
|
|
}
|
|
|
|
func (u *userRoute) Provide(middlewares *SystemMiddleware) []*Route {
|
|
return []*Route{
|
|
{endpoint("/user"), http.MethodGet, u.handleGet, []echo.MiddlewareFunc{middlewares.UserAuth}},
|
|
{endpoint("/user"), http.MethodPost, u.handlePost, nil},
|
|
{endpoint("/user/login"), http.MethodPost, u.handleLogin, nil},
|
|
}
|
|
}
|
|
|
|
// region GET
|
|
|
|
func (u *userRoute) handleGet(c echo.Context) error {
|
|
// by the time we get here, we are guaranteed to have an authenticated user
|
|
rCtx, err := GetRequestContext(c)
|
|
if err != nil {
|
|
return fmt.Errorf("userRoute handleGet failed to get request context: %w", err)
|
|
}
|
|
|
|
user := rCtx.User
|
|
user.Hash = ""
|
|
payload, err := proto.Marshal(user)
|
|
if err != nil {
|
|
return fmt.Errorf("userRoute handleGet failed to serialize response: %w", err)
|
|
}
|
|
|
|
c.Response().Header().Set("Content-Type", "application/json")
|
|
_, err = c.Response().Write(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("userRoute handleGet failed to write response: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// endregion
|
|
|
|
// region POST
|
|
|
|
func (u *userRoute) handlePost(c echo.Context) error {
|
|
registration := &models.UserRegistration{}
|
|
bodyReader := c.Request().Body
|
|
defer bodyReader.Close()
|
|
|
|
var err error
|
|
contentType := c.Request().Header.Get("Content-Type")
|
|
switch contentType {
|
|
case "application/json":
|
|
var bodyContent []byte
|
|
bodyContent, err = io.ReadAll(bodyReader)
|
|
if err != nil {
|
|
return fmt.Errorf("userRoute handlePost failed to read request body: %w", err)
|
|
}
|
|
|
|
err = json.Unmarshal(bodyContent, registration)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to deserialize proto request: %w", err)
|
|
}
|
|
case "application/x-protobuf":
|
|
var bodyContent []byte
|
|
bodyContent, err = io.ReadAll(bodyReader)
|
|
if err != nil {
|
|
return fmt.Errorf("userRoute handlePost failed to read request body: %w", err)
|
|
}
|
|
|
|
err = proto.Unmarshal(bodyContent, registration)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to deserialize proto request: %w", err)
|
|
}
|
|
|
|
case "multipart/form-data; boundary=X-INSOMNIA-BOUNDARY":
|
|
// Parse the multipart form
|
|
err = c.Request().ParseMultipartForm(10 << 20) // 10 MB max memory
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse multipart form: %w", err)
|
|
}
|
|
|
|
// Extract form fields
|
|
userName := c.FormValue("userName")
|
|
prefName := c.FormValue("prefName")
|
|
pswd := c.FormValue("pswd")
|
|
pswdConfirm := c.FormValue("pswdConfirm")
|
|
|
|
// Populate the registration object
|
|
registration.UserName = userName
|
|
registration.PrefName = prefName
|
|
registration.Pswd = []byte(pswd)
|
|
registration.PswdConfirm = []byte(pswdConfirm)
|
|
default:
|
|
return fmt.Errorf("did not understand content-type: %s", contentType)
|
|
}
|
|
|
|
err = validateRegistration(registration)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to validate registration: %w", err)
|
|
}
|
|
|
|
user, err := u.system.AuthService.RegisterUser(registration.UserName, registration.PrefName, registration.Pswd)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create user: %w", err)
|
|
}
|
|
|
|
session, err := u.system.AuthService.AuthenticateUser(user.UserName, registration.Pswd)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to authenticate newly-created user: %w", err)
|
|
}
|
|
|
|
resp := &models.UserRegistrationResponse{
|
|
User: user,
|
|
Token: session.Token,
|
|
}
|
|
|
|
payload, err := proto.Marshal(resp)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal response: %w", err)
|
|
}
|
|
|
|
_, err = c.Response().Write(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write response to client: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateRegistration(registration *models.UserRegistration) error {
|
|
if len(registration.UserName) == 0 {
|
|
return errNoUsername
|
|
}
|
|
|
|
if len(registration.PrefName) == 0 {
|
|
return errNoPrefname
|
|
}
|
|
|
|
if len(registration.Pswd) == 0 {
|
|
return errNoPassword
|
|
}
|
|
|
|
if len(registration.Pswd) != len(registration.PswdConfirm) || bytes.Compare(registration.Pswd, registration.PswdConfirm) != 0 {
|
|
return errMismatched
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// endregion
|
|
|
|
// region LOGIN
|
|
|
|
func (u *userRoute) handleLogin(c echo.Context) error {
|
|
login := &models.Login{}
|
|
bodyReader := c.Request().Body
|
|
defer bodyReader.Close()
|
|
|
|
var err error
|
|
contentType := c.Request().Header.Get("Content-Type")
|
|
switch contentType {
|
|
case "application/json":
|
|
var bodyContent []byte
|
|
bodyContent, err = io.ReadAll(bodyReader)
|
|
if err != nil {
|
|
return fmt.Errorf("userRoute handleLogin failed to read request body: %w", err)
|
|
}
|
|
|
|
err = json.Unmarshal(bodyContent, login)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to deserialize json request: %w", err)
|
|
}
|
|
|
|
case "multipart/form-data; boundary=X-INSOMNIA-BOUNDARY":
|
|
// Parse the multipart form
|
|
err = c.Request().ParseMultipartForm(10 << 20) // 10 MB max memory
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse multipart form: %w", err)
|
|
}
|
|
|
|
// Extract form fields
|
|
login.UserName = c.FormValue("userName")
|
|
login.Password = []byte(c.FormValue("password"))
|
|
case "application/x-protobuf":
|
|
payload, err := io.ReadAll(c.Request().Body)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read protobuf request body: %w", err)
|
|
}
|
|
|
|
err = proto.Unmarshal(payload, login)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to unmarshal protobuf payload: %w", err)
|
|
}
|
|
default:
|
|
return fmt.Errorf("did not understand content-type: %s", contentType)
|
|
}
|
|
|
|
// Validate login request
|
|
if len(login.UserName) == 0 {
|
|
return errNoUsername
|
|
}
|
|
if len(login.Password) == 0 {
|
|
return errNoPassword
|
|
}
|
|
|
|
// Authenticate user
|
|
session, err := u.system.AuthService.AuthenticateUser(login.UserName, []byte(login.Password))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to authenticate user: %w", err)
|
|
}
|
|
|
|
// Get user details
|
|
user, err := u.system.AuthService.AuthenticateToken(session.Token)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get user details: %w", err)
|
|
}
|
|
|
|
// Clear password hash for security
|
|
user.Hash = ""
|
|
|
|
// Create response
|
|
resp := &models.UserRegistrationResponse{
|
|
User: user,
|
|
Token: session.Token,
|
|
}
|
|
|
|
payload, err := proto.Marshal(resp)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal response: %w", err)
|
|
}
|
|
|
|
_, err = c.Response().Write(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write response to client: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// endregion
|