package middleware

import (
	"context"
	"net/http"
	"strings"

	"go.b0esche.cloud/backend/internal/audit"
	"go.b0esche.cloud/backend/internal/database"
	"go.b0esche.cloud/backend/internal/errors"
	"go.b0esche.cloud/backend/internal/org"
	"go.b0esche.cloud/backend/internal/permission"
	"go.b0esche.cloud/backend/pkg/jwt"
	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/google/uuid"
)

// Re-exported chi middleware so a router can take its whole chain from this
// package.
var (
	// RequestID is chi's middleware.RequestID.
	RequestID = middleware.RequestID
	// Logger is chi's middleware.Logger.
	Logger = middleware.Logger
	// Recoverer is chi's middleware.Recoverer.
	Recoverer = middleware.Recoverer
)

// CORS returns middleware that sets CORS response headers.
//
// allowedOrigins is a comma-separated list of exact origins; surrounding
// whitespace and empty entries are ignored. When the request's Origin header
// matches an entry, that origin is reflected in Access-Control-Allow-Origin
// together with Access-Control-Allow-Credentials: true. Otherwise, unless an
// earlier handler already set Access-Control-Allow-Origin, it is set to "*".
// OPTIONS requests are answered with 200 and are not passed to next.
func CORS(allowedOrigins string) func(http.Handler) http.Handler {
	// The allow-list is fixed for the lifetime of the middleware, so parse it
	// once here rather than on every request.
	allowed := make(map[string]struct{})
	for _, o := range strings.Split(allowedOrigins, ",") {
		if o = strings.TrimSpace(o); o != "" {
			allowed[o] = struct{}{}
		}
	}

	// X-Org-ID is read by the Org middleware, so cross-origin browser clients
	// must be allowed to send it or their preflight requests fail.
	const allowHeaders = "Content-Type, Authorization, X-Org-ID"

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()

			// The Allow-Origin value depends on the request's Origin, so tell
			// shared caches not to reuse one origin's response for another.
			h.Add("Vary", "Origin")

			origin := r.Header.Get("Origin")
			if _, ok := allowed[origin]; ok && origin != "" {
				h.Set("Access-Control-Allow-Origin", origin)
				h.Set("Access-Control-Allow-Credentials", "true")
			} else if h.Get("Access-Control-Allow-Origin") == "" {
				// Browsers only honour "*" for non-credentialed requests.
				h.Set("Access-Control-Allow-Origin", "*")
			}

			h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
			h.Set("Access-Control-Allow-Headers", allowHeaders)
			h.Set("Access-Control-Max-Age", "3600")

			if r.Method == http.MethodOptions {
				w.WriteHeader(http.StatusOK)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}

// RateLimit is a placeholder for request rate limiting. It currently performs
// no limiting and passes every request straight to next.
//
// TODO: Implement rate limiter.
var RateLimit = func(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Rate limiting logic belongs here.
		next.ServeHTTP(w, r)
	})
}

// contextKey is an unexported key type so values this package stores in a
// request context cannot collide with keys defined by other packages.
type contextKey string

const (
	userKey    contextKey = "user"    // user ID (string), set by Auth
	sessionKey contextKey = "session" // *database.Session, set by Auth
	orgKey     contextKey = "org"     // org ID (uuid.UUID), set by Org
)

// GetUserID retrieves the user ID stored in ctx by Auth. The boolean is false
// when no user ID (of type string) is present.
func GetUserID(ctx context.Context) (string, bool) {
	userID, ok := ctx.Value(userKey).(string)
	return userID, ok
}

// GetSession retrieves the session stored in ctx by Auth. The boolean is false
// when no *database.Session is present.
func GetSession(ctx context.Context) (*database.Session, bool) {
	session, ok := ctx.Value(sessionKey).(*database.Session)
	return session, ok
}

