diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8e7a349..dd80fb9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -20,7 +20,7 @@ jobs: arch: arm64 - os: linux runs-on: ubuntu-latest - arch: amd64 + arch: x86_64 steps: - name: Checkout code @@ -96,8 +96,8 @@ jobs: echo "GOARCH=amd64" >> $GITHUB_ENV echo "CGO_ENABLED=1" >> $GITHUB_ENV - - name: Set build environment (Linux) - if: matrix.os == 'linux' + - name: Set build environment (Linux x86_64) + if: matrix.os == 'linux' && matrix.arch == 'x86_64' run: | echo "GOOS=linux" >> $GITHUB_ENV echo "GOARCH=amd64" >> $GITHUB_ENV @@ -161,7 +161,7 @@ jobs: uses: softprops/action-gh-release@v1 with: files: | - sqlrsync-linux-amd64/sqlrsync-linux-amd64 + sqlrsync-linux-x86_64/sqlrsync-linux-x86_64 sqlrsync-darwin-amd64/sqlrsync-darwin-amd64 sqlrsync-darwin-arm64/sqlrsync-darwin-arm64 sqlrsync-windows-amd64/sqlrsync-windows-amd64.exe diff --git a/.gitignore b/.gitignore index 7fc3c42..248f074 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,6 @@ CLAUDE.md **/CLAUDE.md tmp/ +client/sqlrsync +client/sqlrsync +client/sqlrsync_simple diff --git a/bridge/client.go b/bridge/client.go index 53e39a2..981b65a 100644 --- a/bridge/client.go +++ b/bridge/client.go @@ -73,14 +73,15 @@ func (c *Client) GetDatabaseInfo() (*DatabaseInfo, error) { func (c *Client) RunPushSync(readFunc ReadFunc, writeFunc WriteFunc) error { c.Logger.Info("Starting origin sync", zap.String("database", c.Config.DatabasePath)) + if c.Config.DryRun { + fmt.Println("Running in dry-run mode") + return nil + } + // Store I/O functions for callbacks c.ReadFunc = readFunc c.WriteFunc = writeFunc - if c.Config.DryRun { - c.Logger.Info("Running in dry-run mode") - } - c.Logger.Debug("Calling C sqlite_rsync_run_origin") // Run the origin synchronization via CGO bridge @@ -102,7 +103,8 @@ func (c *Client) RunPullSync(readFunc ReadFunc, writeFunc WriteFunc) error { c.WriteFunc = writeFunc if c.Config.DryRun { - 
c.Logger.Info("Running in dry-run mode") + fmt.Println("Running in dry-run mode. We should not have gotten here.") + return nil } c.Logger.Debug("Calling C sqlite_rsync_run_replica") @@ -120,12 +122,13 @@ func (c *Client) RunPullSync(readFunc ReadFunc, writeFunc WriteFunc) error { // RunDirectSync runs direct local synchronization between two SQLite files func (c *Client) RunDirectSync(replicaPath string) error { - c.Logger.Info("Starting direct local sync", + c.Logger.Info("Starting direct local sync", zap.String("origin", c.Config.DatabasePath), zap.String("replica", replicaPath)) if c.Config.DryRun { - c.Logger.Info("Running in dry-run mode") + fmt.Println("Running in dry-run mode. We should not have gotten here.") + return nil } verboseLevel := 0 diff --git a/client/auth/config.go b/client/auth/config.go new file mode 100644 index 0000000..76d241c --- /dev/null +++ b/client/auth/config.go @@ -0,0 +1,336 @@ +package auth + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/BurntSushi/toml" +) + +// DefaultsConfig represents ~/.config/sqlrsync/defaults.toml +type DefaultsConfig struct { + Defaults struct { + Server string `toml:"server"` + } `toml:"defaults"` +} + +// LocalSecretsConfig represents ~/.config/sqlrsync/local-secrets.toml +type LocalSecretsConfig struct { + SQLRsyncDatabases []SQLRsyncDatabase `toml:"sqlrsync-databases"` +} + +// SQLRsyncDatabase represents a configured database with auth info +type SQLRsyncDatabase struct { + LocalPath string `toml:"path,omitempty"` + Server string `toml:"server"` + CustomerSuppliedEncryptionKey string `toml:"customerSuppliedEncryptionKey,omitempty"` + ReplicaID string `toml:"replicaID"` + RemotePath string `toml:"remotePath,omitempty"` + PushKey string `toml:"pushKey,omitempty"` + LastPush time.Time `toml:"lastPush,omitempty"` +} + +// DashSQLRsync manages the -sqlrsync file for a database +type DashSQLRsync struct { + DatabasePath string + RemotePath string + PullKey string + 
Server string + ReplicaID string +} + +// GetConfigDir returns the sqlrsync config directory path +func GetConfigDir() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get user home directory: %w", err) + } + return filepath.Join(homeDir, ".config", "sqlrsync"), nil +} + +// GetDefaultsPath returns the path to defaults.toml +func GetDefaultsPath() (string, error) { + configDir, err := GetConfigDir() + if err != nil { + return "", err + } + return filepath.Join(configDir, "defaults.toml"), nil +} + +// GetLocalSecretsPath returns the path to local-secrets.toml +func GetLocalSecretsPath() (string, error) { + configDir, err := GetConfigDir() + if err != nil { + return "", err + } + return filepath.Join(configDir, "local-secrets.toml"), nil +} + +// LoadDefaultsConfig loads the defaults configuration +func LoadDefaultsConfig() (*DefaultsConfig, error) { + path, err := GetDefaultsPath() + if err != nil { + return nil, err + } + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + // Return default config if file doesn't exist + config := &DefaultsConfig{} + config.Defaults.Server = "wss://sqlrsync.com" + return config, nil + } + return nil, fmt.Errorf("failed to read defaults config file %s: %w", path, err) + } + + var config DefaultsConfig + if _, err := toml.Decode(string(data), &config); err != nil { + return nil, fmt.Errorf("failed to parse TOML defaults config: %w", err) + } + + // Set default server if not specified + if config.Defaults.Server == "" { + config.Defaults.Server = "wss://sqlrsync.com" + } + + return &config, nil +} + +// SaveDefaultsConfig saves the defaults configuration +func SaveDefaultsConfig(config *DefaultsConfig) error { + path, err := GetDefaultsPath() + if err != nil { + return err + } + + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create directory %s: %w", dir, err) + } + + file, err := 
os.Create(path) + if err != nil { + return fmt.Errorf("failed to create defaults config file %s: %w", path, err) + } + defer file.Close() + + encoder := toml.NewEncoder(file) + if err := encoder.Encode(config); err != nil { + return fmt.Errorf("failed to write defaults config: %w", err) + } + + return nil +} + +// LoadLocalSecretsConfig loads the local secrets configuration +func LoadLocalSecretsConfig() (*LocalSecretsConfig, error) { + path, err := GetLocalSecretsPath() + if err != nil { + return nil, err + } + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + // Return empty config if file doesn't exist + return &LocalSecretsConfig{ + SQLRsyncDatabases: []SQLRsyncDatabase{}, + }, nil + } + return nil, fmt.Errorf("failed to read local-secrets config file %s: %w", path, err) + } + + var config LocalSecretsConfig + if _, err := toml.Decode(string(data), &config); err != nil { + return nil, fmt.Errorf("failed to parse TOML local-secrets config: %w", err) + } + + return &config, nil +} + +// SaveLocalSecretsConfig saves the local secrets configuration +func SaveLocalSecretsConfig(config *LocalSecretsConfig) error { + path, err := GetLocalSecretsPath() + if err != nil { + return err + } + + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create directory %s: %w", dir, err) + } + + file, err := os.Create(path) + if err != nil { + return fmt.Errorf("failed to create local-secrets config file %s: %w", path, err) + } + defer file.Close() + + // Set file permissions to 0600 (read/write for owner only) + if err := file.Chmod(0600); err != nil { + return fmt.Errorf("failed to set permissions on local-secrets config file: %w", err) + } + + encoder := toml.NewEncoder(file) + if err := encoder.Encode(config); err != nil { + return fmt.Errorf("failed to write local-secrets config: %w", err) + } + + return nil +} + +// FindDatabaseByPath finds a database configuration by local path +func (c 
*LocalSecretsConfig) FindDatabaseByPath(path string) *SQLRsyncDatabase { + // Normalize the search path to absolute path + searchPath, err := filepath.Abs(path) + if err != nil { + // If we can't get absolute path, fall back to original comparison + for i := range c.SQLRsyncDatabases { + if c.SQLRsyncDatabases[i].LocalPath == path { + return &c.SQLRsyncDatabases[i] + } + } + return nil + } + + for i := range c.SQLRsyncDatabases { + // Normalize the stored path to absolute path for comparison + storedPath, err := filepath.Abs(c.SQLRsyncDatabases[i].LocalPath) + if err != nil { + // If we can't normalize stored path, compare as-is + if c.SQLRsyncDatabases[i].LocalPath == path || c.SQLRsyncDatabases[i].LocalPath == searchPath { + return &c.SQLRsyncDatabases[i] + } + } else { + // Compare normalized absolute paths + if storedPath == searchPath { + return &c.SQLRsyncDatabases[i] + } + } + } + return nil +} + +// UpdateOrAddDatabase updates an existing database or adds a new one +func (c *LocalSecretsConfig) UpdateOrAddDatabase(db SQLRsyncDatabase) { + for i := range c.SQLRsyncDatabases { + if c.SQLRsyncDatabases[i].LocalPath == db.LocalPath { + // Update existing database + c.SQLRsyncDatabases[i] = db + return + } + } + // Add new database + c.SQLRsyncDatabases = append(c.SQLRsyncDatabases, db) +} + +// RemoveDatabase removes a database configuration by local path +func (c *LocalSecretsConfig) RemoveDatabase(path string) { + for i, db := range c.SQLRsyncDatabases { + if db.LocalPath == path { + // Remove database from slice + c.SQLRsyncDatabases = append(c.SQLRsyncDatabases[:i], c.SQLRsyncDatabases[i+1:]...) 
+ return + } + } +} + +// NewDashSQLRsync creates a new DashSQLRsync instance for the given database path +func NewDashSQLRsync(databasePath string) *DashSQLRsync { + if strings.Contains(databasePath, "@") { + databasePath = strings.Split(databasePath, "@")[0] + } + + return &DashSQLRsync{ + DatabasePath: databasePath, + } +} + +// FilePath returns the path to the -sqlrsync file +func (d *DashSQLRsync) FilePath() string { + return d.DatabasePath + "-sqlrsync" +} + +// Exists checks if the -sqlrsync file exists +func (d *DashSQLRsync) Exists() bool { + _, err := os.Stat(d.FilePath()) + return err == nil +} + +// Read reads the -sqlrsync file and populates the struct fields +func (d *DashSQLRsync) Read() error { + if !d.Exists() { + return fmt.Errorf("file does not exist: %s", d.FilePath()) + } + + file, err := os.Open(d.FilePath()) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + if strings.HasPrefix(line, "#") || line == "" { + continue + } + + if strings.HasPrefix(line, "sqlrsync ") { + parts := strings.Fields(line) + if len(parts) >= 2 { + d.RemotePath = parts[1] + } + + for _, part := range parts { + if strings.HasPrefix(part, "--pullKey=") { + d.PullKey = strings.TrimPrefix(part, "--pullKey=") + } + if strings.HasPrefix(part, "--replicaID=") { + d.ReplicaID = strings.TrimPrefix(part, "--replicaID=") + } + if strings.HasPrefix(part, "--server=") { + d.Server = strings.TrimPrefix(part, "--server=") + } + } + break + } + } + + return scanner.Err() +} + +// Write writes the -sqlrsync file with the given remote path and pull key +func (d *DashSQLRsync) Write(remotePath string, localName string, replicaID string, pullKey string, serverURL string) error { + d.RemotePath = remotePath + d.PullKey = pullKey + + localNameTree := strings.Split(localName, "/") + localName = localNameTree[len(localNameTree)-1] + + content 
:= fmt.Sprintf(`#!/bin/bash +# https://sqlrsync.com/help/dash-sqlrsync +sqlrsync %s %s --replicaID=%s --pullKey=%s --server=%s "$@" + +`, remotePath, localName, replicaID, pullKey, serverURL) + + if err := os.WriteFile(d.FilePath(), []byte(content), 0755); err != nil { + return fmt.Errorf("failed to write -sqlrsync file: %w", err) + } + + return nil +} + +// Remove removes the -sqlrsync file if it exists +func (d *DashSQLRsync) Remove() error { + if !d.Exists() { + return nil + } + return os.Remove(d.FilePath()) +} diff --git a/client/auth/resolver.go b/client/auth/resolver.go new file mode 100644 index 0000000..d807e2f --- /dev/null +++ b/client/auth/resolver.go @@ -0,0 +1,275 @@ +package auth + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strings" + + "go.uber.org/zap" +) + +// ResolveResult contains the resolved authentication information +type ResolveResult struct { + AuthToken string + ReplicaID string + ServerURL string + RemotePath string + LocalPath string + ShouldPrompt bool +} + +// ResolveRequest contains the parameters for authentication resolution +type ResolveRequest struct { + LocalPath string + RemotePath string + ServerURL string + ProvidedPullKey string + ProvidedPushKey string + ProvidedReplicaID string + Operation string // "pull", "push", "subscribe" + Logger *zap.Logger +} + +// Resolver handles authentication and configuration resolution +type Resolver struct { + logger *zap.Logger +} + +// NewResolver creates a new authentication resolver +func NewResolver(logger *zap.Logger) *Resolver { + return &Resolver{ + logger: logger, + } +} + +// Resolve determines the authentication method and configuration for an operation +func (r *Resolver) Resolve(req *ResolveRequest) (*ResolveResult, error) { + result := &ResolveResult{ + ServerURL: req.ServerURL, + LocalPath: req.LocalPath, + RemotePath: req.RemotePath, + } + + // 1. 
Try environment variable first + if token := os.Getenv("SQLRSYNC_AUTH_TOKEN"); token != "" { + r.logger.Debug("Using SQLRSYNC_AUTH_TOKEN from environment") + result.AuthToken = token + result.ReplicaID = req.ProvidedReplicaID + return result, nil + } + + // 2. Try explicitly provided keys + if req.ProvidedPullKey != "" { + r.logger.Debug("Using provided pull key") + result.AuthToken = req.ProvidedPullKey + result.ReplicaID = req.ProvidedReplicaID + return result, nil + } + + if req.ProvidedPushKey != "" { + r.logger.Debug("Using provided push key") + result.AuthToken = req.ProvidedPushKey + result.ReplicaID = req.ProvidedReplicaID + return result, nil + } + + // 3. For operations with local paths, check stored configurations + if req.LocalPath != "" { + // Get absolute path for lookups + absLocalPath, err := filepath.Abs(req.LocalPath) + if err != nil { + return nil, fmt.Errorf("failed to get absolute path: %w", err) + } + + // If no server was explicitly provided (still default), try to get server from local config + if req.ServerURL == "wss://sqlrsync.com" { + if localSecretsConfig, err := LoadLocalSecretsConfig(); err == nil { + if dbConfig := localSecretsConfig.FindDatabaseByPath(absLocalPath); dbConfig != nil { + r.logger.Debug("Using server URL from local secrets config", + zap.String("configuredServer", dbConfig.Server), + zap.String("defaultServer", req.ServerURL)) + result.ServerURL = dbConfig.Server + } + } + } + + // Check local secrets config for push operations + if req.Operation == "push" { + if authResult, err := r.resolveFromLocalSecrets(absLocalPath, result.ServerURL, result); err == nil { + return authResult, nil + } + } + + // Check -sqlrsync file for pull/subscribe operations + if req.Operation == "pull" || req.Operation == "subscribe" { + if authResult, err := r.resolveFromDashFile(absLocalPath, result); err == nil { + return authResult, nil + } + } + } + + // 4. 
For push operations, check if we need to prompt for admin key + if req.Operation == "push" { + if os.Getenv("SQLRSYNC_ADMIN_KEY") != "" { + r.logger.Debug("Using SQLRSYNC_ADMIN_KEY from environment") + result.AuthToken = os.Getenv("SQLRSYNC_ADMIN_KEY") + result.ShouldPrompt = false + return result, nil + } + + // Need to prompt for admin key + result.ShouldPrompt = true + return result, nil + } + + // 5. No authentication found + return nil, fmt.Errorf("no authentication credentials found") +} + +// resolveFromLocalSecrets attempts to resolve auth from local-secrets.toml +func (r *Resolver) resolveFromLocalSecrets(absLocalPath, serverURL string, result *ResolveResult) (*ResolveResult, error) { + r.logger.Debug("Attempting to resolve from local secrets", zap.String("absLocalPath", absLocalPath), zap.String("serverURL", serverURL)) + + localSecretsConfig, err := LoadLocalSecretsConfig() + if err != nil { + r.logger.Debug("Failed to load local secrets config", zap.Error(err)) + return nil, fmt.Errorf("failed to load local secrets config: %w", err) + } + + r.logger.Debug("Loaded local secrets config", zap.Int("databaseCount", len(localSecretsConfig.SQLRsyncDatabases))) + for i, db := range localSecretsConfig.SQLRsyncDatabases { + r.logger.Debug("Checking database config", zap.Int("index", i), zap.String("storedPath", db.LocalPath), zap.String("server", db.Server)) + } + + dbConfig := localSecretsConfig.FindDatabaseByPath(absLocalPath) + if dbConfig == nil { + r.logger.Debug("No database configuration found", zap.String("searchPath", absLocalPath)) + return nil, fmt.Errorf("no database configuration found for path: %s", absLocalPath) + } + + if dbConfig.PushKey == "" { + r.logger.Debug("No push key found for database", zap.String("path", absLocalPath)) + return nil, fmt.Errorf("no push key found for database") + } + + if dbConfig.Server != serverURL { + r.logger.Debug("Server URL mismatch", + zap.String("configured", dbConfig.Server), + zap.String("requested", 
serverURL)) + return nil, fmt.Errorf("server URL mismatch: configured=%s, requested=%s", dbConfig.Server, serverURL) + } + + r.logger.Debug("Found authentication in local secrets config") + result.AuthToken = dbConfig.PushKey + result.ReplicaID = dbConfig.ReplicaID + result.RemotePath = dbConfig.RemotePath + result.ServerURL = dbConfig.Server + + return result, nil +} + +// resolveFromDashFile attempts to resolve auth from -sqlrsync file +func (r *Resolver) resolveFromDashFile(localPath string, result *ResolveResult) (*ResolveResult, error) { + dashSQLRsync := NewDashSQLRsync(localPath) + if !dashSQLRsync.Exists() { + return nil, fmt.Errorf("no -sqlrsync file found for: %s", localPath) + } + + if err := dashSQLRsync.Read(); err != nil { + return nil, fmt.Errorf("failed to read -sqlrsync file: %w", err) + } + + if dashSQLRsync.PullKey == "" { + return nil, fmt.Errorf("no pull key found in -sqlrsync file") + } + + r.logger.Debug("Found authentication in -sqlrsync file") + result.AuthToken = dashSQLRsync.PullKey + result.ReplicaID = dashSQLRsync.ReplicaID + result.RemotePath = dashSQLRsync.RemotePath + result.ServerURL = dashSQLRsync.Server + + return result, nil +} + +// PromptForAdminKey prompts the user for an admin key +func (r *Resolver) PromptForAdminKey(serverURL string) (string, error) { + httpServer := strings.Replace(serverURL, "ws", "http", 1) + fmt.Println("No Key provided. Creating a new Replica? 
Get a key at " + httpServer + "/namespaces") + fmt.Print(" Enter an Account Admin Key to create a new Replica: ") + + reader := bufio.NewReader(os.Stdin) + token, err := reader.ReadString('\n') + if err != nil { + return "", fmt.Errorf("failed to read admin key: %w", err) + } + + token = strings.TrimSpace(token) + if token == "" { + return "", fmt.Errorf("admin key cannot be empty") + } + + return token, nil +} + +// SavePushResult saves the result of a successful push operation +func (r *Resolver) SavePushResult(localPath, serverURL, remotePath, replicaID, pushKey string) error { + absLocalPath, err := filepath.Abs(localPath) + if err != nil { + return fmt.Errorf("failed to get absolute path: %w", err) + } + + localSecretsConfig, err := LoadLocalSecretsConfig() + if err != nil { + return fmt.Errorf("failed to load local secrets config: %w", err) + } + + dbConfig := SQLRsyncDatabase{ + LocalPath: absLocalPath, + Server: serverURL, + ReplicaID: replicaID, + RemotePath: remotePath, + PushKey: pushKey, + } + + localSecretsConfig.UpdateOrAddDatabase(dbConfig) + + if err := SaveLocalSecretsConfig(localSecretsConfig); err != nil { + return fmt.Errorf("failed to save local secrets config: %w", err) + } + + r.logger.Info("Saved push authentication to local secrets config") + return nil +} + +// SavePullResult saves the result of a successful pull operation +func (r *Resolver) SavePullResult(localPath, serverURL, remotePath, replicaID, pullKey string) error { + dashSQLRsync := NewDashSQLRsync(localPath) + + localNameTree := strings.Split(localPath, "/") + localName := localNameTree[len(localNameTree)-1] + + if err := dashSQLRsync.Write(remotePath, localName, replicaID, pullKey, serverURL); err != nil { + return fmt.Errorf("failed to create -sqlrsync file: %w", err) + } + + r.logger.Info("Created -sqlrsync file", zap.String("path", dashSQLRsync.FilePath())) + return nil +} + +// CheckNeedsDashFile determines if a -sqlrsync file should be created +func (r *Resolver) 
CheckNeedsDashFile(localPath, remotePath string) bool { + dashSQLRsync := NewDashSQLRsync(localPath) + if !dashSQLRsync.Exists() { + return true + } + + // Read existing file to check if remote path matches + if err := dashSQLRsync.Read(); err != nil { + return true // If we can't read it, recreate it + } + + return dashSQLRsync.RemotePath != remotePath +} \ No newline at end of file diff --git a/client/config.go b/client/config.go index 5b20ebe..77724b3 100644 --- a/client/config.go +++ b/client/config.go @@ -209,7 +209,7 @@ func (c *LocalSecretsConfig) RemoveDatabase(path string) { // NewDashSQLRsync creates a new DashSQLRsync instance for the given database path func NewDashSQLRsync(databasePath string) *DashSQLRsync { - if(strings.Contains(databasePath, "@")) { + if strings.Contains(databasePath, "@") { databasePath = strings.Split(databasePath, "@")[0] } @@ -282,8 +282,9 @@ func (d *DashSQLRsync) Write(remotePath string, localName string, replicaID stri localName = localNameTree[len(localNameTree)-1] content := fmt.Sprintf(`#!/bin/bash -# https://sqlrsync.com/docs/-sqlrsync -sqlrsync %s %s --replicaID=%s --pullKey=%s --server=%s +# https://sqlrsync.com/help/dash-sqlrsync +sqlrsync %s %s --replicaID=%s --pullKey=%s --server=%s "$@" + `, remotePath, localName, replicaID, pullKey, serverURL) if err := os.WriteFile(d.FilePath(), []byte(content), 0755); err != nil { diff --git a/client/main.go b/client/main.go index 08c34bb..048e766 100644 --- a/client/main.go +++ b/client/main.go @@ -1,122 +1,62 @@ package main import ( - "bufio" "fmt" "log" "os" "path/filepath" "strconv" "strings" - "time" - "github.com/fatih/color" "github.com/spf13/cobra" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "github.com/sqlrsync/sqlrsync.com/bridge" - "github.com/sqlrsync/sqlrsync.com/remote" + "github.com/sqlrsync/sqlrsync.com/sync" ) +var VERSION = "0.0.1" var ( - serverURL string - verbose bool - dryRun bool - setPublic bool - timeout int - logger *zap.Logger - inspectTraffic bool - 
inspectionDepth int - newReadToken bool - pullKey string - pushKey string - replicaID string + serverURL string + verbose bool + dryRun bool + SetPublic bool + SetUnlisted bool + subscribing bool + pullKey string + pushKey string + replicaID string + logger *zap.Logger + showVersion bool ) var rootCmd = &cobra.Command{ Use: "sqlrsync [ORIGIN] [REPLICA] or [LOCAL] or [REMOTE]", - Short: "SQLite Rsync - ", - Long: `A rsync-like utility built specifically to replicate SQLite databases -to sqlrsync.com for features such as backup, version control, and distribution. - -Using the page hashing algorithm designed by the authors of SQLite, only -changed pages are communicated between ORIGIN and REPLICA, allowing for -efficient synchronization. - -REPLICA becomes a copy of a snapshot of ORIGIN as it existed when the sqlrsync -command started. If other processes change the content of ORIGIN while this -command is running, those changes will be applied to ORIGIN, but they are not -transferred to REPLICA. Thus, REPLICA ends up as a fully-consistent snapshot -of ORIGIN at an instant in time. - -Learn about SQLite Pages: sqlite.org/fileformat2.html -Learn about sqlite3_rsync: sqlite.org/rsync.html - -This utility, a wrapper around sqlite3_rsync, uses sqlrsync.com as the REMOTE -server to allow specific benefits over simply using the utility the developers -of the SQLite project provide. - -ORIGIN and REPLICA can be LOCAL or REMOTE. Both cannot be REMOTE. - -LOCAL is this local machine in the current working directory (or prefixed with -./, ../, or /). -REMOTE is a database hosted on sqlrsync.com, and must have at least one / in its -path. - -If REPLICA does not already exist, it is created. - -Local databases may be "live" while this utility is running. Other programs can have -active connections to the local database (in either role) without any disruption. -Other programs can write to/read from ORIGIN, and can read from REPLICA while this -utility runs. 
- -All of the table (and index) content will be byte-for-byte identical in the -replica. However, there can be some minor changes in the database header. See -Limitations at sqlite.org/rsync.html - -A REMOTE ORIGIN database may be specified with an appended @, such as: - mynamespace/mydb.sqlite # Requests the latest uploaded version - mynamespace/mydb.sqlite@ # VERSION is a number greater than 0 and - identifies the nth version uploaded - mynamespace/mydb.sqlite@latest. # Redundant to leaving the value unspecified - mynamespace/mydb.sqlite@latest- # N is a number greater than 0 and - will cause the version N uploads prior to the latest version to be used. - -When ORIGIN is LOCAL and REPLICA is LOCAL, a local transfer (no network) causes -REPLICA to become a copy of ORIGIN. - -When ORIGIN is LOCAL and REPLICA is REMOTE, a secure websocket connects to -sqlrsync.com and then any pages REPLICA needs synchronized are transferred to -the remote database. - -When ORIGIN is LOCAL and REPLICA is unspecified, the remote REPLICA is created -at sqlrsync.com using the default namespace and database name derived from ORIGIN. - -When ORIGIN is REMOTE and REPLICA is LOCAL, the local REPLICA becomes a complete -copy of ORIGIN. - -When ORIGIN is REMOTE and REPLICA is unspecified, a local REPLICA is created -at using the database name derived from ORIGIN. + Short: "SQLRsync v" + VERSION, + Long: `SQLRsync v` + VERSION + ` +A web-enabled rsync-like utility for SQLite databases with subscription support. Usage modes: -1. Direct local sync: sqlrsync LOCAL LOCAL [OPTIONS] - Example: sqlrsync mydb.sqlite ./mydb2.sqlite - -2. Push to sqlrsync.com: sqlrsync LOCAL [REMOTE] [OPTIONS] - Example: sqlrsync mydb.sqlite mynamespace/mydb.sqlite - -3. 
Pull from sqlrsync.com: sqlrsync REMOTE [LOCAL] [OPTIONS] - or sqlrsync REMOTE@ [LOCAL] [OPTIONS] - Example: sqlrsync mynamespace/mydb.sqlite - Example: sqlrsync mynamespace/mydb.sqlite@latest-1 /overhere/mydb.sqlite - Example: sqlrsync mynamespace/mydb.sqlite@7 - -Eternal gratitude to the authors of the SQLite Project for their contributions -to the world of data storage. +1. Push to server: sqlrsync LOCAL [REMOTE] [OPTIONS] +2. Pull from server: sqlrsync REMOTE [LOCAL] [OPTIONS] +3. Pull with subscription: sqlrsync REMOTE [LOCAL] --subscribe [OPTIONS] +4. Local to local sync: sqlrsync LOCAL1 LOCAL2 [OPTIONS] + +Where: +- REMOTE is a path like namespace/db.sqlite (remote server) +- LOCAL is a local file path like ./db.sqlite or db.sqlite (local file) + +Limitations: +- Pushing to the server requires page size of 4096 (default for SQLite). + Check by querying "PRAGMA page_size;". + +Examples: + sqlrsync mydb.sqlite # Push local to remote + sqlrsync namespace/db.sqlite # Pull to local db.sqlite + sqlrsync namespace/db.sqlite --subscribe # Pull and watch for updates + sqlrsync mydb.sqlite mydb2.sqlite # Local to local sync `, - - Version: "1.0.0", + Version: VERSION, PreRun: func(cmd *cobra.Command, args []string) { setupLogger() }, @@ -125,491 +65,141 @@ to the world of data storage. SilenceUsage: true, } -func showLocalError(message string) { - fmt.Println(color.RedString("[error]"), message) -} - func runSync(cmd *cobra.Command, args []string) error { - // Determine the sync mode based on arguments - - // The two arguments are ORIGIN REPLICA. - // Either can be LOCAL or REMOTE. 
- // LOCAL is determined if the path begins with /, ./, or ../ OR doesn't have a / anywhere in it - // REMOTE is determined by !LOCAL - // DBNAME is string after the final / of the REMOTE path - // - // Examples: - // IF ORIGIN:LOCAL REPLICA:LOCAL - // runDirectSync(ORIGIN,REPLICA); - // IF ORIGIN:LOCAL REPLICA:REMOTE - // runPushSync(ORIGIN,REPLICA); - // IF ORIGIN:REMOTE REPLICA:LOCAL - // runPullSync(ORIGIN,REPLICA); - // IF ORIGIN:LOCAL (no REPLICA) - // runPushSync(ORIGIN, PREFIX/) - // IF (no ORIGIN) REPLICA:REMOTE - // runPullSync(REPLICA, ) - // IF (no ORIGIN) REPLICA:LOCAL - // This cannot happen - - isLocal := func(path string) bool { - return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./") || strings.HasPrefix(path, "../") || strings.HasPrefix(path, "~/") || !strings.Contains(path, "/") - } - if len(args) == 0 { return cmd.Help() - } else if len(args) == 2 { - // Two arguments: ORIGIN REPLICA - origin, replica := args[0], args[1] - originLocal := isLocal(origin) - replicaLocal := isLocal(replica) - - if originLocal && replicaLocal { - // IF ORIGIN:LOCAL REPLICA:LOCAL - return runDirectSync(origin, replica) - } else if originLocal && !replicaLocal { - // IF ORIGIN:LOCAL REPLICA:REMOTE - return runPushSync(origin, replica) - } else if !originLocal && replicaLocal { - // IF ORIGIN:REMOTE REPLICA:LOCAL - return runPullSync(origin, replica) - } else { - return fmt.Errorf("remote to remote sync not supported") - } - } else if len(args) == 1 { - // One argument: either ORIGIN (push/pull depends on ~.config & -sqlrsync) or REPLICA (for pull) - path := args[0] - if isLocal(path) { - // IF ORIGIN:LOCAL (no REPLICA) - varies - localSecretsConfig, err := LoadLocalSecretsConfig() - if err != nil { - return fmt.Errorf("failed to load local secrets config: %w", err) - } - // Get absolute path for the local database - absPath, err := filepath.Abs(path) - if err == nil { - // If we have a push key for this database, use it to push - pushedDBInfo := 
localSecretsConfig.FindDatabaseByPath(absPath) - if pushedDBInfo != nil && pushedDBInfo.PushKey != "" && pushedDBInfo.Server == serverURL { - pushKey = pushedDBInfo.PushKey - return runPushSync(absPath, pushedDBInfo.RemotePath) - } - } - - // else if there is a -sqlrsync file, do a pull instead - dashSQLRsync := NewDashSQLRsync(path) - if dashSQLRsync.Exists() { - if err := dashSQLRsync.Read(); err != nil { - return fmt.Errorf("failed to read -sqlrsync file: %w", err) - } - if dashSQLRsync.RemotePath == "" { - return fmt.Errorf("invalid -sqlrsync file: missing remote path") - } - if dashSQLRsync.Server == serverURL { - localPath := "" - version := "latest" - localPath, version, _ = strings.Cut(path, "@") - - pullKey = dashSQLRsync.PullKey - replicaID = dashSQLRsync.ReplicaID - serverURL = dashSQLRsync.Server - return runPullSync(dashSQLRsync.RemotePath+"@"+version, localPath) - } - } - - // else push this file up - return runPushSync(path, "") - } else { - // IF REPLICA:REMOTE (no ORIGIN) - pull to default local name - dbname := filepath.Base(path) - return runPullSync(path, dbname) - } - - } else { - return fmt.Errorf("invalid arguments. Usage:\n1. Direct local sync: sqlrsync ORIGIN REPLICA [OPTIONS]\n2. Push to sqlrsync.com: sqlrsync ORIGIN [REPLICA] [OPTIONS]\n3. 
Pull from sqlrsync.com: sqlrsync REPLICA [OPTIONS] or sqlrsync REPLICA@ [OPTIONS]") - } -} - -func runDirectSync(originPath, replicaPath string) error { - // Validate that origin database file exists - if _, err := os.Stat(originPath); os.IsNotExist(err) { - return fmt.Errorf("origin database file does not exist: %s", originPath) - } - - logger.Info("Starting direct SQLite Rsync synchronization", - zap.String("origin", originPath), - zap.String("replica", replicaPath), - zap.Bool("dryRun", dryRun)) - - // Create local client for SQLite operations - localClient, err := bridge.New(&bridge.Config{ - DatabasePath: originPath, - DryRun: dryRun, - Logger: logger.Named("local"), - }) - if err != nil { - return fmt.Errorf("failed to create local client: %w", err) } - defer localClient.Close() - // Get database info - dbInfo, err := localClient.GetDatabaseInfo() - if err != nil { - return fmt.Errorf("failed to get database info: %w", err) - } - - logger.Info("Database information", - zap.Int("pageSize", dbInfo.PageSize), - zap.Int("pageCount", dbInfo.PageCount), - zap.String("journalMode", dbInfo.JournalMode)) - - // Perform direct sync - if err := localClient.RunDirectSync(replicaPath); err != nil { - return fmt.Errorf("direct synchronization failed: %w", err) - } + // Preprocess variables + serverURL = strings.TrimRight(serverURL, "/") - logger.Info("Direct synchronization completed successfully") - fmt.Println("✅ Locally replicated", originPath, "to", replicaPath+".") - return nil -} - -func runPushSync(localPath string, remotePath string) error { - logger.Info("Running a PUSH sync", - zap.String("local", localPath), - zap.String("remote", remotePath)) - // Validate that database file exists - if _, err := os.Stat(localPath); os.IsNotExist(err) { - return fmt.Errorf("database file does not exist: %s", localPath) - } - - // Load local secrets config - localSecretsConfig, err := LoadLocalSecretsConfig() + // Determine operation based on arguments and flags + operation, 
localPath, remotePath, err := determineOperation(args) if err != nil { - return fmt.Errorf("failed to load local secrets config: %w", err) + return err } - // Get absolute path for the local database - absLocalPath, err := filepath.Abs(localPath) - if err != nil { - return fmt.Errorf("failed to get absolute path: %w", err) - } - - // Find or create database entry - dbConfig := localSecretsConfig.FindDatabaseByPath(absLocalPath) - if dbConfig == nil { - // Create new database entry - dbConfig = &SQLRsyncDatabase{ - LocalPath: absLocalPath, - Server: serverURL, - } - } else { - if serverURL == "" { - serverURL = dbConfig.Server - } - if pushKey == "" { - pushKey = dbConfig.PushKey - } - if remotePath == "" { - remotePath = dbConfig.RemotePath - } - } - - if remotePath == "" { - // Check for -sqlrsync file - dashSQLRsync := NewDashSQLRsync(absLocalPath) - if !dashSQLRsync.Exists() { - fmt.Println("No -sqlrsync file found. This database hasn't been pushed to SQLRsync Server before.") - fmt.Println("No REMOTE name provided. Will use Account Admin Key's default Replica name.") - } else { - logger.Info("Found -sqlrsync file.") - } - } - - // Check if we have a push key for this database - if os.Getenv("SQLRSYNC_ADMIN_KEY") == "" && pushKey == "" { - httpServer := strings.Replace(serverURL, "ws", "http", 1) - fmt.Println("No Key provided. Creating a new Replica? 
Get a key at " + httpServer + "/namespaces") - fmt.Print(" Enter an Account Admin Key to create a new Replica: ") - reader := bufio.NewReader(os.Stdin) - token, err := reader.ReadString('\n') - if err != nil { - return fmt.Errorf("failed to read push key: %w", err) - } - token = strings.TrimSpace(token) - - if token == "" { - return fmt.Errorf("push key cannot be empty") + versionRaw := strings.SplitN(remotePath, "@", 2) + version := "latest" + if len(versionRaw) == 2 { + version = strings.TrimPrefix(strings.ToLower(versionRaw[1]), "v") + remotePath = versionRaw[0] + versionCheck, _ := strconv.Atoi(version) + if strings.HasPrefix(version, "latest") && versionCheck <= 0 { + return fmt.Errorf("invalid version specified: %s (must be `latest`, `latest-`, or `` where the number is greater than 0)", version) } - pushKey = token - fmt.Println() } - logger.Info("Starting push synchronization to sqlrsync.com", - zap.String("local", localPath), - zap.String("remote", remotePath), - zap.String("server", serverURL), - zap.Bool("dryRun", dryRun)) - - fmt.Println("PUSHing up to " + serverURL + " ...") - - // Create local client for SQLite operations - localClient, err := bridge.New(&bridge.Config{ - DatabasePath: localPath, - DryRun: dryRun, - Logger: logger.Named("local"), - }) - if err != nil { - return fmt.Errorf("failed to create local client: %w", err) - } - defer localClient.Close() - - localHostname, _ := os.Hostname() - - // Create remote client for WebSocket transport - remoteClient, err := remote.New(&remote.Config{ - ServerURL: serverURL + "/sapi/push/" + remotePath, - PingPong: false, - Timeout: timeout, - AuthToken: pushKey, - Logger: logger.Named("remote"), - EnableTrafficInspection: inspectTraffic, - LocalHostname: localHostname, - LocalAbsolutePath: absLocalPath, - InspectionDepth: inspectionDepth, - SendConfigCmd: needsToBuildDashSQLRSyncFile(localPath, remotePath), - SetPublic: setPublic, + visibility := 0 + if SetPublic && SetUnlisted { + return 
fmt.Errorf("cannot set both public and unlisted visibility") + } else if SetPublic { + visibility = 2 + } else if SetUnlisted { + visibility = 1 + } + + // Create sync coordinator + coordinator := sync.NewCoordinator(&sync.Config{ + ServerURL: serverURL, + ProvidedAuthToken: getAuthToken(), + ProvidedPullKey: pullKey, + ProvidedPushKey: pushKey, + ProvidedReplicaID: replicaID, + LocalPath: localPath, + RemotePath: remotePath, + ReplicaPath: remotePath, // For LOCAL TO LOCAL, remotePath is actually the replica path + Version: version, // Could be extended to parse @version syntax + Operation: operation, + SetVisibility: visibility, + DryRun: dryRun, + Logger: logger, + Verbose: verbose, }) - if err != nil { - return fmt.Errorf("failed to create remote client: %w", err) - } - defer remoteClient.Close() - - // Connect to remote server - if err := remoteClient.Connect(); err != nil { - return fmt.Errorf("%w", err) - } - - // Get database info - dbInfo, err := localClient.GetDatabaseInfo() - if err != nil { - return fmt.Errorf("failed to get database info: %w", err) - } - - logger.Info("Database information", - zap.Int("pageSize", dbInfo.PageSize), - zap.Int("pageCount", dbInfo.PageCount), - zap.String("journalMode", dbInfo.JournalMode)) - - // Perform the sync by bridging local and remote - if err := performPushSync(localClient, remoteClient); err != nil { - return fmt.Errorf("push synchronization failed: %w", err) - } - - logger.Info("Push synchronization completed successfully") - - dbConfig.LastPush = time.Now() - if remoteClient.GetNewPushKey() != "" { - fmt.Println("🔑 This database is now PUSH-enabled on this system.") - fmt.Println(" A new, replica-specific PUSH key has been stored at ~/.config/sqlrsync/local-secrets.toml") - dbConfig.ReplicaID = remoteClient.GetReplicaID() - dbConfig.RemotePath = remoteClient.GetReplicaPath() - dbConfig.PushKey = remoteClient.GetNewPushKey() - } - localSecretsConfig.UpdateOrAddDatabase(*dbConfig) - // Save the updated config - 
if err := SaveLocalSecretsConfig(localSecretsConfig); err != nil { - logger.Warn("Failed to save local secrets config", zap.Error(err)) - } + // Execute the operation + return coordinator.Execute() +} - if setPublic { - fmt.Println("🌐 This replica is now publicly accessible.") - fmt.Println(" Share this database with sqlrsync.com/" + remoteClient.GetReplicaPath()) +func determineOperation(args []string) (sync.Operation, string, string, error) { + isLocal := func(path string) bool { + return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./") || + strings.HasPrefix(path, "../") || strings.HasPrefix(path, "~/") || + !strings.Contains(path, "/") } - if needsToBuildDashSQLRSyncFile(localPath, remotePath) { - token := remoteClient.GetNewPullKey() - replicaID := remoteClient.GetReplicaID() - replicaPath := remoteClient.GetReplicaPath() - - dashSQLRsync := NewDashSQLRsync(localPath) - if err := dashSQLRsync.Write(replicaPath, localPath, replicaID, token, serverURL); err != nil { - return fmt.Errorf("failed to create shareable config file: %w", err) + if len(args) == 1 { + path := args[0] + if isLocal(path) { + // LOCAL -> push to default remote + return sync.OperationPush, path, "", nil + } else { + // REMOTE -> pull to default local + dbname := filepath.Base(path) + if subscribing { + return sync.OperationSubscribe, dbname, path, nil + } + return sync.OperationPull, dbname, path, nil } - fmt.Println("🔑 Shareable config file created:", dashSQLRsync.FilePath()) - fmt.Println(" Anyone with this file will be able to PULL any version of this database from sqlrsync.com") } - return nil -} - -func isValidVersion(version string) bool { - // Check if the version is a number - if num, err := strconv.Atoi(version); err == nil { - return num > 0 - } + if len(args) == 2 { + origin, replica := args[0], args[1] + originLocal := isLocal(origin) + replicaLocal := isLocal(replica) - // Check for "latest" or "latest-" - if version == "latest" || strings.HasPrefix(version, 
"latest-") { - _, after, _ := strings.Cut(version, "-") - if after != "" { - if num, err := strconv.Atoi(after); err == nil { - return num > 0 + if originLocal && !replicaLocal { + // LOCAL REMOTE -> push + return sync.OperationPush, origin, replica, nil + } else if !originLocal && replicaLocal { + // REMOTE LOCAL -> pull (or subscribe) + if subscribing { + return sync.OperationSubscribe, replica, origin, nil } + return sync.OperationPull, replica, origin, nil + } else if originLocal && replicaLocal { + // LOCAL LOCAL -> direct local sync + return sync.OperationLocalSync, origin, replica, nil } else { - return true + return sync.Operation(0), "", "", fmt.Errorf("remote to remote sync not supported") } } - return false -} - -func needsToBuildDashSQLRSyncFile(filepath string, remotePath string) bool { - if !newReadToken { - return false - } - - dashSQLRsync := NewDashSQLRsync(filepath) - dashSQLRsync.Read() - // check if the {path}-sqlrsync file exists - return !(dashSQLRsync.Exists() && dashSQLRsync.RemotePath == remotePath) + return sync.Operation(0), "", "", fmt.Errorf("invalid arguments") } -func runPullSync(remotePath string, localPath string) error { - logger.Info("Starting pull synchronization from sqlrsync.com", - zap.String("remote", remotePath), - zap.String("local", localPath), - zap.String("server", serverURL), - zap.Bool("dryRun", dryRun)) - - version := "latest" - // if remotePath has an @, then we want to pass that version through - if strings.Contains(remotePath, "@") { - remotePath, version, _ = strings.Cut(remotePath, "@") - if version == "" { - version = "latest" - } - - // if version is not a number, `latest`, or `latest-` then error - if !isValidVersion(version) { - return fmt.Errorf("invalid version format: %s", version) - } - } - - fmt.Println("PULLing down from " + serverURL + "/" + remotePath + "@" + version + " ...") - - // Create remote client for WebSocket transport - remoteClient, err := remote.New(&remote.Config{ - ServerURL: serverURL + 
"/sapi/pull/" + remotePath, - AuthToken: pullKey, - ReplicaID: replicaID, - Timeout: timeout, - PingPong: false, - Logger: logger.Named("remote"), - EnableTrafficInspection: inspectTraffic, - InspectionDepth: inspectionDepth, - Version: version, - SendConfigCmd: needsToBuildDashSQLRSyncFile(localPath, remotePath), - }) - if err != nil { - return fmt.Errorf("failed to create remote client: %w", err) - } - defer remoteClient.Close() - - // Connect to remote server - if err := remoteClient.Connect(); err != nil { - return fmt.Errorf("%w", err) - } - - // Create local client for SQLite operations - localClient, err := bridge.New(&bridge.Config{ - DatabasePath: localPath, - DryRun: dryRun, - Logger: logger.Named("local"), - }) - if err != nil { - return fmt.Errorf("failed to create local client: %w", err) - } - defer localClient.Close() - - // Perform the sync by bridging remote and local (reverse direction for pull) - if err := performPullSync(localClient, remoteClient); err != nil { - return fmt.Errorf("pull synchronization failed: %w", err) - } - - if needsToBuildDashSQLRSyncFile(localPath, remotePath) { - token := remoteClient.GetNewPullKey() - dashSQLRsync := NewDashSQLRsync(localPath) - replicaID := remoteClient.GetReplicaID() - if err := dashSQLRsync.Write(remotePath, localPath, replicaID, token, serverURL); err != nil { - return fmt.Errorf("failed to create shareable config file: %w", err) - } +func getAuthToken() string { + // Try environment variable first + if token := os.Getenv("SQLRSYNC_AUTH_TOKEN"); token != "" { + return token } - logger.Info("Pull synchronization completed successfully") - return nil -} - -func performPushSync(localClient *bridge.Client, remoteClient *remote.Client) error { - // Create I/O bridge between local and remote clients - readFunc := func(buffer []byte) (int, error) { - return remoteClient.Read(buffer) + // Try pull/push keys + if pullKey != "" { + return pullKey } - - writeFunc := func(data []byte) error { - return 
remoteClient.Write(data) + if pushKey != "" { + return pushKey } - // Run the origin sync through the bridge - err := localClient.RunPushSync(readFunc, writeFunc) - - // After sync completes, signal remote to close gracefully - // Give a moment for any final messages to be sent - time.Sleep(500 * time.Millisecond) + // TODO: Could try to load from config files here - return err + return "" } -func performPullSync(localClient *bridge.Client, remoteClient *remote.Client) error { - // Create I/O bridge between remote and local clients (reverse direction for pull) - readFunc := func(buffer []byte) (int, error) { - return remoteClient.Read(buffer) - } - - writeFunc := func(data []byte) error { - return remoteClient.Write(data) +func setupLogger() { + //config := zap.NewDevelopmentConfig() + config := zap.Config{ + Level: zap.NewAtomicLevelAt(zap.InfoLevel), + Development: false, + DisableStacktrace: true, // This disables stack traces + Encoding: "console", + EncoderConfig: zap.NewProductionEncoderConfig(), + OutputPaths: []string{"stdout"}, + ErrorOutputPaths: []string{"stderr"}, } - // Run the replica sync through the bridge (local acts as replica for pull) - err := localClient.RunPullSync(readFunc, writeFunc) - - // After sync completes, signal remote to close gracefully - // Give a moment for any final messages to be sent - time.Sleep(500 * time.Millisecond) - - return err -} - -func Execute() error { - return rootCmd.Execute() -} - -func init() { - rootCmd.Flags().StringVar(&pullKey, "pullKey", "", "Authentication key for pull operations") - rootCmd.Flags().StringVar(&pushKey, "pushKey", "", "Authentication key for push operations") - rootCmd.Flags().StringVar(&replicaID, "replicaID", "", "Replica ID for the remote database (overwrites the REMOTE path)") - rootCmd.Flags().StringVarP(&serverURL, "server", "s", "wss://sqlrsync.com", "Server URL for push/pull operations") - rootCmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Enable verbose logging") - 
rootCmd.Flags().BoolVar(&setPublic, "public", false, "Enable public access to the replica (only for push operations)") - rootCmd.Flags().BoolVar(&newReadToken, "storeNewReadToken", true, "After syncing, the server creates a new read-only token that is stored in the -sqlrsync file adjacent to the local database") - rootCmd.Flags().BoolVar(&dryRun, "dry", false, "Perform a dry run without making changes") - rootCmd.Flags().IntVarP(&timeout, "timeout", "t", 8000, "Connection timeout in milliseconds (Max 10 seconds)") - rootCmd.Flags().BoolVar(&inspectTraffic, "inspect-traffic", false, "Enable traffic inspection between Go and Bridge layers") - rootCmd.Flags().IntVar(&inspectionDepth, "inspection-depth", 5, "Number of bytes to inspect from each message (default: 5)") -} - -func setupLogger() { - config := zap.NewDevelopmentConfig() + // zapcore Levels: DebugLevel, InfoLevel, WarnLevel, ErrorLevel, DPanicLevel, PanicLevel, FatalLevel if verbose { config.Level.SetLevel(zapcore.DebugLevel) @@ -627,9 +217,23 @@ func setupLogger() { } } +func init() { + rootCmd.Flags().StringVar(&pullKey, "pullKey", "", "Authentication key for PULL operations") + rootCmd.Flags().StringVar(&pushKey, "pushKey", "", "Authentication key for PUSH operations") + rootCmd.Flags().StringVar(&replicaID, "replicaID", "", "Replica ID for the remote database") + rootCmd.Flags().StringVarP(&serverURL, "server", "s", "wss://sqlrsync.com", "Server URL for operations") + rootCmd.Flags().BoolVar(&subscribing, "subscribe", false, "Enable subscription to PULL changes") + rootCmd.Flags().BoolVar(&verbose, "verbose", false, "Enable verbose logging") + rootCmd.Flags().BoolVar(&SetUnlisted, "unlisted", false, "Enable unlisted access to the replica (initial PUSH only)") + rootCmd.Flags().BoolVar(&SetPublic, "public", false, "Enable public access to the replica (initial PUSH only)") + rootCmd.Flags().BoolVar(&dryRun, "dry", false, "Perform a dry run without making changes") + rootCmd.Flags().BoolVarP(&showVersion, 
"version", "v", false, "Show version information") + +} + func main() { - if err := Execute(); err != nil { - showLocalError(err.Error()) + if err := rootCmd.Execute(); err != nil { + fmt.Printf("Error: %v\n", err) os.Exit(1) } } diff --git a/client/remote/client.go b/client/remote/client.go index 26aca7d..d94db75 100644 --- a/client/remote/client.go +++ b/client/remote/client.go @@ -2,22 +2,118 @@ package remote import ( "context" + "encoding/json" "fmt" "io" "net/http" "net/url" + "strconv" "strings" "sync" "time" + "github.com/fatih/color" "github.com/gorilla/websocket" "go.uber.org/zap" ) const ( - SQLRSYNC_CONFIG = 0x51 // Send to keys and replicaID + SQLRSYNC_CONFIG = 0x51 // Send to keys and replicaID + SQLRSYNC_NEWREPLICAVERSION = 0x52 // New version available + SQLRSYNC_KEYREQUEST = 0x53 // request keys ) +// ProgressPhase represents the current phase of the sync operation +type ProgressPhase int + +const ( + PhaseInitializing ProgressPhase = iota + PhaseNegotiating + PhaseTransferring + PhaseCompleting + PhaseCompleted +) + +// SyncDirection represents the direction of the sync operation +type SyncDirection int + +const ( + DirectionUnknown SyncDirection = iota + DirectionPush // Local → Remote (ORIGIN_* messages outbound) + DirectionPull // Remote → Local (REPLICA_* messages inbound) +) + +// ProgressEventType represents different types of progress events +type ProgressEventType int + +const ( + EventSyncStart ProgressEventType = iota + EventNegotiationComplete + EventPageSent + EventPageReceived + EventPageConfirmed + EventSyncComplete + EventError +) + +// SyncProgress tracks the current state of a sync operation +type SyncProgress struct { + // Basic metrics + TotalPages int + PageSize int + TotalBytes int64 + + // Progress counters + PagesSent int // ORIGIN_PAGE count + PagesReceived int // REPLICA_PAGE count + PagesConfirmed int // REPLICA_HASH/ORIGIN_HASH count + BytesTransferred int64 + + // Timing + StartTime time.Time + LastUpdate time.Time + + 
// State + Phase ProgressPhase + Direction SyncDirection + + // Calculated fields + PercentComplete float64 + EstimatedETA time.Duration + PagesPerSecond float64 +} + +// SyncProgressEvent represents a progress event +type SyncProgressEvent struct { + Type ProgressEventType + Progress *SyncProgress + Message string + Error error +} + +// ProgressFormat defines how progress should be displayed +type ProgressFormat int + +const ( + FormatSimple ProgressFormat = iota // Just percentage + FormatDetailed // Full progress bar with details + FormatJSON // Machine-readable JSON +) + +// ProgressConfig configures progress reporting behavior +type ProgressConfig struct { + Enabled bool + Format ProgressFormat + UpdateRate time.Duration // Minimum time between updates + ShowETA bool + ShowBytes bool + ShowPages bool + PagesPerUpdate int // Update every N pages (default: 10) +} + +// ProgressCallback is called when progress events occur +type ProgressCallback func(event SyncProgressEvent) + // TrafficInspector provides traffic inspection and protocol message detection type TrafficInspector struct { logger *zap.Logger @@ -101,6 +197,142 @@ func (t *TrafficInspector) LogWebSocketTraffic(data []byte, direction string, en zap.String("preview", fmt.Sprintf("%x", data[:inspectionSize]))) } +// InspectForProgress analyzes messages for progress tracking and calls the callback +func (t *TrafficInspector) InspectForProgress(data []byte, direction string, callback ProgressCallback, enableLogging bool) { + if len(data) == 0 || callback == nil { + return + } + + msgType := t.parseMessageType(data) + + // Log the message if enabled + if enableLogging { + inspectionSize := t.depth + if len(data) < inspectionSize { + inspectionSize = len(data) + } + header := data[:inspectionSize] + + t.logger.Debug(fmt.Sprintf("Progress inspection %s", direction), + zap.String("messageType", msgType), + zap.Int("totalBytes", len(data)), + zap.String("header", fmt.Sprintf("%x", header))) + } + + switch 
msgType { + case "ORIGIN_BEGIN": + if progress := t.parseBeginMessage(data, DirectionPush); progress != nil { + callback(SyncProgressEvent{ + Type: EventSyncStart, + Progress: progress, + Message: fmt.Sprintf("Starting sync: %d pages (%d bytes) to push", progress.TotalPages, progress.TotalBytes), + }) + } + case "REPLICA_BEGIN": + if progress := t.parseBeginMessage(data, DirectionPull); progress != nil { + callback(SyncProgressEvent{ + Type: EventSyncStart, + Progress: progress, + Message: fmt.Sprintf("Starting sync: %d pages (%d bytes) to pull", progress.TotalPages, progress.TotalBytes), + }) + } + case "ORIGIN_PAGE": + callback(SyncProgressEvent{ + Type: EventPageSent, + Message: "Page sent", + }) + case "REPLICA_PAGE": + callback(SyncProgressEvent{ + Type: EventPageReceived, + Message: "Page received", + }) + case "REPLICA_HASH", "ORIGIN_HASH": + callback(SyncProgressEvent{ + Type: EventPageConfirmed, + Message: "Page confirmed", + }) + case "ORIGIN_READY", "REPLICA_READY": + callback(SyncProgressEvent{ + Type: EventNegotiationComplete, + Message: "Protocol negotiation complete", + }) + case "ORIGIN_END", "REPLICA_END": + callback(SyncProgressEvent{ + Type: EventSyncComplete, + Message: "Sync operation completed", + }) + } +} + +// parseBeginMessage attempts to parse ORIGIN_BEGIN or REPLICA_BEGIN message payload +func (t *TrafficInspector) parseBeginMessage(data []byte, direction SyncDirection) *SyncProgress { + if len(data) < 9 { // Need at least message type + 8 bytes for basic info + return nil + } + + // Log the raw message bytes for debugging + minLen := len(data) + if minLen > 16 { + minLen = 16 + } + t.logger.Info("Parsing BEGIN message", + zap.String("direction", func() string { + if direction == DirectionPush { + return "PUSH" + } + return "PULL" + }()), + zap.Int("messageLength", len(data)), + zap.String("rawBytes", fmt.Sprintf("%x", data[:minLen]))) + + // SQLite rsync protocol structure for BEGIN messages (simplified parsing) + // This is a 
best-effort parse - exact structure may vary + // Byte 0: Message type (0x41 for ORIGIN_BEGIN, 0x61 for REPLICA_BEGIN) + // Bytes 1-4: Total pages (little-endian uint32) + // Bytes 5-8: Page size (little-endian uint32) + + totalPages := int(data[1]) | int(data[2])<<8 | int(data[3])<<16 | int(data[4])<<24 + pageSize := int(data[5]) | int(data[6])<<8 | int(data[7])<<16 | int(data[8])<<24 + + t.logger.Info("Parsed values from BEGIN message", + zap.Int("totalPages", totalPages), + zap.Int("pageSize", pageSize), + zap.String("bytes1-4", fmt.Sprintf("%02x %02x %02x %02x", data[1], data[2], data[3], data[4])), + zap.String("bytes5-8", fmt.Sprintf("%02x %02x %02x %02x", data[5], data[6], data[7], data[8]))) + + // Sanity check the parsed values - allow smaller page sizes like 4096 + if totalPages <= 0 || totalPages > 1000000 || pageSize <= 0 || pageSize > 65536 { + t.logger.Warn("Parsed BEGIN message with suspicious values", + zap.Int("totalPages", totalPages), + zap.Int("pageSize", pageSize)) + return nil + } + + progress := &SyncProgress{ + TotalPages: totalPages, + PageSize: pageSize, + TotalBytes: int64(totalPages) * int64(pageSize), + Phase: PhaseInitializing, + Direction: direction, + StartTime: time.Now(), + LastUpdate: time.Now(), + PercentComplete: 0.0, + } + + t.logger.Info("Parsed sync parameters", + zap.Int("totalPages", totalPages), + zap.Int("pageSize", pageSize), + zap.Int64("totalBytes", progress.TotalBytes), + zap.String("direction", func() string { + if direction == DirectionPush { + return "PUSH" + } + return "PULL" + }())) + + return progress +} + // parseMessageType attempts to identify the SQLite rsync message type func (t *TrafficInspector) parseMessageType(data []byte) string { if len(data) == 0 { @@ -144,6 +376,8 @@ func (t *TrafficInspector) parseMessageType(data []byte) string { return "REPLICA_CONFIG" case 0x51: return "SQLRSYNC_CONFIG" + case 0x52: + return "SQLRSYNC_NEWREPLICAVERSION" default: // For unknown messages, classify by first byte if 
firstByte >= 32 && firstByte <= 126 { @@ -158,16 +392,23 @@ type Config struct { ServerURL string Version string ReplicaID string - SetPublic bool // for PUSH - Timeout int // in milliseconds + Subscribe bool + SetVisibility int // for PUSH + Timeout int // in milliseconds Logger *zap.Logger EnableTrafficInspection bool // Enable detailed traffic logging InspectionDepth int // How many bytes to inspect (default: 32) PingPong bool AuthToken string - SendConfigCmd bool // the -sqlrsync file doesn't exist, so make a token - LocalHostname string - LocalAbsolutePath string + SendKeyRequest bool // the -sqlrsync file doesn't exist, so make a token + + SendConfigCmd bool // we don't have the version number or remote path + LocalHostname string + LocalAbsolutePath string + + // Progress tracking + ProgressConfig *ProgressConfig + ProgressCallback ProgressCallback } // Client handles WebSocket communication with the remote server @@ -203,11 +444,19 @@ type Client struct { syncMu sync.RWMutex // sqlrsync specific - NewPullKey string - NewPushKey string - ReplicaID string - ReplicaPath string - SetPublic bool + NewPullKey string + NewPushKey string + ReplicaID string + Version string + ReplicaPath string + SetVisibility int + newVersionChan chan struct{} + + // Progress tracking + progress *SyncProgress + progressMu sync.RWMutex + lastProgressSent time.Time + pagesSinceUpdate int } // New creates a new remote WebSocket client @@ -226,21 +475,191 @@ func New(config *Config) (*Client, error) { inspectionDepth = 32 } + // Set default progress config if progress is enabled but config is nil + if config.ProgressCallback != nil && config.ProgressConfig == nil { + config.ProgressConfig = &ProgressConfig{ + Enabled: true, + Format: FormatSimple, + UpdateRate: 500 * time.Millisecond, + ShowETA: true, + ShowBytes: true, + ShowPages: true, + PagesPerUpdate: 10, + } + } + ctx, cancel := context.WithCancel(context.Background()) client := &Client{ - config: config, - logger: config.Logger, 
- ctx: ctx, - cancel: cancel, - readQueue: make(chan []byte, 8202), - writeQueue: make(chan []byte, 100), - reconnectChan: make(chan struct{}, 1), - inspector: NewTrafficInspector(config.Logger, inspectionDepth), + config: config, + logger: config.Logger, + ctx: ctx, + cancel: cancel, + readQueue: make(chan []byte, 3), + writeQueue: make(chan []byte, 5), + reconnectChan: make(chan struct{}, 1), + inspector: NewTrafficInspector(config.Logger, inspectionDepth), + newVersionChan: make(chan struct{}, 1), } return client, nil } +// Progress tracking methods + +// initProgress initializes progress tracking with the given parameters +func (c *Client) initProgress(totalPages, pageSize int, direction SyncDirection) { + c.progressMu.Lock() + defer c.progressMu.Unlock() + + c.progress = &SyncProgress{ + TotalPages: totalPages, + PageSize: pageSize, + TotalBytes: int64(totalPages) * int64(pageSize), + Phase: PhaseInitializing, + Direction: direction, + StartTime: time.Now(), + LastUpdate: time.Now(), + PercentComplete: 0.0, + } + c.lastProgressSent = time.Now() + c.pagesSinceUpdate = 0 +} + +// updateProgress updates progress state and potentially calls the callback +func (c *Client) updateProgress(eventType ProgressEventType, message string) { + if c.config.ProgressCallback == nil || c.config.ProgressConfig == nil || !c.config.ProgressConfig.Enabled { + return + } + + c.progressMu.Lock() + defer c.progressMu.Unlock() + + if c.progress == nil { + return + } + + now := time.Now() + c.progress.LastUpdate = now + + // Update counters based on event type + switch eventType { + case EventPageSent: + c.progress.PagesSent++ + c.pagesSinceUpdate++ + case EventPageReceived: + c.progress.PagesReceived++ + c.pagesSinceUpdate++ + case EventPageConfirmed: + c.progress.PagesConfirmed++ + case EventNegotiationComplete: + c.progress.Phase = PhaseTransferring + case EventSyncComplete: + c.progress.Phase = PhaseCompleted + c.progress.PercentComplete = 100.0 + } + + // Calculate derived metrics 
+ c.calculateProgressMetrics() + + // Determine if we should send an update + shouldUpdate := c.shouldSendProgressUpdate(eventType, now) + + if shouldUpdate { + c.sendProgressUpdate(eventType, message) + c.lastProgressSent = now + c.pagesSinceUpdate = 0 + } +} + +// calculateProgressMetrics updates calculated fields in progress +func (c *Client) calculateProgressMetrics() { + if c.progress == nil { + return + } + + // Calculate progress percentage + var completedPages int + if c.progress.Direction == DirectionPush { + completedPages = c.progress.PagesSent + } else { + completedPages = c.progress.PagesReceived + } + + if c.progress.TotalPages > 0 { + c.progress.PercentComplete = float64(completedPages) / float64(c.progress.TotalPages) * 100.0 + } + + // Calculate bytes transferred + c.progress.BytesTransferred = int64(completedPages) * int64(c.progress.PageSize) + + // Calculate speed and ETA + elapsed := time.Since(c.progress.StartTime) + if elapsed > 0 { + c.progress.PagesPerSecond = float64(completedPages) / elapsed.Seconds() + + if c.progress.PagesPerSecond > 0 { + remainingPages := c.progress.TotalPages - completedPages + c.progress.EstimatedETA = time.Duration(float64(remainingPages)/c.progress.PagesPerSecond) * time.Second + } + } +} + +// shouldSendProgressUpdate determines if a progress update should be sent +func (c *Client) shouldSendProgressUpdate(eventType ProgressEventType, now time.Time) bool { + // Always send for phase changes and completion + if eventType == EventSyncStart || eventType == EventNegotiationComplete || eventType == EventSyncComplete || eventType == EventError { + return true + } + + // Check rate limiting + if now.Sub(c.lastProgressSent) < c.config.ProgressConfig.UpdateRate { + // But still send if we've hit the page threshold + return c.pagesSinceUpdate >= c.config.ProgressConfig.PagesPerUpdate + } + + // Send if enough time has passed and we have updates + return c.pagesSinceUpdate > 0 +} + +// sendProgressUpdate calls the progress 
callback +func (c *Client) sendProgressUpdate(eventType ProgressEventType, message string) { + if c.config.ProgressCallback == nil || c.progress == nil { + return + } + + // Create a copy of progress to avoid race conditions + progressCopy := *c.progress + + event := SyncProgressEvent{ + Type: eventType, + Progress: &progressCopy, + Message: message, + } + + // Call the callback in a goroutine to avoid blocking + go func() { + defer func() { + if r := recover(); r != nil { + c.logger.Error("Progress callback panicked", zap.Any("panic", r)) + } + }() + c.config.ProgressCallback(event) + }() +} + +// getProgress returns a copy of the current progress (thread-safe) +func (c *Client) GetProgress() *SyncProgress { + c.progressMu.RLock() + defer c.progressMu.RUnlock() + + if c.progress == nil { + return nil + } + + progressCopy := *c.progress + return &progressCopy +} + // Connect establishes WebSocket connection to the remote server func (c *Client) Connect() error { c.logger.Info("Connecting to remote server", zap.String("url", c.config.ServerURL)) @@ -279,15 +698,18 @@ func (c *Client) Connect() error { if c.config.ReplicaID != "" { headers.Set("X-ReplicaID", c.config.ReplicaID) } - - if c.config.SetPublic { - headers.Set("X-SetPublic", fmt.Sprintf("%t", c.config.SetPublic)) + if c.config.SetVisibility != 0 { + headers.Set("X-Visibility", strconv.Itoa(c.config.SetVisibility)) } conn, response, err := dialer.DialContext(connectCtx, u.String(), headers) if err != nil { - respStr, _ := io.ReadAll(response.Body) - return fmt.Errorf("%s", respStr) + if response != nil && response.Body != nil { + respStr, _ := io.ReadAll(response.Body) + response.Body.Close() + return fmt.Errorf("%s", respStr) + } + return fmt.Errorf("failed to connect to WebSocket: %w", err) } defer response.Body.Close() @@ -300,7 +722,7 @@ func (c *Client) Connect() error { // Set up ping/pong handlers for connection health conn.SetPingHandler(func(data string) error { c.logger.Debug("Received ping from 
server") - return conn.WriteControl(websocket.PongMessage, []byte(data), time.Now().Add(8*time.Second)) + return conn.WriteControl(websocket.PongMessage, []byte(data), time.Now().Add(5*time.Second)) }) conn.SetPongHandler(func(data string) error { @@ -335,7 +757,18 @@ func (c *Client) Read(buffer []byte) (int, error) { return 0, nil } - // Check if we have a connection error first + // Check if connection is still alive first + if !c.isConnected() { + c.logger.Debug("Connection not active, returning immediately") + // If sync completed normally, return success + if c.isSyncCompleted() { + return 0, nil + } + // Otherwise return connection lost + return 0, fmt.Errorf("connection lost") + } + + // Check if we have a connection error if lastErr := c.GetLastError(); lastErr != nil { // If sync is completed and this is a normal closure, return immediately if c.isSyncCompleted() && websocket.IsCloseError(lastErr, websocket.CloseNoStatusReceived, websocket.CloseNormalClosure, websocket.CloseGoingAway) { @@ -375,24 +808,43 @@ func (c *Client) Read(buffer []byte) (int, error) { if isOriginEnd { c.logger.Info("ORIGIN_END received from server - sync completing") c.setSyncCompleted(true) + // In subscribe mode, continue reading for new version notifications + if c.config.Subscribe { + c.logger.Info("Subscribe mode: continuing to listen for new version notifications") + } + } + + // Handle progress tracking for inbound traffic + if c.config.ProgressCallback != nil { + c.inspector.InspectForProgress(data, "IN (Server → Client)", func(event SyncProgressEvent) { + c.handleProgressEvent(event) + }, c.config.EnableTrafficInspection) } return bytesRead, nil case <-time.After(func() time.Duration { - // Use a much shorter timeout if sync is completed + // In subscribe mode, use very long timeout to accommodate hibernated connections + if c.config.Subscribe { + return 1 * time.Hour + } + // Use a longer timeout if sync is completed to allow final transaction processing if 
c.isSyncCompleted() { - return 100 * time.Millisecond + return 2 * time.Second } - return 9 * time.Second + return 30 * time.Second }()): // Check if connection is still alive if !c.isConnected() { return 0, fmt.Errorf("connection lost") } - // If sync is completed, don't wait long + // If sync is completed and not in subscribe mode, don't wait long if c.isSyncCompleted() { return 0, nil } + // In subscribe mode, continue reading even after timeouts + if c.config.Subscribe { + return 0, nil // Return 0 bytes but no error, allowing caller to retry + } return 0, fmt.Errorf("read timeout") } } @@ -408,7 +860,7 @@ func (c *Client) setSyncCompleted(completed bool) { func (c *Client) isSyncCompleted() bool { c.syncMu.RLock() defer c.syncMu.RUnlock() - return c.syncCompleted + return c.syncCompleted && !c.config.Subscribe } // handleOutboundTraffic inspects outbound data and handles sync completion detection @@ -419,6 +871,36 @@ func (c *Client) handleOutboundTraffic(data []byte) { c.logger.Info("ORIGIN_END detected - sync completing") c.setSyncCompleted(true) } + + // Handle progress tracking + if c.config.ProgressCallback != nil { + c.inspector.InspectForProgress(data, "OUT (Client → Server)", func(event SyncProgressEvent) { + c.handleProgressEvent(event) + }, c.config.EnableTrafficInspection) + } +} + +// handleProgressEvent processes progress events from the traffic inspector +func (c *Client) handleProgressEvent(event SyncProgressEvent) { + switch event.Type { + case EventSyncStart: + if event.Progress != nil { + c.initProgress(event.Progress.TotalPages, event.Progress.PageSize, event.Progress.Direction) + } + c.updateProgress(EventSyncStart, event.Message) + case EventPageSent: + c.updateProgress(EventPageSent, event.Message) + case EventPageReceived: + c.updateProgress(EventPageReceived, event.Message) + case EventPageConfirmed: + c.updateProgress(EventPageConfirmed, event.Message) + case EventNegotiationComplete: + c.updateProgress(EventNegotiationComplete, 
event.Message) + case EventSyncComplete: + c.updateProgress(EventSyncComplete, event.Message) + case EventError: + c.updateProgress(EventError, event.Message) + } } // Write sends data to the remote server @@ -438,7 +920,7 @@ func (c *Client) Write(data []byte) error { case c.writeQueue <- dataCopy: c.logger.Debug("Data queued for writing", zap.Int("bytes", len(dataCopy))) return nil - case <-time.After(10 * time.Second): + case <-time.After(30 * time.Second): return fmt.Errorf("write queue timeout") } } @@ -458,13 +940,35 @@ func (c *Client) Close() { // Cancel context to signal all goroutines to stop c.cancel() - // Close the WebSocket connection + // Close the WebSocket connection gracefully c.mu.Lock() if c.conn != nil { // Send close message - c.conn.WriteControl(websocket.CloseMessage, - websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), - time.Now().Add(5*time.Second)) + closeMessage := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + err := c.conn.WriteControl(websocket.CloseMessage, closeMessage, time.Now().Add(5*time.Second)) + if err != nil { + c.logger.Debug("Error sending close message", zap.Error(err)) + } else { + c.logger.Debug("Sent WebSocket close message") + } + + // Set a read deadline for the close handshake + c.conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + + // Wait for server's close response by reading until we get a close frame + for { + _, _, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { + c.logger.Debug("Received close acknowledgment from server") + } else { + c.logger.Debug("Connection closed during close handshake", zap.Error(err)) + } + break + } + // Keep reading until we get the close frame or timeout + } + c.conn.Close() c.conn = nil } @@ -560,7 +1064,14 @@ func (c *Client) pingLoop() { c.logger.Debug("Starting ping loop") defer c.logger.Debug("Ping loop terminated") - ticker := time.NewTicker(30 * 
time.Second) + // Use longer ping interval in subscribe mode to accommodate hibernated connections + pingInterval := 5 * time.Second + if c.config.Subscribe { + pingInterval = 25 * time.Minute + c.logger.Info("Subscribe mode: using 25-minute ping interval for hibernated connections") + } + + ticker := time.NewTicker(pingInterval) defer ticker.Stop() // Check connection status more frequently to exit quickly when disconnected @@ -586,14 +1097,25 @@ func (c *Client) pingLoop() { return } - err := conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(10*time.Second)) + // Use longer ping timeout in subscribe mode + pingTimeout := 10 * time.Second + if c.config.Subscribe { + pingTimeout = 30 * time.Second + } + + err := conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(pingTimeout)) if err != nil { c.logger.Error("Failed to send ping", zap.Error(err)) c.setError(err) c.setConnected(false) return } - c.logger.Debug("Sent ping to server") + + if c.config.Subscribe { + c.logger.Info("Sent ping to hibernated server connection") + } else { + c.logger.Debug("Sent ping to server") + } } } } @@ -644,37 +1166,65 @@ func (c *Client) readLoop() { conn := c.conn c.mu.RUnlock() - if conn == nil || c.isSyncCompleted() { + if conn == nil { + c.setConnected(false) + return + } + + // In subscribe mode, continue even after sync completion + if c.isSyncCompleted() && !c.config.Subscribe { + c.setConnected(false) + return + } + + // In subscribe mode, if connection failed, don't continue reading + if c.config.Subscribe && c.GetLastError() != nil { + c.logger.Debug("Connection has error in subscribe mode, exiting read loop") c.setConnected(false) return } - // Set read deadline - conn.SetReadDeadline(time.Now().Add(9 * time.Second)) + // Set read deadline - much longer timeout in subscribe mode for hibernated connections + timeout := 30 * time.Second + if c.config.Subscribe { + // In subscribe mode, use very long timeout (1 hour) to allow for hibernated 
connections + timeout = 1 * time.Hour + } + conn.SetReadDeadline(time.Now().Add(timeout)) messageType, data, err := conn.ReadMessage() if err != nil { + c.logger.Debug("ReadMessage error", zap.Error(err)) + // Check if this is an expected/normal connection closure if websocket.IsCloseError(err, websocket.CloseNoStatusReceived, // 1005 - normal close from server after ORIGIN_END websocket.CloseNormalClosure, // 1000 - normal closure websocket.CloseGoingAway) { // 1001 - endpoint going away c.logger.Info("WebSocket connection closed normally", zap.Error(err)) + } else if strings.Contains(err.Error(), "use of closed network connection") { + // This happens when we close the connection during shutdown - it's expected + c.logger.Debug("Connection closed during shutdown", zap.Error(err)) } else { // Any other error is unexpected - c.logger.Error("WebSocket read error", zap.Error(err)) + //c.logger.Error("WebSocket read error", zap.Error(err)) } c.setError(err) c.setConnected(false) - // If sync is completed and this is a normal closure, close read queue immediately - if c.isSyncCompleted() && websocket.IsCloseError(err, + // If this is a normal closure, close read queue immediately + if websocket.IsCloseError(err, websocket.CloseNoStatusReceived, websocket.CloseNormalClosure, websocket.CloseGoingAway) { - c.logger.Debug("Sync completed - closing read queue immediately") + c.logger.Debug("Normal closure - closing read queue immediately") // Close the read queue to signal no more data - close(c.readQueue) + select { + case <-c.readQueue: + // Already closed + default: + close(c.readQueue) + } } // Only signal reconnection for truly unexpected errors (not normal closures) @@ -691,29 +1241,44 @@ func (c *Client) readLoop() { } if messageType == websocket.TextMessage { - accessKeyLength := 22 - replicaIDLength := 18 - readPullKeyResp := "NEWPULLKEY=" - readPushKeyResp := "NEWPUSHKEY=" - replicaIDResp := "REPLICAID=" - replicaPathResp := "REPLICAPATH=" + configMsgResp := 
"CONFIG=" + messageResp := "MESSAGE=" + abortResp := "ABORT=" // Handle text messages for NEWPULLKEY, NEWPUSHKEY, REPLICAID // Example: "NEWPULLKEY=xxxxxxxxxxxxxxxxxxxxxx" strData := string(data) - if (len(data) >= len(readPullKeyResp)+accessKeyLength) && strings.HasPrefix(strData, readPullKeyResp) { - c.NewPullKey = strData[len(readPullKeyResp):] - c.logger.Debug("đŸ“Ĩ Received new Pull Key:", zap.String("key", c.NewPullKey)) - } else if (len(data) >= len(readPushKeyResp)+accessKeyLength) && strings.HasPrefix(strData, readPushKeyResp) { - - c.NewPushKey = strData[len(readPushKeyResp):] - c.logger.Debug("đŸ“Ĩ Received new Push Key:", zap.String("key", c.NewPushKey)) - } else if (len(data) >= len(replicaIDResp)+replicaIDLength) && strings.HasPrefix(strData, replicaIDResp) { - c.ReplicaID = strData[len(replicaIDResp):] - c.logger.Debug("đŸ“Ĩ Received Replica ID:", zap.String("id", c.ReplicaID)) - } else if (len(data) >= len(replicaPathResp)) && strings.HasPrefix(strData, replicaPathResp) { - - c.ReplicaPath = strData[len(replicaPathResp):] - c.logger.Debug("đŸ“Ĩ Received new Replica Path:", zap.String("path", c.ReplicaPath)) + + if len(data) >= len(abortResp) && strings.HasPrefix(strData, abortResp) { + color.Red("❌ Server aborted connection: %s", strData[len(abortResp):]) + c.setConnected(false) + message := strData[len(abortResp):] + c.setError(fmt.Errorf("server aborted connection: %s", message)) + } else if (len(data) >= len(configMsgResp)) && strings.HasPrefix(strData, configMsgResp) { + // CONFIG={JSON} + jsonStr := strData[len(configMsgResp):] + var configMsg map[string]interface{} + err := json.Unmarshal([]byte(jsonStr), &configMsg) + if err != nil { + c.logger.Error("Failed to parse CONFIG JSON", zap.Error(err)) + continue + } + if configMsg["newPullKey"] != nil { + c.NewPullKey = configMsg["newPullKey"].(string) + } + if configMsg["newPushKey"] != nil { + c.NewPushKey = configMsg["newPushKey"].(string) + } + if configMsg["replicaID"] != nil { + c.ReplicaID 
= configMsg["replicaID"].(string) + } + if configMsg["replicaPath"] != nil { + c.ReplicaPath = configMsg["replicaPath"].(string) + } + if configMsg["committedVersionID"] != nil { + c.Version = configMsg["committedVersionID"].(string) + } + } else if (len(data) >= len(messageResp)) && strings.HasPrefix(strData, messageResp) { + fmt.Println(strData[len(messageResp):]) } continue } @@ -731,17 +1296,33 @@ func (c *Client) readLoop() { msgType := c.inspector.parseMessageType(data) if msgType == "ORIGIN_END" { c.logger.Info("ORIGIN_END detected in read loop - sync will complete") - c.setSyncCompleted(true) + // Don't mark as completed yet - let the C code process all remaining data first + // The Read method will mark it as completed when it actually receives ORIGIN_END + } else if msgType == "SQLRSYNC_NEWREPLICAVERSION" && c.config.Subscribe { + // Handle new version notification in subscribe mode + c.logger.Info("SQLRSYNC_NEWREPLICAVERSION (0x52) received - new version available!") + select { + case c.newVersionChan <- struct{}{}: + c.logger.Debug("New version notification sent to channel") + default: + c.logger.Debug("New version channel already has pending notification") + } + // Don't queue this message for normal reading + continue } + // Handle progress tracking in read loop + if c.config.ProgressCallback != nil { + c.inspector.InspectForProgress(data, "IN (Server → Client)", func(event SyncProgressEvent) { + c.handleProgressEvent(event) + }, c.config.EnableTrafficInspection) + } // Queue the data for reading select { case c.readQueue <- data: c.logger.Debug("Data queued for reading", zap.Int("bytes", len(data))) case <-c.ctx.Done(): return - default: - c.logger.Warn("Read queue full, dropping message", zap.Int("bytes", len(data))) } } } @@ -779,6 +1360,15 @@ func (c *Client) writeLoop() { // Set write deadline conn.SetWriteDeadline(time.Now().Add(30 * time.Second)) + if c.config.SendConfigCmd { + conn.WriteMessage(websocket.BinaryMessage, []byte{SQLRSYNC_CONFIG}) 
+ c.config.SendConfigCmd = false + } + if c.config.SendKeyRequest { + conn.WriteMessage(websocket.BinaryMessage, []byte{SQLRSYNC_KEYREQUEST}) + c.config.SendKeyRequest = false + } + // Inspect raw WebSocket outbound traffic c.inspector.LogWebSocketTraffic(data, "OUT (Client → Server)", c.config.EnableTrafficInspection) @@ -796,12 +1386,6 @@ func (c *Client) writeLoop() { return } - if c.config.SendConfigCmd { - conn.WriteMessage(websocket.BinaryMessage, []byte{SQLRSYNC_CONFIG}) - c.config.SendConfigCmd = false - c.logger.Debug("🔑 Also asked for keys and replicaID.") - } - c.logger.Debug("Sent message to remote", zap.Int("bytes", len(data))) } } @@ -821,3 +1405,67 @@ func (c *Client) GetReplicaID() string { func (c *Client) GetReplicaPath() string { return c.ReplicaPath } +func (c *Client) GetVersion() string { + return c.Version +} + +// WaitForNewVersion blocks until a new version notification is received (0x52) +// Returns nil when a new version is available, or an error if the connection is lost +func (c *Client) WaitForNewVersion() error { + if !c.config.Subscribe { + return fmt.Errorf("subscribe mode not enabled") + } + + c.logger.Info("Waiting for new version notification...") + + // Check if connection is still alive + if !c.isConnected() { + return fmt.Errorf("connection lost") + } + + // Check if there's already a notification pending + select { + case <-c.newVersionChan: + c.logger.Info("Found pending new version notification!") + return nil + default: + // No pending notification, continue to blocking wait + } + + c.logger.Debug("No pending notifications, blocking wait...") + + // Use a ticker to periodically check for context cancellation + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + c.logger.Debug("Context cancelled during new version wait") + return fmt.Errorf("client context cancelled") + case <-c.newVersionChan: + c.logger.Info("New version notification received from blocking wait!") 
+ return nil + case <-ticker.C: + // Check if connection is still alive + if !c.isConnected() { + return fmt.Errorf("connection lost while waiting") + } + // Continue waiting + } + } +} + +// ResetForNewSync prepares the client for a new sync operation while maintaining the connection +func (c *Client) ResetForNewSync() { + c.setSyncCompleted(false) + // Clear any pending data in the read queue + for { + select { + case <-c.readQueue: + // Drain the queue + default: + return + } + } +} diff --git a/client/remote/progress.go b/client/remote/progress.go new file mode 100644 index 0000000..c0e9804 --- /dev/null +++ b/client/remote/progress.go @@ -0,0 +1,344 @@ +package remote + +import ( + "encoding/json" + "fmt" + "strings" + "time" +) + +// ProgressFormatter provides different output formats for progress events +type ProgressFormatter struct { + config *ProgressConfig +} + +// NewProgressFormatter creates a new progress formatter +func NewProgressFormatter(config *ProgressConfig) *ProgressFormatter { + if config == nil { + config = &ProgressConfig{ + Enabled: true, + Format: FormatSimple, + UpdateRate: 500 * time.Millisecond, + ShowETA: true, + ShowBytes: true, + ShowPages: true, + PagesPerUpdate: 10, + } + } + return &ProgressFormatter{config: config} +} + +// FormatProgress formats a progress event according to the configured format +func (f *ProgressFormatter) FormatProgress(event SyncProgressEvent) string { + if !f.config.Enabled { + return "" + } + + switch f.config.Format { + case FormatSimple: + return f.formatSimple(event) + case FormatDetailed: + return f.formatDetailed(event) + case FormatJSON: + return f.formatJSON(event) + default: + return f.formatSimple(event) + } +} + +// formatSimple creates a simple one-line progress display +func (f *ProgressFormatter) formatSimple(event SyncProgressEvent) string { + if event.Progress == nil { + return fmt.Sprintf("â€ĸ %s", event.Message) + } + + progress := event.Progress + direction := "Syncing" + if 
progress.Direction == DirectionPush { + direction = "Pushing" + } else if progress.Direction == DirectionPull { + direction = "Pulling" + } + + phase := f.getPhaseString(progress.Phase) + + switch event.Type { + case EventSyncStart: + return fmt.Sprintf("%s: Starting sync (%d pages, %s)", + direction, progress.TotalPages, f.formatBytes(progress.TotalBytes)) + + case EventNegotiationComplete: + return fmt.Sprintf("%s: %s - ready to transfer", direction, phase) + + case EventSyncComplete: + elapsed := time.Since(progress.StartTime) + return fmt.Sprintf("%s: ✅ Complete - %d pages in %s", + direction, progress.TotalPages, f.formatDuration(elapsed)) + + default: + // Active transfer progress + if progress.TotalPages > 0 { + var completedPages int + if progress.Direction == DirectionPush { + completedPages = progress.PagesSent + } else { + completedPages = progress.PagesReceived + } + + progressBar := f.createProgressBar(progress.PercentComplete, 20) + parts := []string{ + fmt.Sprintf("%s: %s %.1f%% (%d/%d pages)", + direction, progressBar, progress.PercentComplete, completedPages, progress.TotalPages), + } + + if f.config.ShowBytes { + parts = append(parts, f.formatBytes(progress.BytesTransferred)+"/"+f.formatBytes(progress.TotalBytes)) + } + + if f.config.ShowETA && progress.EstimatedETA > 0 { + parts = append(parts, "ETA: "+f.formatDuration(progress.EstimatedETA)) + } + + return strings.Join(parts, " â€ĸ ") + } + + return fmt.Sprintf("%s: %s", direction, event.Message) + } +} + +// formatDetailed creates a detailed multi-line progress display +func (f *ProgressFormatter) formatDetailed(event SyncProgressEvent) string { + if event.Progress == nil { + return fmt.Sprintf("â€ĸ %s", event.Message) + } + + progress := event.Progress + direction := "Database Sync" + if progress.Direction == DirectionPush { + direction = "Push to Remote" + } else if progress.Direction == DirectionPull { + direction = "Pull from Remote" + } + + phase := f.getPhaseString(progress.Phase) + + 
switch event.Type { + case EventSyncStart: + return fmt.Sprintf(`╭─ %s ─────────────────────────────────────╮ +│ Starting: %d pages (%s) │ +│ Phase: %s │ +╰──────────────────────────────────────────────────╯`, + direction, progress.TotalPages, f.formatBytes(progress.TotalBytes), phase) + + case EventSyncComplete: + elapsed := time.Since(progress.StartTime) + avgSpeed := float64(progress.TotalPages) / elapsed.Seconds() + return fmt.Sprintf(`╭─ %s Complete ───────────────────────────╮ +│ ✅ Successfully synced %d pages │ +│ Time: %s │ +│ Average speed: %.1f pages/sec │ +╰──────────────────────────────────────────────────╯`, + direction, progress.TotalPages, f.formatDuration(elapsed), avgSpeed) + + default: + // Active transfer progress + if progress.TotalPages > 0 { + var completedPages int + if progress.Direction == DirectionPush { + completedPages = progress.PagesSent + } else { + completedPages = progress.PagesReceived + } + + progressBar := f.createProgressBar(progress.PercentComplete, 40) + + lines := []string{ + fmt.Sprintf("╭─ %s Progress ───────────────────────────╮", direction), + fmt.Sprintf("│ %s %.1f%% │", progressBar, progress.PercentComplete), + "│ │", + } + + if f.config.ShowPages { + lines = append(lines, fmt.Sprintf("│ Pages: %d/%d (sent: %d, confirmed: %d) │", + completedPages, progress.TotalPages, progress.PagesSent, progress.PagesConfirmed)) + } + + if f.config.ShowBytes { + lines = append(lines, fmt.Sprintf("│ Data: %s/%s │", + f.formatBytes(progress.BytesTransferred), f.formatBytes(progress.TotalBytes))) + } + + if progress.PagesPerSecond > 0 { + lines = append(lines, fmt.Sprintf("│ Speed: %.1f pages/sec │", progress.PagesPerSecond)) + } + + if f.config.ShowETA && progress.EstimatedETA > 0 { + lines = append(lines, fmt.Sprintf("│ ETA: %s │", f.formatDuration(progress.EstimatedETA))) + } + + lines = append(lines, "╰──────────────────────────────────────────────────╯") + return strings.Join(lines, "\n") + } + + return fmt.Sprintf("â€ĸ %s: %s", 
direction, event.Message) + } +} + +// formatJSON creates a JSON representation of the progress +func (f *ProgressFormatter) formatJSON(event SyncProgressEvent) string { + data := map[string]interface{}{ + "timestamp": time.Now().Unix(), + "event": f.getEventTypeString(event.Type), + "message": event.Message, + } + + if event.Progress != nil { + progress := event.Progress + + var completedPages int + if progress.Direction == DirectionPush { + completedPages = progress.PagesSent + } else { + completedPages = progress.PagesReceived + } + + data["progress"] = map[string]interface{}{ + "phase": f.getPhaseString(progress.Phase), + "direction": f.getDirectionString(progress.Direction), + "percent_complete": progress.PercentComplete, + "pages": map[string]interface{}{ + "completed": completedPages, + "sent": progress.PagesSent, + "received": progress.PagesReceived, + "confirmed": progress.PagesConfirmed, + "total": progress.TotalPages, + }, + "bytes": map[string]interface{}{ + "transferred": progress.BytesTransferred, + "total": progress.TotalBytes, + "page_size": progress.PageSize, + }, + "timing": map[string]interface{}{ + "start_time": progress.StartTime.Unix(), + "last_update": progress.LastUpdate.Unix(), + "elapsed_seconds": time.Since(progress.StartTime).Seconds(), + "eta_seconds": progress.EstimatedETA.Seconds(), + "pages_per_sec": progress.PagesPerSecond, + }, + } + } + + jsonData, _ := json.Marshal(data) + return string(jsonData) +} + +// Helper methods + +func (f *ProgressFormatter) createProgressBar(percent float64, width int) string { + filled := int(percent / 100.0 * float64(width)) + if filled > width { + filled = width + } + + bar := strings.Repeat("█", filled) + strings.Repeat("░", width-filled) + return fmt.Sprintf("[%s]", bar) +} + +func (f *ProgressFormatter) formatBytes(bytes int64) string { + if bytes < 1024 { + return fmt.Sprintf("%d B", bytes) + } else if bytes < 1024*1024 { + return fmt.Sprintf("%.1f KB", float64(bytes)/1024) + } else if bytes < 
1024*1024*1024 { + return fmt.Sprintf("%.1f MB", float64(bytes)/(1024*1024)) + } else { + return fmt.Sprintf("%.1f GB", float64(bytes)/(1024*1024*1024)) + } +} + +func (f *ProgressFormatter) formatDuration(d time.Duration) string { + if d < time.Second { + return fmt.Sprintf("%dms", d.Milliseconds()) + } else if d < time.Minute { + return fmt.Sprintf("%.1fs", d.Seconds()) + } else if d < time.Hour { + return fmt.Sprintf("%dm%ds", int(d.Minutes()), int(d.Seconds())%60) + } else { + return fmt.Sprintf("%dh%dm", int(d.Hours()), int(d.Minutes())%60) + } +} + +func (f *ProgressFormatter) getPhaseString(phase ProgressPhase) string { + switch phase { + case PhaseInitializing: + return "Initializing" + case PhaseNegotiating: + return "Negotiating" + case PhaseTransferring: + return "Transferring" + case PhaseCompleting: + return "Completing" + case PhaseCompleted: + return "Completed" + default: + return "Unknown" + } +} + +func (f *ProgressFormatter) getDirectionString(direction SyncDirection) string { + switch direction { + case DirectionPush: + return "push" + case DirectionPull: + return "pull" + default: + return "unknown" + } +} + +func (f *ProgressFormatter) getEventTypeString(eventType ProgressEventType) string { + switch eventType { + case EventSyncStart: + return "sync_start" + case EventNegotiationComplete: + return "negotiation_complete" + case EventPageSent: + return "page_sent" + case EventPageReceived: + return "page_received" + case EventPageConfirmed: + return "page_confirmed" + case EventSyncComplete: + return "sync_complete" + case EventError: + return "error" + default: + return "unknown" + } +} + +// DefaultProgressCallback provides a simple callback that prints to stdout +func DefaultProgressCallback(format ProgressFormat) ProgressCallback { + formatter := NewProgressFormatter(&ProgressConfig{ + Enabled: true, + Format: format, + UpdateRate: 500 * time.Millisecond, + ShowETA: true, + ShowBytes: true, + ShowPages: true, + PagesPerUpdate: 10, + }) + + 
return func(event SyncProgressEvent) { + output := formatter.FormatProgress(event) + if output != "" { + // For detailed format, clear previous lines + if format == FormatDetailed && event.Type != EventSyncStart && event.Type != EventSyncComplete { + // Simple clear - in real implementation might want to use terminal escape codes + fmt.Print("\r") + } + + fmt.Println(output) + } + } +} diff --git a/client/subscription/manager.go b/client/subscription/manager.go new file mode 100644 index 0000000..304af46 --- /dev/null +++ b/client/subscription/manager.go @@ -0,0 +1,519 @@ +package subscription + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + "go.uber.org/zap" +) + +// Message types for subscription communication +const ( + MsgTypeLatestVersion = "LATEST_VERSION" + MsgTypePing = "PING" + MsgTypePong = "PONG" + MsgTypeSubscribe = "SUBSCRIBE" + MsgTypeSubscribed = "SUBSCRIBED" + MsgTypeUnsubscribe = "UNSUBSCRIBE" + MsgTypeError = "ERROR" +) +const PING_INTERVAL = 1 * time.Hour + +// Message represents a subscription control message +type Message struct { + Type string `json:"type"` + Data map[string]interface{} `json:"data,omitempty"` + Timestamp time.Time `json:"timestamp"` +} + +// Config holds subscription manager configuration +type Config struct { + ServerURL string + ReplicaPath string + AuthToken string + ReplicaID string + Logger *zap.Logger + MaxReconnectAttempts int // Maximum number of reconnect attempts (0 = infinite) + InitialReconnectDelay time.Duration // Initial delay before first reconnect + MaxReconnectDelay time.Duration // Maximum delay between reconnect attempts +} + +// Manager handles WebSocket subscriptions for new version notifications +// with automatic reconnection using exponential backoff when connections are lost. 
+// +// The manager will automatically attempt to reconnect when WebSocket errors occur, +// with delays starting at InitialReconnectDelay and doubling until +// MaxReconnectDelay is reached. Reconnection attempts continue indefinitely unless +// MaxReconnectAttempts is set to a positive value. +type Manager struct { + config *Config + logger *zap.Logger + conn *websocket.Conn + ctx context.Context + cancel context.CancelFunc + mu sync.RWMutex + connected bool + + // Reconnection state + reconnectAttempts int + lastConnectTime time.Time + reconnecting bool + + // Event channels + newVersionChan chan string + errorChan chan error +} + +// NewManager creates a new subscription manager +func NewManager(config *Config) *Manager { + ctx, cancel := context.WithCancel(context.Background()) + + // Set default reconnection parameters if not provided + if config.MaxReconnectAttempts == 0 { + config.MaxReconnectAttempts = -1 // Infinite reconnect attempts by default + } + if config.InitialReconnectDelay == 0 { + config.InitialReconnectDelay = 3 * time.Second + } + if config.MaxReconnectDelay == 0 { + config.MaxReconnectDelay = 300 * time.Second // 5 minutes max + } + + return &Manager{ + config: config, + logger: config.Logger, + ctx: ctx, + cancel: cancel, + newVersionChan: make(chan string, 1), + errorChan: make(chan error, 1), + } +} + +// Connect establishes WebSocket connection for subscription with automatic reconnection. +// If the initial connection fails, it will retry according to the configured reconnection parameters. +// Once connected, the manager will automatically handle reconnection if the connection is lost. 
+func (m *Manager) Connect() error {
+	return m.connectWithRetry(false)
+}
+
+// connectWithRetry handles connection with optional retry logic
+func (m *Manager) connectWithRetry(isReconnect bool) error {
+	m.mu.Lock()
+	if isReconnect {
+		m.reconnecting = true
+	}
+	m.mu.Unlock()
+
+	defer func() {
+		m.mu.Lock()
+		m.reconnecting = false
+		m.mu.Unlock()
+	}()
+
+	var lastErr error
+	currentDelay := m.config.InitialReconnectDelay
+
+	// MaxReconnectAttempts < 0 means retry forever (the NewManager default).
+	for attempt := 0; m.config.MaxReconnectAttempts < 0 || attempt < m.config.MaxReconnectAttempts; attempt++ {
+		// Update reconnect attempts counter
+		m.mu.Lock()
+		m.reconnectAttempts = attempt + 1
+		m.mu.Unlock()
+
+		// Wait for delay on retry attempts (but not the very first attempt)
+		if attempt > 0 {
+			if isReconnect {
+				m.logger.Info("Waiting before reconnect attempt",
+					zap.Duration("delay", currentDelay),
+					zap.Int("attempt", attempt+1))
+			} else {
+				m.logger.Info("Waiting before connection retry",
+					zap.Duration("delay", currentDelay),
+					zap.Int("attempt", attempt+1))
+			}
+
+			select {
+			case <-m.ctx.Done():
+				return fmt.Errorf("connection cancelled during backoff")
+			case <-time.After(currentDelay):
+			}
+		}
+
+		if err := m.doConnect(); err != nil {
+			lastErr = err
+			// Calculate next delay with exponential backoff for the following attempt
+			currentDelay = m.calculateNextDelay(currentDelay)
+			if isReconnect {
+				m.logger.Warn("Reconnection attempt failed",
+					zap.Error(err),
+					zap.Int("attempt", attempt+1),
+					zap.Duration("next_retry_in", currentDelay))
+			} else {
+				m.logger.Warn("Connection attempt failed",
+					zap.Error(err),
+					zap.Int("attempt", attempt+1),
+					zap.Duration("next_retry_in", currentDelay))
+			}
+
+			continue
+		}
+
+		// Connection successful
+		m.mu.Lock()
+		m.reconnectAttempts = 0
+		m.lastConnectTime = time.Now()
+		m.mu.Unlock()
+
+		if isReconnect {
+			m.logger.Info("Successfully reconnected to subscription service",
+				zap.Int("attempts", attempt+1))
+		} else {
+			m.logger.Info("Successfully connected to subscription service")
+		}
+		return nil
+	}
+
+	return fmt.Errorf("failed to connect after %d attempts: %w",
+		m.config.MaxReconnectAttempts, lastErr)
+}
+
+// doConnect performs the actual WebSocket connection
+func (m *Manager) doConnect() error {
+	m.logger.Info("Connecting to subscription service", zap.String("url", m.config.ServerURL))
+
+	u, err := url.Parse(m.config.ServerURL)
+	if err != nil {
+		return fmt.Errorf("invalid server URL: %w", err)
+	}
+
+	// Add subscription endpoint
+	u.Path = strings.TrimSuffix(u.Path, "/") + "/sapi/subscribe/" + m.config.ReplicaPath
+
+	headers := http.Header{}
+	headers.Set("Authorization", m.config.AuthToken)
+	if m.config.ReplicaID != "" {
+		headers.Set("X-ReplicaID", m.config.ReplicaID)
+	}
+
+	dialer := websocket.Dialer{
+		HandshakeTimeout: 10 * time.Second,
+	}
+
+	m.logger.Debug("Dialing WebSocket", zap.String("url", u.String()))
+
+	conn, _, err := dialer.DialContext(m.ctx, u.String(), headers)
+	if err != nil {
+		return fmt.Errorf("failed to connect to subscription service: %w", err)
+	}
+
+	m.mu.Lock()
+	m.conn = conn
+	m.connected = true
+	m.mu.Unlock()
+
+	// Start message handling loops
+	go m.readLoop()
+	go m.pingLoop()
+
+	// Send subscription message
+	if err := m.sendMessage(Message{
+		Type: MsgTypeSubscribe,
+		Data: map[string]interface{}{
+			"replicaID": m.config.ReplicaID,
+		},
+		Timestamp: time.Now(),
+	}); err != nil {
+		// Tear down the half-initialized connection so a retry in
+		// connectWithRetry does not leave duplicate read/ping loops
+		// running against a stale conn while a new dial is attempted.
+		m.mu.Lock()
+		if m.conn != nil {
+			m.conn.Close()
+			m.conn = nil
+		}
+		m.connected = false
+		m.mu.Unlock()
+		return fmt.Errorf("failed to send subscribe message: %w", err)
+	}
+
+	return nil
+}
+
+// WaitForNewVersionMsg blocks until a new version is available
+func (m *Manager) WaitForNewVersionMsg() (string, error) {
+	m.logger.Info("Waiting for a new version notification...")
+
+	for {
+		select {
+		case <-m.ctx.Done():
+			return "", fmt.Errorf("subscription cancelled")
+		case err := <-m.errorChan:
+			// Check if this is a reconnection failure
+			if strings.Contains(err.Error(), "reconnection failed") {
+				m.logger.Error("Reconnection failed permanently", zap.Error(err))
+				return "", err
+			}
+			// For other errors, log and continue waiting (reconnection might be in progress)
+			m.logger.Warn("Temporary subscription error", zap.Error(err))
+			continue
+		case version := <-m.newVersionChan:
+			m.logger.Info("Latest Version message received", zap.String("version", version))
+			return version, nil
+		}
+	}
+}
+
+// Close cleanly shuts down the subscription manager
+func (m *Manager) Close() error {
+	m.logger.Info("Closing subscription manager")
+
+	// Cancel context to stop all operations
+	m.cancel()
+
+	m.mu.Lock()
+	if m.conn != nil {
+		// Send unsubscribe message (best effort)
+		m.sendMessage(Message{
+			Type:      MsgTypeUnsubscribe,
+			Timestamp: time.Now(),
+		})
+
+		// Close connection gracefully
+		m.conn.WriteControl(websocket.CloseMessage,
+			websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
+			time.Now().Add(5*time.Second))
+		m.conn.Close()
+		m.conn = nil
+	}
+	m.connected = false
+	m.reconnecting = false
+	m.mu.Unlock()
+
+	return nil
+}
+
+// IsConnected returns current connection status
+func (m *Manager) IsConnected() bool {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	return m.connected
+}
+
+// IsReconnecting returns whether the manager is currently attempting to reconnect
+func (m *Manager) IsReconnecting() bool {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	return m.reconnecting
+}
+
+// GetConnectionStatus returns detailed connection status information
+func (m *Manager) GetConnectionStatus() (connected bool, reconnecting bool, lastConnectTime time.Time, reconnectAttempts int) {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
+	return m.connected, m.reconnecting, m.lastConnectTime, m.reconnectAttempts
+}
+
+// calculateNextDelay calculates the next delay using exponential backoff
+func (m *Manager) calculateNextDelay(currentDelay time.Duration) time.Duration {
+	nextDelay := time.Duration(float64(currentDelay) * 2.0)
+	if nextDelay > m.config.MaxReconnectDelay {
+		return m.config.MaxReconnectDelay
+	}
+	return nextDelay
+}
+
+// sendMessage sends a message to the server
+func (m *Manager) sendMessage(msg Message) error { + m.mu.RLock() + conn := m.conn + m.mu.RUnlock() + + if conn == nil { + return fmt.Errorf("not connected") + } + + var data []byte + var err error + + if msg.Type == MsgTypePing { + data = []byte("PING") + } else if msg.Type == MsgTypePong { + data = []byte("PONG") + } else { + data, err = json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + } + + conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + return conn.WriteMessage(websocket.TextMessage, data) +} + +// readLoop handles incoming messages +func (m *Manager) readLoop() { + defer func() { + m.mu.Lock() + if m.conn != nil { + m.conn.Close() + m.conn = nil + } + m.connected = false + m.mu.Unlock() + }() + + for { + select { + case <-m.ctx.Done(): + return + default: + } + + m.mu.RLock() + conn := m.conn + m.mu.RUnlock() + + if conn == nil { + return + } + + conn.SetReadDeadline(time.Now().Add(PING_INTERVAL + 2*time.Minute)) // Longer than ping interval + + messageType, data, err := conn.ReadMessage() + if err != nil { + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + m.logger.Info("WebSocket connection closed normally") + return + } + + m.logger.Error("WebSocket read error", zap.Error(err)) + + // Mark as disconnected + m.mu.Lock() + if m.conn != nil { + m.conn.Close() + m.conn = nil + } + m.connected = false + wasReconnecting := m.reconnecting + m.mu.Unlock() + + // Only attempt reconnection if we're not already reconnecting + // and the context hasn't been cancelled + select { + case <-m.ctx.Done(): + return + default: + if !wasReconnecting { + m.logger.Info("Attempting to reconnect after connection loss") + go m.attemptReconnect() + } + } + return + } + + if messageType != websocket.TextMessage { + continue + } + + var msg Message + if err := json.Unmarshal(data, &msg); err != nil { + m.logger.Warn("Failed to unmarshal message", zap.Error(err)) + continue + } + + 
m.handleMessage(msg) + } +} + +// attemptReconnect handles automatic reconnection logic +func (m *Manager) attemptReconnect() { + fmt.Println("🔄 Connection lost. Attempting to reconnect...") + + if err := m.connectWithRetry(true); err != nil { + m.logger.Error("Failed to reconnect to subscription service", zap.Error(err)) + fmt.Printf("❌ Reconnection failed: %v\n", err) + // Send error to error channel for coordinator to handle + select { + case m.errorChan <- fmt.Errorf("reconnection failed: %w", err): + default: + } + } else { + fmt.Println("✅ Reconnected successfully! Continuing to watch for updates...") + } +} + +// handleMessage processes incoming subscription messages +func (m *Manager) handleMessage(msg Message) { + // Only log debug for non-PONG messages + if false && msg.Type != MsgTypePong { + m.logger.Debug("Received message", zap.String("type", msg.Type)) + } + + switch msg.Type { + case MsgTypeLatestVersion, MsgTypeSubscribed: + latestVersion, ok := msg.Data["version"].(string) + if !ok { + actualValue := msg.Data["version"] + m.logger.Error("Invalid LATEST_VERSION message format: version field is not a string", + zap.Any("data", msg.Data), + zap.Any("actualVersion", actualValue), + zap.String("actualType", fmt.Sprintf("%T", actualValue))) + latestVersion = "latest" + } + select { + case m.newVersionChan <- latestVersion: + default: + // Channel full, version notification already pending + } + + case MsgTypePing: + // Respond to server ping + m.sendMessage(Message{ + Type: MsgTypePong, + Timestamp: time.Now(), + }) + + case MsgTypePong: + // Server responded to our ping - connection is healthy + m.logger.Info("Received PONG - connection healthy") + + case MsgTypeError: + errMsg, _ := msg.Data["message"].(string) + err := fmt.Errorf("subscription error: %s", errMsg) + select { + case m.errorChan <- err: + default: + } + default: + m.logger.Debug("Unknown message type", zap.String("type", msg.Type)) + } +} + +// pingLoop sends periodic ping messages 
// pingLoop sends a PING every PING_INTERVAL; it exits when the context
// is cancelled, when the connection is observed down, or when a send
// fails (readLoop owns reconnection in all of those cases).
func (m *Manager) pingLoop() {
	ticker := time.NewTicker(PING_INTERVAL)
	defer ticker.Stop()

	for {
		select {
		case <-m.ctx.Done():
			return
		case <-ticker.C:
			// Check if we're connected before trying to ping
			m.mu.RLock()
			connected := m.connected && m.conn != nil
			m.mu.RUnlock()

			if !connected {
				// Connection lost, stop ping loop - readLoop will handle reconnection
				return
			}

			if err := m.sendMessage(Message{
				Type:      MsgTypePing,
				Timestamp: time.Now(),
			}); err != nil {
				m.logger.Error("Failed to send ping", zap.Error(err))
				// Don't send to error channel here - let readLoop handle the disconnection
				return
			}
		}
	}
}
diff --git a/client/sync/coordinator.go b/client/sync/coordinator.go
new file mode 100644
index 0000000..faa75cd
--- /dev/null
+++ b/client/sync/coordinator.go
@@ -0,0 +1,655 @@
// Package sync coordinates push/pull/subscribe/local sync operations,
// wiring auth resolution, the remote WebSocket transport, the CGO SQLite
// bridge, and the subscription manager together.
package sync

