Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mysql.template.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,9 @@
"username": "user",
"password": "password",
"server": "localhost",
"database": "mydb"
"database": "mydb",
"tls": false,
"caCertPath": "/path/to/CA/certificate",
"clientCertPath": "/path/to/client/certificate",
"clientkeyPath": "/path/to/client/key"
}
66 changes: 57 additions & 9 deletions utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@ package utils

import (
"bufio"
"crypto/tls"
"crypto/x509"
"database/sql"
"encoding/json"
"errors"
"fmt"
"github.com/rs/zerolog"
"golang.org/x/sys/unix"
"io"
"os"
"path/filepath"
"pbench/log"
"reflect"
"strings"

"github.com/go-sql-driver/mysql"
"github.com/rs/zerolog"
"golang.org/x/sys/unix"
)

const (
Expand Down Expand Up @@ -64,6 +68,35 @@ func InitLogFile(logPath string) (finalizer func()) {
}
}

func createTLSConfig(caCertPath, clientCertPath, clientKeyPath string) (*tls.Config, error) {
rootCertPool := x509.NewCertPool()
pem, err := os.ReadFile(caCertPath)
if err != nil {
log.Error().Err(err).Msg("failed to read CA certificate")
return nil, err
}
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
log.Error().Msg("failed to append CA certificate")
return nil, err
}
tlsConfig := &tls.Config{
RootCAs: rootCertPool,
}

// Check if client certificate and key are provided for mutual TLS
if clientCertPath != "" && clientKeyPath != "" {
clientCert := make([]tls.Certificate, 0, 1)
certs, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
if err != nil {
log.Error().Err(err).Msg("failed to load client certificate or key")
return nil, err
}
clientCert = append(clientCert, certs)
tlsConfig.Certificates = clientCert
}
return tlsConfig, nil
}

func InitMySQLConnFromCfg(cfgPath string) *sql.DB {
if cfgPath == "" {
return nil
Expand All @@ -73,17 +106,32 @@ func InitMySQLConnFromCfg(cfgPath string) *sql.DB {
return nil
} else {
mySQLCfg := &struct {
Username string `json:"username"`
Password string `json:"password"`
Server string `json:"server"`
Database string `json:"database"`
}{}
Username string `json:"username"`
Password string `json:"password"`
Server string `json:"server"`
Database string `json:"database"`
TLS bool `json:"tls"`
CaCertPath string `json:"caCertPath"`
ClientCertPath string `json:"clientCertPath"`
ClientKeyPath string `json:"clientKeyPath"`
}{
TLS: false,
}
if err := json.Unmarshal(cfgBytes, mySQLCfg); err != nil {
log.Error().Err(err).Msg("failed to unmarshal MySQL connection config for the run recorder")
return nil
}
if db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true",
mySQLCfg.Username, mySQLCfg.Password, mySQLCfg.Server, mySQLCfg.Database)); err != nil {
tlsType := "false"
if mySQLCfg.TLS {
tlsType = "custom"
tlsConfig, err := createTLSConfig(mySQLCfg.CaCertPath, mySQLCfg.ClientCertPath, mySQLCfg.ClientKeyPath)
if err != nil {
log.Error().Err(err).Msg("TLS enabled but failed to load certificates")
}
mysql.RegisterTLSConfig(tlsType, tlsConfig)
}
if db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=%s&parseTime=true",
mySQLCfg.Username, mySQLCfg.Password, mySQLCfg.Server, mySQLCfg.Database, tlsType)); err != nil {
log.Error().Err(err).Msg("failed to initialize MySQL connection for the run recorder")
return nil
} else if err = db.Ping(); err != nil {
Expand Down