Fix context key mismatch - use typed contextKey consistently
This commit is contained in:
@@ -6,7 +6,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
chimiddleware "github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
@@ -40,7 +40,7 @@ func WriteError(w http.ResponseWriter, code ErrorCode, message string, status in
|
||||
|
||||
// GetRequestID extracts the request ID from the request context
|
||||
func GetRequestID(r *http.Request) string {
|
||||
if reqID := middleware.GetReqID(r.Context()); reqID != "" {
|
||||
if reqID := chimiddleware.GetReqID(r.Context()); reqID != "" {
|
||||
return reqID
|
||||
}
|
||||
return "unknown"
|
||||
@@ -48,10 +48,10 @@ func GetRequestID(r *http.Request) string {
|
||||
|
||||
// GetUserID extracts user ID from context if available
|
||||
func GetUserID(r *http.Request) string {
|
||||
if userID := r.Context().Value("user"); userID != nil {
|
||||
if uid, ok := userID.(string); ok {
|
||||
return uid
|
||||
}
|
||||
// Use type contextKey matching middleware package
|
||||
type contextKey string
|
||||
if userID, ok := r.Context().Value(contextKey("user")).(string); ok && userID != "" {
|
||||
return userID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -351,7 +351,7 @@ func listFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) {
|
||||
}
|
||||
|
||||
func viewerHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) {
|
||||
userIDStr := r.Context().Value("user").(string)
|
||||
userIDStr, _ := middleware.GetUserID(r.Context())
|
||||
userID, _ := uuid.Parse(userIDStr)
|
||||
orgID := r.Context().Value("org").(uuid.UUID)
|
||||
fileId := chi.URLParam(r, "fileId")
|
||||
@@ -382,7 +382,7 @@ func viewerHandler(w http.ResponseWriter, r *http.Request, db *database.DB, audi
|
||||
}
|
||||
|
||||
func editorHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) {
|
||||
userIDStr := r.Context().Value("user").(string)
|
||||
userIDStr, _ := middleware.GetUserID(r.Context())
|
||||
userID, _ := uuid.Parse(userIDStr)
|
||||
orgID := r.Context().Value("org").(uuid.UUID)
|
||||
fileId := chi.URLParam(r, "fileId")
|
||||
@@ -405,7 +405,7 @@ func editorHandler(w http.ResponseWriter, r *http.Request, db *database.DB, audi
|
||||
}
|
||||
|
||||
func annotationsHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) {
|
||||
userIDStr := r.Context().Value("user").(string)
|
||||
userIDStr, _ := middleware.GetUserID(r.Context())
|
||||
userID, _ := uuid.Parse(userIDStr)
|
||||
orgID := r.Context().Value("org").(uuid.UUID)
|
||||
fileId := chi.URLParam(r, "fileId")
|
||||
@@ -848,7 +848,7 @@ func passwordLoginHandler(w http.ResponseWriter, r *http.Request, db *database.D
|
||||
|
||||
// userFilesHandler returns files for the authenticated user's personal workspace.
|
||||
func userFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) {
|
||||
userIDStr, ok := r.Context().Value("user").(string)
|
||||
userIDStr, ok := middleware.GetUserID(r.Context())
|
||||
if !ok || userIDStr == "" {
|
||||
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
@@ -899,7 +899,7 @@ func userFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) {
|
||||
// createOrgFileHandler creates a file or folder record for an org workspace.
|
||||
func createOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
|
||||
orgID := r.Context().Value("org").(uuid.UUID)
|
||||
userIDStr, _ := r.Context().Value("user").(string)
|
||||
userIDStr, _ := middleware.GetUserID(r.Context())
|
||||
userID, _ := uuid.Parse(userIDStr)
|
||||
var f *database.File
|
||||
var err error
|
||||
@@ -1039,7 +1039,7 @@ func createOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.D
|
||||
// deleteOrgFileHandler deletes a file/folder in org workspace by path
|
||||
func deleteOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
|
||||
orgID := r.Context().Value("org").(uuid.UUID)
|
||||
userIDStr, _ := r.Context().Value("user").(string)
|
||||
userIDStr, _ := middleware.GetUserID(r.Context())
|
||||
userID, _ := uuid.Parse(userIDStr)
|
||||
|
||||
var req struct {
|
||||
@@ -1084,7 +1084,7 @@ func deleteOrgFilePostHandler(w http.ResponseWriter, r *http.Request, db *databa
|
||||
|
||||
// createUserFileHandler creates a file or folder record for the authenticated user's personal workspace.
|
||||
func createUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
|
||||
userIDStr, ok := r.Context().Value("user").(string)
|
||||
userIDStr, ok := middleware.GetUserID(r.Context())
|
||||
if !ok || userIDStr == "" {
|
||||
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
@@ -1216,7 +1216,7 @@ func deleteUserFilePostHandler(w http.ResponseWriter, r *http.Request, db *datab
|
||||
|
||||
// deleteUserFileHandler deletes a file/folder in user's personal workspace by path
|
||||
func deleteUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
|
||||
userIDStr, ok := r.Context().Value("user").(string)
|
||||
userIDStr, ok := middleware.GetUserID(r.Context())
|
||||
if !ok || userIDStr == "" {
|
||||
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
@@ -1298,7 +1298,7 @@ func downloadOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database
|
||||
|
||||
// downloadUserFileHandler downloads a file from user's personal workspace
|
||||
func downloadUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, storageClient *storage.WebDAVClient) {
|
||||
userIDStr, ok := r.Context().Value("user").(string)
|
||||
userIDStr, ok := middleware.GetUserID(r.Context())
|
||||
if !ok || userIDStr == "" {
|
||||
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
|
||||
@@ -54,6 +54,18 @@ const (
|
||||
orgKey contextKey = "org"
|
||||
)
|
||||
|
||||
// GetUserID retrieves the user ID from the request context
|
||||
func GetUserID(ctx context.Context) (string, bool) {
|
||||
userID, ok := ctx.Value(userKey).(string)
|
||||
return userID, ok
|
||||
}
|
||||
|
||||
// GetSession retrieves the session from the request context
|
||||
func GetSession(ctx context.Context) (*database.Session, bool) {
|
||||
session, ok := ctx.Value(sessionKey).(*database.Session)
|
||||
return session, ok
|
||||
}
|
||||
|
||||
// Auth middleware
|
||||
func Auth(jwtManager *jwt.Manager, db *database.DB) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
|
||||
Reference in New Issue
Block a user