import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"path/filepath"
	"strings"
	"syscall"
	"time"

	"go.uber.org/zap"

	"github.com/fatih/color"
	"github.com/sqlrsync/sqlrsync.com/auth"
	"github.com/sqlrsync/sqlrsync.com/bridge"
	"github.com/sqlrsync/sqlrsync.com/remote"
	"github.com/sqlrsync/sqlrsync.com/subscription"
)

// Operation represents a sync operation type
type Operation int

const (
	OperationPull Operation = iota
	OperationPush
	OperationSubscribe
	OperationLocalSync
)

// Config holds sync coordinator configuration
type Config struct {
	ServerURL         string
	ProvidedAuthToken string // Explicitly provided auth token
	ProvidedPullKey   string // Explicitly provided pull key
	ProvidedPushKey   string // Explicitly provided push key
	ProvidedReplicaID string // Explicitly provided replica ID
	LocalPath         string
	RemotePath        string
	ReplicaPath       string // For LOCAL TO LOCAL sync
	Version           string
	Operation         Operation
	SetVisibility     int // 0 = private, 1 = unlisted, 2 = public (see displayDryRunInfo)
	DryRun            bool
	Logger            *zap.Logger
	Verbose           bool
}

// Coordinator manages sync operations and subscriptions
type Coordinator struct {
	config       *Config
	logger       *zap.Logger
	authResolver *auth.Resolver
	subManager   *subscription.Manager
	ctx          context.Context
	cancel       context.CancelFunc
}

