feat: add ownerId to Organization and update related database queries; enhance CORS middleware for origin validation

This commit is contained in:
Leon Bösche
2026-01-11 05:33:16 +01:00
parent 619b2fe23c
commit 9d466fd63a
5 changed files with 92 additions and 22 deletions

View File

@@ -33,7 +33,7 @@ func Load() *Config {
NextcloudUser: os.Getenv("NEXTCLOUD_USER"), NextcloudUser: os.Getenv("NEXTCLOUD_USER"),
NextcloudPass: os.Getenv("NEXTCLOUD_PASSWORD"), NextcloudPass: os.Getenv("NEXTCLOUD_PASSWORD"),
NextcloudBase: getEnv("NEXTCLOUD_BASEPATH", "/"), NextcloudBase: getEnv("NEXTCLOUD_BASEPATH", "/"),
AllowedOrigins: getEnv("ALLOWED_ORIGINS", "https://b0esche.cloud,http://localhost:8080"), AllowedOrigins: getEnv("ALLOWED_ORIGINS", "https://b0esche.cloud,https://www.b0esche.cloud,https://*.b0esche.cloud,http://localhost:8080"),
} }
fmt.Printf("[CONFIG] Nextcloud URL: %q, User: %q, BasePath: %q\n", cfg.NextcloudURL, cfg.NextcloudUser, cfg.NextcloudBase) fmt.Printf("[CONFIG] Nextcloud URL: %q, User: %q, BasePath: %q\n", cfg.NextcloudURL, cfg.NextcloudUser, cfg.NextcloudBase)
return cfg return cfg

View File

@@ -57,6 +57,7 @@ type Session struct {
type Organization struct { type Organization struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
OwnerID uuid.UUID `json:"ownerId"`
Name string `json:"name"` Name string `json:"name"`
Slug string `json:"slug"` Slug string `json:"slug"`
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
@@ -144,7 +145,7 @@ func (db *DB) RevokeSession(ctx context.Context, sessionID uuid.UUID) error {
func (db *DB) GetUserOrganizations(ctx context.Context, userID uuid.UUID) ([]Organization, error) { func (db *DB) GetUserOrganizations(ctx context.Context, userID uuid.UUID) ([]Organization, error) {
rows, err := db.QueryContext(ctx, ` rows, err := db.QueryContext(ctx, `
SELECT o.id, o.name, o.slug, o.created_at SELECT o.id, o.owner_id, o.name, o.slug, o.created_at
FROM organizations o FROM organizations o
JOIN memberships m ON o.id = m.org_id JOIN memberships m ON o.id = m.org_id
WHERE m.user_id = $1 WHERE m.user_id = $1
@@ -157,7 +158,7 @@ func (db *DB) GetUserOrganizations(ctx context.Context, userID uuid.UUID) ([]Org
var orgs []Organization var orgs []Organization
for rows.Next() { for rows.Next() {
var org Organization var org Organization
if err := rows.Scan(&org.ID, &org.Name, &org.Slug, &org.CreatedAt); err != nil { if err := rows.Scan(&org.ID, &org.OwnerID, &org.Name, &org.Slug, &org.CreatedAt); err != nil {
return nil, err return nil, err
} }
orgs = append(orgs, org) orgs = append(orgs, org)
@@ -178,13 +179,13 @@ func (db *DB) GetUserMembership(ctx context.Context, userID, orgID uuid.UUID) (*
return &membership, nil return &membership, nil
} }
func (db *DB) CreateOrg(ctx context.Context, name, slug string) (*Organization, error) { func (db *DB) CreateOrg(ctx context.Context, ownerID uuid.UUID, name, slug string) (*Organization, error) {
var org Organization var org Organization
err := db.QueryRowContext(ctx, ` err := db.QueryRowContext(ctx, `
INSERT INTO organizations (name, slug) INSERT INTO organizations (owner_id, name, slug)
VALUES ($1, $2) VALUES ($1, $2, $3)
RETURNING id, name, slug, created_at RETURNING id, owner_id, name, slug, created_at
`, name, slug).Scan(&org.ID, &org.Name, &org.Slug, &org.CreatedAt) `, ownerID, name, slug).Scan(&org.ID, &org.OwnerID, &org.Name, &org.Slug, &org.CreatedAt)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -3,6 +3,7 @@ package middleware
import ( import (
"context" "context"
"net/http" "net/http"
"regexp"
"strings" "strings"
"go.b0esche.cloud/backend/internal/audit" "go.b0esche.cloud/backend/internal/audit"
@@ -23,22 +24,15 @@ var Recoverer = middleware.Recoverer
// CORS middleware - accepts allowedOrigins comma-separated string // CORS middleware - accepts allowedOrigins comma-separated string
func CORS(allowedOrigins string) func(http.Handler) http.Handler { func CORS(allowedOrigins string) func(http.Handler) http.Handler {
allowedList, allowAll := compileAllowedOrigins(allowedOrigins)
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin") origin := r.Header.Get("Origin")
// Check if origin is allowed if origin != "" && isOriginAllowed(origin, allowedList) {
if origin != "" { w.Header().Set("Access-Control-Allow-Origin", origin)
// Simple check - in production you'd want to parse allowedOrigins properly w.Header().Add("Vary", "Origin")
for _, allowed := range strings.Split(allowedOrigins, ",") { w.Header().Set("Access-Control-Allow-Credentials", "true")
if strings.TrimSpace(allowed) == origin { } else if allowAll {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Credentials", "true")
break
}
}
}
// Fallback to * if no credentials needed
if w.Header().Get("Access-Control-Allow-Origin") == "" {
w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Origin", "*")
} }
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
@@ -56,6 +50,53 @@ func CORS(allowedOrigins string) func(http.Handler) http.Handler {
} }
} }
func compileAllowedOrigins(origins string) ([]string, bool) {
var allowed []string
allowAll := false
for _, origin := range strings.Split(origins, ",") {
trimmed := strings.TrimSpace(origin)
if trimmed == "" {
continue
}
if trimmed == "*" {
allowAll = true
}
allowed = append(allowed, trimmed)
}
if len(allowed) == 0 && !allowAll {
allowAll = true
}
return allowed, allowAll
}
func isOriginAllowed(origin string, allowed []string) bool {
if origin == "" {
return false
}
for _, pattern := range allowed {
if originMatches(origin, pattern) {
return true
}
}
return false
}
func originMatches(origin, pattern string) bool {
if pattern == "*" {
return true
}
if !strings.Contains(pattern, "*") {
return strings.EqualFold(origin, pattern)
}
regexPattern := "(?i)^" + regexp.QuoteMeta(pattern) + "$"
regexPattern = strings.ReplaceAll(regexPattern, "\\*", ".*")
matched, err := regexp.MatchString(regexPattern, origin)
return err == nil && matched
}
// TODO: Implement rate limiter // TODO: Implement rate limiter
var RateLimit = func(next http.Handler) http.Handler { var RateLimit = func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@@ -49,7 +49,7 @@ func CreateOrg(ctx context.Context, db *database.DB, userID uuid.UUID, name, slu
if i > 0 { if i > 0 {
candidate = fmt.Sprintf("%s-%d", baseSlug, i+1) candidate = fmt.Sprintf("%s-%d", baseSlug, i+1)
} }
org, err = db.CreateOrg(ctx, trimmedName, candidate) org, err = db.CreateOrg(ctx, userID, trimmedName, candidate)
if err != nil { if err != nil {
if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" { if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" {
// Unique violation; try next suffix // Unique violation; try next suffix

View File

@@ -0,0 +1,28 @@
-- Scope organization slugs per owner instead of globally unique
ALTER TABLE organizations ADD COLUMN owner_id UUID REFERENCES users(id);
WITH first_owner AS (
SELECT DISTINCT ON (org_id) org_id, user_id
FROM memberships
WHERE role = 'owner'
ORDER BY org_id, created_at
)
UPDATE organizations o
SET owner_id = fo.user_id
FROM first_owner fo
WHERE o.id = fo.org_id;
WITH first_member AS (
SELECT DISTINCT ON (org_id) org_id, user_id
FROM memberships
ORDER BY org_id, created_at
)
UPDATE organizations o
SET owner_id = fm.user_id
FROM first_member fm
WHERE o.owner_id IS NULL
AND o.id = fm.org_id;
ALTER TABLE organizations ALTER COLUMN owner_id SET NOT NULL;
ALTER TABLE organizations DROP CONSTRAINT organizations_slug_key;
CREATE UNIQUE INDEX organizations_owner_slug_key ON organizations(owner_id, slug);