tlsrp

TLS reverse proxy
git clone git://git.rr3.xyz/tlsrp
Log | Files | Refs | README | LICENSE

commit d257c068fa028e68e51aff4e61eeeb5978b51270
parent 0766dcb9d85f4636416d4c249da3997d5879504f
Author: Robert Russell <robertrussell.72001@gmail.com>
Date:   Sun, 14 Jul 2024 17:05:53 -0700

Make it build

Diffstat:
Mtlsrp.go | 77+++++++++++++++++++++++++++++++++++++++--------------------------------------
1 file changed, 39 insertions(+), 38 deletions(-)

diff --git a/tlsrp.go b/tlsrp.go @@ -10,6 +10,7 @@ import ( "flag" "fmt" "golang.org/x/sys/unix" + "io" "log" "net" "os" @@ -73,7 +74,7 @@ var errHostnameEmpty = errors.New("empty hostname") type hostname string // XXX: We currently don't length check hostnames or the labels within. -func parseHostname(s string) (Hostname, error) { +func parseHostname(s string) (hostname, error) { if len(s) == 0 { return "", errHostnameEmpty } @@ -99,15 +100,15 @@ func parseHostname(s string) (Hostname, error) { // Ok case '0' <= r && r <= '9': // Ok - case '-' && !first: + case r == '-' && !first: // Ok - case '-' && first: + case r == '-' && first: return "", fmt.Errorf("illegal hostname: hyphen at start of label") default: return "", fmt.Errorf("illegal hostname: illegal rune: %q", r) } - buf = append(buf, r) - first := false + buf = append(buf, byte(r)) + first = false } if buf[len(buf)-1] == '-' { @@ -115,7 +116,7 @@ func parseHostname(s string) (Hostname, error) { } } - return string(buf), nil + return hostname(buf), nil } // TODO: Add support for more than just alternation in patterns. @@ -125,11 +126,11 @@ func (pattern pattern) matches(hostname hostname) bool { return slices.Contains(pattern, hostname) } -func parsePattern(ss []string) (pattern, error) { +func parsePattern(ss [][]byte) (pattern, error) { pat := make(pattern, 0, len(ss)) for _, s := range ss { - hostname, err := parseHostname(s) + hostname, err := parseHostname(string(s)) if err != nil { return nil, err } @@ -143,7 +144,7 @@ func parsePattern(ss []string) (pattern, error) { // These are really more like "sink specifications" and "cert specifications". type sink struct { - pattern Pattern + pattern pattern network string address string } @@ -157,15 +158,15 @@ func (sink sink) String() string { } type cert struct { - pattern Pattern + pattern pattern crtPath string keyPath string cert *tls.Certificate } type config struct { - sinks []Sink - certs []Cert + sinks []sink + certs []cert } func loadConfig(path string) (config, error) { @@ -182,7 +183,7 @@ func loadConfig(path string) (config, error) { var cfg config - lines := bytes.Split(data, '\n') + lines := bytes.Split(data, []byte{'\n'}) for _, line := range lines { if len(line) == 0 { continue @@ -190,10 +191,10 @@ func loadConfig(path string) (config, error) { fields := bytes.Fields(line) if len(fields) < 3 { - return &config{}, fmt.Errorf("illegal config: line with fewer than 3 fields") + return config{}, fmt.Errorf("illegal config: line with fewer than 3 fields") } - switch fields[0] { + switch string(fields[0]) { case "sink": pat, err := parsePattern(fields[3:]) if err != nil { @@ -202,14 +203,14 @@ func loadConfig(path string) (config, error) { sink := sink{ pattern: pat, - network: fields[1], - address: fields[2], + network: string(fields[1]), + address: string(fields[2]), } - config.sinks = append(config.sinks, sink) + cfg.sinks = append(cfg.sinks, sink) case "cert": - crtPath := fields[1] - keyPath := fields[2] + crtPath := string(fields[1]) + keyPath := string(fields[2]) c, err := tls.LoadX509KeyPair(crtPath, keyPath) if err != nil { return config{}, err @@ -222,11 +223,11 @@ func loadConfig(path string) (config, error) { cert := cert{ pattern: pat, - crtPath: fields[1], - keyPath: fields[2], - cert: c, + crtPath: crtPath, + keyPath: keyPath, + cert: &c, } - config.certs = append(config.certs, cert) + cfg.certs = append(cfg.certs, cert) default: return config{}, fmt.Errorf("illegal config: expected \"sink\" or \"cert\" as first field") @@ -257,7 +258,7 @@ func manageConfig(configPath string) { for i := range certs { c, err := tls.LoadX509KeyPair(certs[i].crtPath, certs[i].keyPath) if err == nil { - certs[i].cert = c + certs[i].cert = &c } else { log.Printf("failed to reload certificates: %s\n", err) } @@ -303,18 +304,16 @@ func manageConfig(configPath string) { } func handshake(conn *tls.Conn) error { + ctx, cancel := context.WithTimeout(context.Background(), handshakeTimeout) + defer cancel() + handshakeErr := make(chan error, 1) - go func() { - ctx := context.WithTimeout(context.Background(), handshakeTimeout) - handshakeErr <- conn.HandshakeContext(ctx) - }() + go func() { handshakeErr <- conn.HandshakeContext(ctx) }() var err error select { case err = <-handshakeErr: case <-hardExit: - conn.Close() - <-handshakeErr } return err @@ -389,7 +388,7 @@ func proxy(client *tls.Conn) { } defer sink.Close() - err = splice(client, sink) + err = splice(client, sink.(*tls.Conn)) if err != nil { logf("splice error: %s", err) return @@ -429,7 +428,8 @@ func manageExit() { log.Println("received SIGINT; exiting softly") close(softExit) <-hardSigs - fallthrough + log.Println("received SIGQUIT/SIGTERM; exiting harshly") + close(hardExit) case <-hardSigs: log.Println("received SIGQUIT/SIGTERM; exiting harshly") close(hardExit) @@ -444,7 +444,7 @@ func listen(sources []string) ([]net.Listener, error) { return nil, err } return lookupCert(hostname) - } + }, } listeners := make([]net.Listener, 0, len(sources)) @@ -486,20 +486,21 @@ client using the Server Name Indication TLS extension (RFC 3546). func init() { softExit = make(chan struct{}) hardExit = make(chan struct{}) - lookupSink = make(chan lookupSinkMsg, 16) - lookupCert = make(chan lookupCertMsg, 16) + + lookupSinkChan = make(chan lookupSinkMsg, 16) + lookupCertChan = make(chan lookupCertMsg, 16) } func main() { flag.Usage = usage flag.Parse() - if flag.NArgs() < 2 { + if flag.NArg() < 2 { log.Fatalln("expected 2 or more arguments") } configPath := flag.Args()[0] - listeners, err := listen(flags.Args()[1:]) + listeners, err := listen(flag.Args()[1:]) if err != nil { log.Fatalln(err.Error()) }