// NewCoordinator creates a new sync coordinator
func NewCoordinator(config *Config) *Coordinator {
	ctx, cancel := context.WithCancel(context.Background())

	return &Coordinator{
		config:       config,
		logger:       config.Logger,
		authResolver: auth.NewResolver(config.Logger),
		ctx:          ctx,
		cancel:       cancel,
	}
}

// Execute runs the sync operation selected by config.Operation.
// It installs a SIGINT/SIGTERM handler that cancels the context and, if
// graceful shutdown has not completed within 2 seconds, force-exits.
func (c *Coordinator) Execute() error {
	// Set up signal handling for graceful shutdown
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)

	go func() {
		<-sigChan
		fmt.Println("\nShutting down...")
		c.cancel()
		// Force exit after 2 seconds if graceful shutdown fails
		go func() {
			time.Sleep(2 * time.Second)
			fmt.Println("Force exiting...")
			os.Exit(0)
		}()
	}()

	switch c.config.Operation {
	case OperationPull:
		return c.executePull(false)
	case OperationPush:
		return c.executePush()
	case OperationSubscribe:
		return c.executeSubscribe()
	case OperationLocalSync:
		return c.executeLocalSync()
	default:
		return fmt.Errorf("unknown operation")
	}
}

// displayDryRunInfo displays dry run information for different operations.
// operation is one of "push", "pull", "subscribe", or "local"; authResult
// may be nil only for "local" (no network fields are read in that branch).
func (c *Coordinator) displayDryRunInfo(operation string, authResult *auth.ResolveResult, absLocalPath, serverURL, remotePath, localHostname string) {
	fmt.Println("SQLRsync Dry Run:")

	switch operation {
	case "push":
		fmt.Printf(" - Mode: %s the LOCAL ORIGIN file up to the REMOTE REPLICA\n", color.YellowString("PUSHing"))
		fmt.Printf(" - LOCAL ORIGIN: %s\n", color.GreenString(absLocalPath))
		if remotePath == "" {
			// An empty remote path means the server will derive one.
			fmt.Println(" - REMOTE REPLICA: " + color.YellowString("(None - the server will assign a path using this hostname)"))
			fmt.Printf(" - Hostname: %s\n", color.GreenString(localHostname))
		} else {
			fmt.Printf(" - REMOTE REPLICA: %s\n", color.GreenString(remotePath))
		}
	case "pull":
		fmt.Printf(" - Mode: %s the REMOTE ORIGIN file down to LOCAL REPLICA\n", color.YellowString("PULLing"))
		fmt.Printf(" - REMOTE ORIGIN: %s\n", color.GreenString(remotePath))
		fmt.Printf(" - LOCAL REPLICA: %s\n", color.GreenString(absLocalPath))
	case "subscribe":
		fmt.Printf(" - Mode: %s to REMOTE ORIGIN to PULL down current and future updates\n", color.YellowString("SUBSCRIBing"))
		fmt.Printf(" - REMOTE ORIGIN: %s\n", color.GreenString(remotePath))
		fmt.Printf(" - LOCAL REPLICA: %s\n", color.GreenString(absLocalPath))
	case "local":
		fmt.Printf(" - Mode: %s between two databases\n", color.YellowString("LOCAL ONLY"))
	}

	if operation != "local" {
		fmt.Printf(" - Server: %s\n", color.GreenString(serverURL))

		fmt.Printf(" - Auth Token: %s\n", color.GreenString(authResult.AuthToken))

		if operation == "push" {
			switch c.config.SetVisibility {
			case 0:
				fmt.Println(" - Visibility: " + color.YellowString("PRIVATE") + " (only accessible with access key)")
			case 1:
				fmt.Println(" - Visibility: " + color.YellowString("UNLISTED") + " (anyone with the link can access)")
			case 2:
				fmt.Println(" - Visibility: " + color.GreenString("PUBLIC") + " (anyone can access)")
			}
		}

		if c.authResolver.CheckNeedsDashFile(c.config.LocalPath, remotePath) {
			fmt.Println(" - A shareable config (the -sqlrsync file) " + color.GreenString("WILL BE") + " created for future PULLs and SUBSCRIBEs")
		} else {
			fmt.Println(" - A shareable config (the -sqlrsync file) will " + color.RedString("NOT") + " be created")
		}
	} else {
		// For local sync, show the replica path
		if c.config.ReplicaPath != "" {
			absReplicaPath, _ := filepath.Abs(c.config.ReplicaPath)
			fmt.Printf(" - LOCAL ORIGIN: %s\n", color.GreenString(absLocalPath))
			fmt.Printf(" - LOCAL REPLICA: %s\n", color.GreenString(absReplicaPath))
		}
	}
	fmt.Println("\nAfter running this command, REPLICA will become a copy of ORIGIN at the moment the command begins.")
}

