commit 0766dcb9d85f4636416d4c249da3997d5879504f
Author: Robert Russell <robertrussell.72001@gmail.com>
Date: Sun, 14 Jul 2024 16:54:56 -0700
Initial commit
Diffstat:
| A | LICENSE | | | 16 | ++++++++++++++++ |
| A | tlsrp.go | | | 528 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
2 files changed, 544 insertions(+), 0 deletions(-)
diff --git a/LICENSE b/LICENSE
@@ -0,0 +1,15 @@
+ISC License
+
+Copyright (c) 2024, Robert Russell
+
+Permission to use, copy, modify, and/or distribute this software for any
+purpose with or without fee is hereby granted, provided that the above
+copyright notice and this permission notice appear in all copies.
+
+THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+\ No newline at end of file
diff --git a/tlsrp.go b/tlsrp.go
@@ -0,0 +1,528 @@
+package main
+
+// TODO: check log.Printf newlines
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "errors"
+ "flag"
+ "fmt"
+ "golang.org/x/sys/unix"
+ "log"
+ "net"
+ "os"
+ "os/signal"
+ "slices"
+ "strings"
+ "sync"
+ "time"
+)
+
+// We only enforce a timeout on the handshake. After the handshake is complete,
+// the sink is responsible for timing-out clients.
+const handshakeTimeout = 10 * time.Second
+
+// softExit or hardExit is closed when exiting. During a soft exit, we
+// accept no new clients, but existing clients should finish gracefully;
+// during a hard exit, we accept no new clients, and exiting clients
+// should be forcefully disconnected. Since we have two different exit
+// modes, we don't use context.Context's to handle cancellation.
+var softExit chan struct{}
+var hardExit chan struct{}
+
+var lookupSinkChan chan lookupSinkMsg
+type lookupSinkMsg struct {
+ hostname hostname
+ reply chan<- net.Addr
+}
+func lookupSink(hostname hostname) (net.Addr, error) {
+ reply := make(chan net.Addr, 1)
+ lookupSinkChan <- lookupSinkMsg{
+ hostname: hostname,
+ reply: reply,
+ }
+ sink, ok := <-reply
+ if !ok {
+ return nil, fmt.Errorf("no sink for hostname %q", hostname)
+ }
+ return sink, nil
+}
+
+var lookupCertChan chan lookupCertMsg
+type lookupCertMsg struct {
+ hostname hostname
+ reply chan<- *tls.Certificate
+}
+func lookupCert(hostname hostname) (*tls.Certificate, error) {
+ reply := make(chan *tls.Certificate, 1)
+ lookupCertChan <- lookupCertMsg{
+ hostname: hostname,
+ reply: reply,
+ }
+ cert, ok := <-reply
+ if !ok {
+ return nil, fmt.Errorf("no certificate for hostname %q", hostname)
+ }
+ return cert, nil
+}
+
+var errHostnameEmpty = errors.New("empty hostname")
+
+type hostname string
+
+// XXX: We currently don't length check hostnames or the labels within.
+func parseHostname(s string) (Hostname, error) {
+ if len(s) == 0 {
+ return "", errHostnameEmpty
+ }
+
+ buf := make([]byte, 0, len(s))
+
+ labels := strings.Split(s, ".")
+ for _, label := range labels {
+ if len(label) == 0 {
+ return "", fmt.Errorf("illegal hostname: empty label")
+ }
+
+ if len(buf) > 0 {
+ buf = append(buf, '.')
+ }
+
+ first := true
+ for _, r := range label {
+ switch {
+ case 'A' <= r && r <= 'Z':
+ r += 'a' - 'A'
+ case 'a' <= r && r <= 'z':
+ // Ok
+ case '0' <= r && r <= '9':
+ // Ok
+ case '-' && !first:
+ // Ok
+ case '-' && first:
+ return "", fmt.Errorf("illegal hostname: hyphen at start of label")
+ default:
+ return "", fmt.Errorf("illegal hostname: illegal rune: %q", r)
+ }
+ buf = append(buf, r)
+ first := false
+ }
+
+ if buf[len(buf)-1] == '-' {
+ return "", fmt.Errorf("illegal hostname: hyphen at end of label")
+ }
+ }
+
+ return string(buf), nil
+}
+
+// TODO: Add support for more than just alternation in patterns.
+type pattern []hostname
+
+func (pattern pattern) matches(hostname hostname) bool {
+ return slices.Contains(pattern, hostname)
+}
+
+func parsePattern(ss []string) (pattern, error) {
+ pat := make(pattern, 0, len(ss))
+
+ for _, s := range ss {
+ hostname, err := parseHostname(s)
+ if err != nil {
+ return nil, err
+ }
+ pat = append(pat, hostname)
+ }
+
+ return pat, nil
+}
+
+// TODO: Come up with better names for the types "sink" and "cert".
+// These are really more like "sink specifications" and "cert specifications".
+
+type sink struct {
+ pattern Pattern
+ network string
+ address string
+}
+
+func (sink sink) Network() string {
+ return sink.network
+}
+
+func (sink sink) String() string {
+ return sink.address
+}
+
+type cert struct {
+ pattern Pattern
+ crtPath string
+ keyPath string
+ cert *tls.Certificate
+}
+
+type config struct {
+ sinks []Sink
+ certs []Cert
+}
+
+func loadConfig(path string) (config, error) {
+ file, err := os.Open(path)
+ if err != nil {
+ return config{}, err
+ }
+ defer file.Close()
+
+ data, err := io.ReadAll(file)
+ if err != nil {
+ return config{}, err
+ }
+
+ var cfg config
+
+ lines := bytes.Split(data, '\n')
+ for _, line := range lines {
+ if len(line) == 0 {
+ continue
+ }
+
+ fields := bytes.Fields(line)
+ if len(fields) < 3 {
+ return &config{}, fmt.Errorf("illegal config: line with fewer than 3 fields")
+ }
+
+ switch fields[0] {
+ case "sink":
+ pat, err := parsePattern(fields[3:])
+ if err != nil {
+ return config{}, err
+ }
+
+ sink := sink{
+ pattern: pat,
+ network: fields[1],
+ address: fields[2],
+ }
+ config.sinks = append(config.sinks, sink)
+
+ case "cert":
+ crtPath := fields[1]
+ keyPath := fields[2]
+ c, err := tls.LoadX509KeyPair(crtPath, keyPath)
+ if err != nil {
+ return config{}, err
+ }
+
+ pat, err := parsePattern(fields[3:])
+ if err != nil {
+ return config{}, err
+ }
+
+ cert := cert{
+ pattern: pat,
+ crtPath: fields[1],
+ keyPath: fields[2],
+ cert: c,
+ }
+ config.certs = append(config.certs, cert)
+
+ default:
+ return config{}, fmt.Errorf("illegal config: expected \"sink\" or \"cert\" as first field")
+ }
+ }
+
+ return cfg, nil
+}
+
+func manageConfig(configPath string) {
+ cfg, err := loadConfig(configPath)
+ if err != nil {
+ log.Printf("failed to load initial configuration: %s\n", err)
+ }
+ // Proceed with cfg == config{}, i.e., with no sinks and no certs, causing
+ // every client to be rejected.
+
+ sigusr := make(chan os.Signal, 2)
+ signal.Notify(sigusr, unix.SIGUSR1, unix.SIGUSR2)
+
+ for {
+ select {
+ case sig := <-sigusr:
+ switch sig {
+ case unix.SIGUSR1:
+ log.Println("received SIGUSR1; reloading certificates")
+ certs := cfg.certs
+ for i := range certs {
+ c, err := tls.LoadX509KeyPair(certs[i].crtPath, certs[i].keyPath)
+ if err == nil {
+ certs[i].cert = c
+ } else {
+ log.Printf("failed to reload certificates: %s\n", err)
+ }
+ }
+
+ case unix.SIGUSR2:
+ log.Println("received SIGUSR2; reloading configuration")
+ newCfg, err := loadConfig(configPath)
+ if err == nil {
+ cfg = newCfg
+ } else {
+ log.Printf("failed to reload configuration: %s\n", err)
+ }
+ }
+
+ case msg := <-lookupSinkChan:
+ if msg.hostname != "" {
+ for _, sink := range cfg.sinks {
+ if sink.pattern.matches(msg.hostname) {
+ msg.reply <- sink
+ break
+ }
+ }
+ } else if len(cfg.sinks) > 0 {
+ msg.reply <- cfg.sinks[0]
+ }
+ close(msg.reply)
+
+ case msg := <-lookupCertChan:
+ if msg.hostname != "" {
+ for _, cert := range cfg.certs {
+ if cert.pattern.matches(msg.hostname) {
+ msg.reply <- cert.cert
+ break
+ }
+ }
+ } else if len(cfg.certs) > 0 {
+ msg.reply <- cfg.certs[0].cert
+ }
+ close(msg.reply)
+ }
+ }
+}
+
+func handshake(conn *tls.Conn) error {
+ handshakeErr := make(chan error, 1)
+ go func() {
+ ctx := context.WithTimeout(context.Background(), handshakeTimeout)
+ handshakeErr <- conn.HandshakeContext(ctx)
+ }()
+
+ var err error
+ select {
+ case err = <-handshakeErr:
+ case <-hardExit:
+ conn.Close()
+ <-handshakeErr
+ }
+
+ return err
+}
+
+func splice(a, b *tls.Conn) error {
+ a2bErr := make(chan error, 1)
+ go func() {
+ _, err := io.Copy(b, a)
+ a2bErr <- err
+ }()
+
+ b2aErr := make(chan error, 1)
+ go func() {
+ _, err := io.Copy(a, b)
+ b2aErr <- err
+ }()
+
+ // In the first two cases, we call CloseWrite (not Close) on the
+ // corresponding destination connection, so that the other copy goroutine
+ // can continue reading from it (until it hits EOF). In the hard exit
+ // case, we want to exit ASAP, so we close both ends of both connections.
+ var err error
+ select {
+ case err = <-a2bErr:
+ b.CloseWrite()
+ <-b2aErr
+ case err = <-b2aErr:
+ a.CloseWrite()
+ <-a2bErr
+ case <-hardExit:
+ a.Close()
+ b.Close()
+ <-a2bErr
+ <-b2aErr
+ }
+
+ return err
+}
+
+func proxy(client *tls.Conn) {
+ logf := func(format string, a ...interface{}) {
+ log.Printf("client %s: %s\n", client.RemoteAddr(), fmt.Sprintf(format, a...))
+ }
+
+ logf("connected")
+ defer logf("disconnected")
+ defer client.Close()
+
+ err := handshake(client)
+ if err != nil {
+ logf("handshake error: %s", err)
+ return
+ }
+
+ hostname, err := parseHostname(client.ConnectionState().ServerName)
+ if err != nil && !errors.Is(err, errHostnameEmpty) {
+ logf("rejected: %s", err)
+ return
+ }
+
+ sinkAddr, err := lookupSink(hostname)
+ if err != nil {
+ logf("rejected: %s", err)
+ return
+ }
+
+ sink, err := net.Dial(sinkAddr.Network(), sinkAddr.String())
+ if err != nil {
+ logf("dial error: %s", err)
+ return
+ }
+ defer sink.Close()
+
+ err = splice(client, sink)
+ if err != nil {
+ logf("splice error: %s", err)
+ return
+ }
+}
+
+func accept(listener net.Listener) {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ for {
+ conn, err := listener.Accept()
+ if err != nil {
+ if !errors.Is(err, net.ErrClosed) {
+ log.Fatalf("source %s error: %s\n", listener.Addr(), err)
+ }
+ return
+ }
+
+ wg.Add(1)
+ go func() {
+ proxy(conn.(*tls.Conn))
+ wg.Done()
+ }()
+ }
+}
+
+func manageExit() {
+ softSigs := make(chan os.Signal, 1)
+ signal.Notify(softSigs, unix.SIGINT)
+
+ hardSigs := make(chan os.Signal, 1)
+ signal.Notify(hardSigs, unix.SIGQUIT, unix.SIGTERM)
+
+ select {
+ case <-softSigs:
+ log.Println("received SIGINT; exiting softly")
+ close(softExit)
+ <-hardSigs
+ fallthrough
+ case <-hardSigs:
+ log.Println("received SIGQUIT/SIGTERM; exiting harshly")
+ close(hardExit)
+ }
+}
+
+func listen(sources []string) ([]net.Listener, error) {
+ tlsConfig := &tls.Config{
+ GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
+ hostname, err := parseHostname(chi.ServerName)
+ if err != nil && !errors.Is(err, errHostnameEmpty) {
+ return nil, err
+ }
+ return lookupCert(hostname)
+ }
+ }
+
+ listeners := make([]net.Listener, 0, len(sources))
+ for _, s := range sources {
+ fields := strings.SplitN(s, ":", 2)
+ if len(fields) != 2 {
+ return nil, fmt.Errorf("invalid source: expected colon separating network type from address")
+ }
+ network := fields[0]
+ address := fields[1]
+
+ switch network {
+ case "tcp", "unix":
+ l, err := tls.Listen(network, address, tlsConfig)
+ if err != nil {
+ return nil, err
+ }
+ listeners = append(listeners, l)
+ default:
+ return nil, fmt.Errorf("invalid source: expected network type of \"tcp\" or \"unix\"")
+ }
+ }
+
+ return listeners, nil
+}
+
+func usage() {
+ format := `Usage: %s CONFIG_PATH SOURCE0 SOURCE1 ...
+ SOURCEi = tcp:HOST:PORT | unix:PATH
+tlsrp is a TLS reverse proxy. tlsrp accepts TLS-secured connections on one or
+more source sockets and tunnels the decrypted bytes to one of many specified
+sink sockets. The sink socket is chosen based on the hostname specified by the
+client using the Server Name Indication TLS extension (RFC 3546).
+`
+ // TODO: more doc
+ fmt.Fprintf(os.Stderr, format, os.Args[0])
+}
+
+func init() {
+ softExit = make(chan struct{})
+ hardExit = make(chan struct{})
+ lookupSink = make(chan lookupSinkMsg, 16)
+ lookupCert = make(chan lookupCertMsg, 16)
+}
+
+func main() {
+ flag.Usage = usage
+ flag.Parse()
+ if flag.NArgs() < 2 {
+ log.Fatalln("expected 2 or more arguments")
+ }
+
+ configPath := flag.Args()[0]
+
+ listeners, err := listen(flags.Args()[1:])
+ if err != nil {
+ log.Fatalln(err.Error())
+ }
+
+ go manageExit()
+ go manageConfig(configPath)
+
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ for _, l := range listeners {
+ wg.Add(1)
+ go func (l net.Listener) {
+ accept(l)
+ wg.Done()
+ }(l)
+ }
+
+ select {
+ case <-softExit:
+ case <-hardExit:
+ }
+ for _, l := range listeners {
+ l.Close()
+ }
+}