Files
b0esche_cloud/go_cloud/internal/database/db.go

679 lines
19 KiB
Go

package database
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"log"
"time"
"github.com/google/uuid"
)
// DB wraps a *sql.DB. The query helpers in this package are methods on DB,
// and the embedded *sql.DB's methods (QueryContext, ExecContext, ...) are
// promoted onto it.
type DB struct{ *sql.DB }

// New wraps an already-opened *sql.DB. It does not open, ping, or configure
// the underlying connection pool.
func New(db *sql.DB) *DB {
	wrapped := DB{DB: db}
	return &wrapped
}
// StringArray is a []string that can be scanned from, and written to, a
// nullable database column.
//
// Scan accepts a JSON string array (e.g. ["usb","nfc"]), a one-dimensional
// PostgreSQL array literal in the server's text output form (e.g. {usb,nfc}),
// or a plain scalar string. Value always writes a JSON array.
type StringArray []string

// stringArrayScanError is returned by StringArray.Scan for driver values that
// are not text (int64, float64, bool, time.Time).
type stringArrayScanError struct{}

func (stringArrayScanError) Error() string {
	return "database: StringArray.Scan: unsupported source type (want []byte, string or nil)"
}

// Scan implements sql.Scanner.
//
// NULL and empty text become an empty, non-nil slice. []byte and string
// sources are decoded identically (see parseStringArray), so the result does
// not depend on which text representation the driver chose. Non-text
// sources return an error instead of silently leaving *sa unchanged.
func (sa *StringArray) Scan(value interface{}) error {
	var text string
	switch v := value.(type) {
	case nil:
		*sa = StringArray{}
		return nil
	case []byte:
		text = string(v)
	case string:
		text = v
	default:
		return stringArrayScanError{}
	}
	*sa = parseStringArray(text)
	return nil
}

// parseStringArray decodes the textual forms accepted by StringArray.Scan,
// trying JSON first, then a PostgreSQL array literal, and finally treating
// the whole text as a single element.
func parseStringArray(text string) StringArray {
	if text == "" {
		return StringArray{}
	}
	var arr []string
	if err := json.Unmarshal([]byte(text), &arr); err == nil {
		if arr == nil { // JSON null
			return StringArray{}
		}
		return StringArray(arr)
	}
	if elems, ok := parsePGTextArray(text); ok {
		return StringArray(elems)
	}
	return StringArray{text}
}

// parsePGTextArray parses a one-dimensional PostgreSQL array literal as
// emitted by the server's text output, e.g. {usb,"two words","a\"b",NULL}.
// Backslash escapes the next character; an unquoted NULL element becomes "".
// ok is false when s is not such a literal (including nested arrays and
// unterminated quotes).
func parsePGTextArray(s string) (elems []string, ok bool) {
	if len(s) < 2 || s[0] != '{' || s[len(s)-1] != '}' {
		return nil, false
	}
	body := s[1 : len(s)-1]
	elems = []string{}
	if body == "" {
		return elems, true
	}

	var (
		cur      []byte
		quoted   bool // the current element used double quotes
		inQuotes bool
	)
	// finish converts the accumulated bytes into an element; string() copies,
	// so cur can be reused afterwards.
	finish := func() string {
		el := string(cur)
		if !quoted && el == "NULL" {
			return ""
		}
		return el
	}

	for i := 0; i < len(body); i++ {
		c := body[i]
		switch {
		case c == '\\' && i+1 < len(body):
			i++
			cur = append(cur, body[i])
		case c == '"':
			inQuotes = !inQuotes
			quoted = true
		case c == ',' && !inQuotes:
			elems = append(elems, finish())
			cur, quoted = cur[:0], false
		case c == '{' && !inQuotes:
			return nil, false // multi-dimensional arrays are not supported
		default:
			cur = append(cur, c)
		}
	}
	if inQuotes {
		return nil, false
	}
	return append(elems, finish()), true
}