// resolveAuth
resolves authentication for the given operation +func (c *Coordinator) resolveAuth(operation string) (*auth.ResolveResult, error) { + req := &auth.ResolveRequest{ + LocalPath: c.config.LocalPath, + RemotePath: c.config.RemotePath, + ServerURL: c.config.ServerURL, + ProvidedPullKey: c.config.ProvidedPullKey, + ProvidedPushKey: c.config.ProvidedPushKey, + ProvidedReplicaID: c.config.ProvidedReplicaID, + Operation: operation, + Logger: c.logger, + } + + // Try explicit auth token first + if c.config.ProvidedAuthToken != "" { + return &auth.ResolveResult{ + AuthToken: c.config.ProvidedAuthToken, + ReplicaID: c.config.ProvidedReplicaID, + ServerURL: c.config.ServerURL, + RemotePath: c.config.RemotePath, + LocalPath: c.config.LocalPath, + }, nil + } + + result, err := c.authResolver.Resolve(req) + if err != nil { + return nil, err + } + + // If prompting is needed for push operations + if result.ShouldPrompt && operation == "push" { + token, err := c.authResolver.PromptForAdminKey(c.config.ServerURL) + if err != nil { + return nil, err + } + result.AuthToken = token + result.ShouldPrompt = false + } + + return result, nil +} + +// executeSubscribe runs pull sync with subscription for new versions +func (c *Coordinator) executeSubscribe() error { + fmt.Println("📡 Subscribe mode enabled - will watch for new versions...") + fmt.Println(" Press Ctrl+C to stop watching...") + + // Resolve authentication for subscription + authResult, err := c.resolveAuth("subscribe") + if err != nil { + return fmt.Errorf("authentication failed: %w", err) + } + + // Check for dry run mode + if c.config.DryRun { + absLocalPath, err := filepath.Abs(c.config.LocalPath) + if err != nil { + return fmt.Errorf("failed to get absolute path: %w", err) + } + localHostname, _ := os.Hostname() + + serverURL := authResult.ServerURL + if c.config.ServerURL != "" && c.config.ServerURL != "wss://sqlrsync.com" { + serverURL = c.config.ServerURL + } + + remotePath := authResult.RemotePath + if 
c.config.RemotePath != "" { + remotePath = c.config.RemotePath + } + + c.displayDryRunInfo("subscribe", authResult, absLocalPath, serverURL, remotePath, localHostname) + return nil + } + + // Create subscription manager with reconnection configuration + c.subManager = subscription.NewManager(&subscription.Config{ + ServerURL: authResult.ServerURL, + ReplicaPath: authResult.RemotePath, + AuthToken: authResult.AuthToken, + ReplicaID: authResult.ReplicaID, + Logger: c.logger.Named("subscription"), + MaxReconnectAttempts: 20, // Infinite reconnect attempts + InitialReconnectDelay: 5 * time.Second, // Start with 5 seconds delay + MaxReconnectDelay: 5 * time.Minute, // Cap at 5 minutes + }) + + c.logger.Info("Starting subscription service", zap.String("server", authResult.ServerURL)) + + // Connect to subscription service + if err := c.subManager.Connect(); err != nil { + return fmt.Errorf("failed to connect to subscription service2: %w", err) + } + defer c.subManager.Close() + + syncCount := 0 + for { + syncCount++ + fmt.Printf("🔄 Starting sync...\n") + + // Perform pull sync + if err := c.executePull(true); err != nil { + c.logger.Error("Sync failed", zap.Error(err), zap.Int("syncCount", syncCount)) + return fmt.Errorf("sync #%d failed: %w", syncCount, err) + } + + fmt.Printf("✅ Sync complete. 
Waiting for new version...\n") + + // Wait for new version or shutdown + select { + case <-c.ctx.Done(): + fmt.Println("Subscription stopped by user.") + return nil + default: + } + + // Wait for new version notification + var version string + for { + version, err = c.subManager.WaitForNewVersionMsg() + if err != nil { + // Check if this is a cancellation (graceful shutdown) + if strings.Contains(err.Error(), "cancelled") { + fmt.Println("Subscription stopped by user.") + return nil + } + + // Check if this is a permanent reconnection failure + if strings.Contains(err.Error(), "reconnection failed") { + fmt.Printf("❌ Failed to maintain connection to subscription service: %v\n", err) + fmt.Println(" Please check your network connection and try again later.") + return fmt.Errorf("subscription connection lost: %w", err) + } + + c.logger.Error("Subscription error", zap.Error(err)) + return fmt.Errorf("subscription error: %w", err) + } + if c.config.Version == version { + fmt.Printf("â„šī¸ Already at version %s, waiting for next update...\n", version) + continue + } else { + break + } + } + + fmt.Printf("🔄 New version %s announced!\n", version) + // Update version for next sync + if version != "latest" { + c.config.Version = version + } + + } +} + +// executePull performs a single pull sync operation +func (c *Coordinator) executePull(isSubscription bool) error { + // Resolve authentication + authResult, err := c.resolveAuth("pull") + if err != nil { + return fmt.Errorf("authentication failed: %w", err) + } + + version := c.config.Version + if version == "" { + version = "latest" + } + + // Use resolved values, with config overrides + serverURL := authResult.ServerURL + // Only override if user explicitly provided a different server (not just using the default) + if c.config.ServerURL != "" && c.config.ServerURL != "wss://sqlrsync.com" { + serverURL = c.config.ServerURL + } + + remotePath := authResult.RemotePath + if c.config.RemotePath != "" { + remotePath = 
			c.config.RemotePath
	}

	// Get absolute path and hostname for dry run display
	absLocalPath, err := filepath.Abs(c.config.LocalPath)
	if err != nil {
		return fmt.Errorf("failed to get absolute path: %w", err)
	}
	localHostname, _ := os.Hostname()

	if c.config.DryRun {
		c.displayDryRunInfo("pull", authResult, absLocalPath, serverURL, remotePath, localHostname)
		return nil
	}

	if !isSubscription {
		fmt.Printf("PULLing down from %s/%s@%s ...\n", serverURL, remotePath, version)
	}

	// Create remote client for WebSocket transport
	remoteClient, err := remote.New(&remote.Config{
		ServerURL: serverURL + "/sapi/pull/" + remotePath,
		AuthToken: authResult.AuthToken,
		ReplicaID: authResult.ReplicaID,
		Timeout:   8000, // NOTE(review): units not shown here — presumably milliseconds; confirm against remote.Config
		PingPong:  false, // No ping/pong needed for single sync
		Logger:    c.logger.Named("remote"),
		Subscribe: false, // Subscription handled separately
		EnableTrafficInspection: c.config.Verbose,
		InspectionDepth:         5,
		Version:                 version,
		SendConfigCmd:           true,
		// Ask the server for a pull key only when a -sqlrsync file should be created.
		SendKeyRequest: c.authResolver.CheckNeedsDashFile(c.config.LocalPath, remotePath),
		//ProgressCallback: remote.DefaultProgressCallback(remote.FormatSimple),
		ProgressCallback: nil,
		ProgressConfig: &remote.ProgressConfig{
			Enabled:        true,
			Format:         remote.FormatSimple,
			UpdateRate:     500 * time.Millisecond,
			ShowETA:        true,
			ShowBytes:      true,
			ShowPages:      true,
			PagesPerUpdate: 10,
		},
	})
	if err != nil {
		return fmt.Errorf("failed to create remote client: %w", err)
	}
	defer remoteClient.Close()

	// Connect to remote server
	if err := remoteClient.Connect(); err != nil {
		return fmt.Errorf("failed to connect to server: %w", err)
	}

	// Create local client for SQLite operations
	localClient, err := bridge.New(&bridge.Config{
		DatabasePath: c.config.LocalPath,
		DryRun:       c.config.DryRun,
		Logger:       c.logger.Named("local"),
	})
	if err != nil {
		return fmt.Errorf("failed to create local client: %w", err)
	}
	defer localClient.Close()

	// Perform the sync
	if err := c.performPullSync(localClient, remoteClient); err != nil {
		return fmt.Errorf("pull synchronization failed: %w", err)
	}
	// Remember the version we actually received for subsequent subscribe loops.
	c.config.Version = remoteClient.GetVersion()
	// Save pull result if needed (writes the shareable -sqlrsync file)
	if remoteClient.GetNewPullKey() != "" && c.authResolver.CheckNeedsDashFile(c.config.LocalPath, remotePath) {
		if err := c.authResolver.SavePullResult(
			c.config.LocalPath,
			serverURL,
			remoteClient.GetReplicaPath(),
			remoteClient.GetReplicaID(),
			remoteClient.GetNewPullKey(),
		); err != nil {
			c.logger.Warn("Failed to save pull result", zap.Error(err))
		} else {
			fmt.Println("🔑 Shareable config file created for future pulls")
		}
	}
	if !isSubscription {
		c.logger.Info("Pull synchronization completed successfully")
	}
	return nil
}