// Auth returns middleware that requires an "Authorization: Bearer <token>"
// header. The token is validated with jwtManager.ValidateWithSession against
// db; on success the user ID and session are stored in the request context
// (see GetUserID and GetSession). Missing or invalid credentials yield 401.
func Auth(jwtManager *jwt.Manager, db *database.DB) func(http.Handler) http.Handler {
	const bearerPrefix = "Bearer "

	// unauthorized answers 401 and advertises the Bearer scheme, as RFC 6750
	// §3 requires for requests lacking a usable access token.
	unauthorized := func(w http.ResponseWriter) {
		w.Header().Set("WWW-Authenticate", "Bearer")
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			authHeader := r.Header.Get("Authorization")
			// Auth scheme names are case-insensitive (RFC 7235 §2.1), so
			// "bearer" and "BEARER" are accepted as well as "Bearer".
			if len(authHeader) < len(bearerPrefix) ||
				!strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
				unauthorized(w)
				return
			}
			tokenString := strings.TrimSpace(authHeader[len(bearerPrefix):])
			if tokenString == "" {
				unauthorized(w)
				return
			}

			claims, session, err := jwtManager.ValidateWithSession(r.Context(), tokenString, db)
			if err != nil {
				unauthorized(w)
				return
			}

			ctx := context.WithValue(r.Context(), userKey, claims.UserID)
			ctx = context.WithValue(ctx, sessionKey, session)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// Org returns middleware that resolves the target organization and checks
// that the authenticated user is a member of it. It must run after Auth.
//
// The org ID comes from the X-Org-ID header, falling back to the "orgId" chi
// URL parameter. A missing or malformed user ID yields 401, an unparsable org
// ID yields 400, and a failed membership check is audit-logged and yields
// 403. On success the org ID (uuid.UUID) is stored in the request context.
func Org(db *database.DB, auditLogger *audit.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Fail closed instead of panicking when Auth did not run or left
			// an unusable user ID in the context.
			userIDStr, ok := GetUserID(r.Context())
			if !ok {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			userID, err := uuid.Parse(userIDStr)
			if err != nil {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}

			// NOTE(review): when both X-Org-ID and the orgId URL parameter are
			// present, only the header is checked. Handlers must take the org
			// from the context set below, not from chi.URLParam, or a mismatch
			// would bypass this membership check.
			orgIDStr := r.Header.Get("X-Org-ID")
			if orgIDStr == "" {
				orgIDStr = chi.URLParam(r, "orgId")
			}
			orgID, err := uuid.Parse(orgIDStr)
			if err != nil {
				errors.WriteError(w, errors.CodeInvalidArgument, "Invalid org ID", http.StatusBadRequest)
				return
			}

			if _, err := org.CheckMembership(r.Context(), db, userID, orgID); err != nil {
				auditLogger.Log(r.Context(), audit.Entry{
					UserID:   &userID,
					Action:   "org_access",
					Success:  false,
					Metadata: map[string]interface{}{"org_id": orgID, "error": err.Error()},
				})
				errors.LogError(r, err, "Org access denied")
				errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden)
				return
			}

			ctx := context.WithValue(r.Context(), orgKey, orgID)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// Permission returns middleware that requires the authenticated user to hold
// perm in the organization selected by Org; it must run after Auth and Org.
// A denied or failed check is audit-logged and yields 403. A missing or
// malformed user ID yields 401; a missing org ID in the context (Org not in
// the chain) yields 500.
func Permission(db *database.DB, auditLogger *audit.Logger, perm permission.Permission) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			userIDStr, ok := GetUserID(r.Context())
			if !ok {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			userID, err := uuid.Parse(userIDStr)
			if err != nil {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			orgID, ok := r.Context().Value(orgKey).(uuid.UUID)
			if !ok {
				// Org middleware is missing from the chain: a server wiring
				// bug, not a client error.
				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
				return
			}

			hasPerm, err := permission.HasPermission(r.Context(), db, userID, orgID, perm)
			if err != nil || !hasPerm {
				resource := string(perm)
				metadata := map[string]interface{}{"permission": perm}
				if err != nil {
					metadata["error"] = err.Error()
				}
				auditLogger.Log(r.Context(), audit.Entry{
					UserID:   &userID,
					OrgID:    &orgID,
					Action:   "permission_check",
					Resource: &resource,
					Success:  false,
					Metadata: metadata,
				})
				// A plain denial has no error to log; only log real failures.
				if err != nil {
					errors.LogError(r, err, "Permission denied")
				}
				errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}