package middleware import ( "context" "fmt" "net/http" "regexp" "strings" "sync" "time" "go.b0esche.cloud/backend/internal/audit" "go.b0esche.cloud/backend/internal/database" "go.b0esche.cloud/backend/internal/errors" "go.b0esche.cloud/backend/internal/org" "go.b0esche.cloud/backend/internal/permission" "go.b0esche.cloud/backend/pkg/jwt" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/google/uuid" ) var RequestID = middleware.RequestID var Logger = middleware.Logger var Recoverer = middleware.Recoverer // SecurityHeaders adds security-related HTTP headers func SecurityHeaders() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Prevent MIME type sniffing w.Header().Set("X-Content-Type-Options", "nosniff") // Prevent clickjacking - allow for WOPI routes if !strings.HasPrefix(r.URL.Path, "/wopi") && !strings.HasPrefix(r.URL.Path, "/user/files/") && !strings.HasPrefix(r.URL.Path, "/orgs/") { w.Header().Set("X-Frame-Options", "DENY") } // Enable XSS protection w.Header().Set("X-XSS-Protection", "1; mode=block") // Referrer policy w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") // Content Security Policy - basic policy w.Header().Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https://go.b0esche.cloud https://of.b0esche.cloud; frame-src 'self' https://of.b0esche.cloud;") next.ServeHTTP(w, r) }) } } // CORS middleware - accepts allowedOrigins comma-separated string func CORS(allowedOrigins string) func(http.Handler) http.Handler { allowedList, allowAll := compileAllowedOrigins(allowedOrigins) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") if origin != "" && isOriginAllowed(origin, allowedList) { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Add("Vary", "Origin") w.Header().Set("Access-Control-Allow-Credentials", "true") } else if allowAll { w.Header().Set("Access-Control-Allow-Origin", "*") } w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") allowHeaders := []string{"Content-Type", "Authorization", "Range", "Accept", "Origin", "X-Requested-With"} if reqHeaders := r.Header.Get("Access-Control-Request-Headers"); reqHeaders != "" { allowHeaders = append(allowHeaders, reqHeaders) } w.Header().Set("Access-Control-Allow-Headers", strings.Join(uniqueStrings(allowHeaders), ", ")) w.Header().Set("Access-Control-Expose-Headers", "Content-Length, Content-Type, Content-Disposition, Content-Range, Accept-Ranges") w.Header().Set("Access-Control-Max-Age", "3600") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusOK) return } next.ServeHTTP(w, r) }) } } func compileAllowedOrigins(origins string) ([]string, bool) { var allowed []string allowAll := false for _, origin := range strings.Split(origins, ",") { trimmed := strings.TrimSpace(origin) if trimmed == "" { continue } if trimmed == "*" { allowAll = true } allowed = append(allowed, trimmed) } if len(allowed) == 0 && !allowAll { allowAll = true } return allowed, allowAll } func uniqueStrings(values []string) []string { seen := make(map[string]struct{}) var out []string for _, v := range values { trimmed := strings.TrimSpace(v) if trimmed == "" { continue } key := strings.ToLower(trimmed) if _, ok := seen[key]; ok { continue } seen[key] = struct{}{} out = append(out, trimmed) } return out } func isOriginAllowed(origin string, allowed []string) bool { if origin == "" { return false } for _, pattern := range allowed { if originMatches(origin, pattern) { return true } } return false } func originMatches(origin, pattern string) bool { if pattern == "*" { return true } if !strings.Contains(pattern, "*") { return strings.EqualFold(origin, pattern) } regexPattern := "(?i)^" + regexp.QuoteMeta(pattern) + "$" regexPattern = strings.ReplaceAll(regexPattern, "\\*", ".*") matched, err := regexp.MatchString(regexPattern, origin) return err == nil && matched } // rateLimiter tracks request counts per IP address type rateLimiter struct { mu sync.RWMutex requests map[string]*clientRequests } type clientRequests struct { count int resetTime time.Time } var limiter = &rateLimiter{ requests: make(map[string]*clientRequests), } // RateLimit implements a simple sliding window rate limiter // Limits: 100 requests per minute per IP for general endpoints // 10 requests per minute per IP for auth endpoints var RateLimit = func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Get client IP (consider X-Forwarded-For from reverse proxy) ip := r.RemoteAddr if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" { ip = strings.Split(forwarded, ",")[0] } // Determine rate limit based on endpoint limit := 100 // Default: 100 requests/minute if strings.HasPrefix(r.URL.Path, "/auth/") { limit = 10 // Auth endpoints: 10 requests/minute } limiter.mu.Lock() client, exists := limiter.requests[ip] now := time.Now() if !exists || now.After(client.resetTime) { // New window limiter.requests[ip] = &clientRequests{ count: 1, resetTime: now.Add(time.Minute), } limiter.mu.Unlock() next.ServeHTTP(w, r) return } if client.count >= limit { limiter.mu.Unlock() w.Header().Set("Retry-After", "60") errors.WriteError(w, errors.CodeInvalidArgument, "Rate limit exceeded. Please try again later.", http.StatusTooManyRequests) return } client.count++ limiter.mu.Unlock() next.ServeHTTP(w, r) }) } type ContextKey string const ( UserKey ContextKey = "user" SessionKey ContextKey = "session" TokenKey ContextKey = "token" OrgKey ContextKey = "org" ) // GetUserID retrieves the user ID from the request context func GetUserID(ctx context.Context) (string, bool) { userID, ok := ctx.Value(UserKey).(string) return userID, ok } // GetSession retrieves the session from the request context func GetSession(ctx context.Context) (*database.Session, bool) { session, ok := ctx.Value(SessionKey).(*database.Session) return session, ok } // GetToken retrieves the JWT token from the request context func GetToken(ctx context.Context) (string, bool) { token, ok := ctx.Value(TokenKey).(string) return token, ok } // Auth middleware func Auth(jwtManager *jwt.Manager, db *database.DB) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") var tokenString string var tokenSource string if strings.HasPrefix(authHeader, "Bearer ") { tokenString = strings.TrimPrefix(authHeader, "Bearer ") tokenSource = "header" } else { // Fallback to query parameter token (for viewers that cannot set headers) qToken := r.URL.Query().Get("token") if qToken == "" { fmt.Printf("[AUTH-TOKEN] source=none, path=%s, statusCode=401\n", r.RequestURI) http.Error(w, "Unauthorized", http.StatusUnauthorized) return } tokenString = qToken tokenSource = "query" } fmt.Printf("[AUTH-TOKEN] source=%s, path=%s\n", tokenSource, r.RequestURI) claims, session, err := jwtManager.ValidateWithSession(r.Context(), tokenString, db) if err != nil { fmt.Printf("[AUTH-TOKEN] validation_failed, source=%s, path=%s, error=%v\n", tokenSource, r.RequestURI, err) http.Error(w, "Unauthorized", http.StatusUnauthorized) return } fmt.Printf("[AUTH-TOKEN] valid, source=%s, userId=%s\n", tokenSource, claims.UserID) ctx := context.WithValue(r.Context(), UserKey, claims.UserID) ctx = context.WithValue(ctx, SessionKey, session) ctx = context.WithValue(ctx, TokenKey, tokenString) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // Org middleware func Org(db *database.DB, auditLogger *audit.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { userIDStr := r.Context().Value(UserKey).(string) userID, _ := uuid.Parse(userIDStr) orgIDStr := r.Header.Get("X-Org-ID") if orgIDStr == "" { orgIDStr = chi.URLParam(r, "orgId") } orgID, err := uuid.Parse(orgIDStr) if err != nil { errors.WriteError(w, errors.CodeInvalidArgument, "Invalid org ID", http.StatusBadRequest) return } _, err = org.CheckMembership(r.Context(), db, userID, orgID) if err != nil { auditLogger.Log(r.Context(), audit.Entry{ UserID: &userID, Action: "org_access", Success: false, Metadata: map[string]interface{}{"org_id": orgID, "error": err.Error()}, }) errors.LogError(r, err, "Org access denied") errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden) return } ctx := context.WithValue(r.Context(), OrgKey, orgID) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // Permission middleware func Permission(db *database.DB, auditLogger *audit.Logger, perm permission.Permission) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { userIDStr := r.Context().Value(UserKey).(string) userID, _ := uuid.Parse(userIDStr) orgID := r.Context().Value(OrgKey).(uuid.UUID) hasPerm, err := permission.HasPermission(r.Context(), db, userID, orgID, perm) if err != nil || !hasPerm { auditLogger.Log(r.Context(), audit.Entry{ UserID: &userID, OrgID: &orgID, Action: "permission_check", Resource: &[]string{string(perm)}[0], Success: false, Metadata: map[string]interface{}{"permission": perm}, }) errors.LogError(r, err, "Permission denied") errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden) return } next.ServeHTTP(w, r) }) } }