Fix context key mismatch - use typed contextKey consistently
This commit is contained in:
@@ -6,7 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
chimiddleware "github.com/go-chi/chi/v5/middleware"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -40,7 +40,7 @@ func WriteError(w http.ResponseWriter, code ErrorCode, message string, status in
|
|||||||
|
|
||||||
// GetRequestID extracts the request ID from the request context
|
// GetRequestID extracts the request ID from the request context
|
||||||
func GetRequestID(r *http.Request) string {
|
func GetRequestID(r *http.Request) string {
|
||||||
if reqID := middleware.GetReqID(r.Context()); reqID != "" {
|
if reqID := chimiddleware.GetReqID(r.Context()); reqID != "" {
|
||||||
return reqID
|
return reqID
|
||||||
}
|
}
|
||||||
return "unknown"
|
return "unknown"
|
||||||
@@ -48,10 +48,10 @@ func GetRequestID(r *http.Request) string {
|
|||||||
|
|
||||||
// GetUserID extracts user ID from context if available
|
// GetUserID extracts user ID from context if available
|
||||||
func GetUserID(r *http.Request) string {
|
func GetUserID(r *http.Request) string {
|
||||||
if userID := r.Context().Value("user"); userID != nil {
|
// Use type contextKey matching middleware package
|
||||||
if uid, ok := userID.(string); ok {
|
type contextKey string
|
||||||
return uid
|
if userID, ok := r.Context().Value(contextKey("user")).(string); ok && userID != "" {
|
||||||
}
|
return userID
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -351,7 +351,7 @@ func listFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func viewerHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) {
|
func viewerHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) {
|
||||||
userIDStr := r.Context().Value("user").(string)
|
userIDStr, _ := middleware.GetUserID(r.Context())
|
||||||
userID, _ := uuid.Parse(userIDStr)
|
userID, _ := uuid.Parse(userIDStr)
|
||||||
orgID := r.Context().Value("org").(uuid.UUID)
|
orgID := r.Context().Value("org").(uuid.UUID)
|
||||||
fileId := chi.URLParam(r, "fileId")
|
fileId := chi.URLParam(r, "fileId")
|
||||||
@@ -382,7 +382,7 @@ func viewerHandler(w http.ResponseWriter, r *http.Request, db *database.DB, audi
|
|||||||
}
|
}
|
||||||
|
|
||||||
func editorHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) {
|
func editorHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) {
|
||||||
userIDStr := r.Context().Value("user").(string)
|
userIDStr, _ := middleware.GetUserID(r.Context())
|
||||||
userID, _ := uuid.Parse(userIDStr)
|
userID, _ := uuid.Parse(userIDStr)
|
||||||
orgID := r.Context().Value("org").(uuid.UUID)
|
orgID := r.Context().Value("org").(uuid.UUID)
|
||||||
fileId := chi.URLParam(r, "fileId")
|
fileId := chi.URLParam(r, "fileId")
|
||||||
@@ -405,7 +405,7 @@ func editorHandler(w http.ResponseWriter, r *http.Request, db *database.DB, audi
|
|||||||
}
|
}
|
||||||
|
|
||||||
func annotationsHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) {
|
func annotationsHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger) {
|
||||||
userIDStr := r.Context().Value("user").(string)
|
userIDStr, _ := middleware.GetUserID(r.Context())
|
||||||
userID, _ := uuid.Parse(userIDStr)
|
userID, _ := uuid.Parse(userIDStr)
|
||||||
orgID := r.Context().Value("org").(uuid.UUID)
|
orgID := r.Context().Value("org").(uuid.UUID)
|
||||||
fileId := chi.URLParam(r, "fileId")
|
fileId := chi.URLParam(r, "fileId")
|
||||||
@@ -848,7 +848,7 @@ func passwordLoginHandler(w http.ResponseWriter, r *http.Request, db *database.D
|
|||||||
|
|
||||||
// userFilesHandler returns files for the authenticated user's personal workspace.
|
// userFilesHandler returns files for the authenticated user's personal workspace.
|
||||||
func userFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) {
|
func userFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) {
|
||||||
userIDStr, ok := r.Context().Value("user").(string)
|
userIDStr, ok := middleware.GetUserID(r.Context())
|
||||||
if !ok || userIDStr == "" {
|
if !ok || userIDStr == "" {
|
||||||
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
|
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
@@ -899,7 +899,7 @@ func userFilesHandler(w http.ResponseWriter, r *http.Request, db *database.DB) {
|
|||||||
// createOrgFileHandler creates a file or folder record for an org workspace.
|
// createOrgFileHandler creates a file or folder record for an org workspace.
|
||||||
func createOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
|
func createOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
|
||||||
orgID := r.Context().Value("org").(uuid.UUID)
|
orgID := r.Context().Value("org").(uuid.UUID)
|
||||||
userIDStr, _ := r.Context().Value("user").(string)
|
userIDStr, _ := middleware.GetUserID(r.Context())
|
||||||
userID, _ := uuid.Parse(userIDStr)
|
userID, _ := uuid.Parse(userIDStr)
|
||||||
var f *database.File
|
var f *database.File
|
||||||
var err error
|
var err error
|
||||||
@@ -1039,7 +1039,7 @@ func createOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.D
|
|||||||
// deleteOrgFileHandler deletes a file/folder in org workspace by path
|
// deleteOrgFileHandler deletes a file/folder in org workspace by path
|
||||||
func deleteOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
|
func deleteOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
|
||||||
orgID := r.Context().Value("org").(uuid.UUID)
|
orgID := r.Context().Value("org").(uuid.UUID)
|
||||||
userIDStr, _ := r.Context().Value("user").(string)
|
userIDStr, _ := middleware.GetUserID(r.Context())
|
||||||
userID, _ := uuid.Parse(userIDStr)
|
userID, _ := uuid.Parse(userIDStr)
|
||||||
|
|
||||||
var req struct {
|
var req struct {
|
||||||
@@ -1084,7 +1084,7 @@ func deleteOrgFilePostHandler(w http.ResponseWriter, r *http.Request, db *databa
|
|||||||
|
|
||||||
// createUserFileHandler creates a file or folder record for the authenticated user's personal workspace.
|
// createUserFileHandler creates a file or folder record for the authenticated user's personal workspace.
|
||||||
func createUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
|
func createUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
|
||||||
userIDStr, ok := r.Context().Value("user").(string)
|
userIDStr, ok := middleware.GetUserID(r.Context())
|
||||||
if !ok || userIDStr == "" {
|
if !ok || userIDStr == "" {
|
||||||
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
|
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
@@ -1216,7 +1216,7 @@ func deleteUserFilePostHandler(w http.ResponseWriter, r *http.Request, db *datab
|
|||||||
|
|
||||||
// deleteUserFileHandler deletes a file/folder in user's personal workspace by path
|
// deleteUserFileHandler deletes a file/folder in user's personal workspace by path
|
||||||
func deleteUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
|
func deleteUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, auditLogger *audit.Logger, storageClient *storage.WebDAVClient) {
|
||||||
userIDStr, ok := r.Context().Value("user").(string)
|
userIDStr, ok := middleware.GetUserID(r.Context())
|
||||||
if !ok || userIDStr == "" {
|
if !ok || userIDStr == "" {
|
||||||
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
|
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
@@ -1298,7 +1298,7 @@ func downloadOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database
|
|||||||
|
|
||||||
// downloadUserFileHandler downloads a file from user's personal workspace
|
// downloadUserFileHandler downloads a file from user's personal workspace
|
||||||
func downloadUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, storageClient *storage.WebDAVClient) {
|
func downloadUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, storageClient *storage.WebDAVClient) {
|
||||||
userIDStr, ok := r.Context().Value("user").(string)
|
userIDStr, ok := middleware.GetUserID(r.Context())
|
||||||
if !ok || userIDStr == "" {
|
if !ok || userIDStr == "" {
|
||||||
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
|
errors.WriteError(w, errors.CodeUnauthenticated, "Unauthorized", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -54,6 +54,18 @@ const (
|
|||||||
orgKey contextKey = "org"
|
orgKey contextKey = "org"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GetUserID retrieves the user ID from the request context
|
||||||
|
func GetUserID(ctx context.Context) (string, bool) {
|
||||||
|
userID, ok := ctx.Value(userKey).(string)
|
||||||
|
return userID, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSession retrieves the session from the request context
|
||||||
|
func GetSession(ctx context.Context) (*database.Session, bool) {
|
||||||
|
session, ok := ctx.Value(sessionKey).(*database.Session)
|
||||||
|
return session, ok
|
||||||
|
}
|
||||||
|
|
||||||
// Auth middleware
|
// Auth middleware
|
||||||
func Auth(jwtManager *jwt.Manager, db *database.DB) func(http.Handler) http.Handler {
|
func Auth(jwtManager *jwt.Manager, db *database.DB) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
|
|||||||
Reference in New Issue
Block a user