Files
b0esche_cloud/go_cloud/internal/middleware/middleware.go

337 lines
10 KiB
Go

package middleware
import (
"context"
"fmt"
"net/http"
"regexp"
"strings"
"sync"
"time"
"go.b0esche.cloud/backend/internal/audit"
"go.b0esche.cloud/backend/internal/database"
"go.b0esche.cloud/backend/internal/errors"
"go.b0esche.cloud/backend/internal/org"
"go.b0esche.cloud/backend/internal/permission"
"go.b0esche.cloud/backend/pkg/jwt"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/google/uuid"
)
// Re-exports of chi's stock middleware, so routers can take every middleware
// they need from this package.
var (
	// RequestID is chi's middleware.RequestID.
	RequestID = middleware.RequestID
	// Logger is chi's middleware.Logger.
	Logger = middleware.Logger
	// Recoverer is chi's middleware.Recoverer.
	Recoverer = middleware.Recoverer
)
// SecurityHeaders returns middleware that sets defensive HTTP response headers
// (MIME-sniffing, framing, XSS filter, referrer and content security policy)
// on every response before delegating to the next handler.
//
// X-Frame-Options: DENY is omitted for paths starting with "/wopi",
// "/user/files/" or "/orgs/" so those responses may be embedded in frames
// (e.g. by the WOPI document editor).
func SecurityHeaders() func(http.Handler) http.Handler {
	const contentSecurityPolicy = "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https://go.b0esche.cloud https://of.b0esche.cloud; frame-src 'self' https://of.b0esche.cloud;"
	frameablePrefixes := []string{"/wopi", "/user/files/", "/orgs/"}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			hdr := w.Header()
			hdr.Set("X-Content-Type-Options", "nosniff")

			frameable := false
			for _, prefix := range frameablePrefixes {
				if strings.HasPrefix(r.URL.Path, prefix) {
					frameable = true
					break
				}
			}
			if !frameable {
				hdr.Set("X-Frame-Options", "DENY")
			}

			hdr.Set("X-XSS-Protection", "1; mode=block")
			hdr.Set("Referrer-Policy", "strict-origin-when-cross-origin")
			hdr.Set("Content-Security-Policy", contentSecurityPolicy)
			next.ServeHTTP(w, r)
		})
	}
}
// CORS returns middleware that applies Cross-Origin Resource Sharing headers.
//
// allowedOrigins is a comma-separated list of origins. Entries may contain "*"
// wildcards (e.g. "https://*.example.com").
//   - An Origin matched by an explicit (non-bare-"*") entry is echoed back
//     with "Access-Control-Allow-Credentials: true".
//   - Otherwise, if the list contains a bare "*" or is empty, the response
//     carries "Access-Control-Allow-Origin: *" without credentials.
//
// Every OPTIONS request is answered with 200 and is not forwarded to next.
func CORS(allowedOrigins string) func(http.Handler) http.Handler {
	allowedList, allowAll := compileAllowedOrigins(allowedOrigins)

	// A bare "*" must not enter the credentialed match: reflecting an arbitrary
	// Origin together with Allow-Credentials would sidestep the browser rule
	// that forbids "*" with credentials. It is honoured via allowAll instead,
	// exactly like an empty configuration.
	credentialed := make([]string, 0, len(allowedList))
	for _, o := range allowedList {
		if o != "*" {
			credentialed = append(credentialed, o)
		}
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()
			origin := r.Header.Get("Origin")
			if origin != "" {
				// Whether (and which) Allow-Origin is sent depends on the request
				// Origin, even when it is rejected, so caches must key on it.
				h.Add("Vary", "Origin")
			}

			switch {
			case origin != "" && isOriginAllowed(origin, credentialed):
				h.Set("Access-Control-Allow-Origin", origin)
				h.Set("Access-Control-Allow-Credentials", "true")
			case allowAll:
				h.Set("Access-Control-Allow-Origin", "*")
			}

			h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
			allowHeaders := []string{"Content-Type", "Authorization", "Range", "Accept", "Origin", "X-Requested-With"}
			if reqHeaders := r.Header.Get("Access-Control-Request-Headers"); reqHeaders != "" {
				h.Add("Vary", "Access-Control-Request-Headers")
				// Split the comma-separated list so uniqueStrings can drop headers
				// that duplicate the defaults above.
				allowHeaders = append(allowHeaders, strings.Split(reqHeaders, ",")...)
			}
			h.Set("Access-Control-Allow-Headers", strings.Join(uniqueStrings(allowHeaders), ", "))
			h.Set("Access-Control-Expose-Headers", "Content-Length, Content-Type, Content-Disposition, Content-Range, Accept-Ranges")
			h.Set("Access-Control-Max-Age", "3600")

			if r.Method == http.MethodOptions {
				w.WriteHeader(http.StatusOK)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// compileAllowedOrigins splits a comma-separated origin list into its
// whitespace-trimmed, non-empty entries (bare "*" included). The boolean
// result reports whether every origin is allowed, which is the case when the
// list contains a bare "*" or has no usable entries at all.
func compileAllowedOrigins(origins string) ([]string, bool) {
	var entries []string
	wildcard := false
	for _, raw := range strings.Split(origins, ",") {
		entry := strings.TrimSpace(raw)
		if entry == "" {
			continue
		}
		wildcard = wildcard || entry == "*"
		entries = append(entries, entry)
	}
	return entries, wildcard || len(entries) == 0
}
// uniqueStrings returns the non-blank entries of values, trimmed of
// surrounding whitespace, with case-insensitive duplicates removed. Input
// order is kept and the first spelling of each entry wins. It returns nil
// when nothing survives.
func uniqueStrings(values []string) []string {
	var result []string
	seen := map[string]bool{}
	for _, raw := range values {
		v := strings.TrimSpace(raw)
		if v == "" {
			continue
		}
		folded := strings.ToLower(v)
		if seen[folded] {
			continue
		}
		seen[folded] = true
		result = append(result, v)
	}
	return result
}
// isOriginAllowed reports whether origin matches at least one pattern in
// allowed, using originMatches. An empty origin is never allowed.
func isOriginAllowed(origin string, allowed []string) bool {
	matched := false
	for i := 0; origin != "" && !matched && i < len(allowed); i++ {
		matched = originMatches(origin, allowed[i])
	}
	return matched
}
// originPatternCache memoizes compiled wildcard origin patterns, keyed by the
// raw pattern string. Patterns come from configuration, so each key is
// written once and then only read.
var originPatternCache sync.Map // map[string]*regexp.Regexp

// originMatches reports whether origin satisfies pattern:
//   - a bare "*" matches everything;
//   - a pattern without "*" must equal origin case-insensitively;
//   - otherwise each "*" matches any run of characters, the whole origin must
//     match, and matching is case-insensitive.
//
// It runs on every CORS request, so compiled regexps are cached rather than
// rebuilt on each call.
func originMatches(origin, pattern string) bool {
	if pattern == "*" {
		return true
	}
	if !strings.Contains(pattern, "*") {
		return strings.EqualFold(origin, pattern)
	}
	cached, ok := originPatternCache.Load(pattern)
	if !ok {
		regexPattern := "(?i)^" + regexp.QuoteMeta(pattern) + "$"
		regexPattern = strings.ReplaceAll(regexPattern, "\\*", ".*")
		compiled, err := regexp.Compile(regexPattern)
		if err != nil {
			return false
		}
		cached, _ = originPatternCache.LoadOrStore(pattern, compiled)
	}
	return cached.(*regexp.Regexp).MatchString(origin)
}
type (
	// rateLimiter holds per-client request counters for RateLimit. mu guards
	// every access to requests.
	rateLimiter struct {
		mu       sync.RWMutex
		requests map[string]*clientRequests
	}

	// clientRequests is one client's current window: the number of requests
	// counted so far and the instant the window expires.
	clientRequests struct {
		count     int
		resetTime time.Time
	}
)

// limiter is the package-wide state shared by every RateLimit handler.
var limiter = &rateLimiter{requests: make(map[string]*clientRequests)}
// rateLimitWindow is the length of each fixed rate-limit window.
const rateLimitWindow = time.Minute

// rateLimitNextSweep is the earliest time at which expired limiter entries are
// purged again. Guarded by limiter.mu.
var rateLimitNextSweep time.Time

// RateLimit implements a fixed-window rate limiter keyed by client IP.
//
// Limits are 100 requests per minute per IP for general endpoints. Endpoints
// under /auth/ have a separate budget of 10 requests per minute per IP.
// Requests over the limit get a 429 with a Retry-After header giving the
// whole seconds left until the current window ends.
var RateLimit = func(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		limit, bucket := 100, "general"
		if strings.HasPrefix(r.URL.Path, "/auth/") {
			limit, bucket = 10, "auth"
		}
		// Count auth and general traffic separately. With a shared counter,
		// ordinary browsing would exhaust the much smaller auth budget and
		// block logins.
		key := bucket + "|" + rateLimitClientIP(r)
		now := time.Now()

		limiter.mu.Lock()
		sweepExpiredRateLimitEntriesLocked(now)
		client, exists := limiter.requests[key]
		if !exists || now.After(client.resetTime) {
			// Start a new window.
			limiter.requests[key] = &clientRequests{
				count:     1,
				resetTime: now.Add(rateLimitWindow),
			}
			limiter.mu.Unlock()
			next.ServeHTTP(w, r)
			return
		}
		if client.count >= limit {
			wait := client.resetTime.Sub(now)
			limiter.mu.Unlock()
			// Round up so clients never retry before the window has reset.
			secs := int((wait + time.Second - 1) / time.Second)
			if secs < 1 {
				secs = 1
			}
			w.Header().Set("Retry-After", fmt.Sprint(secs))
			errors.WriteError(w, errors.CodeInvalidArgument, "Rate limit exceeded. Please try again later.", http.StatusTooManyRequests)
			return
		}
		client.count++
		limiter.mu.Unlock()
		next.ServeHTTP(w, r)
	})
}

// rateLimitClientIP returns the key that identifies a client for rate
// limiting. That is the first X-Forwarded-For entry when present, otherwise
// the host part of r.RemoteAddr. The port is dropped because every new
// connection gets a fresh source port, which would otherwise reset the
// client's budget.
//
// NOTE(review): the left-most X-Forwarded-For entry is client-controlled
// unless the reverse proxy overwrites/strips the header; confirm it does, or
// clients can bypass the limit by spoofing it.
func rateLimitClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		if first := strings.TrimSpace(strings.SplitN(forwarded, ",", 2)[0]); first != "" {
			return first
		}
	}
	addr := r.RemoteAddr
	if strings.HasPrefix(addr, "[") {
		// Bracketed IPv6 literal with port, e.g. "[::1]:8080".
		if end := strings.IndexByte(addr, ']'); end > 0 {
			return addr[1:end]
		}
		return addr
	}
	if strings.Count(addr, ":") == 1 {
		// "host:port". Addresses with several colons are unbracketed IPv6
		// literals that carry no port.
		return addr[:strings.IndexByte(addr, ':')]
	}
	return addr
}

