package routes import ( "bytes" "encoding/json" "errors" "fmt" "github.com/labstack/echo/v4" "google.golang.org/protobuf/proto" "io" "net/http" models "vibeStonk/server/models/v1" "vibeStonk/server/services" ) var ( errNoUsername = errors.New("missing username") errNoPrefname = errors.New("missing preferred name") errNoPassword = errors.New("missing password") errMismatched = errors.New("passwords mismatched") ) func NewUserRoute(system *services.SystemServices) UserRoute { return &userRoute{system: system} } type UserRoute interface { Provide(middlewares *SystemMiddleware) []*Route } type userRoute struct { system *services.SystemServices } func (u *userRoute) Provide(middlewares *SystemMiddleware) []*Route { return []*Route{ {endpoint("/user"), http.MethodGet, u.handleGet, []echo.MiddlewareFunc{middlewares.UserAuth}}, {endpoint("/user"), http.MethodPost, u.handlePost, nil}, {endpoint("/user/login"), http.MethodPost, u.handleLogin, nil}, } } // region GET func (u *userRoute) handleGet(c echo.Context) error { // by the time we get here, we are guaranteed to have an authenticated user rCtx, err := GetRequestContext(c) if err != nil { return fmt.Errorf("userRoute handleGet failed to get request context: %w", err) } user := rCtx.User user.Hash = "" payload, err := proto.Marshal(user) if err != nil { return fmt.Errorf("userRoute handleGet failed to serialize response: %w", err) } c.Response().Header().Set("Content-Type", "application/json") _, err = c.Response().Write(payload) if err != nil { return fmt.Errorf("userRoute handleGet failed to write response: %w", err) } return nil } // endregion // region POST func (u *userRoute) handlePost(c echo.Context) error { registration := &models.UserRegistration{} bodyReader := c.Request().Body defer bodyReader.Close() var err error contentType := c.Request().Header.Get("Content-Type") switch contentType { case "application/json": var bodyContent []byte bodyContent, err = io.ReadAll(bodyReader) if err != nil { return fmt.Errorf("userRoute handlePost failed to read request body: %w", err) } err = json.Unmarshal(bodyContent, registration) if err != nil { return fmt.Errorf("failed to deserialize proto request: %w", err) } case "application/x-protobuf": var bodyContent []byte bodyContent, err = io.ReadAll(bodyReader) if err != nil { return fmt.Errorf("userRoute handlePost failed to read request body: %w", err) } err = proto.Unmarshal(bodyContent, registration) if err != nil { return fmt.Errorf("failed to deserialize proto request: %w", err) } case "multipart/form-data; boundary=X-INSOMNIA-BOUNDARY": // Parse the multipart form err = c.Request().ParseMultipartForm(10 << 20) // 10 MB max memory if err != nil { return fmt.Errorf("failed to parse multipart form: %w", err) } // Extract form fields userName := c.FormValue("userName") prefName := c.FormValue("prefName") pswd := c.FormValue("pswd") pswdConfirm := c.FormValue("pswdConfirm") // Populate the registration object registration.UserName = userName registration.PrefName = prefName registration.Pswd = []byte(pswd) registration.PswdConfirm = []byte(pswdConfirm) default: return fmt.Errorf("did not understand content-type: %s", contentType) } err = validateRegistration(registration) if err != nil { return fmt.Errorf("failed to validate registration: %w", err) } user, err := u.system.AuthService.RegisterUser(registration.UserName, registration.PrefName, registration.Pswd) if err != nil { return fmt.Errorf("failed to create user: %w", err) } session, err := u.system.AuthService.AuthenticateUser(user.UserName, registration.Pswd) if err != nil { return fmt.Errorf("failed to authenticate newly-created user: %w", err) } resp := &models.UserRegistrationResponse{ User: user, Token: session.Token, } payload, err := proto.Marshal(resp) if err != nil { return fmt.Errorf("failed to marshal response: %w", err) } _, err = c.Response().Write(payload) if err != nil { return fmt.Errorf("failed to write response to client: %w", err) } return nil } func validateRegistration(registration *models.UserRegistration) error { if len(registration.UserName) == 0 { return errNoUsername } if len(registration.PrefName) == 0 { return errNoPrefname } if len(registration.Pswd) == 0 { return errNoPassword } if len(registration.Pswd) != len(registration.PswdConfirm) || bytes.Compare(registration.Pswd, registration.PswdConfirm) != 0 { return errMismatched } return nil } // endregion // region LOGIN func (u *userRoute) handleLogin(c echo.Context) error { login := &models.Login{} bodyReader := c.Request().Body defer bodyReader.Close() var err error contentType := c.Request().Header.Get("Content-Type") switch contentType { case "application/json": var bodyContent []byte bodyContent, err = io.ReadAll(bodyReader) if err != nil { return fmt.Errorf("userRoute handleLogin failed to read request body: %w", err) } err = json.Unmarshal(bodyContent, login) if err != nil { return fmt.Errorf("failed to deserialize json request: %w", err) } case "multipart/form-data; boundary=X-INSOMNIA-BOUNDARY": // Parse the multipart form err = c.Request().ParseMultipartForm(10 << 20) // 10 MB max memory if err != nil { return fmt.Errorf("failed to parse multipart form: %w", err) } // Extract form fields login.UserName = c.FormValue("userName") login.Password = []byte(c.FormValue("password")) case "application/x-protobuf": payload, err := io.ReadAll(c.Request().Body) if err != nil { return fmt.Errorf("failed to read protobuf request body: %w", err) } err = proto.Unmarshal(payload, login) if err != nil { return fmt.Errorf("failed to unmarshal protobuf payload: %w", err) } default: return fmt.Errorf("did not understand content-type: %s", contentType) } // Validate login request if len(login.UserName) == 0 { return errNoUsername } if len(login.Password) == 0 { return errNoPassword } // Authenticate user session, err := u.system.AuthService.AuthenticateUser(login.UserName, []byte(login.Password)) if err != nil { return fmt.Errorf("failed to authenticate user: %w", err) } // Get user details user, err := u.system.AuthService.AuthenticateToken(session.Token) if err != nil { return fmt.Errorf("failed to get user details: %w", err) } // Clear password hash for security user.Hash = "" // Create response resp := &models.UserRegistrationResponse{ User: user, Token: session.Token, } payload, err := proto.Marshal(resp) if err != nil { return fmt.Errorf("failed to marshal response: %w", err) } _, err = c.Response().Write(payload) if err != nil { return fmt.Errorf("failed to write response to client: %w", err) } return nil } // endregion