diff --git a/PKGBUILD b/PKGBUILD index 48e24d5..52a5275 100644 --- a/PKGBUILD +++ b/PKGBUILD @@ -1,5 +1,5 @@ pkgname=gonx -pkgver=0.0.2 +pkgver=0.0.3 pkgrel=1 pkgdesc='Simple reverse proxy server' arch=('x86_64' 'aarch64') diff --git a/app.go b/app.go new file mode 100644 index 0000000..745e85c --- /dev/null +++ b/app.go @@ -0,0 +1,77 @@ +package main + +import ( + "crypto/tls" + "fmt" + "log/slog" + "net" +) + +type App struct { + config *Config + tlsListener net.Listener +} + +func newApp(config *Config) (*App, error) { + err := config.initTls() + if err != nil { + return nil, fmt.Errorf("failed to load TLS keys: %v", err) + } + + app := &App{config: config} + + return app, nil +} + +func (app *App) reloadConfig(configFilePath string) error { + config, err := LoadConfig(configFilePath) + if err != nil { + return fmt.Errorf("failed to load config: %v", err) + } + + err = config.initTls() + if err != nil { + return fmt.Errorf("failed to load TLS keys: %v", err) + } + + app.config = config + + slog.Debug("Deactivating TLS listener", slog.String("addr", app.config.TlsListenAddr)) + err = app.tlsListener.Close() + if err != nil { + return fmt.Errorf("failed to close TLS listener: %v", err) + } + + err = app.restartTlsListener() + if err != nil { + return fmt.Errorf("failed to restart TLS listener: %v", err) + } + + return nil +} + +func (app *App) restartTlsListener() error { + slog.Debug("Starting TLS listener", slog.String("addr", app.config.TlsListenAddr)) + + tlsListener, err := tls.Listen("tcp", app.config.TlsListenAddr, app.config.tlsConfig) + if err != nil { + return fmt.Errorf("failed to open tls listener: %v", err) + } + + app.tlsListener = tlsListener + + go func() { + for { + conn, err := tlsListener.Accept() + if err != nil { + slog.Error("failed to receive connection", slog.String("err", err.Error())) // TODO: drop error on closing TLS listener + break + } + slog.Debug("incoming connection", slog.String("RemoteAddr", conn.RemoteAddr().String())) + + go func() { _ = handleTlsConn(conn.(*tls.Conn), app.config.proxyRules) }() + } + }() + + return nil +} diff --git a/config.go b/config.go index c6a2585..5671fe4 100644 --- a/config.go +++ b/config.go @@ -55,6 +55,8 @@ func LoadConfig(configFilePath string) (*Config, error) { } func (c *Config) initTls() error { + slog.Debug("Loading TLS keys") + c.tlsConfig = new(tls.Config) for hostName := range c.proxyRules { diff --git a/main.go b/main.go index d2792d2..ec3811d 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "crypto/tls" "log/slog" "net/http" "os" @@ -33,32 +32,17 @@ func main() { }})) slog.SetDefault(logger) - err = config.initTls() + app, err := newApp(config) if err != nil { - slog.Error("init tls error", slog.String("err", err.Error())) + slog.Error("Failed to start app", slog.String("err", err.Error())) os.Exit(1) } - go func() { - slog.Debug("Starting TLS listener", slog.String("addr", config.TlsListenAddr)) - - listener, err := tls.Listen("tcp", config.TlsListenAddr, config.tlsConfig) - if err != nil { - slog.Error("Failed to open tls listener", slog.String("err", err.Error())) - os.Exit(1) - } - - for { - conn, err := listener.Accept() - if err != nil { - slog.Debug("incoming connection failed", slog.String("err", err.Error())) - continue - } - slog.Debug("incoming connection", slog.String("RemoteAddr", conn.RemoteAddr().String())) - - go func() { _ = handleTlsConn(conn.(*tls.Conn), config.proxyRules) }() - } - }() + err = app.restartTlsListener() + if err != nil { + slog.Error("Failed to start TLS listener", slog.String("err", err.Error())) + os.Exit(1) + } go func() { slog.Debug("Starting HTTP listener", slog.String("addr", config.HttpListenAddr)) @@ -77,6 +61,21 @@ func main() { } }() + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGUSR1) + for { + <-c + slog.Debug("TLS keys reload requested") + + err := app.reloadConfig(configFilePath) + if err != nil { + slog.Error("failed to reload TLS keys", slog.String("err", err.Error())) + } + slog.Debug("Reloading TLS keys completed") + } + }() + c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt, syscall.SIGTERM) <-c