From 9d466fd63a079a145aa628cd523d438209618931 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20B=C3=B6sche?= Date: Sun, 11 Jan 2026 05:33:16 +0100 Subject: [PATCH] feat: add ownerId to Organization and update related database queries; enhance CORS middleware for origin validation --- go_cloud/internal/config/config.go | 2 +- go_cloud/internal/database/db.go | 15 ++--- go_cloud/internal/middleware/middleware.go | 67 +++++++++++++++++---- go_cloud/internal/org/org.go | 2 +- go_cloud/migrations/0004_org_owner_slug.sql | 28 +++++++++ 5 files changed, 92 insertions(+), 22 deletions(-) create mode 100644 go_cloud/migrations/0004_org_owner_slug.sql diff --git a/go_cloud/internal/config/config.go b/go_cloud/internal/config/config.go index 060a013..71bf55c 100644 --- a/go_cloud/internal/config/config.go +++ b/go_cloud/internal/config/config.go @@ -33,7 +33,7 @@ func Load() *Config { NextcloudUser: os.Getenv("NEXTCLOUD_USER"), NextcloudPass: os.Getenv("NEXTCLOUD_PASSWORD"), NextcloudBase: getEnv("NEXTCLOUD_BASEPATH", "/"), - AllowedOrigins: getEnv("ALLOWED_ORIGINS", "https://b0esche.cloud,http://localhost:8080"), + AllowedOrigins: getEnv("ALLOWED_ORIGINS", "https://b0esche.cloud,https://www.b0esche.cloud,https://*.b0esche.cloud,http://localhost:8080"), } fmt.Printf("[CONFIG] Nextcloud URL: %q, User: %q, BasePath: %q\n", cfg.NextcloudURL, cfg.NextcloudUser, cfg.NextcloudBase) return cfg diff --git a/go_cloud/internal/database/db.go b/go_cloud/internal/database/db.go index c802741..97fe538 100644 --- a/go_cloud/internal/database/db.go +++ b/go_cloud/internal/database/db.go @@ -57,6 +57,7 @@ type Session struct { type Organization struct { ID uuid.UUID `json:"id"` + OwnerID uuid.UUID `json:"ownerId"` Name string `json:"name"` Slug string `json:"slug"` CreatedAt time.Time `json:"createdAt"` @@ -144,7 +145,7 @@ func (db *DB) RevokeSession(ctx context.Context, sessionID uuid.UUID) error { func (db *DB) GetUserOrganizations(ctx context.Context, userID uuid.UUID) ([]Organization, error) { rows, err := db.QueryContext(ctx, ` - SELECT o.id, o.name, o.slug, o.created_at + SELECT o.id, o.owner_id, o.name, o.slug, o.created_at FROM organizations o JOIN memberships m ON o.id = m.org_id WHERE m.user_id = $1 @@ -157,7 +158,7 @@ func (db *DB) GetUserOrganizations(ctx context.Context, userID uuid.UUID) ([]Org var orgs []Organization for rows.Next() { var org Organization - if err := rows.Scan(&org.ID, &org.Name, &org.Slug, &org.CreatedAt); err != nil { + if err := rows.Scan(&org.ID, &org.OwnerID, &org.Name, &org.Slug, &org.CreatedAt); err != nil { return nil, err } orgs = append(orgs, org) @@ -178,13 +179,13 @@ func (db *DB) GetUserMembership(ctx context.Context, userID, orgID uuid.UUID) (* return &membership, nil } -func (db *DB) CreateOrg(ctx context.Context, name, slug string) (*Organization, error) { +func (db *DB) CreateOrg(ctx context.Context, ownerID uuid.UUID, name, slug string) (*Organization, error) { var org Organization err := db.QueryRowContext(ctx, ` - INSERT INTO organizations (name, slug) - VALUES ($1, $2) - RETURNING id, name, slug, created_at - `, name, slug).Scan(&org.ID, &org.Name, &org.Slug, &org.CreatedAt) + INSERT INTO organizations (owner_id, name, slug) + VALUES ($1, $2, $3) + RETURNING id, owner_id, name, slug, created_at + `, ownerID, name, slug).Scan(&org.ID, &org.OwnerID, &org.Name, &org.Slug, &org.CreatedAt) if err != nil { return nil, err } diff --git a/go_cloud/internal/middleware/middleware.go b/go_cloud/internal/middleware/middleware.go index e6e29e4..70d25e9 100644 --- a/go_cloud/internal/middleware/middleware.go +++ b/go_cloud/internal/middleware/middleware.go @@ -3,6 +3,7 @@ package middleware import ( "context" "net/http" + "regexp" "strings" "go.b0esche.cloud/backend/internal/audit" @@ -23,22 +24,15 @@ var Recoverer = middleware.Recoverer // CORS middleware - accepts allowedOrigins comma-separated string func CORS(allowedOrigins string) func(http.Handler) http.Handler { + allowedList, allowAll := compileAllowedOrigins(allowedOrigins) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") - // Check if origin is allowed - if origin != "" { - // Simple check - in production you'd want to parse allowedOrigins properly - for _, allowed := range strings.Split(allowedOrigins, ",") { - if strings.TrimSpace(allowed) == origin { - w.Header().Set("Access-Control-Allow-Origin", origin) - w.Header().Set("Access-Control-Allow-Credentials", "true") - break - } - } - } - // Fallback to * if no credentials needed - if w.Header().Get("Access-Control-Allow-Origin") == "" { + if origin != "" && isOriginAllowed(origin, allowedList) { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Add("Vary", "Origin") + w.Header().Set("Access-Control-Allow-Credentials", "true") + } else if allowAll { w.Header().Set("Access-Control-Allow-Origin", "*") } w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") @@ -56,6 +50,53 @@ func CORS(allowedOrigins string) func(http.Handler) http.Handler { } } +func compileAllowedOrigins(origins string) ([]string, bool) { + var allowed []string + allowAll := false + + for _, origin := range strings.Split(origins, ",") { + trimmed := strings.TrimSpace(origin) + if trimmed == "" { + continue + } + if trimmed == "*" { + allowAll = true + } + allowed = append(allowed, trimmed) + } + + if len(allowed) == 0 && !allowAll { + allowAll = true + } + + return allowed, allowAll +} + +func isOriginAllowed(origin string, allowed []string) bool { + if origin == "" { + return false + } + for _, pattern := range allowed { + if originMatches(origin, pattern) { + return true + } + } + return false +} + +func originMatches(origin, pattern string) bool { + if pattern == "*" { + return true + } + if !strings.Contains(pattern, "*") { + return strings.EqualFold(origin, pattern) + } + regexPattern := "(?i)^" + regexp.QuoteMeta(pattern) + "$" + regexPattern = strings.ReplaceAll(regexPattern, "\\*", ".*") + matched, err := regexp.MatchString(regexPattern, origin) + return err == nil && matched +} + // TODO: Implement rate limiter var RateLimit = func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/go_cloud/internal/org/org.go b/go_cloud/internal/org/org.go index 4d180ea..f40ac9d 100644 --- a/go_cloud/internal/org/org.go +++ b/go_cloud/internal/org/org.go @@ -49,7 +49,7 @@ func CreateOrg(ctx context.Context, db *database.DB, userID uuid.UUID, name, slu if i > 0 { candidate = fmt.Sprintf("%s-%d", baseSlug, i+1) } - org, err = db.CreateOrg(ctx, trimmedName, candidate) + org, err = db.CreateOrg(ctx, userID, trimmedName, candidate) if err != nil { if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" { // Unique violation; try next suffix diff --git a/go_cloud/migrations/0004_org_owner_slug.sql b/go_cloud/migrations/0004_org_owner_slug.sql new file mode 100644 index 0000000..929bacd --- /dev/null +++ b/go_cloud/migrations/0004_org_owner_slug.sql @@ -0,0 +1,28 @@ +-- Scope organization slugs per owner instead of globally unique +ALTER TABLE organizations ADD COLUMN owner_id UUID REFERENCES users(id); + +WITH first_owner AS ( + SELECT DISTINCT ON (org_id) org_id, user_id + FROM memberships + WHERE role = 'owner' + ORDER BY org_id, created_at +) +UPDATE organizations o +SET owner_id = fo.user_id +FROM first_owner fo +WHERE o.id = fo.org_id; + +WITH first_member AS ( + SELECT DISTINCT ON (org_id) org_id, user_id + FROM memberships + ORDER BY org_id, created_at +) +UPDATE organizations o +SET owner_id = fm.user_id +FROM first_member fm +WHERE o.owner_id IS NULL + AND o.id = fm.org_id; + +ALTER TABLE organizations ALTER COLUMN owner_id SET NOT NULL; +ALTER TABLE organizations DROP CONSTRAINT organizations_slug_key; +CREATE UNIQUE INDEX organizations_owner_slug_key ON organizations(owner_id, slug);