// Value implements driver.Valuer. An empty or nil array is stored as SQL
// NULL; otherwise the elements are encoded as a JSON array ([]byte).
func (sa StringArray) Value() (driver.Value, error) {
	if len(sa) == 0 {
		return nil, nil
	}
	return json.Marshal(sa)
}
// User is a row of the users table.
type User struct {
ID uuid.UUID
Email string
Username string
DisplayName string
// PasswordHash is nil when the password_hash column is NULL.
PasswordHash *string
CreatedAt time.Time
// LastLoginAt is nil when last_login_at is NULL (presumably no login recorded yet).
LastLoginAt *time.Time
}
// Credential is a stored passkey credential (a row of the credentials table).
type Credential struct {
// ID identifies the row; UpdateCredentialLastUsed matches on this column.
ID string
UserID uuid.UUID
CredentialPublicKey []byte
// CredentialID is the value GetCredentialByID looks up (presumably the
// WebAuthn credential ID presented by the authenticator — confirm).
CredentialID []byte
SignCount int64
CreatedAt time.Time
// LastUsedAt is nil until last_used_at has been set.
LastUsedAt *time.Time
// Transports holds the credential's transport list; a NULL column scans as empty.
Transports StringArray
}
// AuthChallenge mirrors a row of the auth_challenges table.
type AuthChallenge struct {
ID uuid.UUID
UserID uuid.UUID
Challenge []byte
// ChallengeType distinguishes challenge purposes; its values are chosen by callers.
ChallengeType string
CreatedAt time.Time
ExpiresAt time.Time
// UsedAt is nil until the challenge has been marked used.
UsedAt *time.Time
}
// Session is a row of the sessions table.
type Session struct {
ID uuid.UUID
UserID uuid.UUID
ExpiresAt time.Time
// RevokedAt is nil while the session has not been revoked.
RevokedAt *time.Time
}
// Organization is a row of the organizations table. The JSON tags define the
// camelCase field names used when an Organization is encoded to JSON.
type Organization struct {
ID uuid.UUID `json:"id"`
OwnerID uuid.UUID `json:"ownerId"`
Name string `json:"name"`
Slug string `json:"slug"`
CreatedAt time.Time `json:"createdAt"`
}
// Membership links a user to an organization with a role (a row of the
// memberships table).
type Membership struct {
UserID uuid.UUID
OrgID uuid.UUID
// Role is free text; the set of valid roles is not defined in this file.
Role string
CreatedAt time.Time
}
// Activity is a row of the activities table.
type Activity struct {
ID uuid.UUID
UserID uuid.UUID
OrgID uuid.UUID
// FileID is nil when file_id is NULL.
FileID *string
Action string
// Metadata holds free-form key/value details
// (NOTE(review): presumably stored in a JSON/JSONB column — confirm schema).
Metadata map[string]interface{}
Timestamp time.Time
}
// File is a file or folder record from the files table.
type File struct {
ID uuid.UUID
// OrgID is nil for rows with a NULL org_id (personal-workspace entries).
OrgID *uuid.UUID
// UserID is nil when user_id is NULL.
UserID *uuid.UUID
Name string
// Path is the entry's slash-separated path, starting with "/".
Path string
// Type is "folder" for folders; other values denote files.
Type string
// Size is presumably in bytes — confirm with the upload code.
Size int64
LastModified time.Time
CreatedAt time.Time
}
// GetOrCreateUser upserts a user keyed by email and returns the stored row.
//
// A new row gets a server-generated UUID and display_name = name. If a user
// with the same email already exists, its display_name is overwritten with
// name and last_login_at is set to NOW(); a freshly inserted row keeps
// last_login_at as the column default.
//
// The returned User is populated like GetUserByID's result, including
// Username and PasswordHash. A NULL username (rows created by this upsert
// never set it) is returned as "" so the scan cannot fail.
//
// NOTE(review): sub is accepted but unused; presumably an external identity
// provider subject — confirm whether it should be persisted.
func (db *DB) GetOrCreateUser(ctx context.Context, sub, email, name string) (*User, error) {
	var user User
	err := db.QueryRowContext(ctx, `
INSERT INTO users (id, email, display_name)
VALUES (gen_random_uuid(), $1, $2)
ON CONFLICT (email) DO UPDATE SET
display_name = EXCLUDED.display_name,
last_login_at = NOW()
RETURNING id, COALESCE(username, ''), email, display_name, password_hash, created_at, last_login_at
`, email, name).Scan(&user.ID, &user.Username, &user.Email, &user.DisplayName, &user.PasswordHash, &user.CreatedAt, &user.LastLoginAt)
	if err != nil {
		return nil, err
	}
	return &user, nil
}
// CreateSession inserts a session for userID that expires at expiresAt and
// returns the inserted row (id, user_id, expires_at, revoked_at).
func (db *DB) CreateSession(ctx context.Context, userID uuid.UUID, expiresAt time.Time) (*Session, error) {
	const insertSession = `
INSERT INTO sessions (user_id, expires_at)
VALUES ($1, $2)
RETURNING id, user_id, expires_at, revoked_at
`
	s := new(Session)
	row := db.QueryRowContext(ctx, insertSession, userID, expiresAt)
	if scanErr := row.Scan(&s.ID, &s.UserID, &s.ExpiresAt, &s.RevokedAt); scanErr != nil {
		return nil, scanErr
	}
	return s, nil
}
// GetSession loads a session by ID regardless of expiry or revocation; the
// caller decides whether it is still valid. It returns sql.ErrNoRows when no
// session has that ID.
func (db *DB) GetSession(ctx context.Context, sessionID uuid.UUID) (*Session, error) {
	const selectSession = `
SELECT id, user_id, expires_at, revoked_at
FROM sessions
WHERE id = $1
`
	s := &Session{}
	dest := []interface{}{&s.ID, &s.UserID, &s.ExpiresAt, &s.RevokedAt}
	if err := db.QueryRowContext(ctx, selectSession, sessionID).Scan(dest...); err != nil {
		return nil, err
	}
	return s, nil
}
func (db *DB) RevokeSession(ctx context.Context, sessionID uuid.UUID) error {
_, err := db.ExecContext(ctx, `
UPDATE sessions
SET revoked_at = NOW()
WHERE id = $1 AND revoked_at IS NULL
`, sessionID)
return err
}
// GetUserOrganizations returns every organization userID has a membership in.
// The query has no ORDER BY, so the order is unspecified. A user with no
// memberships yields a nil slice.
func (db *DB) GetUserOrganizations(ctx context.Context, userID uuid.UUID) ([]Organization, error) {
	const listOrgs = `
SELECT o.id, o.owner_id, o.name, o.slug, o.created_at
FROM organizations o
JOIN memberships m ON o.id = m.org_id
WHERE m.user_id = $1
`
	rows, queryErr := db.QueryContext(ctx, listOrgs, userID)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	var result []Organization
	for rows.Next() {
		var o Organization
		if scanErr := rows.Scan(&o.ID, &o.OwnerID, &o.Name, &o.Slug, &o.CreatedAt); scanErr != nil {
			return nil, scanErr
		}
		result = append(result, o)
	}
	iterErr := rows.Err()
	return result, iterErr
}
// GetUserMembership returns the membership row linking userID to orgID, or
// sql.ErrNoRows when the user is not a member of that organization.
func (db *DB) GetUserMembership(ctx context.Context, userID, orgID uuid.UUID) (*Membership, error) {
	const selectMembership = `
SELECT user_id, org_id, role, created_at
FROM memberships
WHERE user_id = $1 AND org_id = $2
`
	m := &Membership{}
	err := db.QueryRowContext(ctx, selectMembership, userID, orgID).
		Scan(&m.UserID, &m.OrgID, &m.Role, &m.CreatedAt)
	if err != nil {
		return nil, err
	}
	return m, nil
}
// GetOrgMember reports userID's membership in orgID. It takes the IDs in
// (org, user) order and otherwise behaves exactly like GetUserMembership,
// including returning sql.ErrNoRows for non-members.
func (db *DB) GetOrgMember(ctx context.Context, orgID, userID uuid.UUID) (*Membership, error) {
	membership, err := db.GetUserMembership(ctx, userID, orgID)
	return membership, err
}
// CreateOrg inserts an organization owned by ownerID and returns the new row.
// It does not create a membership for the owner; that is a separate insert
// (AddMembership).
func (db *DB) CreateOrg(ctx context.Context, ownerID uuid.UUID, name, slug string) (*Organization, error) {
	const insertOrg = `
INSERT INTO organizations (owner_id, name, slug)
VALUES ($1, $2, $3)
RETURNING id, owner_id, name, slug, created_at
`
	org := new(Organization)
	row := db.QueryRowContext(ctx, insertOrg, ownerID, name, slug)
	if err := row.Scan(&org.ID, &org.OwnerID, &org.Name, &org.Slug, &org.CreatedAt); err != nil {
		return nil, err
	}
	return org, nil
}
// AddMembership grants userID the given role in orgID by inserting a
// memberships row. Any constraint violation (e.g. a duplicate membership, if
// the table enforces uniqueness) is returned as the driver's error.
func (db *DB) AddMembership(ctx context.Context, userID, orgID uuid.UUID, role string) error {
	args := []interface{}{userID, orgID, role}
	_, execErr := db.ExecContext(ctx, `
INSERT INTO memberships (user_id, org_id, role)
VALUES ($1, $2, $3)
`, args...)
	return execErr
}
func (db *DB) LogActivity(ctx context.Context, userID, orgID uuid.UUID, fileID *string, action string, metadata map[string]interface{}) error {
_, err := db.ExecContext(ctx, `
INSERT INTO activities (user_id, org_id, file_id, action, metadata)
VALUES ($1, $2, $3, $4, $5)
`, userID, orgID, fileID, action, metadata)
return err
}
// GetOrgActivities returns up to limit activities for orgID, newest first
// (ordered by timestamp DESC). An org with no activities yields a nil slice.
//
// The metadata column is scanned as raw bytes and then JSON-decoded into
// Activity.Metadata. database/sql's Rows.Scan cannot convert a driver value
// into a map[string]interface{}, so scanning into &a.Metadata directly fails
// on every row, even NULL ones. A NULL or empty metadata value leaves
// Metadata nil. Metadata that is not a JSON object is returned as an error.
func (db *DB) GetOrgActivities(ctx context.Context, orgID uuid.UUID, limit int) ([]Activity, error) {
	rows, err := db.QueryContext(ctx, `
SELECT id, user_id, org_id, file_id, action, metadata, timestamp
FROM activities
WHERE org_id = $1
ORDER BY timestamp DESC
LIMIT $2
`, orgID, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var activities []Activity
	for rows.Next() {
		var (
			a           Activity
			rawMetadata []byte
		)
		if err := rows.Scan(&a.ID, &a.UserID, &a.OrgID, &a.FileID, &a.Action, &rawMetadata, &a.Timestamp); err != nil {
			return nil, err
		}
		if len(rawMetadata) > 0 {
			if err := json.Unmarshal(rawMetadata, &a.Metadata); err != nil {
				return nil, err
			}
		}
		activities = append(activities, a)
	}
	return activities, rows.Err()
}
// GetOrgMembers returns every membership row for orgID in unspecified order
// (the query has no ORDER BY). An org with no members yields a nil slice.
func (db *DB) GetOrgMembers(ctx context.Context, orgID uuid.UUID) ([]Membership, error) {
	const listMembers = `
SELECT user_id, org_id, role, created_at
FROM memberships
WHERE org_id = $1
`
	rows, queryErr := db.QueryContext(ctx, listMembers, orgID)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	var members []Membership
	for rows.Next() {
		var member Membership
		if scanErr := rows.Scan(&member.UserID, &member.OrgID, &member.Role, &member.CreatedAt); scanErr != nil {
			return nil, scanErr
		}
		members = append(members, member)
	}
	iterErr := rows.Err()
	return members, iterErr
}
// GetOrgFiles lists the direct children of path within organization orgID.
//
// Rows are returned only when userID is a member of orgID (an EXISTS check on
// memberships); otherwise the result is empty, not an error. path "/" lists
// top-level entries. Any other path lists entries whose path is
// path + "/" + a name containing no further "/". The entry whose path equals
// path itself is excluded.
//
// q, when non-empty, keeps only names containing q (case-insensitive).
// Folders sort before files, then by name. page is 1-based: page <= 0 is
// treated as 1 and pageSize <= 0 as 100.
//
// path and q are LIKE-escaped before being embedded in patterns. This makes
// '%', '_' and '\' in folder names or search terms match literally instead
// of acting as wildcards; e.g. listing "/100%" must not include "/100abc/x".
func (db *DB) GetOrgFiles(ctx context.Context, orgID uuid.UUID, userID uuid.UUID, path string, q string, page, pageSize int) ([]File, error) {
	if page <= 0 {
		page = 1
	}
	if pageSize <= 0 {
		pageSize = 100
	}
	offset := (page - 1) * pageSize
	orgIDStr := orgID.String()
	userIDStr := userID.String()
	log.Printf("[DATA-ISOLATION] stage=before, action=list, orgId=%s, userId=%s, fileCount=0, path=%s", orgIDStr, userIDStr, path)

	// escapeLike backslash-escapes LIKE metacharacters; backslash is
	// PostgreSQL's default LIKE/ILIKE escape character.
	escapeLike := func(s string) string {
		out := make([]byte, 0, len(s))
		for i := 0; i < len(s); i++ {
			if c := s[i]; c == '\\' || c == '%' || c == '_' {
				out = append(out, '\\')
			}
			out = append(out, s[i])
		}
		return string(out)
	}
	// Every direct child's path starts with childPrefix and has no "/" after it.
	childPrefix := path + "/"
	if path == "/" {
		childPrefix = "/"
	}

	rows, err := db.QueryContext(ctx, `
SELECT f.id, f.org_id::text, f.user_id::text, f.name, f.path, f.type, f.size, f.last_modified, f.created_at
FROM files f
WHERE f.org_id = $1
AND EXISTS (
SELECT 1
FROM memberships m
WHERE m.org_id = $1 AND m.user_id = $2
)
AND f.path != $3
AND f.path LIKE $4::text || '%'
AND f.path NOT LIKE $4::text || '%/%'
AND ($5 = '' OR f.name ILIKE '%' || $5 || '%')
ORDER BY CASE WHEN f.type = 'folder' THEN 0 ELSE 1 END, f.name
LIMIT $6 OFFSET $7
`, orgID, userID, path, escapeLike(childPrefix), escapeLike(q), pageSize, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var files []File
	for rows.Next() {
		var f File
		var orgNull sql.NullString
		var userNull sql.NullString
		if err := rows.Scan(&f.ID, &orgNull, &userNull, &f.Name, &f.Path, &f.Type, &f.Size, &f.LastModified, &f.CreatedAt); err != nil {
			return nil, err
		}
		if orgNull.Valid {
			oid, _ := uuid.Parse(orgNull.String) // produced by ::text, so well-formed
			f.OrgID = &oid
		}
		if userNull.Valid {
			uid, _ := uuid.Parse(userNull.String)
			f.UserID = &uid
		}
		files = append(files, f)
	}
	err = rows.Err()
	if err == nil {
		log.Printf("[DATA-ISOLATION] stage=after, action=list, orgId=%s, userId=%s, fileCount=%d, path=%s", orgIDStr, userIDStr, len(files), path)
	}
	return files, err
}
// GetUserFiles lists the direct children of path in userID's personal
// workspace: rows owned by userID with org_id IS NULL.
//
// path "/" lists top-level entries. Any other path lists entries whose path
// is path + "/" + a name containing no further "/". The entry whose path
// equals path itself is excluded.
//
// q, when non-empty, keeps only names containing q (case-insensitive).
// Folders sort before files, then by name. page is 1-based: page <= 0 is
// treated as 1 and pageSize <= 0 as 100.
//
// path and q are LIKE-escaped before being embedded in patterns, so '%', '_'
// and '\' in folder names or search terms match literally instead of acting
// as wildcards.
func (db *DB) GetUserFiles(ctx context.Context, userID uuid.UUID, path string, q string, page, pageSize int) ([]File, error) {
	if page <= 0 {
		page = 1
	}
	if pageSize <= 0 {
		pageSize = 100
	}
	offset := (page - 1) * pageSize
	log.Printf("[DATA-ISOLATION] stage=before, action=list, orgId=, userId=%s, fileCount=0, path=%s", userID.String(), path)

	// escapeLike backslash-escapes LIKE metacharacters; backslash is
	// PostgreSQL's default LIKE/ILIKE escape character.
	escapeLike := func(s string) string {
		out := make([]byte, 0, len(s))
		for i := 0; i < len(s); i++ {
			if c := s[i]; c == '\\' || c == '%' || c == '_' {
				out = append(out, '\\')
			}
			out = append(out, s[i])
		}
		return string(out)
	}
	// Every direct child's path starts with childPrefix and has no "/" after it.
	childPrefix := path + "/"
	if path == "/" {
		childPrefix = "/"
	}

	rows, err := db.QueryContext(ctx, `
SELECT id, org_id::text, user_id::text, name, path, type, size, last_modified, created_at
FROM files
WHERE user_id = $1
AND org_id IS NULL
AND path != $2
AND path LIKE $3::text || '%'
AND path NOT LIKE $3::text || '%/%'
AND ($4 = '' OR name ILIKE '%' || $4 || '%')
ORDER BY CASE WHEN type = 'folder' THEN 0 ELSE 1 END, name
LIMIT $5 OFFSET $6
`, userID, path, escapeLike(childPrefix), escapeLike(q), pageSize, offset)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var files []File
	for rows.Next() {
		var f File
		var orgNull sql.NullString
		var userNull sql.NullString
		if err := rows.Scan(&f.ID, &orgNull, &userNull, &f.Name, &f.Path, &f.Type, &f.Size, &f.LastModified, &f.CreatedAt); err != nil {
			return nil, err
		}
		if orgNull.Valid {
			oid, _ := uuid.Parse(orgNull.String) // produced by ::text, so well-formed
			f.OrgID = &oid
		}
		if userNull.Valid {
			uid, _ := uuid.Parse(userNull.String)
			f.UserID = &uid
		}
		files = append(files, f)
	}
	err = rows.Err()
	if err == nil {
		log.Printf("[DATA-ISOLATION] stage=after, action=list, orgId=, userId=%s, fileCount=%d, path=%s", userID.String(), len(files), path)
	}
	return files, err
}
// CreateFile inserts a file or folder record and returns the stored row.
//
// orgID or userID may be nil, in which case the corresponding column is
// written as NULL. The returned File has OrgID and UserID populated from the
// row returned by the INSERT (nil when the column is NULL), like
// GetFileByID's result.
func (db *DB) CreateFile(ctx context.Context, orgID *uuid.UUID, userID *uuid.UUID, name, path, fileType string, size int64) (*File, error) {
	var (
		f         File
		orgIDVal  interface{} // stays nil -> NULL
		userIDVal interface{}
		orgIDStr  string
		userIDStr string
	)
	if orgID != nil {
		orgIDVal = *orgID
		orgIDStr = orgID.String()
	}
	if userID != nil {
		userIDVal = *userID
		userIDStr = userID.String()
	}
	log.Printf("[DATA-ISOLATION] stage=before, action=create, orgId=%s, userId=%s, fileCount=1, path=%s", orgIDStr, userIDStr, path)
	var orgNull, userNull sql.NullString
	err := db.QueryRowContext(ctx, `
INSERT INTO files (org_id, user_id, name, path, type, size)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, org_id::text, user_id::text, name, path, type, size, last_modified, created_at
`, orgIDVal, userIDVal, name, path, fileType, size).Scan(&f.ID, &orgNull, &userNull, &f.Name, &f.Path, &f.Type, &f.Size, &f.LastModified, &f.CreatedAt)
	if err != nil {
		return nil, err
	}
	// The row already exists at this point, so a malformed ID (not expected
	// from ::text) leaves the field nil rather than reporting failure.
	if orgNull.Valid {
		if oid, perr := uuid.Parse(orgNull.String); perr == nil {
			f.OrgID = &oid
		}
	}
	if userNull.Valid {
		if uid, perr := uuid.Parse(userNull.String); perr == nil {
			f.UserID = &uid
		}
	}
	log.Printf("[DATA-ISOLATION] stage=after, action=create, orgId=%s, userId=%s, fileCount=1, path=%s", orgIDStr, userIDStr, f.Path)
	return &f, nil
}
// GetFileByID loads a single file or folder record by ID, with no org or user
// scoping. It returns sql.ErrNoRows when no row has that ID. OrgID/UserID
// are nil when the corresponding column is NULL.
func (db *DB) GetFileByID(ctx context.Context, fileID uuid.UUID) (*File, error) {
	const selectFile = `
SELECT id, org_id::text, user_id::text, name, path, type, size, last_modified, created_at
FROM files
WHERE id = $1
`
	var (
		f                 File
		orgText, userText sql.NullString
	)
	row := db.QueryRowContext(ctx, selectFile, fileID)
	if err := row.Scan(&f.ID, &orgText, &userText, &f.Name, &f.Path, &f.Type, &f.Size, &f.LastModified, &f.CreatedAt); err != nil {
		return nil, err
	}

	toUUIDPtr := func(ns sql.NullString) *uuid.UUID {
		if !ns.Valid {
			return nil
		}
		id, _ := uuid.Parse(ns.String) // a parse failure yields the zero UUID (error ignored)
		return &id
	}
	f.OrgID = toUUIDPtr(orgText)
	f.UserID = toUUIDPtr(userText)
	return &f, nil
}
// UpdateFileSize updates the size and modification time of a file
func (db *DB) UpdateFileSize(ctx context.Context, fileID uuid.UUID, size int64) error {
_, err := db.ExecContext(ctx, `
UPDATE files
SET size = $1, last_modified = NOW()
WHERE id = $2
`, size, fileID)
return err
}
// DeleteFileByPath removes a file or folder matching path for a given org or user
func (db *DB) DeleteFileByPath(ctx context.Context, orgID *uuid.UUID, userID *uuid.UUID, path string) error {
var res sql.Result
var err error
if orgID != nil {
res, err = db.ExecContext(ctx, `DELETE FROM files WHERE org_id = $1 AND path = $2`, *orgID, path)
} else if userID != nil {
res, err = db.ExecContext(ctx, `DELETE FROM files WHERE user_id = $1 AND path = $2`, *userID, path)
} else {
return nil
}
if err != nil {
return err
}
_, _ = res.RowsAffected()
return nil
}
func (db *DB) UpdateMemberRole(ctx context.Context, orgID, userID uuid.UUID, role string) error {
_, err := db.ExecContext(ctx, `
UPDATE memberships
SET role = $1
WHERE org_id = $2 AND user_id = $3
`, role, orgID, userID)
return err
}
// Passkey-related methods

// CreateUser inserts a user with a server-generated UUID and returns the
// stored row. passwordHash may be nil, in which case password_hash is NULL.
func (db *DB) CreateUser(ctx context.Context, username, email, displayName string, passwordHash *string) (*User, error) {
	const insertUser = `
INSERT INTO users (id, username, email, display_name, password_hash)
VALUES (gen_random_uuid(), $1, $2, $3, $4)
RETURNING id, username, email, display_name, password_hash, created_at, last_login_at
`
	u := new(User)
	dest := []interface{}{&u.ID, &u.Username, &u.Email, &u.DisplayName, &u.PasswordHash, &u.CreatedAt, &u.LastLoginAt}
	if err := db.QueryRowContext(ctx, insertUser, username, email, displayName, passwordHash).Scan(dest...); err != nil {
		return nil, err
	}
	return u, nil
}
// GetUserByUsername loads the user with the given username, or returns
// sql.ErrNoRows if there is none.
func (db *DB) GetUserByUsername(ctx context.Context, username string) (*User, error) {
	const byUsername = `
SELECT id, username, email, display_name, password_hash, created_at, last_login_at
FROM users
WHERE username = $1
`
	u := &User{}
	row := db.QueryRowContext(ctx, byUsername, username)
	err := row.Scan(&u.ID, &u.Username, &u.Email, &u.DisplayName, &u.PasswordHash, &u.CreatedAt, &u.LastLoginAt)
	if err == nil {
		return u, nil
	}
	return nil, err
}
// GetUserByEmail loads the user with the given email address, or returns
// sql.ErrNoRows if there is none.
func (db *DB) GetUserByEmail(ctx context.Context, email string) (*User, error) {
	const byEmail = `
SELECT id, username, email, display_name, password_hash, created_at, last_login_at
FROM users
WHERE email = $1
`
	var found User
	fields := []interface{}{&found.ID, &found.Username, &found.Email, &found.DisplayName, &found.PasswordHash, &found.CreatedAt, &found.LastLoginAt}
	if scanErr := db.QueryRowContext(ctx, byEmail, email).Scan(fields...); scanErr != nil {
		return nil, scanErr
	}
	return &found, nil
}
// GetUserByID loads the user with the given ID, or returns sql.ErrNoRows if
// there is none.
func (db *DB) GetUserByID(ctx context.Context, userID uuid.UUID) (*User, error) {
	const byID = `
SELECT id, username, email, display_name, password_hash, created_at, last_login_at
FROM users
WHERE id = $1
`
	u := new(User)
	err := db.QueryRowContext(ctx, byID, userID).
		Scan(&u.ID, &u.Username, &u.Email, &u.DisplayName, &u.PasswordHash, &u.CreatedAt, &u.LastLoginAt)
	if err != nil {
		return nil, err
	}
	return u, nil
}
func (db *DB) UpdateUserLastLogin(ctx context.Context, userID uuid.UUID) error {
_, err := db.ExecContext(ctx, `
UPDATE users
SET last_login_at = NOW()
WHERE id = $1
`, userID)
return err
}
func (db *DB) CreateCredential(ctx context.Context, cred *Credential) error {
_, err := db.ExecContext(ctx, `
INSERT INTO credentials (id, user_id, credential_public_key, credential_id, sign_count, transports)
VALUES ($1, $2, $3, $4, $5, $6)
`, cred.ID, cred.UserID, cred.CredentialPublicKey, cred.CredentialID, cred.SignCount, cred.Transports)
return err
}
// GetCredentialByID looks up a credential by its credential_id bytes (not
// by the ID column). It returns sql.ErrNoRows when none matches.
func (db *DB) GetCredentialByID(ctx context.Context, credentialID []byte) (*Credential, error) {
	const byCredentialID = `
SELECT id, user_id, credential_public_key, credential_id, sign_count, created_at, last_used_at, transports
FROM credentials
WHERE credential_id = $1
`
	c := &Credential{}
	row := db.QueryRowContext(ctx, byCredentialID, credentialID)
	err := row.Scan(&c.ID, &c.UserID, &c.CredentialPublicKey, &c.CredentialID, &c.SignCount, &c.CreatedAt, &c.LastUsedAt, &c.Transports)
	if err == nil {
		return c, nil
	}
	return nil, err
}
// GetUserCredentials returns all credentials registered to userID, most
// recently created first. A user with none yields a nil slice.
func (db *DB) GetUserCredentials(ctx context.Context, userID uuid.UUID) ([]Credential, error) {
	const listCredentials = `
SELECT id, user_id, credential_public_key, credential_id, sign_count, created_at, last_used_at, transports
FROM credentials
WHERE user_id = $1
ORDER BY created_at DESC
`
	rows, queryErr := db.QueryContext(ctx, listCredentials, userID)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	var creds []Credential
	for rows.Next() {
		var c Credential
		if scanErr := rows.Scan(&c.ID, &c.UserID, &c.CredentialPublicKey, &c.CredentialID, &c.SignCount, &c.CreatedAt, &c.LastUsedAt, &c.Transports); scanErr != nil {
			return nil, scanErr
		}
		creds = append(creds, c)
	}
	iterErr := rows.Err()
	return creds, iterErr
}
func (db *DB) UpdateCredentialLastUsed(ctx context.Context, credentialID string) error {
_, err := db.ExecContext(ctx, `
UPDATE credentials
SET last_used_at = NOW()
WHERE id = $1
`, credentialID)
return err
}
func (db *DB) CreateAuthChallenge(ctx context.Context, userID uuid.UUID, challenge []byte, challengeType string) error {
_, err := db.ExecContext(ctx, `
INSERT INTO auth_challenges (user_id, challenge, challenge_type, expires_at)
VALUES ($1, $2, $3, NOW() + INTERVAL '15 minutes')
`, userID, challenge, challengeType)
return err
}
// VerifyAuthChallenge returns nil when an unexpired, unused challenge
// matching userID, challenge and challengeType exists, and sql.ErrNoRows
// otherwise. It does not consume the challenge; MarkChallengeUsed does.
//
// NOTE(review): verification and marking are separate statements, so two
// concurrent verifications of the same challenge can both succeed; consider
// consuming it atomically.
func (db *DB) VerifyAuthChallenge(ctx context.Context, userID uuid.UUID, challenge []byte, challengeType string) error {
	const countValid = `
SELECT COUNT(*)
FROM auth_challenges
WHERE user_id = $1 AND challenge = $2 AND challenge_type = $3 AND expires_at > NOW() AND used_at IS NULL
`
	var matches int
	if err := db.QueryRowContext(ctx, countValid, userID, challenge, challengeType).Scan(&matches); err != nil {
		return err
	}
	if matches > 0 {
		return nil
	}
	return sql.ErrNoRows
}
func (db *DB) MarkChallengeUsed(ctx context.Context, challenge []byte) error {
_, err := db.ExecContext(ctx, `
UPDATE auth_challenges
SET used_at = NOW()
WHERE challenge = $1 AND used_at IS NULL
`, challenge)
return err
}
// UpdateFileSize updates the size and last_modified timestamp of a file