- Added modified_by column to files table - Updated WOPI PutFile to track who modified the file - Updated view handlers to return file metadata (name, size, lastModified, modifiedByName) - Updated Flutter models and UI to display last modified info
693 lines
20 KiB
Go
693 lines
20 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"encoding/json"
|
|
"log"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// DB wraps a *sql.DB. The handle is embedded, so its methods (QueryContext,
// ExecContext, ...) are callable directly on a DB alongside the query helpers
// defined in this package.
type DB struct {
	*sql.DB
}
|
|
|
|
func New(db *sql.DB) *DB {
|
|
return &DB{DB: db}
|
|
}
|
|
|
|
// StringArray is a []string that can be stored in and loaded from a nullable
// database column. Value writes it as a JSON array; Scan reads a JSON array
// back and falls back to treating unparseable text as a single element.
type StringArray []string

// Scan implements sql.Scanner.
//
// NULL, empty text and a JSON null all yield an empty (non-nil) StringArray.
// Text that parses as a JSON array of strings yields its elements; any other
// non-empty text is kept verbatim as a one-element array. []byte and string
// sources are decoded identically. A source of any other type leaves the
// receiver untouched and reports no error.
func (sa *StringArray) Scan(value interface{}) error {
	var raw []byte
	switch v := value.(type) {
	case nil:
		*sa = StringArray{}
		return nil
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		// Unsupported driver types are ignored rather than rejected.
		return nil
	}

	if len(raw) == 0 {
		*sa = StringArray{}
		return nil
	}

	var arr []string
	if err := json.Unmarshal(raw, &arr); err != nil {
		// Not a JSON array of strings: keep the raw text as one element.
		// string(raw) copies, so the driver may reuse its buffer afterwards.
		*sa = StringArray{string(raw)}
		return nil
	}
	if arr == nil {
		// JSON null decodes to a nil slice; normalize like SQL NULL.
		arr = []string{}
	}
	*sa = StringArray(arr)
	return nil
}

// Value implements driver.Valuer. An empty or nil array is written as NULL;
// otherwise the elements are encoded as a JSON array.
func (sa StringArray) Value() (driver.Value, error) {
	if len(sa) == 0 {
		return nil, nil
	}
	return json.Marshal(sa)
}
|
|
|
|
type User struct {
|
|
ID uuid.UUID
|
|
Email string
|
|
Username string
|
|
DisplayName string
|
|
PasswordHash *string
|
|
CreatedAt time.Time
|
|
LastLoginAt *time.Time
|
|
}
|
|
|
|
type Credential struct {
|
|
ID string
|
|
UserID uuid.UUID
|
|
CredentialPublicKey []byte
|
|
CredentialID []byte
|
|
SignCount int64
|
|
CreatedAt time.Time
|
|
LastUsedAt *time.Time
|
|
Transports StringArray
|
|
}
|
|
|
|
type AuthChallenge struct {
|
|
ID uuid.UUID
|
|
UserID uuid.UUID
|
|
Challenge []byte
|
|
ChallengeType string
|
|
CreatedAt time.Time
|
|
ExpiresAt time.Time
|
|
UsedAt *time.Time
|
|
}
|
|
|
|
type Session struct {
|
|
ID uuid.UUID
|
|
UserID uuid.UUID
|
|
ExpiresAt time.Time
|
|
RevokedAt *time.Time
|
|
}
|
|
|
|
type Organization struct {
|
|
ID uuid.UUID `json:"id"`
|
|
OwnerID uuid.UUID `json:"ownerId"`
|
|
Name string `json:"name"`
|
|
Slug string `json:"slug"`
|
|
CreatedAt time.Time `json:"createdAt"`
|
|
}
|
|
|
|
type Membership struct {
|
|
UserID uuid.UUID
|
|
OrgID uuid.UUID
|
|
Role string
|
|
CreatedAt time.Time
|
|
}
|
|
|
|
type Activity struct {
|
|
ID uuid.UUID
|
|
UserID uuid.UUID
|
|
OrgID uuid.UUID
|
|
FileID *string
|
|
Action string
|
|
Metadata map[string]interface{}
|
|
Timestamp time.Time
|
|
}
|
|
|
|
type File struct {
|
|
ID uuid.UUID
|
|
OrgID *uuid.UUID
|
|
UserID *uuid.UUID
|
|
Name string
|
|
Path string
|
|
Type string
|
|
Size int64
|
|
LastModified time.Time
|
|
CreatedAt time.Time
|
|
ModifiedBy *uuid.UUID
|
|
ModifiedByName string
|
|
}
|
|
|
|
func (db *DB) GetOrCreateUser(ctx context.Context, sub, email, name string) (*User, error) {
|
|
var user User
|
|
err := db.QueryRowContext(ctx, `
|
|
INSERT INTO users (id, email, display_name)
|
|
VALUES (gen_random_uuid(), $1, $2)
|
|
ON CONFLICT (email) DO UPDATE SET
|
|
display_name = EXCLUDED.display_name,
|
|
last_login_at = NOW()
|
|
RETURNING id, email, display_name, created_at, last_login_at
|
|
`, email, name).Scan(&user.ID, &user.Email, &user.DisplayName, &user.CreatedAt, &user.LastLoginAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (db *DB) CreateSession(ctx context.Context, userID uuid.UUID, expiresAt time.Time) (*Session, error) {
|
|
var session Session
|
|
err := db.QueryRowContext(ctx, `
|
|
INSERT INTO sessions (user_id, expires_at)
|
|
VALUES ($1, $2)
|
|
RETURNING id, user_id, expires_at, revoked_at
|
|
`, userID, expiresAt).Scan(&session.ID, &session.UserID, &session.ExpiresAt, &session.RevokedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &session, nil
|
|
}
|
|
|
|
func (db *DB) GetSession(ctx context.Context, sessionID uuid.UUID) (*Session, error) {
|
|
var session Session
|
|
err := db.QueryRowContext(ctx, `
|
|
SELECT id, user_id, expires_at, revoked_at
|
|
FROM sessions
|
|
WHERE id = $1
|
|
`, sessionID).Scan(&session.ID, &session.UserID, &session.ExpiresAt, &session.RevokedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &session, nil
|
|
}
|
|
|
|
func (db *DB) RevokeSession(ctx context.Context, sessionID uuid.UUID) error {
|
|
_, err := db.ExecContext(ctx, `
|
|
UPDATE sessions
|
|
SET revoked_at = NOW()
|
|
WHERE id = $1 AND revoked_at IS NULL
|
|
`, sessionID)
|
|
return err
|
|
}
|
|
|
|
func (db *DB) GetUserOrganizations(ctx context.Context, userID uuid.UUID) ([]Organization, error) {
|
|
rows, err := db.QueryContext(ctx, `
|
|
SELECT o.id, o.owner_id, o.name, o.slug, o.created_at
|
|
FROM organizations o
|
|
JOIN memberships m ON o.id = m.org_id
|
|
WHERE m.user_id = $1
|
|
`, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var orgs []Organization
|
|
for rows.Next() {
|
|
var org Organization
|
|
if err := rows.Scan(&org.ID, &org.OwnerID, &org.Name, &org.Slug, &org.CreatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
orgs = append(orgs, org)
|
|
}
|
|
return orgs, rows.Err()
|
|
}
|
|
|
|
func (db *DB) GetUserMembership(ctx context.Context, userID, orgID uuid.UUID) (*Membership, error) {
|
|
var membership Membership
|
|
err := db.QueryRowContext(ctx, `
|
|
SELECT user_id, org_id, role, created_at
|
|
FROM memberships
|
|
WHERE user_id = $1 AND org_id = $2
|
|
`, userID, orgID).Scan(&membership.UserID, &membership.OrgID, &membership.Role, &membership.CreatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &membership, nil
|
|
}
|
|
|
|
// GetOrgMember is an alias for GetUserMembership - checks if user is a member of an org
|
|
func (db *DB) GetOrgMember(ctx context.Context, orgID, userID uuid.UUID) (*Membership, error) {
|
|
return db.GetUserMembership(ctx, userID, orgID)
|
|
}
|
|
|
|
func (db *DB) CreateOrg(ctx context.Context, ownerID uuid.UUID, name, slug string) (*Organization, error) {
|
|
var org Organization
|
|
err := db.QueryRowContext(ctx, `
|
|
INSERT INTO organizations (owner_id, name, slug)
|
|
VALUES ($1, $2, $3)
|
|
RETURNING id, owner_id, name, slug, created_at
|
|
`, ownerID, name, slug).Scan(&org.ID, &org.OwnerID, &org.Name, &org.Slug, &org.CreatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &org, nil
|
|
}
|
|
|
|
func (db *DB) AddMembership(ctx context.Context, userID, orgID uuid.UUID, role string) error {
|
|
_, err := db.ExecContext(ctx, `
|
|
INSERT INTO memberships (user_id, org_id, role)
|
|
VALUES ($1, $2, $3)
|
|
`, userID, orgID, role)
|
|
return err
|
|
}
|
|
|
|
func (db *DB) LogActivity(ctx context.Context, userID, orgID uuid.UUID, fileID *string, action string, metadata map[string]interface{}) error {
|
|
_, err := db.ExecContext(ctx, `
|
|
INSERT INTO activities (user_id, org_id, file_id, action, metadata)
|
|
VALUES ($1, $2, $3, $4, $5)
|
|
`, userID, orgID, fileID, action, metadata)
|
|
return err
|
|
}
|
|
|
|
func (db *DB) GetOrgActivities(ctx context.Context, orgID uuid.UUID, limit int) ([]Activity, error) {
|
|
rows, err := db.QueryContext(ctx, `
|
|
SELECT id, user_id, org_id, file_id, action, metadata, timestamp
|
|
FROM activities
|
|
WHERE org_id = $1
|
|
ORDER BY timestamp DESC
|
|
LIMIT $2
|
|
`, orgID, limit)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var activities []Activity
|
|
for rows.Next() {
|
|
var a Activity
|
|
err := rows.Scan(&a.ID, &a.UserID, &a.OrgID, &a.FileID, &a.Action, &a.Metadata, &a.Timestamp)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
activities = append(activities, a)
|
|
}
|
|
return activities, rows.Err()
|
|
}
|
|
|
|
func (db *DB) GetOrgMembers(ctx context.Context, orgID uuid.UUID) ([]Membership, error) {
|
|
rows, err := db.QueryContext(ctx, `
|
|
SELECT user_id, org_id, role, created_at
|
|
FROM memberships
|
|
WHERE org_id = $1
|
|
`, orgID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var memberships []Membership
|
|
for rows.Next() {
|
|
var m Membership
|
|
err := rows.Scan(&m.UserID, &m.OrgID, &m.Role, &m.CreatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
memberships = append(memberships, m)
|
|
}
|
|
return memberships, rows.Err()
|
|
}
|
|
|
|
// GetOrgFiles returns files for a given organization (top-level folder listing)
|
|
func (db *DB) GetOrgFiles(ctx context.Context, orgID uuid.UUID, userID uuid.UUID, path string, q string, page, pageSize int) ([]File, error) {
|
|
if page <= 0 {
|
|
page = 1
|
|
}
|
|
if pageSize <= 0 {
|
|
pageSize = 100
|
|
}
|
|
offset := (page - 1) * pageSize
|
|
|
|
orgIDStr := orgID.String()
|
|
userIDStr := userID.String()
|
|
log.Printf("[DATA-ISOLATION] stage=before, action=list, orgId=%s, userId=%s, fileCount=0, path=%s", orgIDStr, userIDStr, path)
|
|
|
|
// Basic search and pagination. Returns only direct children of the given path.
|
|
// For root ("/"), we want files where path doesn't contain "/" after the first character.
|
|
// For subdirs, we want files where path starts with parent but has no additional "/" after parent.
|
|
rows, err := db.QueryContext(ctx, `
|
|
SELECT f.id, f.org_id::text, f.user_id::text, f.name, f.path, f.type, f.size, f.last_modified, f.created_at
|
|
FROM files f
|
|
WHERE f.org_id = $1
|
|
AND EXISTS (
|
|
SELECT 1
|
|
FROM memberships m
|
|
WHERE m.org_id = $1 AND m.user_id = $2
|
|
)
|
|
AND f.path != $3
|
|
AND (
|
|
($3 = '/' AND f.path LIKE '/%' AND f.path NOT LIKE '/%/%')
|
|
OR ($3 != '/' AND f.path LIKE $3 || '/%' AND f.path NOT LIKE $3 || '/%/%')
|
|
)
|
|
AND ($4 = '' OR f.name ILIKE '%' || $4 || '%')
|
|
ORDER BY CASE WHEN f.type = 'folder' THEN 0 ELSE 1 END, f.name
|
|
LIMIT $5 OFFSET $6
|
|
`, orgID, userID, path, q, pageSize, offset)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var files []File
|
|
for rows.Next() {
|
|
var f File
|
|
var orgNull sql.NullString
|
|
var userNull sql.NullString
|
|
if err := rows.Scan(&f.ID, &orgNull, &userNull, &f.Name, &f.Path, &f.Type, &f.Size, &f.LastModified, &f.CreatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
if orgNull.Valid {
|
|
oid, _ := uuid.Parse(orgNull.String)
|
|
f.OrgID = &oid
|
|
}
|
|
if userNull.Valid {
|
|
uid, _ := uuid.Parse(userNull.String)
|
|
f.UserID = &uid
|
|
}
|
|
files = append(files, f)
|
|
}
|
|
err = rows.Err()
|
|
if err == nil {
|
|
log.Printf("[DATA-ISOLATION] stage=after, action=list, orgId=%s, userId=%s, fileCount=%d, path=%s", orgIDStr, userIDStr, len(files), path)
|
|
}
|
|
return files, err
|
|
}
|
|
|
|
// GetUserFiles returns files for a user's personal workspace at a given path
|
|
func (db *DB) GetUserFiles(ctx context.Context, userID uuid.UUID, path string, q string, page, pageSize int) ([]File, error) {
|
|
if page <= 0 {
|
|
page = 1
|
|
}
|
|
if pageSize <= 0 {
|
|
pageSize = 100
|
|
}
|
|
offset := (page - 1) * pageSize
|
|
|
|
// Return only direct children of the given path
|
|
log.Printf("[DATA-ISOLATION] stage=before, action=list, orgId=, userId=%s, fileCount=0, path=%s", userID.String(), path)
|
|
rows, err := db.QueryContext(ctx, `
|
|
SELECT id, org_id::text, user_id::text, name, path, type, size, last_modified, created_at
|
|
FROM files
|
|
WHERE user_id = $1
|
|
AND org_id IS NULL
|
|
AND path != $2
|
|
AND (
|
|
($2 = '/' AND path LIKE '/%' AND path NOT LIKE '/%/%')
|
|
OR ($2 != '/' AND path LIKE $2 || '/%' AND path NOT LIKE $2 || '/%/%')
|
|
)
|
|
AND ($3 = '' OR name ILIKE '%' || $3 || '%')
|
|
ORDER BY CASE WHEN type = 'folder' THEN 0 ELSE 1 END, name
|
|
LIMIT $4 OFFSET $5
|
|
`, userID, path, q, pageSize, offset)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var files []File
|
|
for rows.Next() {
|
|
var f File
|
|
var orgNull sql.NullString
|
|
var userNull sql.NullString
|
|
if err := rows.Scan(&f.ID, &orgNull, &userNull, &f.Name, &f.Path, &f.Type, &f.Size, &f.LastModified, &f.CreatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
if orgNull.Valid {
|
|
oid, _ := uuid.Parse(orgNull.String)
|
|
f.OrgID = &oid
|
|
}
|
|
if userNull.Valid {
|
|
uid, _ := uuid.Parse(userNull.String)
|
|
f.UserID = &uid
|
|
}
|
|
files = append(files, f)
|
|
}
|
|
err = rows.Err()
|
|
if err == nil {
|
|
log.Printf("[DATA-ISOLATION] stage=after, action=list, orgId=, userId=%s, fileCount=%d, path=%s", userID.String(), len(files), path)
|
|
}
|
|
return files, err
|
|
}
|
|
|
|
// CreateFile inserts a file or folder record. orgID or userID may be nil.
|
|
func (db *DB) CreateFile(ctx context.Context, orgID *uuid.UUID, userID *uuid.UUID, name, path, fileType string, size int64) (*File, error) {
|
|
var f File
|
|
var orgIDVal interface{}
|
|
var userIDVal interface{}
|
|
orgIDStr := ""
|
|
userIDStr := ""
|
|
if orgID != nil {
|
|
orgIDVal = *orgID
|
|
orgIDStr = orgID.String()
|
|
} else {
|
|
orgIDVal = nil
|
|
}
|
|
if userID != nil {
|
|
userIDVal = *userID
|
|
userIDStr = userID.String()
|
|
} else {
|
|
userIDVal = nil
|
|
}
|
|
log.Printf("[DATA-ISOLATION] stage=before, action=create, orgId=%s, userId=%s, fileCount=1, path=%s", orgIDStr, userIDStr, path)
|
|
|
|
err := db.QueryRowContext(ctx, `
|
|
INSERT INTO files (org_id, user_id, name, path, type, size)
|
|
VALUES ($1, $2, $3, $4, $5, $6)
|
|
RETURNING id, org_id::text, user_id::text, name, path, type, size, last_modified, created_at
|
|
`, orgIDVal, userIDVal, name, path, fileType, size).Scan(&f.ID, new(sql.NullString), new(sql.NullString), &f.Name, &f.Path, &f.Type, &f.Size, &f.LastModified, &f.CreatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
log.Printf("[DATA-ISOLATION] stage=after, action=create, orgId=%s, userId=%s, fileCount=1, path=%s", orgIDStr, userIDStr, f.Path)
|
|
return &f, nil
|
|
}
|
|
|
|
// GetFileByID retrieves a file by its ID
|
|
func (db *DB) GetFileByID(ctx context.Context, fileID uuid.UUID) (*File, error) {
|
|
var f File
|
|
var orgNull sql.NullString
|
|
var userNull sql.NullString
|
|
var modifiedByNull sql.NullString
|
|
var modifiedByNameNull sql.NullString
|
|
|
|
err := db.QueryRowContext(ctx, `
|
|
SELECT f.id, f.org_id::text, f.user_id::text, f.name, f.path, f.type, f.size, f.last_modified, f.created_at,
|
|
f.modified_by::text, u.display_name
|
|
FROM files f
|
|
LEFT JOIN users u ON f.modified_by = u.id
|
|
WHERE f.id = $1
|
|
`, fileID).Scan(&f.ID, &orgNull, &userNull, &f.Name, &f.Path, &f.Type, &f.Size, &f.LastModified, &f.CreatedAt,
|
|
&modifiedByNull, &modifiedByNameNull)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if orgNull.Valid {
|
|
oid, _ := uuid.Parse(orgNull.String)
|
|
f.OrgID = &oid
|
|
}
|
|
if userNull.Valid {
|
|
uid, _ := uuid.Parse(userNull.String)
|
|
f.UserID = &uid
|
|
}
|
|
if modifiedByNull.Valid {
|
|
mid, _ := uuid.Parse(modifiedByNull.String)
|
|
f.ModifiedBy = &mid
|
|
}
|
|
if modifiedByNameNull.Valid {
|
|
f.ModifiedByName = modifiedByNameNull.String
|
|
}
|
|
|
|
return &f, nil
|
|
}
|
|
|
|
// UpdateFileSize updates the size, modification time, and modifier of a file
|
|
func (db *DB) UpdateFileSize(ctx context.Context, fileID uuid.UUID, size int64, modifiedBy *uuid.UUID) error {
|
|
_, err := db.ExecContext(ctx, `
|
|
UPDATE files
|
|
SET size = $1, last_modified = NOW(), modified_by = $3
|
|
WHERE id = $2
|
|
`, size, fileID, modifiedBy)
|
|
return err
|
|
}
|
|
|
|
// DeleteFileByPath removes a file or folder matching path for a given org or user
|
|
func (db *DB) DeleteFileByPath(ctx context.Context, orgID *uuid.UUID, userID *uuid.UUID, path string) error {
|
|
var res sql.Result
|
|
var err error
|
|
if orgID != nil {
|
|
res, err = db.ExecContext(ctx, `DELETE FROM files WHERE org_id = $1 AND path = $2`, *orgID, path)
|
|
} else if userID != nil {
|
|
res, err = db.ExecContext(ctx, `DELETE FROM files WHERE user_id = $1 AND path = $2`, *userID, path)
|
|
} else {
|
|
return nil
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, _ = res.RowsAffected()
|
|
return nil
|
|
}
|
|
|
|
func (db *DB) UpdateMemberRole(ctx context.Context, orgID, userID uuid.UUID, role string) error {
|
|
_, err := db.ExecContext(ctx, `
|
|
UPDATE memberships
|
|
SET role = $1
|
|
WHERE org_id = $2 AND user_id = $3
|
|
`, role, orgID, userID)
|
|
return err
|
|
}
|
|
|
|
// Passkey-related methods
|
|
|
|
func (db *DB) CreateUser(ctx context.Context, username, email, displayName string, passwordHash *string) (*User, error) {
|
|
var user User
|
|
err := db.QueryRowContext(ctx, `
|
|
INSERT INTO users (id, username, email, display_name, password_hash)
|
|
VALUES (gen_random_uuid(), $1, $2, $3, $4)
|
|
RETURNING id, username, email, display_name, password_hash, created_at, last_login_at
|
|
`, username, email, displayName, passwordHash).Scan(&user.ID, &user.Username, &user.Email, &user.DisplayName, &user.PasswordHash, &user.CreatedAt, &user.LastLoginAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (db *DB) GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
|
var user User
|
|
err := db.QueryRowContext(ctx, `
|
|
SELECT id, username, email, display_name, password_hash, created_at, last_login_at
|
|
FROM users
|
|
WHERE username = $1
|
|
`, username).Scan(&user.ID, &user.Username, &user.Email, &user.DisplayName, &user.PasswordHash, &user.CreatedAt, &user.LastLoginAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (db *DB) GetUserByEmail(ctx context.Context, email string) (*User, error) {
|
|
var user User
|
|
err := db.QueryRowContext(ctx, `
|
|
SELECT id, username, email, display_name, password_hash, created_at, last_login_at
|
|
FROM users
|
|
WHERE email = $1
|
|
`, email).Scan(&user.ID, &user.Username, &user.Email, &user.DisplayName, &user.PasswordHash, &user.CreatedAt, &user.LastLoginAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (db *DB) GetUserByID(ctx context.Context, userID uuid.UUID) (*User, error) {
|
|
var user User
|
|
err := db.QueryRowContext(ctx, `
|
|
SELECT id, username, email, display_name, password_hash, created_at, last_login_at
|
|
FROM users
|
|
WHERE id = $1
|
|
`, userID).Scan(&user.ID, &user.Username, &user.Email, &user.DisplayName, &user.PasswordHash, &user.CreatedAt, &user.LastLoginAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (db *DB) UpdateUserLastLogin(ctx context.Context, userID uuid.UUID) error {
|
|
_, err := db.ExecContext(ctx, `
|
|
UPDATE users
|
|
SET last_login_at = NOW()
|
|
WHERE id = $1
|
|
`, userID)
|
|
return err
|
|
}
|
|
|
|
func (db *DB) CreateCredential(ctx context.Context, cred *Credential) error {
|
|
_, err := db.ExecContext(ctx, `
|
|
INSERT INTO credentials (id, user_id, credential_public_key, credential_id, sign_count, transports)
|
|
VALUES ($1, $2, $3, $4, $5, $6)
|
|
`, cred.ID, cred.UserID, cred.CredentialPublicKey, cred.CredentialID, cred.SignCount, cred.Transports)
|
|
return err
|
|
}
|
|
|
|
func (db *DB) GetCredentialByID(ctx context.Context, credentialID []byte) (*Credential, error) {
|
|
var cred Credential
|
|
err := db.QueryRowContext(ctx, `
|
|
SELECT id, user_id, credential_public_key, credential_id, sign_count, created_at, last_used_at, transports
|
|
FROM credentials
|
|
WHERE credential_id = $1
|
|
`, credentialID).Scan(&cred.ID, &cred.UserID, &cred.CredentialPublicKey, &cred.CredentialID, &cred.SignCount, &cred.CreatedAt, &cred.LastUsedAt, &cred.Transports)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &cred, nil
|
|
}
|
|
|
|
func (db *DB) GetUserCredentials(ctx context.Context, userID uuid.UUID) ([]Credential, error) {
|
|
rows, err := db.QueryContext(ctx, `
|
|
SELECT id, user_id, credential_public_key, credential_id, sign_count, created_at, last_used_at, transports
|
|
FROM credentials
|
|
WHERE user_id = $1
|
|
ORDER BY created_at DESC
|
|
`, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var credentials []Credential
|
|
for rows.Next() {
|
|
var cred Credential
|
|
err := rows.Scan(&cred.ID, &cred.UserID, &cred.CredentialPublicKey, &cred.CredentialID, &cred.SignCount, &cred.CreatedAt, &cred.LastUsedAt, &cred.Transports)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
credentials = append(credentials, cred)
|
|
}
|
|
return credentials, rows.Err()
|
|
}
|
|
|
|
func (db *DB) UpdateCredentialLastUsed(ctx context.Context, credentialID string) error {
|
|
_, err := db.ExecContext(ctx, `
|
|
UPDATE credentials
|
|
SET last_used_at = NOW()
|
|
WHERE id = $1
|
|
`, credentialID)
|
|
return err
|
|
}
|
|
|
|
func (db *DB) CreateAuthChallenge(ctx context.Context, userID uuid.UUID, challenge []byte, challengeType string) error {
|
|
_, err := db.ExecContext(ctx, `
|
|
INSERT INTO auth_challenges (user_id, challenge, challenge_type, expires_at)
|
|
VALUES ($1, $2, $3, NOW() + INTERVAL '15 minutes')
|
|
`, userID, challenge, challengeType)
|
|
return err
|
|
}
|
|
|
|
func (db *DB) VerifyAuthChallenge(ctx context.Context, userID uuid.UUID, challenge []byte, challengeType string) error {
|
|
var count int
|
|
err := db.QueryRowContext(ctx, `
|
|
SELECT COUNT(*)
|
|
FROM auth_challenges
|
|
WHERE user_id = $1 AND challenge = $2 AND challenge_type = $3 AND expires_at > NOW() AND used_at IS NULL
|
|
`, userID, challenge, challengeType).Scan(&count)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if count == 0 {
|
|
return sql.ErrNoRows
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *DB) MarkChallengeUsed(ctx context.Context, challenge []byte) error {
|
|
_, err := db.ExecContext(ctx, `
|
|
UPDATE auth_challenges
|
|
SET used_at = NOW()
|
|
WHERE challenge = $1 AND used_at IS NULL
|
|
`, challenge)
|
|
return err
|
|
}
|
|
|
|
// UpdateFileSize updates the size and last_modified timestamp of a file
|