Fix context key mismatch - use typed contextKey consistently

This commit is contained in:
Leon Bösche
2026-01-09 20:26:55 +01:00
parent a9d205f454
commit 8114a3746b
3 changed files with 27 additions and 15 deletions

View File

@@ -6,7 +6,7 @@ import (
"net/http" "net/http"
"os" "os"
"github.com/go-chi/chi/v5/middleware" chimiddleware "github.com/go-chi/chi/v5/middleware"
"github.com/google/uuid" "github.com/google/uuid"
) )
@@ -40,7 +40,7 @@ func WriteError(w http.ResponseWriter, code ErrorCode, message string, status in
// GetRequestID extracts the request ID from the request context // GetRequestID extracts the request ID from the request context
func GetRequestID(r *http.Request) string { func GetRequestID(r *http.Request) string {
if reqID := middleware.GetReqID(r.Context()); reqID != "" { if reqID := chimiddleware.GetReqID(r.Context()); reqID != "" {
return reqID return reqID
} }
return "unknown" return "unknown"
@@ -48,10 +48,10 @@ func GetRequestID(r *http.Request) string {
// GetUserID extracts user ID from context if available // GetUserID extracts user ID from context if available
func GetUserID(r *http.Request) string { func GetUserID(r *http.Request) string {
if userID := r.Context().Value("user"); userID != nil { // Use type contextKey matching middleware package
if uid, ok := userID.(string); ok { type contextKey string
return uid if userID, ok := r.Context().Value(contextKey("user")).(string); ok && userID != "" {
} return userID
} }
return "" return ""
} }

View File

@@ -351,7 +351,7 @@ func listFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) {
} }
func viewerHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) { func viewerHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) {
userIDStr := r.Context().Value("user").(string) userIDStr, _ := middleware.GetUserID(r.Context())
userID, _ := uuid.Parse(userIDStr) userID, _ := uuid.Parse(userIDStr)
orgID := r.Context().Value("org").(uuid.UUID) orgID := r.Context().Value("org").(uuid.UUID)
fileId := chi.URLParam(r, "fileId") fileId := chi.URLParam(r, "fileId")
@@ -382,7 +382,7 @@ func viewerHandler(w http.ResponseWriter, r *http.Request, db *database.DB, audi
} }
func editorHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) { func editorHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) {
userIDStr := r.Context().Value("user").(string) userIDStr, _ := middleware.GetUserID(r.Context())
userID, _ := uuid.Parse(userIDStr) userID, _ := uuid.Parse(userIDStr)
orgID := r.Context().Value("org").(uuid.UUID) orgID := r.Context().Value("org").(uuid.UUID)
fileId := chi.URLParam(r, "fileId") fileId := chi.URLParam(r, "fileId")
@@ -405,7 +405,7 @@ func editorHandler(w http.ResponseWriter, r *http.Request, db *database.DB, audi
} }
func annotationsHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) { func annotationsHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) {
userIDStr := r.Context().Value("user").(string) userIDStr, _ := middleware.GetUserID(r.Context())
userID, _ := uuid.Parse(userIDStr) userID, _ := uuid.Parse(userIDStr)
orgID := r.Context().Value("org").(uuid.UUID) orgID := r.Context().Value("org").(uuid.UUID)
fileId := chi.URLParam(r, "fileId") fileId := chi.URLParam(r, "fileId")
@@ -848,7 +848,7 @@ func passwordLoginHandler(w http.ResponseWriter, r *http.Request, db *database.D
// userFilesHandler returns files for the authenticated user's personal workspace. // userFilesHandler returns files for the authenticated user's personal workspace.
func userFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) { func userFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) {
userIDStr, ok := r.Context().Value("user").(string) userIDStr, ok := middleware.GetUserID(r.Context())
if !ok || userIDStr == "" { if !ok || userIDStr == "" {
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized) errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
return return
@@ -899,7 +899,7 @@ func userFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) {
// createOrgFileHandler creates a file or folder record for an org workspace. // createOrgFileHandler creates a file or folder record for an org workspace.
func createOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) { func createOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
orgID := r.Context().Value("org").(uuid.UUID) orgID := r.Context().Value("org").(uuid.UUID)
userIDStr, _ := r.Context().Value("user").(string) userIDStr, _ := middleware.GetUserID(r.Context())
userID, _ := uuid.Parse(userIDStr) userID, _ := uuid.Parse(userIDStr)
var f *database.File var f *database.File
var err error var err error
@@ -1039,7 +1039,7 @@ func createOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.D
// deleteOrgFileHandler deletes a file/folder in org workspace by path // deleteOrgFileHandler deletes a file/folder in org workspace by path
func deleteOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) { func deleteOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
orgID := r.Context().Value("org").(uuid.UUID) orgID := r.Context().Value("org").(uuid.UUID)
userIDStr, _ := r.Context().Value("user").(string) userIDStr, _ := middleware.GetUserID(r.Context())
userID, _ := uuid.Parse(userIDStr) userID, _ := uuid.Parse(userIDStr)
var req struct { var req struct {
@@ -1084,7 +1084,7 @@ func deleteOrgFilePostHandler(w http.ResponseWriter, r *http.Request, db *databa
// createUserFileHandler creates a file or folder record for the authenticated user's personal workspace. // createUserFileHandler creates a file or folder record for the authenticated user's personal workspace.
func createUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) { func createUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
userIDStr, ok := r.Context().Value("user").(string) userIDStr, ok := middleware.GetUserID(r.Context())
if !ok || userIDStr == "" { if !ok || userIDStr == "" {
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized) errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
return return
@@ -1216,7 +1216,7 @@ func deleteUserFilePostHandler(w http.ResponseWriter, r *http.Request, db *datab
// deleteUserFileHandler deletes a file/folder in user's personal workspace by path // deleteUserFileHandler deletes a file/folder in user's personal workspace by path
func deleteUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) { func deleteUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
userIDStr, ok := r.Context().Value("user").(string) userIDStr, ok := middleware.GetUserID(r.Context())
if !ok || userIDStr == "" { if !ok || userIDStr == "" {
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized) errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
return return
@@ -1298,7 +1298,7 @@ func downloadOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database
// downloadUserFileHandler downloads a file from user's personal workspace // downloadUserFileHandler downloads a file from user's personal workspace
func downloadUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, storageClient *storage.WebDAVClient) { func downloadUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, storageClient *storage.WebDAVClient) {
userIDStr, ok := r.Context().Value("user").(string) userIDStr, ok := middleware.GetUserID(r.Context())
if !ok || userIDStr == "" { if !ok || userIDStr == "" {
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized) errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
return return

View File

@@ -54,6 +54,18 @@ const (
orgKey contextKey = "org" orgKey contextKey = "org"
) )
// GetUserID retrieves the user ID from the request context
func GetUserID(ctx context.Context) (string, bool) {
userID, ok := ctx.Value(userKey).(string)
return userID, ok
}
// GetSession retrieves the session from the request context
func GetSession(ctx context.Context) (*database.Session, bool) {
session, ok := ctx.Value(sessionKey).(*database.Session)
return session, ok
}
// Auth middleware // Auth middleware
func Auth(jwtManager *jwt.Manager, db *database.DB) func(http.Handler) http.Handler { func Auth(jwtManager *jwt.Manager, db *database.DB) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {