243 lines
7.1 KiB
Go
243 lines
7.1 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"go.b0esche.cloud/backend/internal/audit"
|
|
"go.b0esche.cloud/backend/internal/database"
|
|
"go.b0esche.cloud/backend/internal/errors"
|
|
"go.b0esche.cloud/backend/internal/org"
|
|
"go.b0esche.cloud/backend/internal/permission"
|
|
"go.b0esche.cloud/backend/pkg/jwt"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
var RequestID = middleware.RequestID
|
|
var Logger = middleware.Logger
|
|
var Recoverer = middleware.Recoverer
|
|
|
|
// CORS returns middleware that applies Cross-Origin Resource Sharing headers.
//
// allowedOrigins is a comma-separated list of origins. Entries are matched
// case-insensitively and may contain "*" wildcards (for example
// "https://*.example.com"). An empty list, or an entry that is exactly "*",
// allows every origin.
//
// When the request's Origin matches an entry, that origin is echoed back in
// Access-Control-Allow-Origin together with
// Access-Control-Allow-Credentials: true. Otherwise, if every origin is
// allowed, a literal "*" is sent (without credentials). OPTIONS requests are
// answered here with 200 OK and never reach next.
func CORS(allowedOrigins string) func(http.Handler) http.Handler {
	allowedList, allowAll := compileAllowedOrigins(allowedOrigins)

	// With a non-empty origin list, the Access-Control-Allow-Origin value (or its
	// absence) depends on the request's Origin header. Shared caches must
	// therefore key every response on Origin, not only the matching ones.
	varyOnOrigin := len(allowedList) > 0

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			header := w.Header()
			origin := r.Header.Get("Origin")

			if varyOnOrigin {
				header.Add("Vary", "Origin")
			}

			if origin != "" && isOriginAllowed(origin, allowedList) {
				// NOTE(review): if allowedOrigins contains a bare "*", every
				// origin lands here and is granted credentials. Confirm that is
				// intended rather than answering with a plain "*".
				header.Set("Access-Control-Allow-Origin", origin)
				header.Set("Access-Control-Allow-Credentials", "true")
			} else if allowAll {
				header.Set("Access-Control-Allow-Origin", "*")
			}

			header.Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
			// X-Org-ID is read by the Org middleware, so browsers must be allowed
			// to send it on cross-origin requests.
			header.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Org-ID")
			header.Set("Access-Control-Expose-Headers", "Content-Length, Content-Type, Content-Disposition")
			header.Set("Access-Control-Max-Age", "3600")

			// Preflight (and any other OPTIONS) requests are fully answered here.
			if r.Method == http.MethodOptions {
				w.WriteHeader(http.StatusOK)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
|
|
|
|
// compileAllowedOrigins parses a comma-separated origin list.
//
// Each entry is whitespace-trimmed, and empty entries are dropped. The
// remaining entries are returned in their original order; a bare "*" is kept
// in the list as well. The boolean result is true when any entry is "*" or when
// no usable entry was found, so an empty configuration allows every origin.
func compileAllowedOrigins(origins string) ([]string, bool) {
	var patterns []string
	wildcard := false

	for _, raw := range strings.Split(origins, ",") {
		if entry := strings.TrimSpace(raw); entry != "" {
			wildcard = wildcard || entry == "*"
			patterns = append(patterns, entry)
		}
	}

	// Nothing configured means no restriction.
	return patterns, wildcard || len(patterns) == 0
}
|
|
|
|
func isOriginAllowed(origin string, allowed []string) bool {
|
|
if origin == "" {
|
|
return false
|
|
}
|
|
for _, pattern := range allowed {
|
|
if originMatches(origin, pattern) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// originMatches reports whether origin satisfies a single allow-list pattern.
//
// A pattern of exactly "*" matches anything. A pattern without "*" must equal
// origin ignoring case. Otherwise every "*" in the pattern matches any run of
// characters, the rest of the pattern is matched literally, and the comparison
// is case-insensitive and anchored at both ends. A pattern that fails to
// compile matches nothing.
func originMatches(origin, pattern string) bool {
	switch {
	case pattern == "*":
		return true
	case strings.Contains(pattern, "*"):
		// Escape everything, then turn each escaped wildcard back into ".*".
		expr := "(?i)^" + strings.ReplaceAll(regexp.QuoteMeta(pattern), "\\*", ".*") + "$"
		ok, err := regexp.MatchString(expr, origin)
		return ok && err == nil
	default:
		return strings.EqualFold(origin, pattern)
	}
}
|
|
|
|
// TODO: Implement rate limiter
|
|
var RateLimit = func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Basic rate limiting logic here
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// ContextKey is the type of the keys this package uses to store values in a
// request context. Because it is a distinct named type, these keys never
// compare equal to plain string keys used elsewhere.
type ContextKey string

// Request-context keys populated by this package's middleware.
const (
	// UserKey holds the authenticated user's ID as a string.
	UserKey = ContextKey("user")
	// SessionKey holds the *database.Session returned by token validation.
	SessionKey = ContextKey("session")
	// TokenKey holds the raw token string taken from the request.
	TokenKey = ContextKey("token")
	// OrgKey holds the organization ID (uuid.UUID) resolved by Org.
	OrgKey = ContextKey("org")
)
|
|
|
|
// GetUserID returns the user ID stored in ctx under UserKey (as set by Auth).
// The boolean is false when no value is present or it is not a string.
func GetUserID(ctx context.Context) (string, bool) {
	raw := ctx.Value(UserKey)
	id, isString := raw.(string)
	return id, isString
}
|
|
|
|
// GetSession returns the *database.Session stored in ctx under SessionKey (as
// set by Auth). The boolean is false when no value is present or it has a
// different type.
func GetSession(ctx context.Context) (*database.Session, bool) {
	if s, found := ctx.Value(SessionKey).(*database.Session); found {
		return s, true
	}
	return nil, false
}
|
|
|
|
// GetToken returns the raw token string stored in ctx under TokenKey (as set
// by Auth). The boolean is false when no value is present or it is not a
// string.
func GetToken(ctx context.Context) (string, bool) {
	if tok, found := ctx.Value(TokenKey).(string); found {
		return tok, true
	}
	return "", false
}
|
|
|
|
// Auth returns middleware that authenticates each request with a token.
//
// The token is taken from an "Authorization: Bearer <token>" header. The
// scheme is matched case-insensitively, as auth schemes are per RFC 7235, and
// surrounding whitespace is trimmed from the token. When the header does not
// carry a Bearer credential, the "token" query parameter is used instead, so
// viewers that cannot set headers can still authenticate.
//
// The token is checked with jwtManager.ValidateWithSession against db. On
// success, claims.UserID, the returned session and the raw token are stored in
// the request context under UserKey, SessionKey and TokenKey before next is
// called. A missing or invalid token is answered with 401 Unauthorized plus a
// WWW-Authenticate challenge, and next is not called.
func Auth(jwtManager *jwt.Manager, db *database.DB) func(http.Handler) http.Handler {
	const bearerPrefix = "Bearer "

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			unauthorized := func() {
				// RFC 7235 §3.1: a 401 response must include a challenge.
				w.Header().Set("WWW-Authenticate", "Bearer")
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
			}

			var tokenString string
			authHeader := r.Header.Get("Authorization")
			if len(authHeader) >= len(bearerPrefix) &&
				strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
				tokenString = strings.TrimSpace(authHeader[len(bearerPrefix):])
			} else {
				// Fallback to query parameter token (for viewers that cannot set headers).
				// NOTE(review): tokens in URLs can leak via access logs and Referer headers.
				tokenString = r.URL.Query().Get("token")
			}
			if tokenString == "" {
				unauthorized()
				return
			}

			claims, session, err := jwtManager.ValidateWithSession(r.Context(), tokenString, db)
			if err != nil {
				unauthorized()
				return
			}

			ctx := context.WithValue(r.Context(), UserKey, claims.UserID)
			ctx = context.WithValue(ctx, SessionKey, session)
			ctx = context.WithValue(ctx, TokenKey, tokenString)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
|
|
|
|
// Org returns middleware that resolves the organization a request targets and
// verifies that the authenticated user belongs to it.
//
// It expects Auth to have run first: it reads the user ID stored under
// UserKey. The org ID comes from the "X-Org-ID" header, falling back to the
// "orgId" chi URL parameter. Responses:
//   - 401 if the context holds no user ID, or one that is not a UUID;
//   - 400 (CodeInvalidArgument) if the org ID is missing or not a UUID;
//   - 403 (CodePermissionDenied) if org.CheckMembership returns an error,
//     after a failed "org_access" entry is written to auditLogger.
//
// On success the parsed org ID (uuid.UUID) is stored under OrgKey and next is
// called.
func Org(db *database.DB, auditLogger *audit.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Guarded lookup: a bare type assertion would panic if Auth was not
			// mounted, and an unparseable ID would otherwise become uuid.Nil.
			userIDStr, ok := GetUserID(r.Context())
			if !ok {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			userID, err := uuid.Parse(userIDStr)
			if err != nil {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}

			orgIDStr := r.Header.Get("X-Org-ID")
			if orgIDStr == "" {
				orgIDStr = chi.URLParam(r, "orgId")
			}
			orgID, err := uuid.Parse(orgIDStr)
			if err != nil {
				errors.WriteError(w, errors.CodeInvalidArgument, "Invalid org ID", http.StatusBadRequest)
				return
			}

			// One membership check per request is sufficient.
			if _, err := org.CheckMembership(r.Context(), db, userID, orgID); err != nil {
				auditLogger.Log(r.Context(), audit.Entry{
					UserID:   &userID,
					Action:   "org_access",
					Success:  false,
					Metadata: map[string]interface{}{"org_id": orgID, "error": err.Error()},
				})
				errors.LogError(r, err, "Org access denied")
				errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden)
				return
			}

			ctx := context.WithValue(r.Context(), OrgKey, orgID)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
|
|
|
|
// Permission returns middleware that lets a request through only if the user
// holds perm in the current organization.
//
// It expects Auth and Org to have run first: it reads the user ID under
// UserKey and the org ID (uuid.UUID) under OrgKey. It then calls
// permission.HasPermission. When that returns an error or false, a failed
// "permission_check" entry is written to auditLogger and the request is
// answered with 403 (CodePermissionDenied).
//
// A missing or unparseable user ID yields 401. A missing org ID means the
// middleware chain is mis-wired and yields 500. Neither case panics.
func Permission(db *database.DB, auditLogger *audit.Logger, perm permission.Permission) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			userIDStr, ok := GetUserID(r.Context())
			if !ok {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			userID, err := uuid.Parse(userIDStr)
			if err != nil {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			orgID, ok := r.Context().Value(OrgKey).(uuid.UUID)
			if !ok {
				// Org middleware did not run before this one: a routing bug, not a client error.
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}

			hasPerm, err := permission.HasPermission(r.Context(), db, userID, orgID, perm)
			if err == nil && hasPerm {
				next.ServeHTTP(w, r)
				return
			}

			resource := string(perm)
			metadata := map[string]interface{}{"permission": perm}
			if err != nil {
				metadata["error"] = err.Error()
			}
			auditLogger.Log(r.Context(), audit.Entry{
				UserID:   &userID,
				OrgID:    &orgID,
				Action:   "permission_check",
				Resource: &resource,
				Success:  false,
				Metadata: metadata,
			})
			// Only hand LogError a real error. A plain denial (err == nil) is
			// already recorded by the audit entry above.
			if err != nil {
				errors.LogError(r, err, "Permission denied")
			}
			errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden)
		})
	}
}
|