- Export ContextKey type and context keys from middleware package
- Use exported keys (UserKey, SessionKey, TokenKey, OrgKey) in handlers
- Fixes panic: interface conversion: interface {} is nil, not uuid.UUID
- The middleware stored context values under the unexported contextKey type,
but handlers looked them up with plain string keys, so every lookup returned nil
195 lines
6.1 KiB
Go
195 lines
6.1 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"go.b0esche.cloud/backend/internal/audit"
|
|
"go.b0esche.cloud/backend/internal/database"
|
|
"go.b0esche.cloud/backend/internal/errors"
|
|
"go.b0esche.cloud/backend/internal/org"
|
|
"go.b0esche.cloud/backend/internal/permission"
|
|
"go.b0esche.cloud/backend/pkg/jwt"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
var RequestID = middleware.RequestID
|
|
var Logger = middleware.Logger
|
|
var Recoverer = middleware.Recoverer
|
|
|
|
// CORS returns middleware that adds Cross-Origin Resource Sharing headers.
//
// allowedOrigins is a comma-separated list of exact origins. Whitespace around
// each entry is ignored and empty entries are skipped. The list is parsed once,
// when CORS is called, rather than on every request.
//
// If the request's Origin header matches an allowed origin, that origin is
// echoed in Access-Control-Allow-Origin together with
// Access-Control-Allow-Credentials: true. Otherwise, unless something earlier
// already set Access-Control-Allow-Origin, it is set to "*" (usable only by
// requests made without credentials). Every response also carries
// "Vary: Origin", because the Allow-Origin value depends on the request's
// Origin and caches must not reuse one origin's response for another.
//
// OPTIONS requests (preflights) are answered with 200 OK and are not passed
// on to next.
func CORS(allowedOrigins string) func(http.Handler) http.Handler {
	allowed := make(map[string]struct{})
	for _, entry := range strings.Split(allowedOrigins, ",") {
		if entry = strings.TrimSpace(entry); entry != "" {
			allowed[entry] = struct{}{}
		}
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()
			// Add (not Set) so Vary values set by other layers are kept.
			h.Add("Vary", "Origin")

			// Empty entries were skipped above, so an absent Origin never matches.
			origin := r.Header.Get("Origin")
			if _, ok := allowed[origin]; ok {
				h.Set("Access-Control-Allow-Origin", origin)
				h.Set("Access-Control-Allow-Credentials", "true")
			} else if h.Get("Access-Control-Allow-Origin") == "" {
				// Fallback to * if no credentials needed.
				h.Set("Access-Control-Allow-Origin", "*")
			}

			h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
			h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
			h.Set("Access-Control-Expose-Headers", "Content-Length, Content-Type, Content-Disposition")
			h.Set("Access-Control-Max-Age", "3600")

			if r.Method == http.MethodOptions {
				w.WriteHeader(http.StatusOK)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
|
|
|
|
// RateLimit is a placeholder for request rate limiting.
//
// TODO: Implement rate limiter. For now every request is forwarded to next
// unchanged.
var RateLimit = func(next http.Handler) http.Handler {
	passthrough := func(rw http.ResponseWriter, req *http.Request) {
		// No limiting is applied yet.
		next.ServeHTTP(rw, req)
	}
	return http.HandlerFunc(passthrough)
}
|
|
|
|
// ContextKey is the type of every key this package stores in a request
// context. It is exported so handlers can look values up with the same key
// type the middleware used to store them; context keys compare by type as
// well as value, so a plain string "user" would not match UserKey.
type ContextKey string

// Context keys populated by the Auth and Org middleware.
const (
	// UserKey holds the user ID taken from the token claims (set by Auth).
	UserKey = ContextKey("user")
	// SessionKey holds the session returned by token validation (set by Auth).
	SessionKey = ContextKey("session")
	// TokenKey holds the raw bearer token string (set by Auth).
	TokenKey = ContextKey("token")
	// OrgKey holds the organization ID as a uuid.UUID (set by Org).
	OrgKey = ContextKey("org")
)
|
|
|
|
// GetUserID retrieves the user ID from the request context
|
|
func GetUserID(ctx context.Context) (string, bool) {
|
|
userID, ok := ctx.Value(UserKey).(string)
|
|
return userID, ok
|
|
}
|
|
|
|
// GetSession retrieves the session from the request context
|
|
func GetSession(ctx context.Context) (*database.Session, bool) {
|
|
session, ok := ctx.Value(SessionKey).(*database.Session)
|
|
return session, ok
|
|
}
|
|
|
|
// GetToken retrieves the JWT token from the request context
|
|
func GetToken(ctx context.Context) (string, bool) {
|
|
token, ok := ctx.Value(TokenKey).(string)
|
|
return token, ok
|
|
}
|
|
|
|
// Auth middleware
|
|
func Auth(jwtManager *jwt.Manager, db *database.DB) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
authHeader := r.Header.Get("Authorization")
|
|
if !strings.HasPrefix(authHeader, "Bearer ") {
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
|
claims, session, err := jwtManager.ValidateWithSession(r.Context(), tokenString, db)
|
|
if err != nil {
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), UserKey, claims.UserID)
|
|
ctx = context.WithValue(ctx, SessionKey, session)
|
|
ctx = context.WithValue(ctx, TokenKey, tokenString)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// Org middleware
|
|
func Org(db *database.DB, auditLogger *audit.Logger) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
userIDStr := r.Context().Value(UserKey).(string)
|
|
userID, _ := uuid.Parse(userIDStr)
|
|
|
|
orgIDStr := r.Header.Get("X-Org-ID")
|
|
if orgIDStr == "" {
|
|
orgIDStr = chi.URLParam(r, "orgId")
|
|
}
|
|
orgID, err := uuid.Parse(orgIDStr)
|
|
if err != nil {
|
|
errors.WriteError(w, errors.CodeInvalidArgument, "Invalid org ID", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
_, err = org.CheckMembership(r.Context(), db, userID, orgID)
|
|
if err != nil {
|
|
auditLogger.Log(r.Context(), audit.Entry{
|
|
UserID: &userID,
|
|
Action: "org_access",
|
|
Success: false,
|
|
Metadata: map[string]interface{}{"org_id": orgID, "error": err.Error()},
|
|
})
|
|
errors.LogError(r, err, "Org access denied")
|
|
errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
_, err = org.CheckMembership(r.Context(), db, userID, orgID)
|
|
if err != nil {
|
|
auditLogger.Log(r.Context(), audit.Entry{
|
|
UserID: &userID,
|
|
Action: "org_access",
|
|
Success: false,
|
|
Metadata: map[string]interface{}{"org_id": orgID, "error": err.Error()},
|
|
})
|
|
errors.LogError(r, err, "Org access denied")
|
|
errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), OrgKey, orgID)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// Permission middleware
|
|
func Permission(db *database.DB, auditLogger *audit.Logger, perm permission.Permission) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
userIDStr := r.Context().Value(UserKey).(string)
|
|
userID, _ := uuid.Parse(userIDStr)
|
|
orgID := r.Context().Value(OrgKey).(uuid.UUID)
|
|
|
|
hasPerm, err := permission.HasPermission(r.Context(), db, userID, orgID, perm)
|
|
if err != nil || !hasPerm {
|
|
auditLogger.Log(r.Context(), audit.Entry{
|
|
UserID: &userID,
|
|
OrgID: &orgID,
|
|
Action: "permission_check",
|
|
Resource: &[]string{string(perm)}[0],
|
|
Success: false,
|
|
Metadata: map[string]interface{}{"permission": perm},
|
|
})
|
|
errors.LogError(r, err, "Permission denied")
|
|
errors.WriteError(w, errors.CodePermissionDenied, "Forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|