tlsrp

TLS reverse proxy
git clone git://git.rr3.xyz/tlsrp
Log | Files | Refs | README | LICENSE

commit 5f3c7bb1cdf0cd76d4716d158aa205bc5a3418e2
parent edd853eceb61e994d9726533072ed45fd8aad88c
Author: Robert Russell <robertrussell.72001@gmail.com>
Date:   Tue, 16 Jul 2024 15:00:09 -0700

Move hostname stuff to separate file

Diffstat:
Ahostname.go | 103+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Mtlsrp.go | 262++++++++++++++++++++++++++-----------------------------------------------------
2 files changed, 187 insertions(+), 178 deletions(-)

diff --git a/hostname.go b/hostname.go @@ -0,0 +1,103 @@ +package main + +import ( + "fmt" + "slices" + "strings" +) + +type label string + +func parseLabel(s string) (label, error) { + if len(s) == 0 { + return "", fmt.Errorf("empty label") + } + + buf := make([]byte, 0, len(s)) + + for i, r := range s { + first := i == 0 + last := i == len(s) - 1 + + switch { + case 'A' <= r && r <= 'Z': + r += 'a' - 'A' + case 'a' <= r && r <= 'z': + // Ok + case '0' <= r && r <= '9': + // Ok + case r == '-' && (!first && !last): + // Ok + case r == '-' && first: + return "", fmt.Errorf("hyphen at start of label") + case r == '-' && last: + return "", fmt.Errorf("hyphen at end of label") + default: + return "", fmt.Errorf("illegal rune in label: %q", r) + } + + buf = append(buf, byte(r)) + } + + return label(buf), nil +} + +type hostname []label + +func (hostname hostname) String() string { + // Ughh, Go can't convert between hostname and []string, + // so we can't use strings.Join. + + var sb strings.Builder + for i, label := range hostname { + if i > 0 { + sb.WriteByte('.') + } + sb.WriteString(string(label)) + } + + return sb.String() +} + +func (hostname0 hostname) equal(hostname1 hostname) bool { + return slices.Equal(hostname0, hostname1) +} + +func parseHostname(s string) (hostname, error) { + if len(s) == 0 { + return nil, nil + } + + labelStrs := strings.Split(s, ".") + labels := make([]label, 0, len(labelStrs)) + + for _, labelStr := range labelStrs { + label, err := parseLabel(labelStr) + if err != nil { + return nil, fmt.Errorf("illegal hostname: %w", err) + } + labels = append(labels, label) + } + + return hostname(labels), nil +} + +type pattern []hostname + +func (pat pattern) matches(hostname hostname) bool { + return slices.ContainsFunc(pat, hostname.equal) +} + +func parsePattern(hostnameStrs []string) (pattern, error) { + pat := make(pattern, 0, len(hostnameStrs)) + + for _, hostnameStr := range hostnameStrs { + hostname, err := parseHostname(hostnameStr) + if err != nil { + return nil, err + } + pat = append(pat, hostname) + } + + return pat, nil +} diff --git a/tlsrp.go b/tlsrp.go @@ -13,7 +13,6 @@ import ( "net" "os" "os/signal" - "slices" "strings" "sync" "time" @@ -37,10 +36,18 @@ var softExit chan struct{} var hardExit chan struct{} var lookupSinkChan chan lookupSinkMsg +var lookupCertChan chan lookupCertMsg + type lookupSinkMsg struct { hostname hostname reply chan<- net.Addr } + +type lookupCertMsg struct { + hostname hostname + reply chan<- *tls.Certificate +} + func lookupSink(hostname hostname) (net.Addr, error) { reply := make(chan net.Addr, 1) lookupSinkChan <- lookupSinkMsg{ @@ -54,11 +61,6 @@ func lookupSink(hostname hostname) (net.Addr, error) { return sink, nil } -var lookupCertChan chan lookupCertMsg -type lookupCertMsg struct { - hostname hostname - reply chan<- *tls.Certificate -} func lookupCert(hostname hostname) (*tls.Certificate, error) { reply := make(chan *tls.Certificate, 1) lookupCertChan <- lookupCertMsg{ @@ -72,102 +74,6 @@ func lookupCert(hostname hostname) (*tls.Certificate, error) { return cert, nil } -type label string - -func parseLabel(s string) (label, error) { - if len(s) == 0 { - return "", fmt.Errorf("empty label") - } - - buf := make([]byte, 0, len(s)) - - for i, r := range s { - first := i == 0 - last := i == len(s) - 1 - - switch { - case 'A' <= r && r <= 'Z': - r += 'a' - 'A' - case 'a' <= r && r <= 'z': - // Ok - case '0' <= r && r <= '9': - // Ok - case r == '-' && (!first && !last): - // Ok - case r == '-' && first: - return "", fmt.Errorf("hyphen at start of label") - case r == '-' && last: - return "", fmt.Errorf("hyphen at end of label") - default: - return "", fmt.Errorf("illegal rune in label: %q", r) - } - - buf = append(buf, byte(r)) - } - - return label(buf), nil -} - -type hostname []label - -func (hostname hostname) String() string { - // Ughh, Go can't convert between hostname and []string, - // so we can't use strings.Join. - - var sb strings.Builder - for i, label := range hostname { - if i > 0 { - sb.WriteByte('.') - } - sb.WriteString(string(label)) - } - - return sb.String() -} - -func (hostname0 hostname) equal(hostname1 hostname) bool { - return slices.Equal(hostname0, hostname1) -} - -func parseHostname(s string) (hostname, error) { - if len(s) == 0 { - return nil, nil - } - - labelStrs := strings.Split(s, ".") - labels := make([]label, 0, len(labelStrs)) - - for _, labelStr := range labelStrs { - label, err := parseLabel(labelStr) - if err != nil { - return nil, fmt.Errorf("illegal hostname: %w", err) - } - labels = append(labels, label) - } - - return hostname(labels), nil -} - -type pattern []hostname - -func (pat pattern) matches(hostname hostname) bool { - return slices.ContainsFunc(pat, hostname.equal) -} - -func parsePattern(hostnameStrs []string) (pattern, error) { - pat := make(pattern, 0, len(hostnameStrs)) - - for _, hostnameStr := range hostnameStrs { - hostname, err := parseHostname(hostnameStr) - if err != nil { - return nil, err - } - pat = append(pat, hostname) - } - - return pat, nil -} - type sink struct { pattern pattern network string @@ -273,57 +179,6 @@ func loadConfig(path string) (config, error) { return cfg, nil } -func manageConfig(cfgPath string) { - cfg, err := loadConfig(cfgPath) - if err != nil { - log.Printf("failed to load initial configuration: %s\n", err) - // Proceed with cfg == config{}, i.e., with no sinks and no certs, - // causing every client to be rejected. - } - - sighup := make(chan os.Signal, 1) - signal.Notify(sighup, unix.SIGHUP) - - for { - select { - case <-sighup: - log.Println("received SIGHUP; reloading configuration") - newCfg, err := loadConfig(cfgPath) - if err == nil { - cfg = newCfg - } else { - log.Printf("failed to reload configuration: %s\n", err) - } - - case msg := <-lookupSinkChan: - if msg.hostname != nil { - for _, sink := range cfg.sinks { - if sink.pattern.matches(msg.hostname) { - msg.reply <- sink - break - } - } - } else if len(cfg.sinks) > 0 { - msg.reply <- cfg.sinks[0] - } - close(msg.reply) - - case msg := <-lookupCertChan: - if msg.hostname != nil { - for _, cert := range cfg.certs { - if cert.pattern.matches(msg.hostname) { - msg.reply <- cert.cert - break - } - } - } else if len(cfg.certs) > 0 { - msg.reply <- cfg.certs[0].cert - } - close(msg.reply) - } - } -} - func handshake(conn *tls.Conn) error { ctx, cancel := context.WithTimeout(context.Background(), handshakeTimeout) defer cancel() @@ -450,31 +305,6 @@ func accept(listener net.Listener) { } } -func manageExit() { - sigs := make(chan os.Signal, 3) - signal.Notify(sigs, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM) - - softExiting := false - for sig := range sigs { - switch sig { - case unix.SIGINT, unix.SIGQUIT: - if !softExiting { - log.Println("received SIGINT/SIGQUIT; exiting softly") - close(softExit) - softExiting = true - } else { - log.Println("received another SIGINT/SIGQUIT; exiting harshly") - close(hardExit) - return - } - case unix.SIGTERM: - log.Println("received SIGTERM; exiting harshly") - close(hardExit) - return - } - } -} - func listen(sources []string) ([]net.Listener, error) { tlsConfig := &tls.Config{ GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -510,6 +340,82 @@ func listen(sources []string) ([]net.Listener, error) { return listeners, nil } +func manageConfig(cfgPath string) { + cfg, err := loadConfig(cfgPath) + if err != nil { + log.Printf("failed to load initial configuration: %s\n", err) + // Proceed with cfg == config{}, i.e., with no sinks and no certs, + // causing every client to be rejected. + } + + sighup := make(chan os.Signal, 1) + signal.Notify(sighup, unix.SIGHUP) + + for { + select { + case <-sighup: + log.Println("received SIGHUP; reloading configuration") + newCfg, err := loadConfig(cfgPath) + if err == nil { + cfg = newCfg + } else { + log.Printf("failed to reload configuration: %s\n", err) + } + + case msg := <-lookupSinkChan: + if msg.hostname != nil { + for _, sink := range cfg.sinks { + if sink.pattern.matches(msg.hostname) { + msg.reply <- sink + break + } + } + } else if len(cfg.sinks) > 0 { + msg.reply <- cfg.sinks[0] + } + close(msg.reply) + + case msg := <-lookupCertChan: + if msg.hostname != nil { + for _, cert := range cfg.certs { + if cert.pattern.matches(msg.hostname) { + msg.reply <- cert.cert + break + } + } + } else if len(cfg.certs) > 0 { + msg.reply <- cfg.certs[0].cert + } + close(msg.reply) + } + } +} + +func manageExit() { + sigs := make(chan os.Signal, 3) + signal.Notify(sigs, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM) + + softExiting := false + for sig := range sigs { + switch sig { + case unix.SIGINT, unix.SIGQUIT: + if !softExiting { + log.Println("received SIGINT/SIGQUIT; exiting softly") + close(softExit) + softExiting = true + } else { + log.Println("received another SIGINT/SIGQUIT; exiting harshly") + close(hardExit) + return + } + case unix.SIGTERM: + log.Println("received SIGTERM; exiting harshly") + close(hardExit) + return + } + } +} + func init() { softExit = make(chan struct{}) hardExit = make(chan struct{})