diff --git a/go_cloud/internal/errors/errors.go b/go_cloud/internal/errors/errors.go index 02289a3..6531473 100644 --- a/go_cloud/internal/errors/errors.go +++ b/go_cloud/internal/errors/errors.go @@ -6,7 +6,7 @@ import ( "net/http" "os" - "github.com/go-chi/chi/v5/middleware" + chimiddleware "github.com/go-chi/chi/v5/middleware" "github.com/google/uuid" ) @@ -40,7 +40,7 @@ func WriteError(w http.ResponseWriter, code ErrorCode, message string, status in // GetRequestID extracts the request ID from the request context func GetRequestID(r *http.Request) string { - if reqID := middleware.GetReqID(r.Context()); reqID != "" { + if reqID := chimiddleware.GetReqID(r.Context()); reqID != "" { return reqID } return "unknown" @@ -48,10 +48,10 @@ func GetRequestID(r *http.Request) string { // GetUserID extracts user ID from context if available func GetUserID(r *http.Request) string { - if userID := r.Context().Value("user"); userID != nil { - if uid, ok := userID.(string); ok { - return uid - } + // Use type contextKey matching middleware package + type contextKey string + if userID, ok := r.Context().Value(contextKey("user")).(string); ok && userID != "" { + return userID } return "" } diff --git a/go_cloud/internal/http/routes.go b/go_cloud/internal/http/routes.go index 4b92e3c..d12da6f 100644 --- a/go_cloud/internal/http/routes.go +++ b/go_cloud/internal/http/routes.go @@ -351,7 +351,7 @@ func listFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) { } func viewerHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) { - userIDStr := r.Context().Value("user").(string) + userIDStr, _ := middleware.GetUserID(r.Context()) userID, _ := uuid.Parse(userIDStr) orgID := r.Context().Value("org").(uuid.UUID) fileId := chi.URLParam(r, "fileId") @@ -382,7 +382,7 @@ func viewerHandler(w http.ResponseWriter, r *http.Request, db *database.DB, audi } func editorHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) { - userIDStr := r.Context().Value("user").(string) + userIDStr, _ := middleware.GetUserID(r.Context()) userID, _ := uuid.Parse(userIDStr) orgID := r.Context().Value("org").(uuid.UUID) fileId := chi.URLParam(r, "fileId") @@ -405,7 +405,7 @@ func editorHandler(w http.ResponseWriter, r *http.Request, db *database.DB, audi } func annotationsHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) { - userIDStr := r.Context().Value("user").(string) + userIDStr, _ := middleware.GetUserID(r.Context()) userID, _ := uuid.Parse(userIDStr) orgID := r.Context().Value("org").(uuid.UUID) fileId := chi.URLParam(r, "fileId") @@ -848,7 +848,7 @@ func passwordLoginHandler(w http.ResponseWriter, r *http.Request, db *database.D // userFilesHandler returns files for the authenticated user's personal workspace. func userFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) { - userIDStr, ok := r.Context().Value("user").(string) + userIDStr, ok := middleware.GetUserID(r.Context()) if !ok || userIDStr == "" { errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized) return @@ -899,7 +899,7 @@ func userFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) { // createOrgFileHandler creates a file or folder record for an org workspace. func createOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) { orgID := r.Context().Value("org").(uuid.UUID) - userIDStr, _ := r.Context().Value("user").(string) + userIDStr, _ := middleware.GetUserID(r.Context()) userID, _ := uuid.Parse(userIDStr) var f *database.File var err error @@ -1039,7 +1039,7 @@ func createOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.D // deleteOrgFileHandler deletes a file/folder in org workspace by path func deleteOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) { orgID := r.Context().Value("org").(uuid.UUID) - userIDStr, _ := r.Context().Value("user").(string) + userIDStr, _ := middleware.GetUserID(r.Context()) userID, _ := uuid.Parse(userIDStr) var req struct { @@ -1084,7 +1084,7 @@ func deleteOrgFilePostHandler(w http.ResponseWriter, r *http.Request, db *databa // createUserFileHandler creates a file or folder record for the authenticated user's personal workspace. func createUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) { - userIDStr, ok := r.Context().Value("user").(string) + userIDStr, ok := middleware.GetUserID(r.Context()) if !ok || userIDStr == "" { errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized) return @@ -1216,7 +1216,7 @@ func deleteUserFilePostHandler(w http.ResponseWriter, r *http.Request, db *datab // deleteUserFileHandler deletes a file/folder in user's personal workspace by path func deleteUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) { - userIDStr, ok := r.Context().Value("user").(string) + userIDStr, ok := middleware.GetUserID(r.Context()) if !ok || userIDStr == "" { errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized) return @@ -1298,7 +1298,7 @@ func downloadOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database // downloadUserFileHandler downloads a file from user's personal workspace func downloadUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, storageClient *storage.WebDAVClient) { - userIDStr, ok := r.Context().Value("user").(string) + userIDStr, ok := middleware.GetUserID(r.Context()) if !ok || userIDStr == "" { errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized) return diff --git a/go_cloud/internal/middleware/middleware.go b/go_cloud/internal/middleware/middleware.go index 91dcbb2..215a18c 100644 --- a/go_cloud/internal/middleware/middleware.go +++ b/go_cloud/internal/middleware/middleware.go @@ -54,6 +54,18 @@ const ( orgKey contextKey = "org" ) +// GetUserID retrieves the user ID from the request context +func GetUserID(ctx context.Context) (string, bool) { + userID, ok := ctx.Value(userKey).(string) + return userID, ok +} + +// GetSession retrieves the session from the request context +func GetSession(ctx context.Context) (*database.Session, bool) { + session, ok := ctx.Value(sessionKey).(*database.Session) + return session, ok +} + // Auth middleware func Auth(jwtManager *jwt.Manager, db *database.DB) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler {