diff --git a/mysql.template.json b/mysql.template.json index 9cf36a21..ce79887b 100644 --- a/mysql.template.json +++ b/mysql.template.json @@ -2,5 +2,9 @@ "username": "user", "password": "password", "server": "localhost", - "database": "mydb" + "database": "mydb", + "tls": false, + "caCertPath": "/path/to/CA/certificate", + "clientCertPath": "/path/to/client/certificate", + "clientkeyPath": "/path/to/client/key" } diff --git a/utils/utils.go b/utils/utils.go index 8556cf3f..e3521d79 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -2,18 +2,22 @@ package utils import ( "bufio" + "crypto/tls" + "crypto/x509" "database/sql" "encoding/json" "errors" "fmt" - "github.com/rs/zerolog" - "golang.org/x/sys/unix" "io" "os" "path/filepath" "pbench/log" "reflect" "strings" + + "github.com/go-sql-driver/mysql" + "github.com/rs/zerolog" + "golang.org/x/sys/unix" ) const ( @@ -64,6 +68,35 @@ func InitLogFile(logPath string) (finalizer func()) { } } +func createTLSConfig(caCertPath, clientCertPath, clientKeyPath string) (*tls.Config, error) { + rootCertPool := x509.NewCertPool() + pem, err := os.ReadFile(caCertPath) + if err != nil { + log.Error().Err(err).Msg("failed to read CA certificate") + return nil, err + } + if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { + log.Error().Msg("failed to append CA certificate") + return nil, err + } + tlsConfig := &tls.Config{ + RootCAs: rootCertPool, + } + + // Check if client certificate and key are provided for mutual TLS + if clientCertPath != "" && clientKeyPath != "" { + clientCert := make([]tls.Certificate, 0, 1) + certs, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath) + if err != nil { + log.Error().Err(err).Msg("failed to load client certificate or key") + return nil, err + } + clientCert = append(clientCert, certs) + tlsConfig.Certificates = clientCert + } + return tlsConfig, nil +} + func InitMySQLConnFromCfg(cfgPath string) *sql.DB { if cfgPath == "" { return nil @@ -73,17 +106,32 @@ func InitMySQLConnFromCfg(cfgPath string) *sql.DB { return nil } else { mySQLCfg := &struct { - Username string `json:"username"` - Password string `json:"password"` - Server string `json:"server"` - Database string `json:"database"` - }{} + Username string `json:"username"` + Password string `json:"password"` + Server string `json:"server"` + Database string `json:"database"` + TLS bool `json:"tls"` + CaCertPath string `json:"caCertPath"` + ClientCertPath string `json:"clientCertPath"` + ClientKeyPath string `json:"clientKeyPath"` + }{ + TLS: false, + } if err := json.Unmarshal(cfgBytes, mySQLCfg); err != nil { log.Error().Err(err).Msg("failed to unmarshal MySQL connection config for the run recorder") return nil } - if db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true", - mySQLCfg.Username, mySQLCfg.Password, mySQLCfg.Server, mySQLCfg.Database)); err != nil { + tlsType := "false" + if mySQLCfg.TLS { + tlsType = "custom" + tlsConfig, err := createTLSConfig(mySQLCfg.CaCertPath, mySQLCfg.ClientCertPath, mySQLCfg.ClientKeyPath) + if err != nil { + log.Error().Err(err).Msg("TLS enabled but failed to load certificates") + } + mysql.RegisterTLSConfig(tlsType, tlsConfig) + } + if db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=%s&parseTime=true", + mySQLCfg.Username, mySQLCfg.Password, mySQLCfg.Server, mySQLCfg.Database, tlsType)); err != nil { log.Error().Err(err).Msg("failed to initialize MySQL connection for the run recorder") return nil } else if err = db.Ping(); err != nil {