// executePush performs a push sync operation
func (c *Coordinator) executePush() error {
	// Validate that database file exists
	if _, err := os.Stat(c.config.LocalPath); os.IsNotExist(err) {
		return fmt.Errorf("database file does not exist: %s", c.config.LocalPath)
	}

	// Resolve authentication
	authResult, err := c.resolveAuth("push")
	if err != nil {
		return fmt.Errorf("authentication failed: %w", err)
	}

	// Use resolved values, with config overrides
	serverURL := authResult.ServerURL
	// Only override if user explicitly provided a different server (not just using the default)
	if c.config.ServerURL != "" && c.config.ServerURL != "wss://sqlrsync.com" {
		serverURL = c.config.ServerURL
	}

	remotePath := authResult.RemotePath
	if c.config.RemotePath != "" {
		remotePath = c.config.RemotePath
	}

	// Create local client for SQLite operations
	localClient, err := bridge.New(&bridge.Config{
		DatabasePath: c.config.LocalPath,
		DryRun:       c.config.DryRun,
		Logger:       c.logger.Named("local"),
	})
	if err != nil {
		return fmt.Errorf("failed to create local client: %w", err)
	}
	defer localClient.Close()

	// Get absolute path for the local database
	absLocalPath, err := filepath.Abs(c.config.LocalPath)
	if err != nil {
		return fmt.Errorf("failed to get absolute path: %w", err)
	}

	localHostname, _ := os.Hostname()

	if c.config.DryRun {
		c.displayDryRunInfo("push", authResult, absLocalPath, serverURL, remotePath, localHostname)
		return nil
	}

	fmt.Printf("PUSHing up to %s/%s ...\n", serverURL, remotePath)

	// Create remote client for WebSocket transport
	remoteClient, err := remote.New(&remote.Config{
		ServerURL:               serverURL + "/sapi/push/" + remotePath,
		PingPong:                true,
		Timeout:                 15000, // NOTE(review): presumably milliseconds; confirm against remote.Config
		AuthToken:               authResult.AuthToken,
		Logger:                  c.logger.Named("remote"),
		EnableTrafficInspection: c.config.Verbose,
		LocalHostname:           localHostname,
		LocalAbsolutePath:       absLocalPath,
		InspectionDepth:         5,
		SendKeyRequest:          c.authResolver.CheckNeedsDashFile(c.config.LocalPath, remotePath),
		SendConfigCmd:           true,
		SetVisibility:           c.config.SetVisibility,
		ProgressCallback:        nil, //remote.DefaultProgressCallback(remote.FormatSimple),
		ProgressConfig: &remote.ProgressConfig{
			Enabled:        true,
			Format:         remote.FormatSimple,
			UpdateRate:     500 * time.Millisecond,
			ShowETA:        true,
			ShowBytes:      true,
			ShowPages:      true,
			PagesPerUpdate: 10,
		},
	})
	if err != nil {
		return fmt.Errorf("failed to create remote client: %w", err)
	}
	defer remoteClient.Close()

	// Connect to remote server
	if err := remoteClient.Connect(); err != nil {
		return fmt.Errorf("failed to connect to server: %w", err)
	}

	// Perform the sync
	if err := c.performPushSync(localClient, remoteClient); err != nil {
		return fmt.Errorf("push synchronization failed: %w", err)
	}

	// Save push result if we got new keys
	if remoteClient.GetNewPushKey() != "" {
		if err := c.authResolver.SavePushResult(
			absLocalPath,
			serverURL,
			remoteClient.GetReplicaPath(),
			remoteClient.GetReplicaID(),
			remoteClient.GetNewPushKey(),
		); err != nil {
			c.logger.Warn("Failed to save push result", zap.Error(err))
		} else {
			fmt.Println("🔑 A new PUSH access key
 was stored at ~/.config/sqlrsync/ for ")
			fmt.Println(" revokable permission to push updates in the future. Just")
			fmt.Println(" use `sqlrsync " + absLocalPath + "` to push again.")
		}
	}

	// Create -sqlrsync file for sharing if needed
	if c.authResolver.CheckNeedsDashFile(c.config.LocalPath, remotePath) && remoteClient.GetNewPullKey() != "" {
		if err := c.authResolver.SavePullResult(
			c.config.LocalPath,
			serverURL,
			remoteClient.GetReplicaPath(),
			remoteClient.GetReplicaID(),
			remoteClient.GetNewPullKey(),
		); err != nil {
			c.logger.Warn("Failed to create shareable config file", zap.Error(err))
		} else {
			fmt.Println("🔑 A new PULL access key was created: " + c.config.LocalPath + "-sqlrsync")
			fmt.Println(" Share this file to allow others to download or subscribe")
			fmt.Println(" to this database.")
		}
	}

	c.logger.Info("Push synchronization completed successfully")
	return nil
}

// performPullSync executes the pull synchronization by bridging the
// remote WebSocket client's Read/Write into the CGO replica sync.
func (c *Coordinator) performPullSync(localClient *bridge.Client, remoteClient *remote.Client) error {
	// Create I/O bridge between remote and local clients
	readFunc := func(buffer []byte) (int, error) {
		return remoteClient.Read(buffer)
	}

	writeFunc := func(data []byte) error {
		return remoteClient.Write(data)
	}

	/*
		progress := remoteClient.GetProgress()
		if progress != nil {
			fmt.Printf("Current progress: %.1f%% (%d/%d pages)\n",
				progress.PercentComplete, progress.PagesSent, progress.TotalPages)
		}*/

	// Run the replica sync through the bridge
	return localClient.RunPullSync(readFunc, writeFunc)
}

// performPushSync executes the push synchronization by bridging the
// remote WebSocket client's Read/Write into the CGO origin sync.
func (c *Coordinator) performPushSync(localClient *bridge.Client, remoteClient *remote.Client) error {
	// Create I/O bridge between local and remote clients
	readFunc := func(buffer []byte) (int, error) {
		return remoteClient.Read(buffer)
	}

	writeFunc := func(data []byte) error {
		return remoteClient.Write(data)
	}

	/*
		progress := remoteClient.GetProgress()
		if progress != nil {
			fmt.Printf("Current progress: %.1f%% (%d/%d pages)\n",
				progress.PercentComplete, progress.PagesSent, progress.TotalPages)
		}*/

	// Run the origin sync through the bridge
	return localClient.RunPushSync(readFunc, writeFunc)
}

// executeLocalSync performs a direct local-to-local sync operation
// (no network; the origin must exist, the replica is created if absent).
func (c *Coordinator) executeLocalSync() error {
	// Validate that both database files exist
	if _, err := os.Stat(c.config.LocalPath); os.IsNotExist(err) {
		return fmt.Errorf("origin database file does not exist: %s", c.config.LocalPath)
	}

	// For replica, it's okay if it doesn't exist - it will be created
	absOriginPath, err := filepath.Abs(c.config.LocalPath)
	if err != nil {
		return fmt.Errorf("failed to get absolute path for origin: %w", err)
	}

	absReplicaPath, err := filepath.Abs(c.config.ReplicaPath)
	if err != nil {
		return fmt.Errorf("failed to get absolute path for replica: %w", err)
	}

	if c.config.DryRun {
		c.displayDryRunInfo("local", nil, absOriginPath, "", "", "")
		return nil
	}

	fmt.Printf("Syncing LOCAL to LOCAL: %s → %s\n", absOriginPath, absReplicaPath)

	// Create local client for SQLite operations
	localClient, err := bridge.New(&bridge.Config{
		DatabasePath: absOriginPath,
		DryRun:       c.config.DryRun,
		Logger:       c.logger.Named("local"),
	})
	if err != nil {
		return fmt.Errorf("failed to create local client: %w", err)
	}
	defer localClient.Close()

	// Perform direct sync
	if err := localClient.RunDirectSync(absReplicaPath); err != nil {
		return fmt.Errorf("local-to-local synchronization failed: %w", err)
	}

	c.logger.Info("Local-to-local synchronization completed successfully")
	fmt.Println("✅ Local sync completed")
	return nil
}

// Close cleanly shuts down the coordinator and, if a subscription
// manager was created, closes it as well.
func (c *Coordinator) Close() error {
	c.cancel()
	if c.subManager != nil {
		return c.subManager.Close()
	}
	return nil
}