From aea5ba9e58d9041da4b74cf9dd232bf580f5bf89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20B=C3=B6sche?= Date: Wed, 14 Jan 2026 12:02:20 +0100 Subject: [PATCH] Add functionality to download folders as ZIP archives for both org and user files --- go_cloud/internal/database/db.go | 180 +++++++++++++++++++++++++++++++ go_cloud/internal/http/routes.go | 152 ++++++++++++++++++++++++++ 2 files changed, 332 insertions(+) diff --git a/go_cloud/internal/database/db.go b/go_cloud/internal/database/db.go index 35564e6..c412a3b 100644 --- a/go_cloud/internal/database/db.go +++ b/go_cloud/internal/database/db.go @@ -373,6 +373,55 @@ func (db *DB) GetOrgFiles(ctx context.Context, orgID uuid.UUID, userID uuid.UUID return files, err } +// GetAllOrgFilesUnderPath returns all files recursively under the given path for an org +func (db *DB) GetAllOrgFilesUnderPath(ctx context.Context, orgID uuid.UUID, userID uuid.UUID, path string) ([]File, error) { + orgIDStr := orgID.String() + userIDStr := userID.String() + log.Printf("[DATA-ISOLATION] stage=before, action=list_recursive, orgId=%s, userId=%s, path=%s", orgIDStr, userIDStr, path) + + rows, err := db.QueryContext(ctx, ` + SELECT f.id, f.org_id::text, f.user_id::text, f.name, f.path, f.type, f.size, f.last_modified, f.created_at + FROM files f + WHERE f.org_id = $1 + AND EXISTS ( + SELECT 1 + FROM memberships m + WHERE m.org_id = $1 AND m.user_id = $2 + ) + AND f.path LIKE $3 || '%' + AND f.path != $3 + ORDER BY f.path + `, orgID, userID, path) + if err != nil { + return nil, err + } + defer rows.Close() + + var files []File + for rows.Next() { + var f File + var orgNull sql.NullString + var userNull sql.NullString + if err := rows.Scan(&f.ID, &orgNull, &userNull, &f.Name, &f.Path, &f.Type, &f.Size, &f.LastModified, &f.CreatedAt); err != nil { + return nil, err + } + if orgNull.Valid { + oid, _ := uuid.Parse(orgNull.String) + f.OrgID = &oid + } + if userNull.Valid { + uid, _ := uuid.Parse(userNull.String) + f.UserID = &uid + } + files = append(files, f) + } + err = rows.Err() + if err == nil { + log.Printf("[DATA-ISOLATION] stage=after, action=list_recursive, orgId=%s, userId=%s, fileCount=%d, path=%s", orgIDStr, userIDStr, len(files), path) + } + return files, err +} + // GetUserFiles returns files for a user's personal workspace at a given path func (db *DB) GetUserFiles(ctx context.Context, userID uuid.UUID, path string, q string, page, pageSize int) ([]File, error) { if page <= 0 { @@ -429,6 +478,49 @@ func (db *DB) GetUserFiles(ctx context.Context, userID uuid.UUID, path string, q return files, err } +// GetAllUserFilesUnderPath returns all files recursively under the given path for a user +func (db *DB) GetAllUserFilesUnderPath(ctx context.Context, userID uuid.UUID, path string) ([]File, error) { + // Return all descendants of the given path + log.Printf("[DATA-ISOLATION] stage=before, action=list_recursive, orgId=, userId=%s, path=%s", userID.String(), path) + rows, err := db.QueryContext(ctx, ` + SELECT id, org_id::text, user_id::text, name, path, type, size, last_modified, created_at + FROM files + WHERE user_id = $1 + AND org_id IS NULL + AND path LIKE $2 || '%' + AND path != $2 + ORDER BY path + `, userID, path) + if err != nil { + return nil, err + } + defer rows.Close() + + var files []File + for rows.Next() { + var f File + var orgNull sql.NullString + var userNull sql.NullString + if err := rows.Scan(&f.ID, &orgNull, &userNull, &f.Name, &f.Path, &f.Type, &f.Size, &f.LastModified, &f.CreatedAt); err != nil { + return nil, err + } + if orgNull.Valid { + oid, _ := uuid.Parse(orgNull.String) + f.OrgID = &oid + } + if userNull.Valid { + uid, _ := uuid.Parse(userNull.String) + f.UserID = &uid + } + files = append(files, f) + } + err = rows.Err() + if err == nil { + log.Printf("[DATA-ISOLATION] stage=after, action=list_recursive, orgId=, userId=%s, fileCount=%d, path=%s", userID.String(), len(files), path) + } + return files, err +} + // CreateFile inserts a file or folder record. orgID or userID may be nil. func (db *DB) CreateFile(ctx context.Context, orgID *uuid.UUID, userID *uuid.UUID, name, path, fileType string, size int64) (*File, error) { var f File @@ -502,6 +594,94 @@ func (db *DB) GetFileByID(ctx context.Context, fileID uuid.UUID) (*File, error) return &f, nil } +// GetOrgFileByPath returns a file by path for an org +func (db *DB) GetOrgFileByPath(ctx context.Context, orgID uuid.UUID, userID uuid.UUID, path string) (*File, error) { + var f File + var orgNull sql.NullString + var userNull sql.NullString + var modifiedByNull sql.NullString + var modifiedByNameNull sql.NullString + + err := db.QueryRowContext(ctx, ` + SELECT f.id, f.org_id::text, f.user_id::text, f.name, f.path, f.type, f.size, f.last_modified, f.created_at, + f.modified_by::text, u.display_name + FROM files f + LEFT JOIN users u ON f.modified_by = u.id + WHERE f.org_id = $1 + AND EXISTS ( + SELECT 1 + FROM memberships m + WHERE m.org_id = $1 AND m.user_id = $2 + ) + AND f.path = $3 + `, orgID, userID, path).Scan(&f.ID, &orgNull, &userNull, &f.Name, &f.Path, &f.Type, &f.Size, &f.LastModified, &f.CreatedAt, + &modifiedByNull, &modifiedByNameNull) + + if err != nil { + return nil, err + } + + if orgNull.Valid { + oid, _ := uuid.Parse(orgNull.String) + f.OrgID = &oid + } + if userNull.Valid { + uid, _ := uuid.Parse(userNull.String) + f.UserID = &uid + } + if modifiedByNull.Valid { + mid, _ := uuid.Parse(modifiedByNull.String) + f.ModifiedBy = &mid + } + if modifiedByNameNull.Valid { + f.ModifiedByName = modifiedByNameNull.String + } + + return &f, nil +} + +// GetUserFileByPath returns a file by path for a user +func (db *DB) GetUserFileByPath(ctx context.Context, userID uuid.UUID, path string) (*File, error) { + var f File + var orgNull sql.NullString + var userNull sql.NullString + var modifiedByNull sql.NullString + var modifiedByNameNull sql.NullString + + err := db.QueryRowContext(ctx, ` + SELECT f.id, f.org_id::text, f.user_id::text, f.name, f.path, f.type, f.size, f.last_modified, f.created_at, + f.modified_by::text, u.display_name + FROM files f + LEFT JOIN users u ON f.modified_by = u.id + WHERE f.user_id = $1 + AND f.org_id IS NULL + AND f.path = $2 + `, userID, path).Scan(&f.ID, &orgNull, &userNull, &f.Name, &f.Path, &f.Type, &f.Size, &f.LastModified, &f.CreatedAt, + &modifiedByNull, &modifiedByNameNull) + + if err != nil { + return nil, err + } + + if orgNull.Valid { + oid, _ := uuid.Parse(orgNull.String) + f.OrgID = &oid + } + if userNull.Valid { + uid, _ := uuid.Parse(userNull.String) + f.UserID = &uid + } + if modifiedByNull.Valid { + mid, _ := uuid.Parse(modifiedByNull.String) + f.ModifiedBy = &mid + } + if modifiedByNameNull.Valid { + f.ModifiedByName = modifiedByNameNull.String + } + + return &f, nil +} + // UpdateFileSize updates the size, modification time, and modifier of a file func (db *DB) UpdateFileSize(ctx context.Context, fileID uuid.UUID, size int64, modifiedBy *uuid.UUID) error { _, err := db.ExecContext(ctx, ` diff --git a/go_cloud/internal/http/routes.go b/go_cloud/internal/http/routes.go index a84d320..8fa25b3 100644 --- a/go_cloud/internal/http/routes.go +++ b/go_cloud/internal/http/routes.go @@ -1,6 +1,7 @@ package http import ( + "archive/zip" "bytes" "context" "encoding/json" @@ -1740,6 +1741,24 @@ func downloadOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database return } + // Check if it's a folder + file, err := db.GetOrgFileByPath(r.Context(), orgID, userID, filePath) + if err != nil && err.Error() != "sql: no rows in result set" { + errors.LogError(r, err, "Failed to get file info") + errors.WriteError(w, errors.CodeInternal, "Server error", http.StatusInternalServerError) + return + } + + if file != nil && file.Type == "folder" { + // Download folder as ZIP + err = downloadOrgFolderAsZip(w, r, db, cfg, orgID, userID, filePath, storageClient) + if err != nil { + errors.LogError(r, err, "Failed to download folder") + errors.WriteError(w, errors.CodeInternal, "Failed to download folder", http.StatusInternalServerError) + } + return + } + // Download from user's Nextcloud space under /orgs// rel := strings.TrimPrefix(filePath, "/") remotePath := path.Join("/orgs", orgID.String(), rel) @@ -1785,6 +1804,64 @@ func downloadOrgFileHandler(w http.ResponseWriter, r *http.Request, db *database } +// downloadOrgFolderAsZip downloads a folder as ZIP archive +func downloadOrgFolderAsZip(w http.ResponseWriter, r *http.Request, db *database.DB, cfg *config.Config, orgID, userID uuid.UUID, folderPath string, storageClient *storage.WebDAVClient) error { + // Get all files under the folder + files, err := db.GetAllOrgFilesUnderPath(r.Context(), orgID, userID, folderPath) + if err != nil { + return err + } + + // Filter only files, not folders + var fileList []database.File + for _, f := range files { + if f.Type == "file" { + fileList = append(fileList, f) + } + } + + // Set headers for ZIP download + folderName := path.Base(folderPath) + if folderName == "" || folderName == "/" { + folderName = "org_files" + } + w.Header().Set("Content-Type", "application/zip") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s.zip\"", folderName)) + + // Create ZIP writer + zipWriter := zip.NewWriter(w) + defer zipWriter.Close() + + // Add each file to ZIP + for _, file := range fileList { + // Calculate relative path in ZIP + relPath := strings.TrimPrefix(file.Path, folderPath) + if relPath[0] == '/' { + relPath = relPath[1:] + } + + // Download file from WebDAV + remoteRel := strings.TrimPrefix(file.Path, "/") + remotePath := path.Join("/orgs", orgID.String(), remoteRel) + resp, err := storageClient.Download(r.Context(), remotePath, "") + if err != nil { + continue // Skip files that can't be downloaded + } + defer resp.Body.Close() + + // Create ZIP entry + zipFile, err := zipWriter.Create(relPath) + if err != nil { + continue + } + + // Copy file content to ZIP + io.Copy(zipFile, resp.Body) + } + + return nil +} + // downloadUserFileHandler downloads a file from user's personal workspace func downloadUserFileHandler(w http.ResponseWriter, r *http.Request, db *database.DB, cfg *config.Config) { // Try to get userID from context (Bearer token), fallback to query parameter @@ -1818,6 +1895,24 @@ func downloadUserFileHandler(w http.ResponseWriter, r *http.Request, db *databas return } + // Check if it's a folder + file, err := db.GetUserFileByPath(r.Context(), userID, filePath) + if err != nil && err.Error() != "sql: no rows in result set" { + errors.LogError(r, err, "Failed to get file info") + errors.WriteError(w, errors.CodeInternal, "Server error", http.StatusInternalServerError) + return + } + + if file != nil && file.Type == "folder" { + // Download folder as ZIP + err = downloadUserFolderAsZip(w, r, db, cfg, userID, filePath, storageClient) + if err != nil { + errors.LogError(r, err, "Failed to download folder") + errors.WriteError(w, errors.CodeInternal, "Failed to download folder", http.StatusInternalServerError) + } + return + } + // Download from user's personal Nextcloud space remotePath := strings.TrimPrefix(filePath, "/") @@ -1863,6 +1958,63 @@ func downloadUserFileHandler(w http.ResponseWriter, r *http.Request, db *databas } +// downloadUserFolderAsZip downloads a folder as ZIP archive +func downloadUserFolderAsZip(w http.ResponseWriter, r *http.Request, db *database.DB, cfg *config.Config, userID uuid.UUID, folderPath string, storageClient *storage.WebDAVClient) error { + // Get all files under the folder + files, err := db.GetAllUserFilesUnderPath(r.Context(), userID, folderPath) + if err != nil { + return err + } + + // Filter only files, not folders + var fileList []database.File + for _, f := range files { + if f.Type == "file" { + fileList = append(fileList, f) + } + } + + // Set headers for ZIP download + folderName := path.Base(folderPath) + if folderName == "" || folderName == "/" { + folderName = "user_files" + } + w.Header().Set("Content-Type", "application/zip") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s.zip\"", folderName)) + + // Create ZIP writer + zipWriter := zip.NewWriter(w) + defer zipWriter.Close() + + // Add each file to ZIP + for _, file := range fileList { + // Calculate relative path in ZIP + relPath := strings.TrimPrefix(file.Path, folderPath) + if relPath[0] == '/' { + relPath = relPath[1:] + } + + // Download file from WebDAV + remotePath := strings.TrimPrefix(file.Path, "/") + resp, err := storageClient.Download(r.Context(), "/"+remotePath, "") + if err != nil { + continue // Skip files that can't be downloaded + } + defer resp.Body.Close() + + // Create ZIP entry + zipFile, err := zipWriter.Create(relPath) + if err != nil { + continue + } + + // Copy file content to ZIP + io.Copy(zipFile, resp.Body) + } + + return nil +} + // getMimeType returns the MIME type based on file extension func getMimeType(filename string) string { lower := strings.ToLower(filename)