// Package middleware provides the HTTP middleware chain used by the backend:
// CORS handling, JWT authentication, organization scoping and permission checks.
package middleware

import (
	"context"
	"net/http"
	"strings"

	"go.b0esche.cloud/backend/internal/audit"
	"go.b0esche.cloud/backend/internal/database"
	"go.b0esche.cloud/backend/internal/errors"
	"go.b0esche.cloud/backend/internal/org"
	"go.b0esche.cloud/backend/internal/permission"
	"go.b0esche.cloud/backend/pkg/jwt"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/google/uuid"
)

// RequestID, Logger and Recoverer re-export chi's stock middleware so callers
// can assemble the whole chain from this package.
var (
	RequestID = middleware.RequestID
	Logger    = middleware.Logger
	Recoverer = middleware.Recoverer
)

// CORS returns middleware that applies Cross-Origin Resource Sharing headers.
//
// allowedOrigins is a comma-separated list of origins. Entries may contain "*"
// as a wildcard (e.g. "https://*.example.com"), and a lone "*" matches every
// origin. An empty list is treated as allow-all.
//
// When the request's Origin matches an entry, that origin is echoed back with
// Access-Control-Allow-Credentials: true. Otherwise, if allow-all is in effect,
// "Access-Control-Allow-Origin: *" is sent without credentials. OPTIONS
// requests are answered directly with 200 and never reach next.
func CORS(allowedOrigins string) func(http.Handler) http.Handler {
	allowedList, allowAll := compileAllowedOrigins(allowedOrigins)

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()

			// With an allowlist, the response depends on the Origin header even
			// when the origin is rejected. Shared caches must key on it.
			if len(allowedList) > 0 {
				h.Add("Vary", "Origin")
			}

			origin := r.Header.Get("Origin")
			switch {
			case isOriginAllowed(origin, allowedList):
				// NOTE(review): a literal "*" entry makes every origin match here.
				// Arbitrary origins are then echoed with credentials enabled.
				// Confirm that is intended before deploying with "*".
				h.Set("Access-Control-Allow-Origin", origin)
				h.Set("Access-Control-Allow-Credentials", "true")
			case allowAll:
				h.Set("Access-Control-Allow-Origin", "*")
			}

			h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
			// X-Org-ID is listed because the Org middleware reads it. Without it,
			// browsers fail the preflight for cross-origin requests that send it.
			h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Org-ID")
			h.Set("Access-Control-Expose-Headers", "Content-Length, Content-Type, Content-Disposition")
			h.Set("Access-Control-Max-Age", "3600")

			if r.Method == http.MethodOptions {
				w.WriteHeader(http.StatusOK)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// compileAllowedOrigins splits a comma-separated origin list into trimmed,
// non-empty entries. It reports allow-all when the list contains "*" or
// contains no entries at all. "*" itself is kept in the returned slice.
func compileAllowedOrigins(origins string) ([]string, bool) {
	var allowed []string
	allowAll := false
	for _, origin := range strings.Split(origins, ",") {
		trimmed := strings.TrimSpace(origin)
		if trimmed == "" {
			continue
		}
		if trimmed == "*" {
			allowAll = true
		}
		allowed = append(allowed, trimmed)
	}
	// An unconfigured (empty) list means "allow everyone".
	if len(allowed) == 0 {
		allowAll = true
	}
	return allowed, allowAll
}

// isOriginAllowed reports whether a non-empty origin matches any allowed pattern.
func isOriginAllowed(origin string, allowed []string) bool {
	if origin == "" {
		return false
	}
	for _, pattern := range allowed {
		if originMatches(origin, pattern) {
			return true
		}
	}
	return false
}

// originMatches reports whether origin matches pattern, case-insensitively.
// Each "*" in pattern matches any run of characters, including none. The match
// is done by hand rather than with regexp, so nothing is compiled per request.
func originMatches(origin, pattern string) bool {
	if pattern == "*" {
		return true
	}
	if !strings.Contains(pattern, "*") {
		return strings.EqualFold(origin, pattern)
	}

	rest := strings.ToLower(origin)
	parts := strings.Split(strings.ToLower(pattern), "*")
	first, last := parts[0], parts[len(parts)-1]

	// The literal text before the first "*" must be a prefix.
	if !strings.HasPrefix(rest, first) {
		return false
	}
	rest = rest[len(first):]

	// Literals between wildcards must appear in order. Taking the leftmost
	// occurrence is sufficient because a "*" can absorb any gap.
	for _, part := range parts[1 : len(parts)-1] {
		i := strings.Index(rest, part)
		if i < 0 {
			return false
		}
		rest = rest[i+len(part):]
	}

	// The literal text after the last "*" must be a suffix of what remains.
	return strings.HasSuffix(rest, last)
}

// RateLimit is a placeholder for request rate limiting.
//
// TODO: implement rate limiting. It currently passes every request through unchanged.
var RateLimit = func(next http.Handler) http.Handler {
	return next
}

// ContextKey is the type of the keys this package stores in request contexts.
// Using a distinct type avoids collisions with keys set by other packages.
type ContextKey string

// Request-context keys populated by Auth (UserKey, SessionKey, TokenKey) and
// Org (OrgKey).
const (
	UserKey    ContextKey = "user"
	SessionKey ContextKey = "session"
	TokenKey   ContextKey = "token"
	OrgKey     ContextKey = "org"
)

// GetUserID retrieves the user ID from the request context.
func GetUserID(ctx context.Context) (string, bool) {
	userID, ok := ctx.Value(UserKey).(string)
	return userID, ok
}

// GetSession retrieves the session from the request context.
func GetSession(ctx context.Context) (*database.Session, bool) {
	session, ok := ctx.Value(SessionKey).(*database.Session)
	return session, ok
}

// GetToken retrieves the JWT token from the request context.
func GetToken(ctx context.Context) (string, bool) {
	token, ok := ctx.Value(TokenKey).(string)
	return token, ok
}

// Auth returns middleware that authenticates requests with a JWT.
//
// The token comes from an "Authorization: Bearer <token>" header. The scheme
// is matched case-insensitively. When no Bearer header is present, the
// "token" query parameter is used instead, for viewers that cannot set headers.
//
// The token is checked with jwtManager.ValidateWithSession, which also yields
// the session. On success, the user ID, session and raw token are stored in
// the request context under UserKey, SessionKey and TokenKey. Any failure is
// answered with 401.
func Auth(jwtManager *jwt.Manager, db *database.DB) func(http.Handler) http.Handler {
	const bearerPrefix = "Bearer "

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			authHeader := r.Header.Get("Authorization")

			var tokenString string
			if len(authHeader) >= len(bearerPrefix) && strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
				tokenString = authHeader[len(bearerPrefix):]
			} else {
				// NOTE(review): tokens in the query string can end up in request
				// logs, proxies and browser history. Prefer short-lived tokens here.
				tokenString = r.URL.Query().Get("token")
			}
			if tokenString == "" {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}

			claims, session, err := jwtManager.ValidateWithSession(r.Context(), tokenString, db)
			if err != nil {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}

			ctx := context.WithValue(r.Context(), UserKey, claims.UserID)
			ctx = context.WithValue(ctx, SessionKey, session)
			ctx = context.WithValue(ctx, TokenKey, tokenString)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// Org returns middleware that scopes a request to an organization.
// It must run after Auth, because it reads the user ID Auth stores under UserKey.
//
// The org ID is read from the X-Org-ID header, falling back to the "orgId" chi
// URL parameter. If both are present, they must name the same org.
//
// The user must pass org.CheckMembership for that org. A failed check is
// recorded with auditLogger and answered with 403. On success, the org's
// uuid.UUID is stored under OrgKey.
func Org(db *database.DB, auditLogger *audit.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			userIDStr, _ := GetUserID(ctx)
			userID, err := uuid.Parse(userIDStr)
			if err != nil {
				// No valid authenticated user: either Auth did not run or it
				// stored an unparsable ID. Fail closed instead of panicking.
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}

			headerOrgID := r.Header.Get("X-Org-ID")
			pathOrgID := chi.URLParam(r, "orgId")

			orgIDStr := headerOrgID
			if orgIDStr == "" {
				orgIDStr = pathOrgID
			}
			orgID, err := uuid.Parse(orgIDStr)
			if err != nil {
				errors.WriteError(w, errors.CodeInvalidArgument, "Invalid org ID", http.StatusBadRequest)
				return
			}

			// Membership is checked against the header value. A URL naming a
			// different org would otherwise pass this check for the wrong org.
			if headerOrgID != "" && pathOrgID != "" {
				pathID, perr := uuid.Parse(pathOrgID)
				if perr != nil || pathID != orgID {
					errors.WriteError(w, errors.CodeInvalidArgument, "Invalid org ID", http.StatusBadRequest)
					return
				}
			}

			if _, err := org.CheckMembership(ctx, db, userID, orgID); err != nil {
				auditLogger.Log(ctx, audit.Entry{
					UserID:   &userID,
					Action:   "org_access",
					Success:  false,
					Metadata: map[string]interface{}{"org_id": orgID, "error": err.Error()},
				})
				errors.LogError(r, err, "Org access denied")
				errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden)
				return
			}

			ctx = context.WithValue(ctx, OrgKey, orgID)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// Permission returns middleware that requires the current user to hold perm
// within the current org. It must run after Auth and Org, because it reads
// UserKey and OrgKey.
//
// The check uses permission.HasPermission. A lookup error or a missing
// permission is recorded with auditLogger and answered with 403.
func Permission(db *database.DB, auditLogger *audit.Logger, perm permission.Permission) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			userIDStr, _ := GetUserID(ctx)
			userID, err := uuid.Parse(userIDStr)
			if err != nil {
				// Auth did not run or stored an invalid ID. Fail closed.
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}

			orgID, ok := ctx.Value(OrgKey).(uuid.UUID)
			if !ok {
				// Org did not run, so there is no org to check against. Deny
				// rather than panic.
				errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden)
				return
			}

			hasPerm, err := permission.HasPermission(ctx, db, userID, orgID, perm)
			if err != nil || !hasPerm {
				resource := string(perm)
				metadata := map[string]interface{}{"permission": perm}
				if err != nil {
					metadata["error"] = err.Error()
				}
				auditLogger.Log(ctx, audit.Entry{
					UserID:   &userID,
					OrgID:    &orgID,
					Action:   "permission_check",
					Resource: &resource,
					Success:  false,
					Metadata: metadata,
				})
				// Only a failed lookup is an error. A plain denial is already
				// captured by the audit entry above.
				if err != nil {
					errors.LogError(r, err, "Permission denied")
				}
				errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}