Files
b0esche_cloud/go_cloud/internal/middleware/middleware.go
2026-01-10 19:16:23 +01:00

195 lines
6.1 KiB
Go

package middleware
import (
"context"
"net/http"
"strings"
"go.b0esche.cloud/backend/internal/audit"
"go.b0esche.cloud/backend/internal/database"
"go.b0esche.cloud/backend/internal/errors"
"go.b0esche.cloud/backend/internal/org"
"go.b0esche.cloud/backend/internal/permission"
"go.b0esche.cloud/backend/pkg/jwt"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/google/uuid"
)
// Aliases for chi's built-in middleware, re-exported from this package.
var (
	// RequestID is chi's middleware.RequestID.
	RequestID = middleware.RequestID
	// Logger is chi's middleware.Logger.
	Logger = middleware.Logger
	// Recoverer is chi's middleware.Recoverer.
	Recoverer = middleware.Recoverer
)
// CORS returns middleware that sets Cross-Origin Resource Sharing headers.
//
// allowedOrigins is a comma-separated list of exact origins, for example
// "https://app.example.com, https://admin.example.com". Whitespace around
// entries and empty entries are ignored. The list is parsed once, when CORS
// is called, not on every request.
//
// If the request's Origin header exactly matches an allowed origin, that
// origin is echoed in Access-Control-Allow-Origin together with
// Access-Control-Allow-Credentials: true. Otherwise, unless a handler earlier
// in the chain already set Access-Control-Allow-Origin, it is set to "*".
// Browsers honour "*" only for requests made without credentials.
//
// Every response carries "Vary: Origin" because the Allow-Origin value
// depends on the request's Origin. OPTIONS requests are answered here with
// 200 OK and are not passed to next.
func CORS(allowedOrigins string) func(http.Handler) http.Handler {
	allowed := make(map[string]struct{})
	for _, entry := range strings.Split(allowedOrigins, ",") {
		if entry = strings.TrimSpace(entry); entry != "" {
			allowed[entry] = struct{}{}
		}
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()
			// Without Vary, a shared cache could replay a response whose
			// Allow-Origin was echoed for one origin to a different origin.
			h.Add("Vary", "Origin")

			// "" is never stored in allowed, so a missing Origin never matches.
			origin := r.Header.Get("Origin")
			if _, ok := allowed[origin]; ok {
				h.Set("Access-Control-Allow-Origin", origin)
				h.Set("Access-Control-Allow-Credentials", "true")
			} else if h.Get("Access-Control-Allow-Origin") == "" {
				// Fallback for unlisted or missing origins. "*" is invalid for
				// credentialed requests, so browsers still block those.
				h.Set("Access-Control-Allow-Origin", "*")
			}

			h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
			h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
			h.Set("Access-Control-Expose-Headers", "Content-Length, Content-Type, Content-Disposition")
			h.Set("Access-Control-Max-Age", "3600")

			if r.Method == http.MethodOptions {
				w.WriteHeader(http.StatusOK)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// RateLimit is a placeholder for request rate limiting.
//
// TODO: Implement rate limiter. Until then this middleware is a no-op: every
// request is forwarded unchanged to next.
var RateLimit = func(next http.Handler) http.Handler {
	passThrough := func(w http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(passThrough)
}
// contextKey is the type of the keys this package stores in request
// contexts. Because it is an unexported named type, these keys cannot
// collide with keys set by other packages, even ones using the same string.
type contextKey string

// Request-context keys written by the Auth and Org middleware.
const (
	userKey    contextKey = "user"    // user ID, set by Auth
	sessionKey contextKey = "session" // session, set by Auth
	tokenKey   contextKey = "token"   // raw bearer token string, set by Auth
	orgKey     contextKey = "org"     // parsed org ID, set by Org
)
// GetUserID returns the user ID that Auth stored in ctx. The boolean is false
// when nothing is stored under the user key or the value is not a string.
func GetUserID(ctx context.Context) (string, bool) {
	if id, ok := ctx.Value(userKey).(string); ok {
		return id, true
	}
	return "", false
}
// GetSession returns the session that Auth stored in ctx. The boolean is
// false when nothing is stored under the session key or the value is not a
// *database.Session.
func GetSession(ctx context.Context) (*database.Session, bool) {
	if s, ok := ctx.Value(sessionKey).(*database.Session); ok {
		return s, true
	}
	return nil, false
}
// GetToken returns the raw bearer token string that Auth stored in ctx. The
// boolean is false when nothing is stored under the token key or the value
// is not a string.
func GetToken(ctx context.Context) (string, bool) {
	if tok, ok := ctx.Value(tokenKey).(string); ok {
		return tok, true
	}
	return "", false
}
// Auth returns middleware that authenticates requests with a bearer token.
//
// The token is read from an "Authorization: Bearer <token>" header. The
// scheme name is matched case-insensitively, as RFC 7235 §2.1 requires for
// auth schemes, and whitespace around the token is ignored. The token is
// validated with jwtManager.ValidateWithSession against db.
//
// On success, the claims' user ID, the session and the raw token are stored
// in the request context (readable via GetUserID, GetSession and GetToken)
// and next is called.
//
// Any failure is answered with 401 Unauthorized. Failures include a missing
// or malformed header, an empty token, or a validation error. Every 401
// carries a "WWW-Authenticate: Bearer" challenge, which RFC 7235 §3.1
// requires.
func Auth(jwtManager *jwt.Manager, db *database.DB) func(http.Handler) http.Handler {
	const bearerPrefix = "Bearer "

	unauthorized := func(w http.ResponseWriter) {
		w.Header().Set("WWW-Authenticate", "Bearer")
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			authHeader := r.Header.Get("Authorization")
			if len(authHeader) < len(bearerPrefix) ||
				!strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
				unauthorized(w)
				return
			}

			tokenString := strings.TrimSpace(authHeader[len(bearerPrefix):])
			if tokenString == "" {
				// Reject early instead of spending a validation/DB round trip.
				unauthorized(w)
				return
			}

			claims, session, err := jwtManager.ValidateWithSession(r.Context(), tokenString, db)
			if err != nil {
				unauthorized(w)
				return
			}

			ctx := context.WithValue(r.Context(), userKey, claims.UserID)
			ctx = context.WithValue(ctx, sessionKey, session)
			ctx = context.WithValue(ctx, tokenKey, tokenString)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// Org returns middleware that resolves the organization a request targets
// and checks that the authenticated user is a member of it.
//
// It must run after Auth, because the user ID is read from the request
// context via GetUserID. A missing or non-UUID user ID is answered with
// 401 Unauthorized.
//
// The org ID comes from the X-Org-ID header, falling back to the "orgId"
// chi URL parameter. A missing or malformed value is answered with
// 400 Bad Request.
//
// Membership is checked once with org.CheckMembership. If the check returns
// an error, the middleware records an "org_access" audit entry, logs the
// error and answers 403 Forbidden. On success, the parsed org ID is stored
// in the request context under orgKey and next is called.
func Org(db *database.DB, auditLogger *audit.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Checked lookup and parse. An unchecked type assertion would
			// panic when Auth is not in the chain. Ignoring the parse error
			// would run the membership check for uuid.Nil.
			userIDStr, ok := GetUserID(r.Context())
			if !ok {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			userID, err := uuid.Parse(userIDStr)
			if err != nil {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}

			orgIDStr := r.Header.Get("X-Org-ID")
			if orgIDStr == "" {
				orgIDStr = chi.URLParam(r, "orgId")
			}
			orgID, err := uuid.Parse(orgIDStr)
			if err != nil {
				errors.WriteError(w, errors.CodeInvalidArgument, "Invalid org ID", http.StatusBadRequest)
				return
			}

			// A single membership query; the check was previously duplicated.
			_, err = org.CheckMembership(r.Context(), db, userID, orgID)
			if err != nil {
				auditLogger.Log(r.Context(), audit.Entry{
					UserID:   &userID,
					Action:   "org_access",
					Success:  false,
					Metadata: map[string]interface{}{"org_id": orgID, "error": err.Error()},
				})
				errors.LogError(r, err, "Org access denied")
				errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden)
				return
			}

			ctx := context.WithValue(r.Context(), orgKey, orgID)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// Permission returns middleware that lets a request through only if the
// authenticated user holds perm in the current organization.
//
// It must run after Auth and Org. The user ID is read via GetUserID and the
// org ID from the uuid.UUID that Org stores under orgKey.
//   - A missing or non-UUID user ID is answered with 401 Unauthorized.
//   - A missing org ID means Org is not in the chain. That is a wiring error,
//     answered with 500 Internal Server Error instead of panicking.
//
// The check itself is delegated to permission.HasPermission. If it reports
// an error or returns false, the middleware records a "permission_check"
// audit entry and answers 403 Forbidden. A lookup error is also logged and
// included in the audit metadata.
func Permission(db *database.DB, auditLogger *audit.Logger, perm permission.Permission) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			userIDStr, ok := GetUserID(r.Context())
			if !ok {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			userID, err := uuid.Parse(userIDStr)
			if err != nil {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			orgID, ok := r.Context().Value(orgKey).(uuid.UUID)
			if !ok {
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}

			hasPerm, err := permission.HasPermission(r.Context(), db, userID, orgID, perm)
			if err != nil || !hasPerm {
				resource := string(perm)
				metadata := map[string]interface{}{"permission": perm}
				if err != nil {
					metadata["error"] = err.Error()
				}
				auditLogger.Log(r.Context(), audit.Entry{
					UserID:   &userID,
					OrgID:    &orgID,
					Action:   "permission_check",
					Resource: &resource,
					Success:  false,
					Metadata: metadata,
				})
				// Plain denials are already captured by the audit entry.
				// LogError only receives real (non-nil) errors.
				if err != nil {
					errors.LogError(r, err, "Permission denied")
				}
				errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}