// sweepExpiredRateLimitEntriesLocked deletes entries whose window has ended.
// It runs at most once per rateLimitWindow so the map cannot grow without
// bound as new client keys arrive. The caller must hold limiter.mu.
func sweepExpiredRateLimitEntriesLocked(now time.Time) {
	if now.Before(rateLimitNextSweep) {
		return
	}
	for key, c := range limiter.requests {
		if now.After(c.resetTime) {
			delete(limiter.requests, key)
		}
	}
	rateLimitNextSweep = now.Add(rateLimitWindow)
}
// ContextKey is the type of the keys this package stores in a request
// context. A distinct named type keeps them from colliding with plain string
// keys set by other packages.
type ContextKey string

// Request-context keys set by the Auth middleware (user, session, token) and
// the Org middleware (org).
const (
	UserKey    = ContextKey("user")    // user ID as a string
	SessionKey = ContextKey("session") // *database.Session
	TokenKey   = ContextKey("token")   // raw JWT string
	OrgKey     = ContextKey("org")     // uuid.UUID of the resolved org
)
// GetUserID returns the user ID stored under UserKey (set by the Auth
// middleware). ok is false, and the ID empty, when no string value is
// present.
func GetUserID(ctx context.Context) (string, bool) {
	if id, ok := ctx.Value(UserKey).(string); ok {
		return id, true
	}
	return "", false
}
// GetSession returns the session stored under SessionKey (set by the Auth
// middleware). ok is false, and the session nil, when no *database.Session
// value is present.
func GetSession(ctx context.Context) (*database.Session, bool) {
	if s, ok := ctx.Value(SessionKey).(*database.Session); ok {
		return s, true
	}
	return nil, false
}
// GetToken returns the raw JWT string stored under TokenKey (set by the Auth
// middleware). ok is false, and the token empty, when no string value is
// present.
func GetToken(ctx context.Context) (string, bool) {
	if tok, ok := ctx.Value(TokenKey).(string); ok {
		return tok, true
	}
	return "", false
}
// Auth returns middleware that authenticates each request with a JWT.
//
// The token is taken from an "Authorization: Bearer <token>" header. The
// scheme is matched case-insensitively, as HTTP auth schemes are. When there
// is no bearer header, the token is taken from the "token" query parameter,
// for viewers that cannot set headers.
//
// The token and its session are validated with
// jwtManager.ValidateWithSession. On success the user ID, session and raw
// token are stored in the request context under UserKey, SessionKey and
// TokenKey. A missing or invalid token gets a plain-text 401.
func Auth(jwtManager *jwt.Manager, db *database.DB) func(http.Handler) http.Handler {
	const bearerPrefix = "Bearer "
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Log only the path. r.RequestURI includes the query string, which
			// carries the raw JWT when the token arrives via ?token=.
			path := r.URL.Path

			authHeader := r.Header.Get("Authorization")
			var tokenString, tokenSource string
			if len(authHeader) >= len(bearerPrefix) && strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
				tokenString = strings.TrimSpace(authHeader[len(bearerPrefix):])
				tokenSource = "header"
			} else {
				// Fallback to query parameter token (for viewers that cannot set headers)
				tokenString = r.URL.Query().Get("token")
				if tokenString == "" {
					fmt.Printf("[AUTH-TOKEN] source=none, path=%s, statusCode=401\n", path)
					http.Error(w, "Unauthorized", http.StatusUnauthorized)
					return
				}
				tokenSource = "query"
			}
			fmt.Printf("[AUTH-TOKEN] source=%s, path=%s\n", tokenSource, path)

			claims, session, err := jwtManager.ValidateWithSession(r.Context(), tokenString, db)
			if err != nil {
				fmt.Printf("[AUTH-TOKEN] validation_failed, source=%s, path=%s, error=%v\n", tokenSource, path, err)
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			fmt.Printf("[AUTH-TOKEN] valid, source=%s, userId=%s\n", tokenSource, claims.UserID)

			ctx := context.WithValue(r.Context(), UserKey, claims.UserID)
			ctx = context.WithValue(ctx, SessionKey, session)
			ctx = context.WithValue(ctx, TokenKey, tokenString)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// Org returns middleware that resolves the organization a request targets and
// checks that the authenticated user is a member of it.
//
// The org ID comes from the X-Org-ID header, or from the "orgId" chi URL
// parameter when the header is absent. This middleware must run after Auth.
// Responses:
//   - 401 when the context holds no parseable user ID;
//   - 400 when the org ID does not parse;
//   - 403 when org.CheckMembership fails (the failure is also audited).
//
// On success the parsed org ID (uuid.UUID) is stored under OrgKey.
func Org(db *database.DB, auditLogger *audit.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Fail closed instead of panicking on an unchecked type assertion
			// (Auth not run) or checking membership for uuid.Nil (bad user ID).
			userIDStr, ok := GetUserID(r.Context())
			if !ok {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			userID, err := uuid.Parse(userIDStr)
			if err != nil {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}

			// NOTE(review): the header takes precedence over the URL parameter.
			// If a handler reads chi.URLParam(r, "orgId") instead of OrgKey, a
			// mismatched header could authorize a different org than the one
			// addressed; confirm handlers rely on OrgKey.
			orgIDStr := r.Header.Get("X-Org-ID")
			if orgIDStr == "" {
				orgIDStr = chi.URLParam(r, "orgId")
			}
			orgID, err := uuid.Parse(orgIDStr)
			if err != nil {
				errors.WriteError(w, errors.CodeInvalidArgument, "Invalid org ID", http.StatusBadRequest)
				return
			}

			if _, err := org.CheckMembership(r.Context(), db, userID, orgID); err != nil {
				auditLogger.Log(r.Context(), audit.Entry{
					UserID:   &userID,
					Action:   "org_access",
					Success:  false,
					Metadata: map[string]interface{}{"org_id": orgID, "error": err.Error()},
				})
				errors.LogError(r, err, "Org access denied")
				errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden)
				return
			}

			ctx := context.WithValue(r.Context(), OrgKey, orgID)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// Permission returns middleware that lets a request through only if the
// authenticated user holds perm in the org resolved by the Org middleware.
//
// It must run after Auth and Org. Responses:
//   - 401 when the context holds no parseable user ID;
//   - 403 when no org ID is in the context, when permission.HasPermission
//     returns an error, or when it returns false.
//
// Every denial is logged and written to the audit log.
func Permission(db *database.DB, auditLogger *audit.Logger, perm permission.Permission) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Checked lookups instead of bare type assertions, so a route wired
			// without Auth/Org fails closed rather than panicking.
			userIDStr, ok := GetUserID(r.Context())
			if !ok {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			userID, err := uuid.Parse(userIDStr)
			if err != nil {
				http.Error(w, "Unauthorized", http.StatusUnauthorized)
				return
			}
			orgID, ok := r.Context().Value(OrgKey).(uuid.UUID)
			if !ok {
				errors.LogError(r, fmt.Errorf("permission %s checked without org in request context", perm), "Permission denied")
				errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden)
				return
			}

			hasPerm, err := permission.HasPermission(r.Context(), db, userID, orgID, perm)
			if err != nil || !hasPerm {
				resource := string(perm)
				metadata := map[string]interface{}{"permission": perm}
				if err != nil {
					metadata["error"] = err.Error()
				} else {
					// A plain "permission missing" outcome has no error of its own.
					// Give LogError a concrete one rather than nil.
					err = fmt.Errorf("user %s lacks permission %s in org %s", userID, perm, orgID)
				}
				auditLogger.Log(r.Context(), audit.Entry{
					UserID:   &userID,
					OrgID:    &orgID,
					Action:   "permission_check",
					Resource: &resource,
					Success:  false,
					Metadata: metadata,
				})
				errors.LogError(r, err, "Permission denied")
				errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}