tlsrp

TLS reverse proxy
git clone git://git.rr3.xyz/tlsrp
Log | Files | Refs | README | LICENSE

commit 0766dcb9d85f4636416d4c249da3997d5879504f
Author: Robert Russell <robertrussell.72001@gmail.com>
Date:   Sun, 14 Jul 2024 16:54:56 -0700

Initial commit

Diffstat:
ALICENSE | 16++++++++++++++++
Atlsrp.go | 528+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 544 insertions(+), 0 deletions(-)

diff --git a/LICENSE b/LICENSE @@ -0,0 +1,15 @@ +ISC License + +Copyright (c) 2024, Robert Russell + +Permission to use, copy, modify, and/or distribute this software for any +purpose with or without fee is hereby granted, provided that the above +copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +\ No newline at end of file diff --git a/tlsrp.go b/tlsrp.go @@ -0,0 +1,528 @@ +package main + +// TODO: check log.Printf newlines + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "flag" + "fmt" + "golang.org/x/sys/unix" + "log" + "net" + "os" + "os/signal" + "slices" + "strings" + "sync" + "time" +) + +// We only enforce a timeout on the handshake. After the handshake is complete, +// the sink is responsible for timing-out clients. +const handshakeTimeout = 10 * time.Second + +// softExit or hardExit is closed when exiting. During a soft exit, we +// accept no new clients, but existing clients should finish gracefully; +// during a hard exit, we accept no new clients, and exiting clients +// should be forcefully disconnected. Since we have two different exit +// modes, we don't use context.Context's to handle cancellation. +var softExit chan struct{} +var hardExit chan struct{} + +var lookupSinkChan chan lookupSinkMsg +type lookupSinkMsg struct { + hostname hostname + reply chan<- net.Addr +} +func lookupSink(hostname hostname) (net.Addr, error) { + reply := make(chan net.Addr, 1) + lookupSinkChan <- lookupSinkMsg{ + hostname: hostname, + reply: reply, + } + sink, ok := <-reply + if !ok { + return nil, fmt.Errorf("no sink for hostname %q", hostname) + } + return sink, nil +} + +var lookupCertChan chan lookupCertMsg +type lookupCertMsg struct { + hostname hostname + reply chan<- *tls.Certificate +} +func lookupCert(hostname hostname) (*tls.Certificate, error) { + reply := make(chan *tls.Certificate, 1) + lookupCertChan <- lookupCertMsg{ + hostname: hostname, + reply: reply, + } + cert, ok := <-reply + if !ok { + return nil, fmt.Errorf("no certificate for hostname %q", hostname) + } + return cert, nil +} + +var errHostnameEmpty = errors.New("empty hostname") + +type hostname string + +// XXX: We currently don't length check hostnames or the labels within. +func parseHostname(s string) (Hostname, error) { + if len(s) == 0 { + return "", errHostnameEmpty + } + + buf := make([]byte, 0, len(s)) + + labels := strings.Split(s, ".") + for _, label := range labels { + if len(label) == 0 { + return "", fmt.Errorf("illegal hostname: empty label") + } + + if len(buf) > 0 { + buf = append(buf, '.') + } + + first := true + for _, r := range label { + switch { + case 'A' <= r && r <= 'Z': + r += 'a' - 'A' + case 'a' <= r && r <= 'z': + // Ok + case '0' <= r && r <= '9': + // Ok + case '-' && !first: + // Ok + case '-' && first: + return "", fmt.Errorf("illegal hostname: hyphen at start of label") + default: + return "", fmt.Errorf("illegal hostname: illegal rune: %q", r) + } + buf = append(buf, r) + first := false + } + + if buf[len(buf)-1] == '-' { + return "", fmt.Errorf("illegal hostname: hyphen at end of label") + } + } + + return string(buf), nil +} + +// TODO: Add support for more than just alternation in patterns. +type pattern []hostname + +func (pattern pattern) matches(hostname hostname) bool { + return slices.Contains(pattern, hostname) +} + +func parsePattern(ss []string) (pattern, error) { + pat := make(pattern, 0, len(ss)) + + for _, s := range ss { + hostname, err := parseHostname(s) + if err != nil { + return nil, err + } + pat = append(pat, hostname) + } + + return pat, nil +} + +// TODO: Come up with better names for the types "sink" and "cert". +// These are really more like "sink specifications" and "cert specifications". + +type sink struct { + pattern Pattern + network string + address string +} + +func (sink sink) Network() string { + return sink.network +} + +func (sink sink) String() string { + return sink.address +} + +type cert struct { + pattern Pattern + crtPath string + keyPath string + cert *tls.Certificate +} + +type config struct { + sinks []Sink + certs []Cert +} + +func loadConfig(path string) (config, error) { + file, err := os.Open(path) + if err != nil { + return config{}, err + } + defer file.Close() + + data, err := io.ReadAll(file) + if err != nil { + return config{}, err + } + + var cfg config + + lines := bytes.Split(data, '\n') + for _, line := range lines { + if len(line) == 0 { + continue + } + + fields := bytes.Fields(line) + if len(fields) < 3 { + return &config{}, fmt.Errorf("illegal config: line with fewer than 3 fields") + } + + switch fields[0] { + case "sink": + pat, err := parsePattern(fields[3:]) + if err != nil { + return config{}, err + } + + sink := sink{ + pattern: pat, + network: fields[1], + address: fields[2], + } + config.sinks = append(config.sinks, sink) + + case "cert": + crtPath := fields[1] + keyPath := fields[2] + c, err := tls.LoadX509KeyPair(crtPath, keyPath) + if err != nil { + return config{}, err + } + + pat, err := parsePattern(fields[3:]) + if err != nil { + return config{}, err + } + + cert := cert{ + pattern: pat, + crtPath: fields[1], + keyPath: fields[2], + cert: c, + } + config.certs = append(config.certs, cert) + + default: + return config{}, fmt.Errorf("illegal config: expected \"sink\" or \"cert\" as first field") + } + } + + return cfg, nil +} + +func manageConfig(configPath string) { + cfg, err := loadConfig(configPath) + if err != nil { + log.Printf("failed to load initial configuration: %s\n", err) + } + // Proceed with cfg == config{}, i.e., with no sinks and no certs, causing + // every client to be rejected. + + sigusr := make(chan os.Signal, 2) + signal.Notify(sigusr, unix.SIGUSR1, unix.SIGUSR2) + + for { + select { + case sig := <-sigusr: + switch sig { + case unix.SIGUSR1: + log.Println("received SIGUSR1; reloading certificates") + certs := cfg.certs + for i := range certs { + c, err := tls.LoadX509KeyPair(certs[i].crtPath, certs[i].keyPath) + if err == nil { + certs[i].cert = c + } else { + log.Printf("failed to reload certificates: %s\n", err) + } + } + + case unix.SIGUSR2: + log.Println("received SIGUSR2; reloading configuration") + newCfg, err := loadConfig(configPath) + if err == nil { + cfg = newCfg + } else { + log.Printf("failed to reload configuration: %s\n", err) + } + } + + case msg := <-lookupSinkChan: + if msg.hostname != "" { + for _, sink := range cfg.sinks { + if sink.pattern.matches(msg.hostname) { + msg.reply <- sink + break + } + } + } else if len(cfg.sinks) > 0 { + msg.reply <- cfg.sinks[0] + } + close(msg.reply) + + case msg := <-lookupCertChan: + if msg.hostname != "" { + for _, cert := range cfg.certs { + if cert.pattern.matches(msg.hostname) { + msg.reply <- cert.cert + break + } + } + } else if len(cfg.certs) > 0 { + msg.reply <- cfg.certs[0].cert + } + close(msg.reply) + } + } +} + +func handshake(conn *tls.Conn) error { + handshakeErr := make(chan error, 1) + go func() { + ctx := context.WithTimeout(context.Background(), handshakeTimeout) + handshakeErr <- conn.HandshakeContext(ctx) + }() + + var err error + select { + case err = <-handshakeErr: + case <-hardExit: + conn.Close() + <-handshakeErr + } + + return err +} + +func splice(a, b *tls.Conn) error { + a2bErr := make(chan error, 1) + go func() { + _, err := io.Copy(b, a) + a2bErr <- err + }() + + b2aErr := make(chan error, 1) + go func() { + _, err := io.Copy(a, b) + b2aErr <- err + }() + + // In the first two cases, we call CloseWrite (not Close) on the + // corresponding destination connection, so that the other copy goroutine + // can continue reading from it (until it hits EOF). In the hard exit + // case, we want to exit ASAP, so we close both ends of both connections. + var err error + select { + case err = <-a2bErr: + b.CloseWrite() + <-b2aErr + case err = <-b2aErr: + a.CloseWrite() + <-a2bErr + case <-hardExit: + a.Close() + b.Close() + <-a2bErr + <-b2aErr + } + + return err +} + +func proxy(client *tls.Conn) { + logf := func(format string, a ...interface{}) { + log.Printf("client %s: %s\n", client.RemoteAddr(), fmt.Sprintf(format, a...)) + } + + logf("connected") + defer logf("disconnected") + defer client.Close() + + err := handshake(client) + if err != nil { + logf("handshake error: %s", err) + return + } + + hostname, err := parseHostname(client.ConnectionState().ServerName) + if err != nil && !errors.Is(err, errHostnameEmpty) { + logf("rejected: %s", err) + return + } + + sinkAddr, err := lookupSink(hostname) + if err != nil { + logf("rejected: %s", err) + return + } + + sink, err := net.Dial(sinkAddr.Network(), sinkAddr.String()) + if err != nil { + logf("dial error: %s", err) + return + } + defer sink.Close() + + err = splice(client, sink) + if err != nil { + logf("splice error: %s", err) + return + } +} + +func accept(listener net.Listener) { + var wg sync.WaitGroup + defer wg.Wait() + + for { + conn, err := listener.Accept() + if err != nil { + if !errors.Is(err, net.ErrClosed) { + log.Fatalf("source %s error: %s\n", listener.Addr(), err) + } + return + } + + wg.Add(1) + go func() { + proxy(conn.(*tls.Conn)) + wg.Done() + }() + } +} + +func manageExit() { + softSigs := make(chan os.Signal, 1) + signal.Notify(softSigs, unix.SIGINT) + + hardSigs := make(chan os.Signal, 1) + signal.Notify(hardSigs, unix.SIGQUIT, unix.SIGTERM) + + select { + case <-softSigs: + log.Println("received SIGINT; exiting softly") + close(softExit) + <-hardSigs + fallthrough + case <-hardSigs: + log.Println("received SIGQUIT/SIGTERM; exiting harshly") + close(hardExit) + } +} + +func listen(sources []string) ([]net.Listener, error) { + tlsConfig := &tls.Config{ + GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { + hostname, err := parseHostname(chi.ServerName) + if err != nil && !errors.Is(err, errHostnameEmpty) { + return nil, err + } + return lookupCert(hostname) + } + } + + listeners := make([]net.Listener, 0, len(sources)) + for _, s := range sources { + fields := strings.SplitN(s, ":", 2) + if len(fields) != 2 { + return nil, fmt.Errorf("invalid source: expected colon separating network type from address") + } + network := fields[0] + address := fields[1] + + switch network { + case "tcp", "unix": + l, err := tls.Listen(network, address, tlsConfig) + if err != nil { + return nil, err + } + listeners = append(listeners, l) + default: + return nil, fmt.Errorf("invalid source: expected network type of \"tcp\" or \"unix\"") + } + } + + return listeners, nil +} + +func usage() { + format := `Usage: %s CONFIG_PATH SOURCE0 SOURCE1 ... + SOURCEi = tcp:HOST:PORT | unix:PATH +tlsrp is a TLS reverse proxy. tlsrp accepts TLS-secured connections on one or +more source sockets and tunnels the decrypted bytes to one of many specified +sink sockets. The sink socket is chosen based on the hostname specified by the +client using the Server Name Indication TLS extension (RFC 3546). +` + // TODO: more doc + fmt.Fprintf(os.Stderr, format, os.Args[0]) +} + +func init() { + softExit = make(chan struct{}) + hardExit = make(chan struct{}) + lookupSink = make(chan lookupSinkMsg, 16) + lookupCert = make(chan lookupCertMsg, 16) +} + +func main() { + flag.Usage = usage + flag.Parse() + if flag.NArgs() < 2 { + log.Fatalln("expected 2 or more arguments") + } + + configPath := flag.Args()[0] + + listeners, err := listen(flags.Args()[1:]) + if err != nil { + log.Fatalln(err.Error()) + } + + go manageExit() + go manageConfig(configPath) + + var wg sync.WaitGroup + defer wg.Wait() + + for _, l := range listeners { + wg.Add(1) + go func (l net.Listener) { + accept(l) + wg.Done() + }(l) + } + + select { + case <-softExit: + case <-hardExit: + } + for _, l := range listeners { + l.Close() + } +}