Fix context key mismatch - use typed contextKey consistently

This commit is contained in:
Leon Bösche
2026-01-09 20:26:55 +01:00
parent a9d205f454
commit 8114a3746b
3 changed files with 27 additions and 15 deletions

View File

@@ -6,7 +6,7 @@ import (
"net/http"
"os"
"github.com/go-chi/chi/v5/middleware"
chimiddleware "github.com/go-chi/chi/v5/middleware"
"github.com/google/uuid"
)
@@ -40,7 +40,7 @@ func WriteError(w http.ResponseWriter, code ErrorCode, message string, status in
// GetRequestID extracts the request ID from the request context
func GetRequestID(r *http.Request) string {
if reqID := middleware.GetReqID(r.Context()); reqID != "" {
if reqID := chimiddleware.GetReqID(r.Context()); reqID != "" {
return reqID
}
return "unknown"
@@ -48,10 +48,10 @@ func GetRequestID(r *http.Request) string {
// GetUserID extracts user ID from context if available
func GetUserID(r *http.Request) string {
if userID := r.Context().Value("user"); userID != nil {
if uid, ok := userID.(string); ok {
return uid
}
// Use type contextKey matching middleware package
type contextKey string
if userID, ok := r.Context().Value(contextKey("user")).(string); ok && userID != "" {
return userID
}
return ""
}

View File

@@ -351,7 +351,7 @@ func listFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) {
}
func viewerHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) {
userIDStr := r.Context().Value("user").(string)
userIDStr, _ := middleware.GetUserID(r.Context())
userID, _ := uuid.Parse(userIDStr)
orgID := r.Context().Value("org").(uuid.UUID)
fileId := chi.URLParam(r, "fileId")
@@ -382,7 +382,7 @@ func viewerHandler(w http.ResponseWriter, r *http.Request, db *database.DB, audi
}
func editorHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) {
userIDStr := r.Context().Value("user").(string)
userIDStr, _ := middleware.GetUserID(r.Context())
userID, _ := uuid.Parse(userIDStr)
orgID := r.Context().Value("org").(uuid.UUID)
fileId := chi.URLParam(r, "fileId")
@@ -405,7 +405,7 @@ func editorHandler(w http.ResponseWriter, r *http.Request, db *database.DB, audi
}
func annotationsHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) {
userIDStr := r.Context().Value("user").(string)
userIDStr, _ := middleware.GetUserID(r.Context())
userID, _ := uuid.Parse(userIDStr)
orgID := r.Context().Value("org").(uuid.UUID)
fileId := chi.URLParam(r, "fileId")
@@ -848,7 +848,7 @@ func passwordLoginHandler(w http.ResponseWriter, r *http.Request, db *database.D
// userFilesHandler returns files for the authenticated user's personal workspace.
func userFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) {
userIDStr, ok := r.Context().Value("user").(string)
userIDStr, ok := middleware.GetUserID(r.Context())
if !ok || userIDStr == "" {
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
return
@@ -899,7 +899,7 @@ func userFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) {
// createOrgFileHandler creates a file or folder record for an org workspace.
func createOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
orgID := r.Context().Value("org").(uuid.UUID)
userIDStr, _ := r.Context().Value("user").(string)
userIDStr, _ := middleware.GetUserID(r.Context())
userID, _ := uuid.Parse(userIDStr)
var f *database.File
var err error
@@ -1039,7 +1039,7 @@ func createOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.D
// deleteOrgFileHandler deletes a file/folder in org workspace by path
func deleteOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
orgID := r.Context().Value("org").(uuid.UUID)
userIDStr, _ := r.Context().Value("user").(string)
userIDStr, _ := middleware.GetUserID(r.Context())
userID, _ := uuid.Parse(userIDStr)
var req struct {
@@ -1084,7 +1084,7 @@ func deleteOrgFilePostHandler(w http.ResponseWriter, r *http.Request, db *databa
// createUserFileHandler creates a file or folder record for the authenticated user's personal workspace.
func createUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
userIDStr, ok := r.Context().Value("user").(string)
userIDStr, ok := middleware.GetUserID(r.Context())
if !ok || userIDStr == "" {
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
return
@@ -1216,7 +1216,7 @@ func deleteUserFilePostHandler(w http.ResponseWriter, r *http.Request, db *datab
// deleteUserFileHandler deletes a file/folder in user's personal workspace by path
func deleteUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
userIDStr, ok := r.Context().Value("user").(string)
userIDStr, ok := middleware.GetUserID(r.Context())
if !ok || userIDStr == "" {
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
return
@@ -1298,7 +1298,7 @@ func downloadOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database
// downloadUserFileHandler downloads a file from user's personal workspace
func downloadUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, storageClient *storage.WebDAVClient) {
userIDStr, ok := r.Context().Value("user").(string)
userIDStr, ok := middleware.GetUserID(r.Context())
if !ok || userIDStr == "" {
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
return

View File

@@ -54,6 +54,18 @@ const (
orgKey contextKey = "org"
)
// GetUserID retrieves the user ID from the request context
func GetUserID(ctx context.Context) (string, bool) {
userID, ok := ctx.Value(userKey).(string)
return userID, ok
}
// GetSession retrieves the session from the request context
func GetSession(ctx context.Context) (*database.Session, bool) {
session, ok := ctx.Value(sessionKey).(*database.Session)
return session, ok
}
// Auth middleware
func Auth(jwtManager *jwt.Manager, db *database.DB) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {