From 8da1c9346bfb6ad871da9957e327fcc750397482 Mon Sep 17 00:00:00 2001 From: Quaternions <krakow20@gmail.com> Date: Tue, 25 Mar 2025 16:45:30 -0700 Subject: [PATCH 1/7] openapi: add session endpoints --- openapi.yaml | 83 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/openapi.yaml b/openapi.yaml index 1c75f06..e086cfa 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -6,6 +6,8 @@ info: servers: - url: https://submissions.strafes.net/v1 tags: + - name: Session + description: Session operations - name: Submissions description: Submission operations - name: Scripts @@ -15,6 +17,63 @@ tags: security: - cookieAuth: [] paths: + /session/user: + get: + summary: Get information about the currently logged in user + operationId: sessionUser + tags: + - Session + responses: + "200": + description: Successful response + content: + application/json: + schema: + $ref: "#/components/schemas/User" + default: + description: General Error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /session/roles: + get: + summary: Get list of roles for the current session + operationId: sessionRoles + tags: + - Session + responses: + "200": + description: Successful response + content: + application/json: + schema: + $ref: "#/components/schemas/Roles" + default: + description: General Error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" + /session/validate: + get: + summary: Ask if the current session is valid + operationId: sessionValidate + tags: + - Session + responses: + "200": + description: Successful response + content: + application/json: + schema: + type: boolean + default: + description: General Error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" /submissions: get: summary: Get list of submissions @@ -619,6 +678,30 @@ components: ID: type: integer format: int64 + Roles: + required: + - Roles + type: object + properties: + Roles: + type: integer + format: int32 + User: + required: + - UserID + - Username + - AvatarURL + type: object + properties: + UserID: + type: integer + format: int64 + Username: + type: string + maxLength: 128 + AvatarURL: + type: string + maxLength: 256 Submission: required: - ID -- 2.47.1 From d7634de9ec14908e48288428f709513861cfd456 Mon Sep 17 00:00:00 2001 From: Quaternions <krakow20@gmail.com> Date: Tue, 25 Mar 2025 16:49:22 -0700 Subject: [PATCH 2/7] openapi: generate --- pkg/api/oas_client_gen.go | 337 ++++++++++++++++- pkg/api/oas_handlers_gen.go | 544 ++++++++++++++++++++++++++- pkg/api/oas_json_gen.go | 226 +++++++++++ pkg/api/oas_operations_gen.go | 3 + pkg/api/oas_response_decoders_gen.go | 260 +++++++++++++ pkg/api/oas_response_encoders_gen.go | 42 +++ pkg/api/oas_router_gen.go | 162 +++++++- pkg/api/oas_schemas_gen.go | 52 +++ pkg/api/oas_server_gen.go | 20 +- pkg/api/oas_unimplemented_gen.go | 29 +- pkg/api/oas_validators_gen.go | 50 +++ 11 files changed, 1718 insertions(+), 7 deletions(-) diff --git a/pkg/api/oas_client_gen.go b/pkg/api/oas_client_gen.go index e7b244d..ab4d7a2 100644 --- a/pkg/api/oas_client_gen.go +++ b/pkg/api/oas_client_gen.go @@ -149,9 +149,27 @@ type Invoker interface { // // POST /release-submissions ReleaseSubmissions(ctx context.Context, request []ReleaseInfo) error + // SessionRoles invokes sessionRoles operation. + // + // Get list of roles for the current session. + // + // GET /session/roles + SessionRoles(ctx context.Context) (*Roles, error) + // SessionUser invokes sessionUser operation. + // + // Get information about the currently logged in user. + // + // GET /session/user + SessionUser(ctx context.Context) (*User, error) + // SessionValidate invokes sessionValidate operation. + // + // Ask if the current session is valid. + // + // GET /session/validate + SessionValidate(ctx context.Context) (bool, error) // SetSubmissionCompleted invokes setSubmissionCompleted operation. // - // Retrieve map with ID. + // Called by maptest when a player completes the map. // // POST /submissions/{SubmissionID}/completed SetSubmissionCompleted(ctx context.Context, params SetSubmissionCompletedParams) error @@ -2861,9 +2879,324 @@ func (c *Client) sendReleaseSubmissions(ctx context.Context, request []ReleaseIn return result, nil } +// SessionRoles invokes sessionRoles operation. +// +// Get list of roles for the current session. +// +// GET /session/roles +func (c *Client) SessionRoles(ctx context.Context) (*Roles, error) { + res, err := c.sendSessionRoles(ctx) + return res, err +} + +func (c *Client) sendSessionRoles(ctx context.Context) (res *Roles, err error) { + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("sessionRoles"), + semconv.HTTPRequestMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/session/roles"), + } + + // Run stopwatch. + startTime := time.Now() + defer func() { + // Use floating point division here for higher precision (instead of Millisecond method). + elapsedDuration := time.Since(startTime) + c.duration.Record(ctx, float64(elapsedDuration)/float64(time.Millisecond), metric.WithAttributes(otelAttrs...)) + }() + + // Increment request counter. + c.requests.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + + // Start a span for this request. + ctx, span := c.cfg.Tracer.Start(ctx, SessionRolesOperation, + trace.WithAttributes(otelAttrs...), + clientSpanKind, + ) + // Track stage for error reporting. + var stage string + defer func() { + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, stage) + c.errors.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + } + span.End() + }() + + stage = "BuildURL" + u := uri.Clone(c.requestURL(ctx)) + var pathParts [1]string + pathParts[0] = "/session/roles" + uri.AddPathParts(u, pathParts[:]...) + + stage = "EncodeRequest" + r, err := ht.NewRequest(ctx, "GET", u) + if err != nil { + return res, errors.Wrap(err, "create request") + } + + { + type bitset = [1]uint8 + var satisfied bitset + { + stage = "Security:CookieAuth" + switch err := c.securityCookieAuth(ctx, SessionRolesOperation, r); { + case err == nil: // if NO error + satisfied[0] |= 1 << 0 + case errors.Is(err, ogenerrors.ErrSkipClientSecurity): + // Skip this security. + default: + return res, errors.Wrap(err, "security \"CookieAuth\"") + } + } + + if ok := func() bool { + nextRequirement: + for _, requirement := range []bitset{ + {0b00000001}, + } { + for i, mask := range requirement { + if satisfied[i]&mask != mask { + continue nextRequirement + } + } + return true + } + return false + }(); !ok { + return res, ogenerrors.ErrSecurityRequirementIsNotSatisfied + } + } + + stage = "SendRequest" + resp, err := c.cfg.Client.Do(r) + if err != nil { + return res, errors.Wrap(err, "do request") + } + defer resp.Body.Close() + + stage = "DecodeResponse" + result, err := decodeSessionRolesResponse(resp) + if err != nil { + return res, errors.Wrap(err, "decode response") + } + + return result, nil +} + +// SessionUser invokes sessionUser operation. +// +// Get information about the currently logged in user. +// +// GET /session/user +func (c *Client) SessionUser(ctx context.Context) (*User, error) { + res, err := c.sendSessionUser(ctx) + return res, err +} + +func (c *Client) sendSessionUser(ctx context.Context) (res *User, err error) { + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("sessionUser"), + semconv.HTTPRequestMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/session/user"), + } + + // Run stopwatch. + startTime := time.Now() + defer func() { + // Use floating point division here for higher precision (instead of Millisecond method). + elapsedDuration := time.Since(startTime) + c.duration.Record(ctx, float64(elapsedDuration)/float64(time.Millisecond), metric.WithAttributes(otelAttrs...)) + }() + + // Increment request counter. + c.requests.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + + // Start a span for this request. + ctx, span := c.cfg.Tracer.Start(ctx, SessionUserOperation, + trace.WithAttributes(otelAttrs...), + clientSpanKind, + ) + // Track stage for error reporting. + var stage string + defer func() { + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, stage) + c.errors.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + } + span.End() + }() + + stage = "BuildURL" + u := uri.Clone(c.requestURL(ctx)) + var pathParts [1]string + pathParts[0] = "/session/user" + uri.AddPathParts(u, pathParts[:]...) + + stage = "EncodeRequest" + r, err := ht.NewRequest(ctx, "GET", u) + if err != nil { + return res, errors.Wrap(err, "create request") + } + + { + type bitset = [1]uint8 + var satisfied bitset + { + stage = "Security:CookieAuth" + switch err := c.securityCookieAuth(ctx, SessionUserOperation, r); { + case err == nil: // if NO error + satisfied[0] |= 1 << 0 + case errors.Is(err, ogenerrors.ErrSkipClientSecurity): + // Skip this security. + default: + return res, errors.Wrap(err, "security \"CookieAuth\"") + } + } + + if ok := func() bool { + nextRequirement: + for _, requirement := range []bitset{ + {0b00000001}, + } { + for i, mask := range requirement { + if satisfied[i]&mask != mask { + continue nextRequirement + } + } + return true + } + return false + }(); !ok { + return res, ogenerrors.ErrSecurityRequirementIsNotSatisfied + } + } + + stage = "SendRequest" + resp, err := c.cfg.Client.Do(r) + if err != nil { + return res, errors.Wrap(err, "do request") + } + defer resp.Body.Close() + + stage = "DecodeResponse" + result, err := decodeSessionUserResponse(resp) + if err != nil { + return res, errors.Wrap(err, "decode response") + } + + return result, nil +} + +// SessionValidate invokes sessionValidate operation. +// +// Ask if the current session is valid. +// +// GET /session/validate +func (c *Client) SessionValidate(ctx context.Context) (bool, error) { + res, err := c.sendSessionValidate(ctx) + return res, err +} + +func (c *Client) sendSessionValidate(ctx context.Context) (res bool, err error) { + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("sessionValidate"), + semconv.HTTPRequestMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/session/validate"), + } + + // Run stopwatch. + startTime := time.Now() + defer func() { + // Use floating point division here for higher precision (instead of Millisecond method). + elapsedDuration := time.Since(startTime) + c.duration.Record(ctx, float64(elapsedDuration)/float64(time.Millisecond), metric.WithAttributes(otelAttrs...)) + }() + + // Increment request counter. + c.requests.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + + // Start a span for this request. + ctx, span := c.cfg.Tracer.Start(ctx, SessionValidateOperation, + trace.WithAttributes(otelAttrs...), + clientSpanKind, + ) + // Track stage for error reporting. + var stage string + defer func() { + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, stage) + c.errors.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + } + span.End() + }() + + stage = "BuildURL" + u := uri.Clone(c.requestURL(ctx)) + var pathParts [1]string + pathParts[0] = "/session/validate" + uri.AddPathParts(u, pathParts[:]...) + + stage = "EncodeRequest" + r, err := ht.NewRequest(ctx, "GET", u) + if err != nil { + return res, errors.Wrap(err, "create request") + } + + { + type bitset = [1]uint8 + var satisfied bitset + { + stage = "Security:CookieAuth" + switch err := c.securityCookieAuth(ctx, SessionValidateOperation, r); { + case err == nil: // if NO error + satisfied[0] |= 1 << 0 + case errors.Is(err, ogenerrors.ErrSkipClientSecurity): + // Skip this security. + default: + return res, errors.Wrap(err, "security \"CookieAuth\"") + } + } + + if ok := func() bool { + nextRequirement: + for _, requirement := range []bitset{ + {0b00000001}, + } { + for i, mask := range requirement { + if satisfied[i]&mask != mask { + continue nextRequirement + } + } + return true + } + return false + }(); !ok { + return res, ogenerrors.ErrSecurityRequirementIsNotSatisfied + } + } + + stage = "SendRequest" + resp, err := c.cfg.Client.Do(r) + if err != nil { + return res, errors.Wrap(err, "do request") + } + defer resp.Body.Close() + + stage = "DecodeResponse" + result, err := decodeSessionValidateResponse(resp) + if err != nil { + return res, errors.Wrap(err, "decode response") + } + + return result, nil +} + // SetSubmissionCompleted invokes setSubmissionCompleted operation. // -// Retrieve map with ID. +// Called by maptest when a player completes the map. // // POST /submissions/{SubmissionID}/completed func (c *Client) SetSubmissionCompleted(ctx context.Context, params SetSubmissionCompletedParams) error { diff --git a/pkg/api/oas_handlers_gen.go b/pkg/api/oas_handlers_gen.go index 4c4e4f5..ede4e74 100644 --- a/pkg/api/oas_handlers_gen.go +++ b/pkg/api/oas_handlers_gen.go @@ -3986,9 +3986,549 @@ func (s *Server) handleReleaseSubmissionsRequest(args [0]string, argsEscaped boo } } +// handleSessionRolesRequest handles sessionRoles operation. +// +// Get list of roles for the current session. +// +// GET /session/roles +func (s *Server) handleSessionRolesRequest(args [0]string, argsEscaped bool, w http.ResponseWriter, r *http.Request) { + statusWriter := &codeRecorder{ResponseWriter: w} + w = statusWriter + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("sessionRoles"), + semconv.HTTPRequestMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/session/roles"), + } + + // Start a span for this request. + ctx, span := s.cfg.Tracer.Start(r.Context(), SessionRolesOperation, + trace.WithAttributes(otelAttrs...), + serverSpanKind, + ) + defer span.End() + + // Add Labeler to context. + labeler := &Labeler{attrs: otelAttrs} + ctx = contextWithLabeler(ctx, labeler) + + // Run stopwatch. + startTime := time.Now() + defer func() { + elapsedDuration := time.Since(startTime) + + attrSet := labeler.AttributeSet() + attrs := attrSet.ToSlice() + code := statusWriter.status + if code != 0 { + codeAttr := semconv.HTTPResponseStatusCode(code) + attrs = append(attrs, codeAttr) + span.SetAttributes(codeAttr) + } + attrOpt := metric.WithAttributes(attrs...) + + // Increment request counter. + s.requests.Add(ctx, 1, attrOpt) + + // Use floating point division here for higher precision (instead of Millisecond method). + s.duration.Record(ctx, float64(elapsedDuration)/float64(time.Millisecond), attrOpt) + }() + + var ( + recordError = func(stage string, err error) { + span.RecordError(err) + + // https://opentelemetry.io/docs/specs/semconv/http/http-spans/#status + // Span Status MUST be left unset if HTTP status code was in the 1xx, 2xx or 3xx ranges, + // unless there was another error (e.g., network error receiving the response body; or 3xx codes with + // max redirects exceeded), in which case status MUST be set to Error. + code := statusWriter.status + if code >= 100 && code < 500 { + span.SetStatus(codes.Error, stage) + } + + attrSet := labeler.AttributeSet() + attrs := attrSet.ToSlice() + if code != 0 { + attrs = append(attrs, semconv.HTTPResponseStatusCode(code)) + } + + s.errors.Add(ctx, 1, metric.WithAttributes(attrs...)) + } + err error + opErrContext = ogenerrors.OperationContext{ + Name: SessionRolesOperation, + ID: "sessionRoles", + } + ) + { + type bitset = [1]uint8 + var satisfied bitset + { + sctx, ok, err := s.securityCookieAuth(ctx, SessionRolesOperation, r) + if err != nil { + err = &ogenerrors.SecurityError{ + OperationContext: opErrContext, + Security: "CookieAuth", + Err: err, + } + if encodeErr := encodeErrorResponse(s.h.NewError(ctx, err), w, span); encodeErr != nil { + defer recordError("Security:CookieAuth", err) + } + return + } + if ok { + satisfied[0] |= 1 << 0 + ctx = sctx + } + } + + if ok := func() bool { + nextRequirement: + for _, requirement := range []bitset{ + {0b00000001}, + } { + for i, mask := range requirement { + if satisfied[i]&mask != mask { + continue nextRequirement + } + } + return true + } + return false + }(); !ok { + err = &ogenerrors.SecurityError{ + OperationContext: opErrContext, + Err: ogenerrors.ErrSecurityRequirementIsNotSatisfied, + } + if encodeErr := encodeErrorResponse(s.h.NewError(ctx, err), w, span); encodeErr != nil { + defer recordError("Security", err) + } + return + } + } + + var response *Roles + if m := s.cfg.Middleware; m != nil { + mreq := middleware.Request{ + Context: ctx, + OperationName: SessionRolesOperation, + OperationSummary: "Get list of roles for the current session", + OperationID: "sessionRoles", + Body: nil, + Params: middleware.Parameters{}, + Raw: r, + } + + type ( + Request = struct{} + Params = struct{} + Response = *Roles + ) + response, err = middleware.HookMiddleware[ + Request, + Params, + Response, + ]( + m, + mreq, + nil, + func(ctx context.Context, request Request, params Params) (response Response, err error) { + response, err = s.h.SessionRoles(ctx) + return response, err + }, + ) + } else { + response, err = s.h.SessionRoles(ctx) + } + if err != nil { + if errRes, ok := errors.Into[*ErrorStatusCode](err); ok { + if err := encodeErrorResponse(errRes, w, span); err != nil { + defer recordError("Internal", err) + } + return + } + if errors.Is(err, ht.ErrNotImplemented) { + s.cfg.ErrorHandler(ctx, w, r, err) + return + } + if err := encodeErrorResponse(s.h.NewError(ctx, err), w, span); err != nil { + defer recordError("Internal", err) + } + return + } + + if err := encodeSessionRolesResponse(response, w, span); err != nil { + defer recordError("EncodeResponse", err) + if !errors.Is(err, ht.ErrInternalServerErrorResponse) { + s.cfg.ErrorHandler(ctx, w, r, err) + } + return + } +} + +// handleSessionUserRequest handles sessionUser operation. +// +// Get information about the currently logged in user. +// +// GET /session/user +func (s *Server) handleSessionUserRequest(args [0]string, argsEscaped bool, w http.ResponseWriter, r *http.Request) { + statusWriter := &codeRecorder{ResponseWriter: w} + w = statusWriter + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("sessionUser"), + semconv.HTTPRequestMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/session/user"), + } + + // Start a span for this request. + ctx, span := s.cfg.Tracer.Start(r.Context(), SessionUserOperation, + trace.WithAttributes(otelAttrs...), + serverSpanKind, + ) + defer span.End() + + // Add Labeler to context. + labeler := &Labeler{attrs: otelAttrs} + ctx = contextWithLabeler(ctx, labeler) + + // Run stopwatch. + startTime := time.Now() + defer func() { + elapsedDuration := time.Since(startTime) + + attrSet := labeler.AttributeSet() + attrs := attrSet.ToSlice() + code := statusWriter.status + if code != 0 { + codeAttr := semconv.HTTPResponseStatusCode(code) + attrs = append(attrs, codeAttr) + span.SetAttributes(codeAttr) + } + attrOpt := metric.WithAttributes(attrs...) + + // Increment request counter. + s.requests.Add(ctx, 1, attrOpt) + + // Use floating point division here for higher precision (instead of Millisecond method). + s.duration.Record(ctx, float64(elapsedDuration)/float64(time.Millisecond), attrOpt) + }() + + var ( + recordError = func(stage string, err error) { + span.RecordError(err) + + // https://opentelemetry.io/docs/specs/semconv/http/http-spans/#status + // Span Status MUST be left unset if HTTP status code was in the 1xx, 2xx or 3xx ranges, + // unless there was another error (e.g., network error receiving the response body; or 3xx codes with + // max redirects exceeded), in which case status MUST be set to Error. + code := statusWriter.status + if code >= 100 && code < 500 { + span.SetStatus(codes.Error, stage) + } + + attrSet := labeler.AttributeSet() + attrs := attrSet.ToSlice() + if code != 0 { + attrs = append(attrs, semconv.HTTPResponseStatusCode(code)) + } + + s.errors.Add(ctx, 1, metric.WithAttributes(attrs...)) + } + err error + opErrContext = ogenerrors.OperationContext{ + Name: SessionUserOperation, + ID: "sessionUser", + } + ) + { + type bitset = [1]uint8 + var satisfied bitset + { + sctx, ok, err := s.securityCookieAuth(ctx, SessionUserOperation, r) + if err != nil { + err = &ogenerrors.SecurityError{ + OperationContext: opErrContext, + Security: "CookieAuth", + Err: err, + } + if encodeErr := encodeErrorResponse(s.h.NewError(ctx, err), w, span); encodeErr != nil { + defer recordError("Security:CookieAuth", err) + } + return + } + if ok { + satisfied[0] |= 1 << 0 + ctx = sctx + } + } + + if ok := func() bool { + nextRequirement: + for _, requirement := range []bitset{ + {0b00000001}, + } { + for i, mask := range requirement { + if satisfied[i]&mask != mask { + continue nextRequirement + } + } + return true + } + return false + }(); !ok { + err = &ogenerrors.SecurityError{ + OperationContext: opErrContext, + Err: ogenerrors.ErrSecurityRequirementIsNotSatisfied, + } + if encodeErr := encodeErrorResponse(s.h.NewError(ctx, err), w, span); encodeErr != nil { + defer recordError("Security", err) + } + return + } + } + + var response *User + if m := s.cfg.Middleware; m != nil { + mreq := middleware.Request{ + Context: ctx, + OperationName: SessionUserOperation, + OperationSummary: "Get information about the currently logged in user", + OperationID: "sessionUser", + Body: nil, + Params: middleware.Parameters{}, + Raw: r, + } + + type ( + Request = struct{} + Params = struct{} + Response = *User + ) + response, err = middleware.HookMiddleware[ + Request, + Params, + Response, + ]( + m, + mreq, + nil, + func(ctx context.Context, request Request, params Params) (response Response, err error) { + response, err = s.h.SessionUser(ctx) + return response, err + }, + ) + } else { + response, err = s.h.SessionUser(ctx) + } + if err != nil { + if errRes, ok := errors.Into[*ErrorStatusCode](err); ok { + if err := encodeErrorResponse(errRes, w, span); err != nil { + defer recordError("Internal", err) + } + return + } + if errors.Is(err, ht.ErrNotImplemented) { + s.cfg.ErrorHandler(ctx, w, r, err) + return + } + if err := encodeErrorResponse(s.h.NewError(ctx, err), w, span); err != nil { + defer recordError("Internal", err) + } + return + } + + if err := encodeSessionUserResponse(response, w, span); err != nil { + defer recordError("EncodeResponse", err) + if !errors.Is(err, ht.ErrInternalServerErrorResponse) { + s.cfg.ErrorHandler(ctx, w, r, err) + } + return + } +} + +// handleSessionValidateRequest handles sessionValidate operation. +// +// Ask if the current session is valid. +// +// GET /session/validate +func (s *Server) handleSessionValidateRequest(args [0]string, argsEscaped bool, w http.ResponseWriter, r *http.Request) { + statusWriter := &codeRecorder{ResponseWriter: w} + w = statusWriter + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("sessionValidate"), + semconv.HTTPRequestMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/session/validate"), + } + + // Start a span for this request. + ctx, span := s.cfg.Tracer.Start(r.Context(), SessionValidateOperation, + trace.WithAttributes(otelAttrs...), + serverSpanKind, + ) + defer span.End() + + // Add Labeler to context. + labeler := &Labeler{attrs: otelAttrs} + ctx = contextWithLabeler(ctx, labeler) + + // Run stopwatch. + startTime := time.Now() + defer func() { + elapsedDuration := time.Since(startTime) + + attrSet := labeler.AttributeSet() + attrs := attrSet.ToSlice() + code := statusWriter.status + if code != 0 { + codeAttr := semconv.HTTPResponseStatusCode(code) + attrs = append(attrs, codeAttr) + span.SetAttributes(codeAttr) + } + attrOpt := metric.WithAttributes(attrs...) + + // Increment request counter. + s.requests.Add(ctx, 1, attrOpt) + + // Use floating point division here for higher precision (instead of Millisecond method). + s.duration.Record(ctx, float64(elapsedDuration)/float64(time.Millisecond), attrOpt) + }() + + var ( + recordError = func(stage string, err error) { + span.RecordError(err) + + // https://opentelemetry.io/docs/specs/semconv/http/http-spans/#status + // Span Status MUST be left unset if HTTP status code was in the 1xx, 2xx or 3xx ranges, + // unless there was another error (e.g., network error receiving the response body; or 3xx codes with + // max redirects exceeded), in which case status MUST be set to Error. + code := statusWriter.status + if code >= 100 && code < 500 { + span.SetStatus(codes.Error, stage) + } + + attrSet := labeler.AttributeSet() + attrs := attrSet.ToSlice() + if code != 0 { + attrs = append(attrs, semconv.HTTPResponseStatusCode(code)) + } + + s.errors.Add(ctx, 1, metric.WithAttributes(attrs...)) + } + err error + opErrContext = ogenerrors.OperationContext{ + Name: SessionValidateOperation, + ID: "sessionValidate", + } + ) + { + type bitset = [1]uint8 + var satisfied bitset + { + sctx, ok, err := s.securityCookieAuth(ctx, SessionValidateOperation, r) + if err != nil { + err = &ogenerrors.SecurityError{ + OperationContext: opErrContext, + Security: "CookieAuth", + Err: err, + } + if encodeErr := encodeErrorResponse(s.h.NewError(ctx, err), w, span); encodeErr != nil { + defer recordError("Security:CookieAuth", err) + } + return + } + if ok { + satisfied[0] |= 1 << 0 + ctx = sctx + } + } + + if ok := func() bool { + nextRequirement: + for _, requirement := range []bitset{ + {0b00000001}, + } { + for i, mask := range requirement { + if satisfied[i]&mask != mask { + continue nextRequirement + } + } + return true + } + return false + }(); !ok { + err = &ogenerrors.SecurityError{ + OperationContext: opErrContext, + Err: ogenerrors.ErrSecurityRequirementIsNotSatisfied, + } + if encodeErr := encodeErrorResponse(s.h.NewError(ctx, err), w, span); encodeErr != nil { + defer recordError("Security", err) + } + return + } + } + + var response bool + if m := s.cfg.Middleware; m != nil { + mreq := middleware.Request{ + Context: ctx, + OperationName: SessionValidateOperation, + OperationSummary: "Ask if the current session is valid", + OperationID: "sessionValidate", + Body: nil, + Params: middleware.Parameters{}, + Raw: r, + } + + type ( + Request = struct{} + Params = struct{} + Response = bool + ) + response, err = middleware.HookMiddleware[ + Request, + Params, + Response, + ]( + m, + mreq, + nil, + func(ctx context.Context, request Request, params Params) (response Response, err error) { + response, err = s.h.SessionValidate(ctx) + return response, err + }, + ) + } else { + response, err = s.h.SessionValidate(ctx) + } + if err != nil { + if errRes, ok := errors.Into[*ErrorStatusCode](err); ok { + if err := encodeErrorResponse(errRes, w, span); err != nil { + defer recordError("Internal", err) + } + return + } + if errors.Is(err, ht.ErrNotImplemented) { + s.cfg.ErrorHandler(ctx, w, r, err) + return + } + if err := encodeErrorResponse(s.h.NewError(ctx, err), w, span); err != nil { + defer recordError("Internal", err) + } + return + } + + if err := encodeSessionValidateResponse(response, w, span); err != nil { + defer recordError("EncodeResponse", err) + if !errors.Is(err, ht.ErrInternalServerErrorResponse) { + s.cfg.ErrorHandler(ctx, w, r, err) + } + return + } +} + // handleSetSubmissionCompletedRequest handles setSubmissionCompleted operation. // -// Retrieve map with ID. +// Called by maptest when a player completes the map. // // POST /submissions/{SubmissionID}/completed func (s *Server) handleSetSubmissionCompletedRequest(args [1]string, argsEscaped bool, w http.ResponseWriter, r *http.Request) { @@ -4122,7 +4662,7 @@ func (s *Server) handleSetSubmissionCompletedRequest(args [1]string, argsEscaped mreq := middleware.Request{ Context: ctx, OperationName: SetSubmissionCompletedOperation, - OperationSummary: "Retrieve map with ID", + OperationSummary: "Called by maptest when a player completes the map", OperationID: "setSubmissionCompleted", Body: nil, Params: middleware.Parameters{ diff --git a/pkg/api/oas_json_gen.go b/pkg/api/oas_json_gen.go index 4abaa1e..6f209be 100644 --- a/pkg/api/oas_json_gen.go +++ b/pkg/api/oas_json_gen.go @@ -440,6 +440,102 @@ func (s *ReleaseInfo) UnmarshalJSON(data []byte) error { return s.Decode(d) } +// Encode implements json.Marshaler. +func (s *Roles) Encode(e *jx.Encoder) { + e.ObjStart() + s.encodeFields(e) + e.ObjEnd() +} + +// encodeFields encodes fields. +func (s *Roles) encodeFields(e *jx.Encoder) { + { + e.FieldStart("Roles") + e.Int32(s.Roles) + } +} + +var jsonFieldsNameOfRoles = [1]string{ + 0: "Roles", +} + +// Decode decodes Roles from json. +func (s *Roles) Decode(d *jx.Decoder) error { + if s == nil { + return errors.New("invalid: unable to decode Roles to nil") + } + var requiredBitSet [1]uint8 + + if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { + switch string(k) { + case "Roles": + requiredBitSet[0] |= 1 << 0 + if err := func() error { + v, err := d.Int32() + s.Roles = int32(v) + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"Roles\"") + } + default: + return d.Skip() + } + return nil + }); err != nil { + return errors.Wrap(err, "decode Roles") + } + // Validate required fields. + var failures []validate.FieldError + for i, mask := range [1]uint8{ + 0b00000001, + } { + if result := (requiredBitSet[i] & mask) ^ mask; result != 0 { + // Mask only required fields and check equality to mask using XOR. + // + // If XOR result is not zero, result is not equal to expected, so some fields are missed. + // Bits of fields which would be set are actually bits of missed fields. + missed := bits.OnesCount8(result) + for bitN := 0; bitN < missed; bitN++ { + bitIdx := bits.TrailingZeros8(result) + fieldIdx := i*8 + bitIdx + var name string + if fieldIdx < len(jsonFieldsNameOfRoles) { + name = jsonFieldsNameOfRoles[fieldIdx] + } else { + name = strconv.Itoa(fieldIdx) + } + failures = append(failures, validate.FieldError{ + Name: name, + Error: validate.ErrFieldRequired, + }) + // Reset bit. + result &^= 1 << bitIdx + } + } + } + if len(failures) > 0 { + return &validate.Error{Fields: failures} + } + + return nil +} + +// MarshalJSON implements stdjson.Marshaler. +func (s *Roles) MarshalJSON() ([]byte, error) { + e := jx.Encoder{} + s.Encode(&e) + return e.Bytes(), nil +} + +// UnmarshalJSON implements stdjson.Unmarshaler. +func (s *Roles) UnmarshalJSON(data []byte) error { + d := jx.DecodeBytes(data) + return s.Decode(d) +} + // Encode implements json.Marshaler. func (s *Script) Encode(e *jx.Encoder) { e.ObjStart() @@ -1786,3 +1882,133 @@ func (s *SubmissionCreate) UnmarshalJSON(data []byte) error { d := jx.DecodeBytes(data) return s.Decode(d) } + +// Encode implements json.Marshaler. +func (s *User) Encode(e *jx.Encoder) { + e.ObjStart() + s.encodeFields(e) + e.ObjEnd() +} + +// encodeFields encodes fields. +func (s *User) encodeFields(e *jx.Encoder) { + { + e.FieldStart("UserID") + e.Int64(s.UserID) + } + { + e.FieldStart("Username") + e.Str(s.Username) + } + { + e.FieldStart("AvatarURL") + e.Str(s.AvatarURL) + } +} + +var jsonFieldsNameOfUser = [3]string{ + 0: "UserID", + 1: "Username", + 2: "AvatarURL", +} + +// Decode decodes User from json. +func (s *User) Decode(d *jx.Decoder) error { + if s == nil { + return errors.New("invalid: unable to decode User to nil") + } + var requiredBitSet [1]uint8 + + if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { + switch string(k) { + case "UserID": + requiredBitSet[0] |= 1 << 0 + if err := func() error { + v, err := d.Int64() + s.UserID = int64(v) + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"UserID\"") + } + case "Username": + requiredBitSet[0] |= 1 << 1 + if err := func() error { + v, err := d.Str() + s.Username = string(v) + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"Username\"") + } + case "AvatarURL": + requiredBitSet[0] |= 1 << 2 + if err := func() error { + v, err := d.Str() + s.AvatarURL = string(v) + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"AvatarURL\"") + } + default: + return d.Skip() + } + return nil + }); err != nil { + return errors.Wrap(err, "decode User") + } + // Validate required fields. + var failures []validate.FieldError + for i, mask := range [1]uint8{ + 0b00000111, + } { + if result := (requiredBitSet[i] & mask) ^ mask; result != 0 { + // Mask only required fields and check equality to mask using XOR. + // + // If XOR result is not zero, result is not equal to expected, so some fields are missed. + // Bits of fields which would be set are actually bits of missed fields. + missed := bits.OnesCount8(result) + for bitN := 0; bitN < missed; bitN++ { + bitIdx := bits.TrailingZeros8(result) + fieldIdx := i*8 + bitIdx + var name string + if fieldIdx < len(jsonFieldsNameOfUser) { + name = jsonFieldsNameOfUser[fieldIdx] + } else { + name = strconv.Itoa(fieldIdx) + } + failures = append(failures, validate.FieldError{ + Name: name, + Error: validate.ErrFieldRequired, + }) + // Reset bit. + result &^= 1 << bitIdx + } + } + } + if len(failures) > 0 { + return &validate.Error{Fields: failures} + } + + return nil +} + +// MarshalJSON implements stdjson.Marshaler. +func (s *User) MarshalJSON() ([]byte, error) { + e := jx.Encoder{} + s.Encode(&e) + return e.Bytes(), nil +} + +// UnmarshalJSON implements stdjson.Unmarshaler. +func (s *User) UnmarshalJSON(data []byte) error { + d := jx.DecodeBytes(data) + return s.Decode(d) +} diff --git a/pkg/api/oas_operations_gen.go b/pkg/api/oas_operations_gen.go index 4fa44c7..604c4ff 100644 --- a/pkg/api/oas_operations_gen.go +++ b/pkg/api/oas_operations_gen.go @@ -26,6 +26,9 @@ const ( ListScriptsOperation OperationName = "ListScripts" ListSubmissionsOperation OperationName = "ListSubmissions" ReleaseSubmissionsOperation OperationName = "ReleaseSubmissions" + SessionRolesOperation OperationName = "SessionRoles" + SessionUserOperation OperationName = "SessionUser" + SessionValidateOperation OperationName = "SessionValidate" SetSubmissionCompletedOperation OperationName = "SetSubmissionCompleted" UpdateScriptOperation OperationName = "UpdateScript" UpdateScriptPolicyOperation OperationName = "UpdateScriptPolicy" diff --git a/pkg/api/oas_response_decoders_gen.go b/pkg/api/oas_response_decoders_gen.go index 490c9ee..faa01e6 100644 --- a/pkg/api/oas_response_decoders_gen.go +++ b/pkg/api/oas_response_decoders_gen.go @@ -1452,6 +1452,266 @@ func decodeReleaseSubmissionsResponse(resp *http.Response) (res *ReleaseSubmissi return res, errors.Wrap(defRes, "error") } +func decodeSessionRolesResponse(resp *http.Response) (res *Roles, _ error) { + switch resp.StatusCode { + case 200: + // Code 200. + ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return res, errors.Wrap(err, "parse media type") + } + switch { + case ct == "application/json": + buf, err := io.ReadAll(resp.Body) + if err != nil { + return res, err + } + d := jx.DecodeBytes(buf) + + var response Roles + if err := func() error { + if err := response.Decode(d); err != nil { + return err + } + if err := d.Skip(); err != io.EOF { + return errors.New("unexpected trailing data") + } + return nil + }(); err != nil { + err = &ogenerrors.DecodeBodyError{ + ContentType: ct, + Body: buf, + Err: err, + } + return res, err + } + return &response, nil + default: + return res, validate.InvalidContentType(ct) + } + } + // Convenient error response. + defRes, err := func() (res *ErrorStatusCode, err error) { + ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return res, errors.Wrap(err, "parse media type") + } + switch { + case ct == "application/json": + buf, err := io.ReadAll(resp.Body) + if err != nil { + return res, err + } + d := jx.DecodeBytes(buf) + + var response Error + if err := func() error { + if err := response.Decode(d); err != nil { + return err + } + if err := d.Skip(); err != io.EOF { + return errors.New("unexpected trailing data") + } + return nil + }(); err != nil { + err = &ogenerrors.DecodeBodyError{ + ContentType: ct, + Body: buf, + Err: err, + } + return res, err + } + return &ErrorStatusCode{ + StatusCode: resp.StatusCode, + Response: response, + }, nil + default: + return res, validate.InvalidContentType(ct) + } + }() + if err != nil { + return res, errors.Wrapf(err, "default (code %d)", resp.StatusCode) + } + return res, errors.Wrap(defRes, "error") +} + +func decodeSessionUserResponse(resp *http.Response) (res *User, _ error) { + switch resp.StatusCode { + case 200: + // Code 200. + ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return res, errors.Wrap(err, "parse media type") + } + switch { + case ct == "application/json": + buf, err := io.ReadAll(resp.Body) + if err != nil { + return res, err + } + d := jx.DecodeBytes(buf) + + var response User + if err := func() error { + if err := response.Decode(d); err != nil { + return err + } + if err := d.Skip(); err != io.EOF { + return errors.New("unexpected trailing data") + } + return nil + }(); err != nil { + err = &ogenerrors.DecodeBodyError{ + ContentType: ct, + Body: buf, + Err: err, + } + return res, err + } + // Validate response. + if err := func() error { + if err := response.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } + return &response, nil + default: + return res, validate.InvalidContentType(ct) + } + } + // Convenient error response. + defRes, err := func() (res *ErrorStatusCode, err error) { + ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return res, errors.Wrap(err, "parse media type") + } + switch { + case ct == "application/json": + buf, err := io.ReadAll(resp.Body) + if err != nil { + return res, err + } + d := jx.DecodeBytes(buf) + + var response Error + if err := func() error { + if err := response.Decode(d); err != nil { + return err + } + if err := d.Skip(); err != io.EOF { + return errors.New("unexpected trailing data") + } + return nil + }(); err != nil { + err = &ogenerrors.DecodeBodyError{ + ContentType: ct, + Body: buf, + Err: err, + } + return res, err + } + return &ErrorStatusCode{ + StatusCode: resp.StatusCode, + Response: response, + }, nil + default: + return res, validate.InvalidContentType(ct) + } + }() + if err != nil { + return res, errors.Wrapf(err, "default (code %d)", resp.StatusCode) + } + return res, errors.Wrap(defRes, "error") +} + +func decodeSessionValidateResponse(resp *http.Response) (res bool, _ error) { + switch resp.StatusCode { + case 200: + // Code 200. + ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return res, errors.Wrap(err, "parse media type") + } + switch { + case ct == "application/json": + buf, err := io.ReadAll(resp.Body) + if err != nil { + return res, err + } + d := jx.DecodeBytes(buf) + + var response bool + if err := func() error { + v, err := d.Bool() + response = bool(v) + if err != nil { + return err + } + if err := d.Skip(); err != io.EOF { + return errors.New("unexpected trailing data") + } + return nil + }(); err != nil { + err = &ogenerrors.DecodeBodyError{ + ContentType: ct, + Body: buf, + Err: err, + } + return res, err + } + return response, nil + default: + return res, validate.InvalidContentType(ct) + } + } + // Convenient error response. + defRes, err := func() (res *ErrorStatusCode, err error) { + ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return res, errors.Wrap(err, "parse media type") + } + switch { + case ct == "application/json": + buf, err := io.ReadAll(resp.Body) + if err != nil { + return res, err + } + d := jx.DecodeBytes(buf) + + var response Error + if err := func() error { + if err := response.Decode(d); err != nil { + return err + } + if err := d.Skip(); err != io.EOF { + return errors.New("unexpected trailing data") + } + return nil + }(); err != nil { + err = &ogenerrors.DecodeBodyError{ + ContentType: ct, + Body: buf, + Err: err, + } + return res, err + } + return &ErrorStatusCode{ + StatusCode: resp.StatusCode, + Response: response, + }, nil + default: + return res, validate.InvalidContentType(ct) + } + }() + if err != nil { + return res, errors.Wrapf(err, "default (code %d)", resp.StatusCode) + } + return res, errors.Wrap(defRes, "error") +} + func decodeSetSubmissionCompletedResponse(resp *http.Response) (res *SetSubmissionCompletedNoContent, _ error) { switch resp.StatusCode { case 204: diff --git a/pkg/api/oas_response_encoders_gen.go b/pkg/api/oas_response_encoders_gen.go index 99f0cb4..ad080a4 100644 --- a/pkg/api/oas_response_encoders_gen.go +++ b/pkg/api/oas_response_encoders_gen.go @@ -228,6 +228,48 @@ func encodeReleaseSubmissionsResponse(response *ReleaseSubmissionsCreated, w htt return nil } +func encodeSessionRolesResponse(response *Roles, w http.ResponseWriter, span trace.Span) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + span.SetStatus(codes.Ok, http.StatusText(200)) + + e := new(jx.Encoder) + response.Encode(e) + if _, err := e.WriteTo(w); err != nil { + return errors.Wrap(err, "write") + } + + return nil +} + +func encodeSessionUserResponse(response *User, w http.ResponseWriter, span trace.Span) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + span.SetStatus(codes.Ok, http.StatusText(200)) + + e := new(jx.Encoder) + response.Encode(e) + if _, err := e.WriteTo(w); err != nil { + return errors.Wrap(err, "write") + } + + return nil +} + +func encodeSessionValidateResponse(response bool, w http.ResponseWriter, span trace.Span) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + span.SetStatus(codes.Ok, http.StatusText(200)) + + e := new(jx.Encoder) + e.Bool(response) + if _, err := e.WriteTo(w); err != nil { + return errors.Wrap(err, "write") + } + + return nil +} + func encodeSetSubmissionCompletedResponse(response *SetSubmissionCompletedNoContent, w http.ResponseWriter, span trace.Span) error { w.WriteHeader(204) span.SetStatus(codes.Ok, http.StatusText(204)) diff --git a/pkg/api/oas_router_gen.go b/pkg/api/oas_router_gen.go index f25dd45..554e110 100644 --- a/pkg/api/oas_router_gen.go +++ b/pkg/api/oas_router_gen.go @@ -231,6 +231,80 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } + case 'e': // Prefix: "ession/" + + if l := len("ession/"); len(elem) >= l && elem[0:l] == "ession/" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + break + } + switch elem[0] { + case 'r': // Prefix: "roles" + + if l := len("roles"); len(elem) >= l && elem[0:l] == "roles" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch r.Method { + case "GET": + s.handleSessionRolesRequest([0]string{}, elemIsEscaped, w, r) + default: + s.notAllowed(w, r, "GET") + } + + return + } + + case 'u': // Prefix: "user" + + if l := len("user"); len(elem) >= l && elem[0:l] == "user" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch r.Method { + case "GET": + s.handleSessionUserRequest([0]string{}, elemIsEscaped, w, r) + default: + s.notAllowed(w, r, "GET") + } + + return + } + + case 'v': // Prefix: "validate" + + if l := len("validate"); len(elem) >= l && elem[0:l] == "validate" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch r.Method { + case "GET": + s.handleSessionValidateRequest([0]string{}, elemIsEscaped, w, r) + default: + s.notAllowed(w, r, "GET") + } + + return + } + + } + case 'u': // Prefix: "ubmissions" if l := len("ubmissions"); len(elem) >= l && elem[0:l] == "ubmissions" { @@ -886,6 +960,92 @@ func (s *Server) FindPath(method string, u *url.URL) (r Route, _ bool) { } + case 'e': // Prefix: "ession/" + + if l := len("ession/"); len(elem) >= l && elem[0:l] == "ession/" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + break + } + switch elem[0] { + case 'r': // Prefix: "roles" + + if l := len("roles"); len(elem) >= l && elem[0:l] == "roles" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch method { + case "GET": + r.name = SessionRolesOperation + r.summary = "Get list of roles for the current session" + r.operationID = "sessionRoles" + r.pathPattern = "/session/roles" + r.args = args + r.count = 0 + return r, true + default: + return + } + } + + case 'u': // Prefix: "user" + + if l := len("user"); len(elem) >= l && elem[0:l] == "user" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch method { + case "GET": + r.name = SessionUserOperation + r.summary = "Get information about the currently logged in user" + r.operationID = "sessionUser" + r.pathPattern = "/session/user" + r.args = args + r.count = 0 + return r, true + default: + return + } + } + + case 'v': // Prefix: "validate" + + if l := len("validate"); len(elem) >= l && elem[0:l] == "validate" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch method { + case "GET": + r.name = SessionValidateOperation + r.summary = "Ask if the current session is valid" + r.operationID = "sessionValidate" + r.pathPattern = "/session/validate" + r.args = args + r.count = 0 + return r, true + default: + return + } + } + + } + case 'u': // Prefix: "ubmissions" if l := len("ubmissions"); len(elem) >= l && elem[0:l] == "ubmissions" { @@ -974,7 +1134,7 @@ func (s *Server) FindPath(method string, u *url.URL) (r Route, _ bool) { switch method { case "POST": r.name = SetSubmissionCompletedOperation - r.summary = "Retrieve map with ID" + r.summary = "Called by maptest when a player completes the map" r.operationID = "setSubmissionCompleted" r.pathPattern = "/submissions/{SubmissionID}/completed" r.args = args diff --git a/pkg/api/oas_schemas_gen.go b/pkg/api/oas_schemas_gen.go index 30ef7e4..5523081 100644 --- a/pkg/api/oas_schemas_gen.go +++ b/pkg/api/oas_schemas_gen.go @@ -290,6 +290,21 @@ func (s *ReleaseInfo) SetDate(val time.Time) { // ReleaseSubmissionsCreated is response for ReleaseSubmissions operation. type ReleaseSubmissionsCreated struct{} +// Ref: #/components/schemas/Roles +type Roles struct { + Roles int32 `json:"Roles"` +} + +// GetRoles returns the value of Roles. +func (s *Roles) GetRoles() int32 { + return s.Roles +} + +// SetRoles sets the value of Roles. +func (s *Roles) SetRoles(val int32) { + s.Roles = val +} + // Ref: #/components/schemas/Script type Script struct { ID int64 `json:"ID"` @@ -795,3 +810,40 @@ type UpdateScriptPolicyNoContent struct{} // UpdateSubmissionModelNoContent is response for UpdateSubmissionModel operation. type UpdateSubmissionModelNoContent struct{} + +// Ref: #/components/schemas/User +type User struct { + UserID int64 `json:"UserID"` + Username string `json:"Username"` + AvatarURL string `json:"AvatarURL"` +} + +// GetUserID returns the value of UserID. +func (s *User) GetUserID() int64 { + return s.UserID +} + +// GetUsername returns the value of Username. +func (s *User) GetUsername() string { + return s.Username +} + +// GetAvatarURL returns the value of AvatarURL. +func (s *User) GetAvatarURL() string { + return s.AvatarURL +} + +// SetUserID sets the value of UserID. +func (s *User) SetUserID(val int64) { + s.UserID = val +} + +// SetUsername sets the value of Username. +func (s *User) SetUsername(val string) { + s.Username = val +} + +// SetAvatarURL sets the value of AvatarURL. +func (s *User) SetAvatarURL(val string) { + s.AvatarURL = val +} diff --git a/pkg/api/oas_server_gen.go b/pkg/api/oas_server_gen.go index d0b85d6..505264f 100644 --- a/pkg/api/oas_server_gen.go +++ b/pkg/api/oas_server_gen.go @@ -128,9 +128,27 @@ type Handler interface { // // POST /release-submissions ReleaseSubmissions(ctx context.Context, req []ReleaseInfo) error + // SessionRoles implements sessionRoles operation. + // + // Get list of roles for the current session. + // + // GET /session/roles + SessionRoles(ctx context.Context) (*Roles, error) + // SessionUser implements sessionUser operation. + // + // Get information about the currently logged in user. + // + // GET /session/user + SessionUser(ctx context.Context) (*User, error) + // SessionValidate implements sessionValidate operation. + // + // Ask if the current session is valid. + // + // GET /session/validate + SessionValidate(ctx context.Context) (bool, error) // SetSubmissionCompleted implements setSubmissionCompleted operation. // - // Retrieve map with ID. + // Called by maptest when a player completes the map. // // POST /submissions/{SubmissionID}/completed SetSubmissionCompleted(ctx context.Context, params SetSubmissionCompletedParams) error diff --git a/pkg/api/oas_unimplemented_gen.go b/pkg/api/oas_unimplemented_gen.go index b6cb472..735853f 100644 --- a/pkg/api/oas_unimplemented_gen.go +++ b/pkg/api/oas_unimplemented_gen.go @@ -193,9 +193,36 @@ func (UnimplementedHandler) ReleaseSubmissions(ctx context.Context, req []Releas return ht.ErrNotImplemented } +// SessionRoles implements sessionRoles operation. +// +// Get list of roles for the current session. +// +// GET /session/roles +func (UnimplementedHandler) SessionRoles(ctx context.Context) (r *Roles, _ error) { + return r, ht.ErrNotImplemented +} + +// SessionUser implements sessionUser operation. +// +// Get information about the currently logged in user. +// +// GET /session/user +func (UnimplementedHandler) SessionUser(ctx context.Context) (r *User, _ error) { + return r, ht.ErrNotImplemented +} + +// SessionValidate implements sessionValidate operation. +// +// Ask if the current session is valid. +// +// GET /session/validate +func (UnimplementedHandler) SessionValidate(ctx context.Context) (r bool, _ error) { + return r, ht.ErrNotImplemented +} + // SetSubmissionCompleted implements setSubmissionCompleted operation. // -// Retrieve map with ID. +// Called by maptest when a player completes the map. // // POST /submissions/{SubmissionID}/completed func (UnimplementedHandler) SetSubmissionCompleted(ctx context.Context, params SetSubmissionCompletedParams) error { diff --git a/pkg/api/oas_validators_gen.go b/pkg/api/oas_validators_gen.go index 91e71c1..b8c4978 100644 --- a/pkg/api/oas_validators_gen.go +++ b/pkg/api/oas_validators_gen.go @@ -321,3 +321,53 @@ func (s *SubmissionCreate) Validate() error { } return nil } + +func (s *User) Validate() error { + if s == nil { + return validate.ErrNilPointer + } + + var failures []validate.FieldError + if err := func() error { + if err := (validate.String{ + MinLength: 0, + MinLengthSet: false, + MaxLength: 128, + MaxLengthSet: true, + Email: false, + Hostname: false, + Regex: nil, + }).Validate(string(s.Username)); err != nil { + return errors.Wrap(err, "string") + } + return nil + }(); err != nil { + failures = append(failures, validate.FieldError{ + Name: "Username", + Error: err, + }) + } + if err := func() error { + if err := (validate.String{ + MinLength: 0, + MinLengthSet: false, + MaxLength: 256, + MaxLengthSet: true, + Email: false, + Hostname: false, + Regex: nil, + }).Validate(string(s.AvatarURL)); err != nil { + return errors.Wrap(err, "string") + } + return nil + }(); err != nil { + failures = append(failures, validate.FieldError{ + Name: "AvatarURL", + Error: err, + }) + } + if len(failures) > 0 { + return &validate.Error{Fields: failures} + } + return nil +} -- 2.47.1 From 977d1d20c2700533099d087bfc9505e6eab1d43f Mon Sep 17 00:00:00 2001 From: Quaternions <krakow20@gmail.com> Date: Tue, 25 Mar 2025 17:36:36 -0700 Subject: [PATCH 3/7] submissions: rename UserInfo to UserInfoHandle --- pkg/service/script_policy.go | 8 ++++---- pkg/service/scripts.go | 8 ++++---- pkg/service/security.go | 18 +++++++++--------- pkg/service/submissions.go | 24 ++++++++++++------------ 4 files changed, 29 insertions(+), 29 deletions(-) diff --git a/pkg/service/script_policy.go b/pkg/service/script_policy.go index 20b070b..044a925 100644 --- a/pkg/service/script_policy.go +++ b/pkg/service/script_policy.go @@ -14,7 +14,7 @@ import ( // // POST /script-policy func (svc *Service) CreateScriptPolicy(ctx context.Context, req *api.ScriptPolicyCreate) (*api.ID, error) { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return nil, ErrUserInfo } @@ -99,7 +99,7 @@ func (svc *Service) ListScriptPolicy(ctx context.Context, params api.ListScriptP // // DELETE /script-policy/{ScriptPolicyID} func (svc *Service) DeleteScriptPolicy(ctx context.Context, params api.DeleteScriptPolicyParams) error { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return ErrUserInfo } @@ -121,7 +121,7 @@ func (svc *Service) DeleteScriptPolicy(ctx context.Context, params api.DeleteScr // // GET /script-policy/{ScriptPolicyID} func (svc *Service) GetScriptPolicy(ctx context.Context, params api.GetScriptPolicyParams) (*api.ScriptPolicy, error) { - _, ok := ctx.Value("UserInfo").(UserInfo) + _, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return nil, ErrUserInfo } @@ -147,7 +147,7 @@ func (svc *Service) GetScriptPolicy(ctx context.Context, params api.GetScriptPol // // POST /script-policy/{ScriptPolicyID} func (svc *Service) UpdateScriptPolicy(ctx context.Context, req *api.ScriptPolicyUpdate, params api.UpdateScriptPolicyParams) error { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return ErrUserInfo } diff --git a/pkg/service/scripts.go b/pkg/service/scripts.go index 0be1783..2a67e84 100644 --- a/pkg/service/scripts.go +++ b/pkg/service/scripts.go @@ -14,7 +14,7 @@ import ( // // POST /scripts func (svc *Service) CreateScript(ctx context.Context, req *api.ScriptCreate) (*api.ID, error) { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return nil, ErrUserInfo } @@ -95,7 +95,7 @@ func (svc *Service) ListScripts(ctx context.Context, params api.ListScriptsParam // // DELETE /scripts/{ScriptID} func (svc *Service) DeleteScript(ctx context.Context, params api.DeleteScriptParams) error { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return ErrUserInfo } @@ -117,7 +117,7 @@ func (svc *Service) DeleteScript(ctx context.Context, params api.DeleteScriptPar // // GET /scripts/{ScriptID} func (svc *Service) GetScript(ctx context.Context, params api.GetScriptParams) (*api.Script, error) { - _, ok := ctx.Value("UserInfo").(UserInfo) + _, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return nil, ErrUserInfo } @@ -144,7 +144,7 @@ func (svc *Service) GetScript(ctx context.Context, params api.GetScriptParams) ( // // PATCH /scripts/{ScriptID} func (svc *Service) UpdateScript(ctx context.Context, req *api.ScriptUpdate, params api.UpdateScriptParams) error { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return ErrUserInfo } diff --git a/pkg/service/security.go b/pkg/service/security.go index e48a574..ac70844 100644 --- a/pkg/service/security.go +++ b/pkg/service/security.go @@ -24,14 +24,14 @@ var ( RoleMapCouncil Role = 64 ) -type UserInfo struct { +type UserInfoHandle struct { // Would love to know a better way to do this svc *SecurityHandler ctx *context.Context sessionId string } -func (usr UserInfo) GetUserID() (uint64, error) { +func (usr UserInfoHandle) GetUserID() (uint64, error) { session, err := usr.svc.Client.GetSessionUser(*usr.ctx, &auth.IdMessage{ SessionID: usr.sessionId, }) @@ -40,14 +40,14 @@ func (usr UserInfo) GetUserID() (uint64, error) { } return session.UserID, nil } -func (usr UserInfo) IsSubmitter(submitter uint64) (bool, error) { +func (usr UserInfoHandle) IsSubmitter(submitter uint64) (bool, error) { userId, err := usr.GetUserID() if err != nil { return false, err } return userId == submitter, nil } -func (usr UserInfo) hasRole(role Role) (bool, error) { +func (usr UserInfoHandle) hasRole(role Role) (bool, error) { roles, err := usr.svc.Client.GetGroupRole(*usr.ctx, &auth.IdMessage{ SessionID: usr.sessionId, }) @@ -66,17 +66,17 @@ func (usr UserInfo) hasRole(role Role) (bool, error) { // RoleThumbnail // RoleMapDownload -func (usr UserInfo) HasRoleSubmissionRelease() (bool, error) { +func (usr UserInfoHandle) HasRoleSubmissionRelease() (bool, error) { return usr.hasRole(RoleMapAdmin) } -func (usr UserInfo) HasRoleSubmissionReview() (bool, error) { +func (usr UserInfoHandle) HasRoleSubmissionReview() (bool, error) { return usr.hasRole(RoleMapCouncil) } -func (usr UserInfo) HasRoleScriptWrite() (bool, error) { +func (usr UserInfoHandle) HasRoleScriptWrite() (bool, error) { return usr.hasRole(RoleQuat) } /// Not implemented -func (usr UserInfo) HasRoleMaptest() (bool, error) { +func (usr UserInfoHandle) HasRoleMaptest() (bool, error) { println("HasRoleMaptest is not implemented!") return false, nil } @@ -101,7 +101,7 @@ func (svc SecurityHandler) HandleCookieAuth(ctx context.Context, operationName a return nil, ErrInvalidSession } - newCtx := context.WithValue(ctx, "UserInfo", UserInfo{ + newCtx := context.WithValue(ctx, "UserInfo", UserInfoHandle{ svc: &svc, ctx: &ctx, sessionId: sessionId, diff --git a/pkg/service/submissions.go b/pkg/service/submissions.go index 3ca7029..97eaeb6 100644 --- a/pkg/service/submissions.go +++ b/pkg/service/submissions.go @@ -41,7 +41,7 @@ var ( // POST /submissions func (svc *Service) CreateSubmission(ctx context.Context, request *api.SubmissionCreate) (*api.ID, error) { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return nil, ErrUserInfo } @@ -205,7 +205,7 @@ func (svc *Service) ListSubmissions(ctx context.Context, params api.ListSubmissi // // POST /submissions/{SubmissionID}/completed func (svc *Service) SetSubmissionCompleted(ctx context.Context, params api.SetSubmissionCompletedParams) error { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return ErrUserInfo } @@ -230,7 +230,7 @@ func (svc *Service) SetSubmissionCompleted(ctx context.Context, params api.SetSu // // POST /submissions/{SubmissionID}/model func (svc *Service) UpdateSubmissionModel(ctx context.Context, params api.UpdateSubmissionModelParams) error { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return ErrUserInfo } @@ -265,7 +265,7 @@ func (svc *Service) UpdateSubmissionModel(ctx context.Context, params api.Update // // POST /submissions/{SubmissionID}/status/reject func (svc *Service) ActionSubmissionReject(ctx context.Context, params api.ActionSubmissionRejectParams) error { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return ErrUserInfo } @@ -291,7 +291,7 @@ func (svc *Service) ActionSubmissionReject(ctx context.Context, params api.Actio // // POST /submissions/{SubmissionID}/status/request-changes func (svc *Service) ActionSubmissionRequestChanges(ctx context.Context, params api.ActionSubmissionRequestChangesParams) error { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return ErrUserInfo } @@ -317,7 +317,7 @@ func (svc *Service) ActionSubmissionRequestChanges(ctx context.Context, params a // // POST /submissions/{SubmissionID}/status/revoke func (svc *Service) ActionSubmissionRevoke(ctx context.Context, params api.ActionSubmissionRevokeParams) error { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return ErrUserInfo } @@ -349,7 +349,7 @@ func (svc *Service) ActionSubmissionRevoke(ctx context.Context, params api.Actio // // POST /submissions/{SubmissionID}/status/submit func (svc *Service) ActionSubmissionSubmit(ctx context.Context, params api.ActionSubmissionSubmitParams) error { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return ErrUserInfo } @@ -381,7 +381,7 @@ func (svc *Service) ActionSubmissionSubmit(ctx context.Context, params api.Actio // // POST /submissions/{SubmissionID}/status/trigger-upload func (svc *Service) ActionSubmissionTriggerUpload(ctx context.Context, params api.ActionSubmissionTriggerUploadParams) error { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return ErrUserInfo } @@ -446,7 +446,7 @@ func (svc *Service) ActionSubmissionTriggerUpload(ctx context.Context, params ap // // POST /submissions/{SubmissionID}/status/reset-uploading func (svc *Service) ActionSubmissionValidated(ctx context.Context, params api.ActionSubmissionValidatedParams) error { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return ErrUserInfo } @@ -482,7 +482,7 @@ func (svc *Service) ActionSubmissionValidated(ctx context.Context, params api.Ac // // POST /submissions/{SubmissionID}/status/trigger-validate func (svc *Service) ActionSubmissionTriggerValidate(ctx context.Context, params api.ActionSubmissionTriggerValidateParams) error { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return ErrUserInfo } @@ -542,7 +542,7 @@ func (svc *Service) ActionSubmissionTriggerValidate(ctx context.Context, params // // POST /submissions/{SubmissionID}/status/reset-validating func (svc *Service) ActionSubmissionAccepted(ctx context.Context, params api.ActionSubmissionAcceptedParams) error { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return ErrUserInfo } @@ -579,7 +579,7 @@ func (svc *Service) ActionSubmissionAccepted(ctx context.Context, params api.Act // // POST /release-submissions func (svc *Service) ReleaseSubmissions(ctx context.Context, request []api.ReleaseInfo) error { - userInfo, ok := ctx.Value("UserInfo").(UserInfo) + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) if !ok { return ErrUserInfo } -- 2.47.1 From 783d0e843ce57fcfc4133049746ec2d7fb238020 Mon Sep 17 00:00:00 2001 From: Quaternions <krakow20@gmail.com> Date: Tue, 25 Mar 2025 17:38:19 -0700 Subject: [PATCH 4/7] submissions: refactor roles --- pkg/service/security.go | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/pkg/service/security.go b/pkg/service/security.go index ac70844..130f5b0 100644 --- a/pkg/service/security.go +++ b/pkg/service/security.go @@ -14,14 +14,31 @@ var ( ErrInvalidSession = errors.New("Session invalid") ) +// Submissions roles bitflag +type Roles int32 +var ( + RolesScriptWrite Roles = 8 + RolesSubmissionPublish Roles = 4 + RolesSubmissionReview Roles = 2 + RolesMapDownload Roles = 1 + RolesEmpty Roles = 0 +) + +// StrafesNET group roles type Role int32 var ( // has ScriptWrite RoleQuat Role = 240 + RolesQuat Roles = RolesScriptWrite|RolesSubmissionPublish|RolesSubmissionReview|RolesMapDownload // has SubmissionPublish RoleMapAdmin Role = 128 + RolesMapAdmin Roles = RolesSubmissionPublish|RolesSubmissionReview|RolesMapDownload // has SubmissionReview RoleMapCouncil Role = 64 + RolesMapCouncil Roles = RolesSubmissionReview|RolesMapDownload + // access to downloading maps + RoleMapAccess Role = 32 + RolesMapAccess Roles = RolesMapDownload ) type UserInfoHandle struct { @@ -62,7 +79,31 @@ func (usr UserInfoHandle) hasRole(role Role) (bool, error) { } return false, nil } +func (usr UserInfoHandle) GetRoles() (Roles, error) { + roles, err := usr.svc.Client.GetGroupRole(*usr.ctx, &auth.IdMessage{ + SessionID: usr.sessionId, + }) + var rolesBitflag = RolesEmpty; + if err != nil { + return rolesBitflag, err + } + + // map roles into bitflag + for _, r := range roles.Roles { + switch Role(r.Rank){ + case RoleQuat: + rolesBitflag|=RolesQuat; + case RoleMapAdmin: + rolesBitflag|=RolesMapAdmin; + case RoleMapCouncil: + rolesBitflag|=RolesMapCouncil; + case RoleMapAccess: + rolesBitflag|=RolesMapAccess; + } + } + return rolesBitflag, nil +} // RoleThumbnail // RoleMapDownload -- 2.47.1 From 7213948a26ce92fa15d8bb316394c0a707d77c2c Mon Sep 17 00:00:00 2001 From: Quaternions <krakow20@gmail.com> Date: Tue, 25 Mar 2025 17:38:31 -0700 Subject: [PATCH 5/7] submissions: add UserInfoHandle.GetUserInfo function --- pkg/service/security.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/pkg/service/security.go b/pkg/service/security.go index 130f5b0..ade325e 100644 --- a/pkg/service/security.go +++ b/pkg/service/security.go @@ -47,7 +47,24 @@ type UserInfoHandle struct { ctx *context.Context sessionId string } +type UserInfo struct { + UserID uint64 + Username string + AvatarURL string +} +func (usr UserInfoHandle) GetUserInfo() (userInfo UserInfo, err error) { + session, err := usr.svc.Client.GetSessionUser(*usr.ctx, &auth.IdMessage{ + SessionID: usr.sessionId, + }) + if err != nil { + return userInfo, err + } + userInfo.UserID = session.UserID + userInfo.Username = session.Username + userInfo.AvatarURL = session.AvatarURL + return userInfo, nil +} func (usr UserInfoHandle) GetUserID() (uint64, error) { session, err := usr.svc.Client.GetSessionUser(*usr.ctx, &auth.IdMessage{ SessionID: usr.sessionId, -- 2.47.1 From 1feca92f7d863ae2e7cc113fa32387e2eac39bb3 Mon Sep 17 00:00:00 2001 From: Quaternions <krakow20@gmail.com> Date: Tue, 25 Mar 2025 17:43:51 -0700 Subject: [PATCH 6/7] submissions: add UserInfoHandle.Validate --- pkg/service/security.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pkg/service/security.go b/pkg/service/security.go index ade325e..3706da9 100644 --- a/pkg/service/security.go +++ b/pkg/service/security.go @@ -74,6 +74,15 @@ func (usr UserInfoHandle) GetUserID() (uint64, error) { } return session.UserID, nil } +func (usr UserInfoHandle) Validate() (bool, error) { + validate, err := usr.svc.Client.ValidateSession(*usr.ctx, &auth.IdMessage{ + SessionID: usr.sessionId, + }) + if err != nil { + return false, err + } + return validate.Valid, nil +} func (usr UserInfoHandle) IsSubmitter(submitter uint64) (bool, error) { userId, err := usr.GetUserID() if err != nil { -- 2.47.1 From 1af7d7e94165fa3d25328ff95d26ca85847aac47 Mon Sep 17 00:00:00 2001 From: Quaternions <krakow20@gmail.com> Date: Tue, 25 Mar 2025 17:37:33 -0700 Subject: [PATCH 7/7] submissions: implement session endpoints --- pkg/service/session.go | 68 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 pkg/service/session.go diff --git a/pkg/service/session.go b/pkg/service/session.go new file mode 100644 index 0000000..51fd5bd --- /dev/null +++ b/pkg/service/session.go @@ -0,0 +1,68 @@ +package service + +import ( + "context" + + "git.itzana.me/strafesnet/maps-service/pkg/api" +) + +// SessionRoles implements getSessionRoles operation. +// +// Get bitflags of permissions the currently logged in user has. +// +// GET /session/roles +func (svc *Service) SessionRoles(ctx context.Context) (*api.Roles, error) { + userInfo, ok := ctx.Value("UserInfo").(UserInfoHandle) + if !ok { + return nil, ErrUserInfo + } + + roles, err := userInfo.GetRoles(); + if err != nil { + return nil, err + } + + return &api.Roles{Roles: int32(roles)}, nil +} + +// SessionUser implements sessionUser operation. +// +// Get information about the currently logged in user. +// +// GET /session/roles +func (svc *Service) SessionUser(ctx context.Context) (*api.User, error) { + userInfoHandle, ok := ctx.Value("UserInfo").(UserInfoHandle) + if !ok { + return nil, ErrUserInfo + } + + userInfo, err := userInfoHandle.GetUserInfo(); + if err != nil { + return nil, err + } + + return &api.User{ + UserID:int64(userInfo.UserID), + Username:userInfo.Username, + AvatarURL:userInfo.AvatarURL, + }, nil +} + +// SessionUser implements sessionUser operation. +// +// Get information about the currently logged in user. +// +// GET /session/roles +func (svc *Service) SessionValidate(ctx context.Context) (bool, error) { + userInfoHandle, ok := ctx.Value("UserInfo").(UserInfoHandle) + if !ok { + return false, ErrUserInfo + } + + valid, err := userInfoHandle.Validate(); + if err != nil { + return false, err + } + + return valid, nil +} -- 2.47.1