tlsrp

TLS reverse proxy
git clone git://git.rr3.xyz/tlsrp
Log | Files | Refs | README | LICENSE

commit 5cf5c48fd635559a2f65af47a1ca7c5e201469a0
parent 155ee4b00d9c733d2de6f02e62a7618e9f3f8beb
Author: Robert Russell <robertrussell.72001@gmail.com>
Date:   Tue, 16 Jul 2024 17:22:30 -0700

Clean up

Separate configuration-related stuff into config.go. Now
tlsrp.go has only the main TLS/proxy code.

Also, fix the disturbing "sink" and "cert" naming issues by
having the lookup stuff work with the sink and cert structs
rather than net.Addr and *tls.Certificate.

Diffstat:
Aconfig.go | 160+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Mhostname.go | 5+++--
Mtlsrp.go | 213+++++++++++++++++--------------------------------------------------------------
3 files changed, 207 insertions(+), 171 deletions(-)

diff --git a/config.go b/config.go @@ -0,0 +1,160 @@ +package main + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "os" + "strings" +) + +var lookupSinkChan chan lookupSinkMsg +var lookupCertChan chan lookupCertMsg + +func init() { + lookupSinkChan = make(chan lookupSinkMsg, 16) + lookupCertChan = make(chan lookupCertMsg, 16) +} + +type sink struct { + pattern pattern + network string + address string +} + +type cert struct { + pattern pattern + crtPath string + keyPath string + cert *tls.Certificate +} + +type lookupSinkMsg struct { + hostname hostname + reply chan<- *sink +} + +type lookupCertMsg struct { + hostname hostname + reply chan<- *cert +} + +func lookupSink(hostname hostname) (*sink, error) { + reply := make(chan *sink, 1) + lookupSinkChan <- lookupSinkMsg{ + hostname: hostname, + reply: reply, + } + + sink, ok := <-reply + if !ok { + return nil, fmt.Errorf("no sink for hostname %s", hostname) + } + + return sink, nil +} + +func lookupCert(hostname hostname) (*cert, error) { + reply := make(chan *cert, 1) + lookupCertChan <- lookupCertMsg{ + hostname: hostname, + reply: reply, + } + + cert, ok := <-reply + if !ok { + return nil, fmt.Errorf("no cert for hostname %s", hostname) + } + + return cert, nil +} + +type config struct { + sinks []*sink + certs []*cert +} + +func loadTLSCert(crtPath, keyPath string) (*tls.Certificate, error) { + tlsCert, err := tls.LoadX509KeyPair(crtPath, keyPath) + if err != nil { + return nil, err + } + + // Parsing the leaf certificate in advance is recommended by crypto/tls to + // "reduce per-handshake processing". + tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0]) + if err != nil { + return nil, err + } + + return &tlsCert, nil +} + +func loadConfig(path string) (*config, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + data, err := io.ReadAll(file) + if err != nil { + return nil, err + } + + var cfg config + + lines := strings.Split(string(data), "\n") + for _, line := range lines { + fields := strings.Fields(line) + if len(fields) == 0 { + continue // Empty line + } + if len(fields) < 3 { + return nil, fmt.Errorf("illegal config: line with fewer than 3 fields") + } + + switch fields[0] { + case "sink": + pattern, err := parsePattern(fields[3:]) + if err != nil { + return nil, err + } + + sink := sink{ + pattern: pattern, + network: fields[1], + address: fields[2], + } + cfg.sinks = append(cfg.sinks, &sink) + + case "cert": + crtPath := fields[1] + keyPath := fields[2] + + tlsCert, err := loadTLSCert(crtPath, keyPath) + if err != nil { + return nil, err + } + + pattern, err := parsePattern(fields[3:]) + if err != nil { + return nil, err + } + + cert := cert{ + pattern: pattern, + crtPath: crtPath, + keyPath: keyPath, + cert: tlsCert, + } + cfg.certs = append(cfg.certs, &cert) + + default: + return nil, fmt.Errorf("illegal config: expected \"sink\" or \"cert\" as first field") + } + } + + return &cfg, nil +} diff --git a/hostname.go b/hostname.go @@ -8,6 +8,7 @@ import ( // TODO: Add support for more than just alternation in patterns. We should at // least support leading wildcards, but maybe we should support full regex. +// XXX: We currently don't length check hostnames or the labels within. type label string @@ -35,7 +36,7 @@ func parseLabel(labelStr string) (label, error) { } if buf[0] == '-' || buf[len(buf) - 1] == '-' { - return fmt.Errorf("hyphen at start or end of label") + return "", fmt.Errorf("hyphen at start or end of label") } return label(buf), nil @@ -44,7 +45,7 @@ func parseLabel(labelStr string) (label, error) { type hostname []label func (hostname hostname) String() string { - // Ughh, Go can't convert between hostname and []string, + // Ughhhh, Go can't convert between hostname and []string, // so we can't use strings.Join. var sb strings.Builder diff --git a/tlsrp.go b/tlsrp.go @@ -3,7 +3,6 @@ package main import ( "context" "crypto/tls" - "crypto/x509" "errors" "flag" "fmt" @@ -19,9 +18,6 @@ import ( ) // TODO: Scrutinize. In particular, compare with sltls. -// TODO: Add support for more than just alternation in patterns, like -// arbitrary regex. -// XXX: We currently don't length check hostnames or the labels within. // We only enforce a timeout on the handshake. After the handshake is complete, // the sink is responsible for timing-out clients. @@ -35,148 +31,28 @@ const handshakeTimeout = 10 * time.Second var softExit chan struct{} var hardExit chan struct{} -var lookupSinkChan chan lookupSinkMsg -var lookupCertChan chan lookupCertMsg - -type lookupSinkMsg struct { - hostname hostname - reply chan<- net.Addr -} - -type lookupCertMsg struct { - hostname hostname - reply chan<- *tls.Certificate -} - -func lookupSink(hostname hostname) (net.Addr, error) { - reply := make(chan net.Addr, 1) - lookupSinkChan <- lookupSinkMsg{ - hostname: hostname, - reply: reply, - } - sink, ok := <-reply - if !ok { - return nil, fmt.Errorf("no sink for hostname %s", hostname) - } - return sink, nil -} - -func lookupCert(hostname hostname) (*tls.Certificate, error) { - reply := make(chan *tls.Certificate, 1) - lookupCertChan <- lookupCertMsg{ - hostname: hostname, - reply: reply, - } - cert, ok := <-reply - if !ok { - return nil, fmt.Errorf("no certificate for hostname %s", hostname) - } - return cert, nil -} - -type sink struct { - pattern pattern - network string - address string -} - -func (sink sink) Network() string { - return sink.network -} - -func (sink sink) String() string { - return sink.address -} - -type cert struct { - pattern pattern - crtPath string - keyPath string - cert *tls.Certificate -} - -type config struct { - sinks []sink - certs []cert +func init() { + softExit = make(chan struct{}) + hardExit = make(chan struct{}) } -func loadCert(crtPath, keyPath string) (*tls.Certificate, error) { - tlsCert, err := tls.LoadX509KeyPair(crtPath, keyPath) - if err != nil { - return nil, err - } - tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0]) - if err != nil { - return nil, err - } - return &tlsCert, nil +type socket interface { + io.ReadWriteCloser + CloseWrite() error } -func loadConfig(path string) (config, error) { - file, err := os.Open(path) - if err != nil { - return config{}, err - } - defer file.Close() - - data, err := io.ReadAll(file) - if err != nil { - return config{}, err - } - - var cfg config - - lines := strings.Split(string(data), "\n") - for _, line := range lines { - fields := strings.Fields(line) - if len(fields) == 0 { - continue // Empty line - } - if len(fields) < 3 { - return config{}, fmt.Errorf("illegal config: line with fewer than 3 fields") - } - - switch fields[0] { - case "sink": - pat, err := parsePattern(fields[3:]) - if err != nil { - return config{}, err - } - - sink := sink{ - pattern: pat, - network: fields[1], - address: fields[2], - } - cfg.sinks = append(cfg.sinks, sink) +func parseSNI(sni string) (hostname, error) { + var hostname hostname - case "cert": - crtPath := fields[1] - keyPath := fields[2] - tlsCert, err := loadCert(crtPath, keyPath) - if err != nil { - return config{}, err - } - - pat, err := parsePattern(fields[3:]) - if err != nil { - return config{}, err - } - - cert := cert{ - pattern: pat, - crtPath: crtPath, - keyPath: keyPath, - cert: tlsCert, - } - cfg.certs = append(cfg.certs, cert) - - default: - return config{}, fmt.Errorf("illegal config: expected \"sink\" or \"cert\" as first field") + if sni != "" { + var err error + hostname, err = parseHostname(sni) + if err != nil { + return nil, err } } - return cfg, nil + return hostname, nil } func handshake(conn *tls.Conn) error { @@ -195,12 +71,7 @@ func handshake(conn *tls.Conn) error { return err } -type conn interface { - io.ReadWriteCloser - CloseWrite() error -} - -func splice(a, b conn) error { +func splice(a, b socket) error { a2bErr := make(chan error, 1) go func() { _, err := io.Copy(b, a) @@ -250,26 +121,26 @@ func proxy(client *tls.Conn) { return } - hostname, err := parseHostname(client.ConnectionState().ServerName) + hostname, err := parseSNI(client.ConnectionState().ServerName) if err != nil { logf("rejected: %s", err) return } - sinkAddr, err := lookupSink(hostname) + sink, err := lookupSink(hostname) if err != nil { logf("rejected: %s", err) return } - sink, err := net.Dial(sinkAddr.Network(), sinkAddr.String()) + server, err := net.Dial(sink.network, sink.address) if err != nil { logf("%s", err) return } - defer sink.Close() + defer server.Close() - err = splice(client, sink.(conn)) + err = splice(client, server.(socket)) if err != nil { logf("%s", err) return @@ -277,9 +148,6 @@ func proxy(client *tls.Conn) { } func accept(listener net.Listener) { - var wg sync.WaitGroup - defer wg.Wait() - logf := func(format string, a ...any) { log.Printf("source %s: %s\n", listener.Addr(), fmt.Sprintf(format, a...)) } @@ -287,12 +155,18 @@ func accept(listener net.Listener) { logf("accepting") defer logf("closing") + var wg sync.WaitGroup + defer wg.Wait() + for { conn, err := listener.Accept() if err != nil { if !errors.Is(err, net.ErrClosed) { logf("%s", err) os.Exit(1) + // XXX: Exiting here might be an over-reaction to the error. + // Although, keep in mind that tlsrp should be running under + // some service manager that should restart tlsrp if it exits. } return } @@ -312,26 +186,33 @@ func listen(sources []string) ([]net.Listener, error) { if err != nil { return nil, err } - return lookupCert(hostname) + + cert, err := lookupCert(hostname) + if err != nil { + return nil, err + } + + return cert.cert, nil }, } listeners := make([]net.Listener, 0, len(sources)) - for _, s := range sources { - fields := strings.SplitN(s, ":", 2) + for _, source := range sources { + fields := strings.SplitN(source, ":", 2) if len(fields) != 2 { return nil, fmt.Errorf("invalid source: expected colon separating network type from address") } + network := fields[0] address := fields[1] switch network { case "tcp", "unix": - l, err := tls.Listen(network, address, tlsConfig) + listener, err := tls.Listen(network, address, tlsConfig) if err != nil { return nil, err } - listeners = append(listeners, l) + listeners = append(listeners, listener) default: return nil, fmt.Errorf("invalid source: expected network type of \"tcp\" or \"unix\"") } @@ -344,8 +225,9 @@ func manageConfig(cfgPath string) { cfg, err := loadConfig(cfgPath) if err != nil { log.Printf("failed to load initial configuration: %s\n", err) - // Proceed with cfg == config{}, i.e., with no sinks and no certs, - // causing every client to be rejected. + // Proceed with an empty configuration (i.e., with no sinks and no + // certs), causing every client to be rejected. + cfg = &config{} } sighup := make(chan os.Signal, 1) @@ -379,18 +261,19 @@ func manageConfig(cfgPath string) { if msg.hostname != nil { for _, cert := range cfg.certs { if cert.pattern.matches(msg.hostname) { - msg.reply <- cert.cert + msg.reply <- cert break } } } else if len(cfg.certs) > 0 { - msg.reply <- cfg.certs[0].cert + msg.reply <- cfg.certs[0] } close(msg.reply) } } } +// TODO: Simplify this so all signals soft exit first, then hard exit second func manageExit() { sigs := make(chan os.Signal, 3) signal.Notify(sigs, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM) @@ -416,14 +299,6 @@ func manageExit() { } } -func init() { - softExit = make(chan struct{}) - hardExit = make(chan struct{}) - - lookupSinkChan = make(chan lookupSinkMsg, 16) - lookupCertChan = make(chan lookupCertMsg, 16) -} - func main() { flag.Usage = func() { fmt.Fprintf(os.Stderr, "Usage: %s CONFIG_PATH SOURCE...", os.Args[0])