tlsrp

TLS reverse proxy
git clone git://git.rr3.xyz/tlsrp
Log | Files | Refs | README | LICENSE

commit a31b76827ed7ed10637b7fe730277a3425071dfc
parent d257c068fa028e68e51aff4e61eeeb5978b51270
Author: Robert Russell <robertrussell.72001@gmail.com>
Date:   Sun, 14 Jul 2024 21:51:45 -0700

Add go mod files

Diffstat:
Ago.mod | 5+++++
Ago.sum | 2++
Mtlsrp.go | 103++++++++++++++++++++++++++++++++++++++-----------------------------------------
3 files changed, 57 insertions(+), 53 deletions(-)

diff --git a/go.mod b/go.mod @@ -0,0 +1,5 @@ +module rr3.xyz/tlsrp + +go 1.22.5 + +require golang.org/x/sys v0.22.0 // indirect diff --git a/go.sum b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/tlsrp.go b/tlsrp.go @@ -3,7 +3,6 @@ package main // TODO: check log.Printf newlines import ( - "bytes" "context" "crypto/tls" "errors" @@ -91,8 +90,9 @@ func parseHostname(s string) (hostname, error) { buf = append(buf, '.') } - first := true - for _, r := range label { + for i, r := range label { + first := i == 0 + last := i == len(label) - 1 switch { case 'A' <= r && r <= 'Z': r += 'a' - 'A' @@ -100,19 +100,16 @@ func parseHostname(s string) (hostname, error) { // Ok case '0' <= r && r <= '9': // Ok - case r == '-' && !first: + case r == '-' && (!first && !last): // Ok case r == '-' && first: return "", fmt.Errorf("illegal hostname: hyphen at start of label") + case r == '-' && last: + return "", fmt.Errorf("illegal hostname: hyphen at end of label") default: return "", fmt.Errorf("illegal hostname: illegal rune: %q", r) } buf = append(buf, byte(r)) - first = false - } - - if buf[len(buf)-1] == '-' { - return "", fmt.Errorf("illegal hostname: hyphen at end of label") } } @@ -126,11 +123,11 @@ func (pattern pattern) matches(hostname hostname) bool { return slices.Contains(pattern, hostname) } -func parsePattern(ss [][]byte) (pattern, error) { +func parsePattern(ss []string) (pattern, error) { pat := make(pattern, 0, len(ss)) for _, s := range ss { - hostname, err := parseHostname(string(s)) + hostname, err := parseHostname(s) if err != nil { return nil, err } @@ -140,9 +137,6 @@ func parsePattern(ss [][]byte) (pattern, error) { return pat, nil } -// TODO: Come up with better names for the types "sink" and "cert". -// These are really more like "sink specifications" and "cert specifications". - type sink struct { pattern pattern network string @@ -183,18 +177,17 @@ func loadConfig(path string) (config, error) { var cfg config - lines := bytes.Split(data, []byte{'\n'}) + lines := strings.Split(string(data), "\n") for _, line := range lines { - if len(line) == 0 { - continue + fields := strings.Fields(line) + if len(fields) == 0 { + continue // Empty line } - - fields := bytes.Fields(line) if len(fields) < 3 { return config{}, fmt.Errorf("illegal config: line with fewer than 3 fields") } - switch string(fields[0]) { + switch fields[0] { case "sink": pat, err := parsePattern(fields[3:]) if err != nil { @@ -203,15 +196,15 @@ func loadConfig(path string) (config, error) { sink := sink{ pattern: pat, - network: string(fields[1]), - address: string(fields[2]), + network: fields[1], + address: fields[2], } cfg.sinks = append(cfg.sinks, sink) case "cert": - crtPath := string(fields[1]) - keyPath := string(fields[2]) - c, err := tls.LoadX509KeyPair(crtPath, keyPath) + crtPath := fields[1] + keyPath := fields[2] + tlsCert, err := tls.LoadX509KeyPair(crtPath, keyPath) if err != nil { return config{}, err } @@ -225,7 +218,7 @@ func loadConfig(path string) (config, error) { pattern: pat, crtPath: crtPath, keyPath: keyPath, - cert: &c, + cert: &tlsCert, } cfg.certs = append(cfg.certs, cert) @@ -237,14 +230,14 @@ func loadConfig(path string) (config, error) { return cfg, nil } -func manageConfig(configPath string) { - cfg, err := loadConfig(configPath) +func manageConfig(cfgPath string) { + cfg, err := loadConfig(cfgPath) if err != nil { log.Printf("failed to load initial configuration: %s\n", err) + // Proceed with cfg == config{}, i.e., with no sinks and no certs, + // causing every client to be rejected. } - // Proceed with cfg == config{}, i.e., with no sinks and no certs, causing - // every client to be rejected. - + sigusr := make(chan os.Signal, 2) signal.Notify(sigusr, unix.SIGUSR1, unix.SIGUSR2) @@ -256,17 +249,19 @@ func manageConfig(configPath string) { log.Println("received SIGUSR1; reloading certificates") certs := cfg.certs for i := range certs { - c, err := tls.LoadX509KeyPair(certs[i].crtPath, certs[i].keyPath) + crtPath := certs[i].crtPath + keyPath := certs[i].keyPath + tlsCert, err := tls.LoadX509KeyPair(crtPath, keyPath) if err == nil { - certs[i].cert = &c + certs[i].cert = &tlsCert } else { - log.Printf("failed to reload certificates: %s\n", err) + log.Printf("failed to reload certificate (%s, %s): %s\n", crtPath, keyPath, err) } } case unix.SIGUSR2: log.Println("received SIGUSR2; reloading configuration") - newCfg, err := loadConfig(configPath) + newCfg, err := loadConfig(cfgPath) if err == nil { cfg = newCfg } else { @@ -417,22 +412,24 @@ func accept(listener net.Listener) { } func manageExit() { - softSigs := make(chan os.Signal, 1) - signal.Notify(softSigs, unix.SIGINT) - - hardSigs := make(chan os.Signal, 1) - signal.Notify(hardSigs, unix.SIGQUIT, unix.SIGTERM) - - select { - case <-softSigs: - log.Println("received SIGINT; exiting softly") - close(softExit) - <-hardSigs - log.Println("received SIGQUIT/SIGTERM; exiting harshly") - close(hardExit) - case <-hardSigs: - log.Println("received SIGQUIT/SIGTERM; exiting harshly") - close(hardExit) + sigs := make(chan os.Signal, 3) + signal.Notify(sigs, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM) + + softExiting := false + for sig := range sigs { + switch sig { + case unix.SIGINT: + if !softExiting { + log.Println("received SIGINT; exiting softly") + close(softExit) + softExiting = true + } + // Keep going in case we receive hard exit signal. + case unix.SIGQUIT, unix.SIGTERM: + log.Println("received SIGQUIT/SIGTERM; exiting harshly") + close(hardExit) + return + } } } @@ -498,7 +495,7 @@ func main() { log.Fatalln("expected 2 or more arguments") } - configPath := flag.Args()[0] + cfgPath := flag.Args()[0] listeners, err := listen(flag.Args()[1:]) if err != nil { @@ -506,7 +503,7 @@ func main() { } go manageExit() - go manageConfig(configPath) + go manageConfig(cfgPath) var wg sync.WaitGroup defer wg.Wait()