feat: add ownerId to Organization and update related database queries; enhance CORS middleware for origin validation
This commit is contained in:
@@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"go.b0esche.cloud/backend/internal/audit"
|
||||
@@ -23,22 +24,15 @@ var Recoverer = middleware.Recoverer
|
||||
|
||||
// CORS middleware - accepts allowedOrigins comma-separated string
|
||||
func CORS(allowedOrigins string) func(http.Handler) http.Handler {
|
||||
allowedList, allowAll := compileAllowedOrigins(allowedOrigins)
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
// Check if origin is allowed
|
||||
if origin != "" {
|
||||
// Simple check - in production you'd want to parse allowedOrigins properly
|
||||
for _, allowed := range strings.Split(allowedOrigins, ",") {
|
||||
if strings.TrimSpace(allowed) == origin {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback to * if no credentials needed
|
||||
if w.Header().Get("Access-Control-Allow-Origin") == "" {
|
||||
if origin != "" && isOriginAllowed(origin, allowedList) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Add("Vary", "Origin")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
} else if allowAll {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
@@ -56,6 +50,53 @@ func CORS(allowedOrigins string) func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
|
||||
func compileAllowedOrigins(origins string) ([]string, bool) {
|
||||
var allowed []string
|
||||
allowAll := false
|
||||
|
||||
for _, origin := range strings.Split(origins, ",") {
|
||||
trimmed := strings.TrimSpace(origin)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if trimmed == "*" {
|
||||
allowAll = true
|
||||
}
|
||||
allowed = append(allowed, trimmed)
|
||||
}
|
||||
|
||||
if len(allowed) == 0 && !allowAll {
|
||||
allowAll = true
|
||||
}
|
||||
|
||||
return allowed, allowAll
|
||||
}
|
||||
|
||||
func isOriginAllowed(origin string, allowed []string) bool {
|
||||
if origin == "" {
|
||||
return false
|
||||
}
|
||||
for _, pattern := range allowed {
|
||||
if originMatches(origin, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func originMatches(origin, pattern string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
if !strings.Contains(pattern, "*") {
|
||||
return strings.EqualFold(origin, pattern)
|
||||
}
|
||||
regexPattern := "(?i)^" + regexp.QuoteMeta(pattern) + "$"
|
||||
regexPattern = strings.ReplaceAll(regexPattern, "\\*", ".*")
|
||||
matched, err := regexp.MatchString(regexPattern, origin)
|
||||
return err == nil && matched
|
||||
}
|
||||
|
||||
// TODO: Implement rate limiter
|
||||
var RateLimit = func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
Reference in New Issue
Block a user