commit 5cf5c48fd635559a2f65af47a1ca7c5e201469a0
parent 155ee4b00d9c733d2de6f02e62a7618e9f3f8beb
Author: Robert Russell <robertrussell.72001@gmail.com>
Date: Tue, 16 Jul 2024 17:22:30 -0700
Clean up
Separate configuration-related stuff into config.go. Now
tlsrp.go has only the main TLS/proxy code.
Also, fix the disturbing "sink" and "cert" naming issues by
having the lookup stuff work with the sink and cert structs
rather than net.Addr and *tls.Certificate.
Diffstat:
| A | config.go | | | 160 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
| M | hostname.go | | | 5 | +++-- |
| M | tlsrp.go | | | 213 | +++++++++++++++++-------------------------------------------------------------- |
3 files changed, 207 insertions(+), 171 deletions(-)
diff --git a/config.go b/config.go
@@ -0,0 +1,160 @@
+package main
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ "io"
+ "os"
+ "strings"
+)
+
+var lookupSinkChan chan lookupSinkMsg
+var lookupCertChan chan lookupCertMsg
+
+func init() {
+ lookupSinkChan = make(chan lookupSinkMsg, 16)
+ lookupCertChan = make(chan lookupCertMsg, 16)
+}
+
+type sink struct {
+ pattern pattern
+ network string
+ address string
+}
+
+type cert struct {
+ pattern pattern
+ crtPath string
+ keyPath string
+ cert *tls.Certificate
+}
+
+type lookupSinkMsg struct {
+ hostname hostname
+ reply chan<- *sink
+}
+
+type lookupCertMsg struct {
+ hostname hostname
+ reply chan<- *cert
+}
+
+func lookupSink(hostname hostname) (*sink, error) {
+ reply := make(chan *sink, 1)
+ lookupSinkChan <- lookupSinkMsg{
+ hostname: hostname,
+ reply: reply,
+ }
+
+ sink, ok := <-reply
+ if !ok {
+ return nil, fmt.Errorf("no sink for hostname %s", hostname)
+ }
+
+ return sink, nil
+}
+
+func lookupCert(hostname hostname) (*cert, error) {
+ reply := make(chan *cert, 1)
+ lookupCertChan <- lookupCertMsg{
+ hostname: hostname,
+ reply: reply,
+ }
+
+ cert, ok := <-reply
+ if !ok {
+ return nil, fmt.Errorf("no cert for hostname %s", hostname)
+ }
+
+ return cert, nil
+}
+
+type config struct {
+ sinks []*sink
+ certs []*cert
+}
+
+func loadTLSCert(crtPath, keyPath string) (*tls.Certificate, error) {
+ tlsCert, err := tls.LoadX509KeyPair(crtPath, keyPath)
+ if err != nil {
+ return nil, err
+ }
+
+ // Parsing the leaf certificate in advance is recommended by crypto/tls to
+ // "reduce per-handshake processing".
+ tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
+ if err != nil {
+ return nil, err
+ }
+
+ return &tlsCert, nil
+}
+
+func loadConfig(path string) (*config, error) {
+ file, err := os.Open(path)
+ if err != nil {
+ return nil, err
+ }
+ defer file.Close()
+
+ data, err := io.ReadAll(file)
+ if err != nil {
+ return nil, err
+ }
+
+ var cfg config
+
+ lines := strings.Split(string(data), "\n")
+ for _, line := range lines {
+ fields := strings.Fields(line)
+ if len(fields) == 0 {
+ continue // Empty line
+ }
+ if len(fields) < 3 {
+ return nil, fmt.Errorf("illegal config: line with fewer than 3 fields")
+ }
+
+ switch fields[0] {
+ case "sink":
+ pattern, err := parsePattern(fields[3:])
+ if err != nil {
+ return nil, err
+ }
+
+ sink := sink{
+ pattern: pattern,
+ network: fields[1],
+ address: fields[2],
+ }
+ cfg.sinks = append(cfg.sinks, &sink)
+
+ case "cert":
+ crtPath := fields[1]
+ keyPath := fields[2]
+
+ tlsCert, err := loadTLSCert(crtPath, keyPath)
+ if err != nil {
+ return nil, err
+ }
+
+ pattern, err := parsePattern(fields[3:])
+ if err != nil {
+ return nil, err
+ }
+
+ cert := cert{
+ pattern: pattern,
+ crtPath: crtPath,
+ keyPath: keyPath,
+ cert: tlsCert,
+ }
+ cfg.certs = append(cfg.certs, &cert)
+
+ default:
+ return nil, fmt.Errorf("illegal config: expected \"sink\" or \"cert\" as first field")
+ }
+ }
+
+ return &cfg, nil
+}
diff --git a/hostname.go b/hostname.go
@@ -8,6 +8,7 @@ import (
// TODO: Add support for more than just alternation in patterns. We should at
// least support leading wildcards, but maybe we should support full regex.
+// XXX: We currently don't length check hostnames or the labels within.
type label string
@@ -35,7 +36,7 @@ func parseLabel(labelStr string) (label, error) {
}
if buf[0] == '-' || buf[len(buf) - 1] == '-' {
- return fmt.Errorf("hyphen at start or end of label")
+ return "", fmt.Errorf("hyphen at start or end of label")
}
return label(buf), nil
@@ -44,7 +45,7 @@ func parseLabel(labelStr string) (label, error) {
type hostname []label
func (hostname hostname) String() string {
- // Ughh, Go can't convert between hostname and []string,
+ // Ughhhh, Go can't convert between hostname and []string,
// so we can't use strings.Join.
var sb strings.Builder
diff --git a/tlsrp.go b/tlsrp.go
@@ -3,7 +3,6 @@ package main
import (
"context"
"crypto/tls"
- "crypto/x509"
"errors"
"flag"
"fmt"
@@ -19,9 +18,6 @@ import (
)
// TODO: Scrutinize. In particular, compare with sltls.
-// TODO: Add support for more than just alternation in patterns, like
-// arbitrary regex.
-// XXX: We currently don't length check hostnames or the labels within.
// We only enforce a timeout on the handshake. After the handshake is complete,
// the sink is responsible for timing-out clients.
@@ -35,148 +31,28 @@ const handshakeTimeout = 10 * time.Second
var softExit chan struct{}
var hardExit chan struct{}
-var lookupSinkChan chan lookupSinkMsg
-var lookupCertChan chan lookupCertMsg
-
-type lookupSinkMsg struct {
- hostname hostname
- reply chan<- net.Addr
-}
-
-type lookupCertMsg struct {
- hostname hostname
- reply chan<- *tls.Certificate
-}
-
-func lookupSink(hostname hostname) (net.Addr, error) {
- reply := make(chan net.Addr, 1)
- lookupSinkChan <- lookupSinkMsg{
- hostname: hostname,
- reply: reply,
- }
- sink, ok := <-reply
- if !ok {
- return nil, fmt.Errorf("no sink for hostname %s", hostname)
- }
- return sink, nil
-}
-
-func lookupCert(hostname hostname) (*tls.Certificate, error) {
- reply := make(chan *tls.Certificate, 1)
- lookupCertChan <- lookupCertMsg{
- hostname: hostname,
- reply: reply,
- }
- cert, ok := <-reply
- if !ok {
- return nil, fmt.Errorf("no certificate for hostname %s", hostname)
- }
- return cert, nil
-}
-
-type sink struct {
- pattern pattern
- network string
- address string
-}
-
-func (sink sink) Network() string {
- return sink.network
-}
-
-func (sink sink) String() string {
- return sink.address
-}
-
-type cert struct {
- pattern pattern
- crtPath string
- keyPath string
- cert *tls.Certificate
-}
-
-type config struct {
- sinks []sink
- certs []cert
+func init() {
+ softExit = make(chan struct{})
+ hardExit = make(chan struct{})
}
-func loadCert(crtPath, keyPath string) (*tls.Certificate, error) {
- tlsCert, err := tls.LoadX509KeyPair(crtPath, keyPath)
- if err != nil {
- return nil, err
- }
- tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
- if err != nil {
- return nil, err
- }
- return &tlsCert, nil
+type socket interface {
+ io.ReadWriteCloser
+ CloseWrite() error
}
-func loadConfig(path string) (config, error) {
- file, err := os.Open(path)
- if err != nil {
- return config{}, err
- }
- defer file.Close()
-
- data, err := io.ReadAll(file)
- if err != nil {
- return config{}, err
- }
-
- var cfg config
-
- lines := strings.Split(string(data), "\n")
- for _, line := range lines {
- fields := strings.Fields(line)
- if len(fields) == 0 {
- continue // Empty line
- }
- if len(fields) < 3 {
- return config{}, fmt.Errorf("illegal config: line with fewer than 3 fields")
- }
-
- switch fields[0] {
- case "sink":
- pat, err := parsePattern(fields[3:])
- if err != nil {
- return config{}, err
- }
-
- sink := sink{
- pattern: pat,
- network: fields[1],
- address: fields[2],
- }
- cfg.sinks = append(cfg.sinks, sink)
+func parseSNI(sni string) (hostname, error) {
+ var hostname hostname
- case "cert":
- crtPath := fields[1]
- keyPath := fields[2]
- tlsCert, err := loadCert(crtPath, keyPath)
- if err != nil {
- return config{}, err
- }
-
- pat, err := parsePattern(fields[3:])
- if err != nil {
- return config{}, err
- }
-
- cert := cert{
- pattern: pat,
- crtPath: crtPath,
- keyPath: keyPath,
- cert: tlsCert,
- }
- cfg.certs = append(cfg.certs, cert)
-
- default:
- return config{}, fmt.Errorf("illegal config: expected \"sink\" or \"cert\" as first field")
+ if sni != "" {
+ var err error
+ hostname, err = parseHostname(sni)
+ if err != nil {
+ return nil, err
}
}
- return cfg, nil
+ return hostname, nil
}
func handshake(conn *tls.Conn) error {
@@ -195,12 +71,7 @@ func handshake(conn *tls.Conn) error {
return err
}
-type conn interface {
- io.ReadWriteCloser
- CloseWrite() error
-}
-
-func splice(a, b conn) error {
+func splice(a, b socket) error {
a2bErr := make(chan error, 1)
go func() {
_, err := io.Copy(b, a)
@@ -250,26 +121,26 @@ func proxy(client *tls.Conn) {
return
}
- hostname, err := parseHostname(client.ConnectionState().ServerName)
+ hostname, err := parseSNI(client.ConnectionState().ServerName)
if err != nil {
logf("rejected: %s", err)
return
}
- sinkAddr, err := lookupSink(hostname)
+ sink, err := lookupSink(hostname)
if err != nil {
logf("rejected: %s", err)
return
}
- sink, err := net.Dial(sinkAddr.Network(), sinkAddr.String())
+ server, err := net.Dial(sink.network, sink.address)
if err != nil {
logf("%s", err)
return
}
- defer sink.Close()
+ defer server.Close()
- err = splice(client, sink.(conn))
+ err = splice(client, server.(socket))
if err != nil {
logf("%s", err)
return
@@ -277,9 +148,6 @@ func proxy(client *tls.Conn) {
}
func accept(listener net.Listener) {
- var wg sync.WaitGroup
- defer wg.Wait()
-
logf := func(format string, a ...any) {
log.Printf("source %s: %s\n", listener.Addr(), fmt.Sprintf(format, a...))
}
@@ -287,12 +155,18 @@ func accept(listener net.Listener) {
logf("accepting")
defer logf("closing")
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
for {
conn, err := listener.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
logf("%s", err)
os.Exit(1)
+ // XXX: Exiting here might be an over-reaction to the error.
+ // Although, keep in mind that tlsrp should be running under
+ // some service manager that should restart tlsrp if it exits.
}
return
}
@@ -312,26 +186,33 @@ func listen(sources []string) ([]net.Listener, error) {
if err != nil {
return nil, err
}
- return lookupCert(hostname)
+
+ cert, err := lookupCert(hostname)
+ if err != nil {
+ return nil, err
+ }
+
+ return cert.cert, nil
},
}
listeners := make([]net.Listener, 0, len(sources))
- for _, s := range sources {
- fields := strings.SplitN(s, ":", 2)
+ for _, source := range sources {
+ fields := strings.SplitN(source, ":", 2)
if len(fields) != 2 {
return nil, fmt.Errorf("invalid source: expected colon separating network type from address")
}
+
network := fields[0]
address := fields[1]
switch network {
case "tcp", "unix":
- l, err := tls.Listen(network, address, tlsConfig)
+ listener, err := tls.Listen(network, address, tlsConfig)
if err != nil {
return nil, err
}
- listeners = append(listeners, l)
+ listeners = append(listeners, listener)
default:
return nil, fmt.Errorf("invalid source: expected network type of \"tcp\" or \"unix\"")
}
@@ -344,8 +225,9 @@ func manageConfig(cfgPath string) {
cfg, err := loadConfig(cfgPath)
if err != nil {
log.Printf("failed to load initial configuration: %s\n", err)
- // Proceed with cfg == config{}, i.e., with no sinks and no certs,
- // causing every client to be rejected.
+ // Proceed with an empty configuration (i.e., with no sinks and no
+ // certs), causing every client to be rejected.
+ cfg = &config{}
}
sighup := make(chan os.Signal, 1)
@@ -379,18 +261,19 @@ func manageConfig(cfgPath string) {
if msg.hostname != nil {
for _, cert := range cfg.certs {
if cert.pattern.matches(msg.hostname) {
- msg.reply <- cert.cert
+ msg.reply <- cert
break
}
}
} else if len(cfg.certs) > 0 {
- msg.reply <- cfg.certs[0].cert
+ msg.reply <- cfg.certs[0]
}
close(msg.reply)
}
}
}
+// TODO: Simplify this so all signals soft exit first, then hard exit second
func manageExit() {
sigs := make(chan os.Signal, 3)
signal.Notify(sigs, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM)
@@ -416,14 +299,6 @@ func manageExit() {
}
}
-func init() {
- softExit = make(chan struct{})
- hardExit = make(chan struct{})
-
- lookupSinkChan = make(chan lookupSinkMsg, 16)
- lookupCertChan = make(chan lookupCertMsg, 16)
-}
-
func main() {
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s CONFIG_PATH SOURCE...", os.Args[0])