commit edd853eceb61e994d9726533072ed45fd8aad88c
parent d1c7a12e46b97d25cd4a67f17c21ef597ec82235
Author: Robert Russell <robertrussell.72001@gmail.com>
Date: Tue, 16 Jul 2024 14:47:35 -0700
Restructure hostnames and make listeners print their address
The latter is promised in the man page.
Diffstat:
| M | tlsrp.go | | | 152 | +++++++++++++++++++++++++++++++++++++++++++++++-------------------------------- |
1 file changed, 91 insertions(+), 61 deletions(-)
diff --git a/tlsrp.go b/tlsrp.go
@@ -19,15 +19,10 @@ import (
"time"
)
-// TODO: FS-based config
-// foo.rr3.xyz
-// | _cert
-// | _key
-// | _unix OR _tcp
-// Leading wildcards:
-// _.rr3.xyz
-// Explicit non-wildcard preferred.
-// Just "_" means default for clients with no SNI support.
+// TODO: Scrutinize. In particular, compare with sltls.
+// TODO: Add support for more than just alternation in patterns, like
+// arbitrary regex.
+// XXX: We currently don't length check hostnames or the labels within.
// We only enforce a timeout on the handshake. After the handshake is complete,
// the sink is responsible for timing-out clients.
@@ -54,7 +49,7 @@ func lookupSink(hostname hostname) (net.Addr, error) {
}
sink, ok := <-reply
if !ok {
- return nil, fmt.Errorf("no sink for hostname %q", hostname)
+ return nil, fmt.Errorf("no sink for hostname %s", hostname)
}
return sink, nil
}
@@ -72,71 +67,98 @@ func lookupCert(hostname hostname) (*tls.Certificate, error) {
}
cert, ok := <-reply
if !ok {
- return nil, fmt.Errorf("no certificate for hostname %q", hostname)
+ return nil, fmt.Errorf("no certificate for hostname %s", hostname)
}
return cert, nil
}
-var errHostnameEmpty = errors.New("empty hostname")
-
-type hostname string
+type label string
-// XXX: We currently don't length check hostnames or the labels within.
-func parseHostname(s string) (hostname, error) {
+func parseLabel(s string) (label, error) {
if len(s) == 0 {
- return "", errHostnameEmpty
+ return "", fmt.Errorf("empty label")
}
-
+
buf := make([]byte, 0, len(s))
- labels := strings.Split(s, ".")
- for _, label := range labels {
- if len(label) == 0 {
- return "", fmt.Errorf("illegal hostname: empty label")
+ for i, r := range s {
+ first := i == 0
+ last := i == len(s) - 1
+
+ switch {
+ case 'A' <= r && r <= 'Z':
+ r += 'a' - 'A'
+ case 'a' <= r && r <= 'z':
+ // Ok
+ case '0' <= r && r <= '9':
+ // Ok
+ case r == '-' && (!first && !last):
+ // Ok
+ case r == '-' && first:
+ return "", fmt.Errorf("hyphen at start of label")
+ case r == '-' && last:
+ return "", fmt.Errorf("hyphen at end of label")
+ default:
+ return "", fmt.Errorf("illegal rune in label: %q", r)
}
- if len(buf) > 0 {
- buf = append(buf, '.')
+ buf = append(buf, byte(r))
+ }
+
+ return label(buf), nil
+}
+
+type hostname []label
+
+func (hostname hostname) String() string {
+ // Ughh, Go can't convert between hostname and []string,
+ // so we can't use strings.Join.
+
+ var sb strings.Builder
+ for i, label := range hostname {
+ if i > 0 {
+ sb.WriteByte('.')
}
+ sb.WriteString(string(label))
+ }
- for i, r := range label {
- first := i == 0
- last := i == len(label) - 1
- switch {
- case 'A' <= r && r <= 'Z':
- r += 'a' - 'A'
- case 'a' <= r && r <= 'z':
- // Ok
- case '0' <= r && r <= '9':
- // Ok
- case r == '-' && (!first && !last):
- // Ok
- case r == '-' && first:
- return "", fmt.Errorf("illegal hostname: hyphen at start of label")
- case r == '-' && last:
- return "", fmt.Errorf("illegal hostname: hyphen at end of label")
- default:
- return "", fmt.Errorf("illegal hostname: illegal rune: %q", r)
- }
- buf = append(buf, byte(r))
+ return sb.String()
+}
+
+func (hostname0 hostname) equal(hostname1 hostname) bool {
+ return slices.Equal(hostname0, hostname1)
+}
+
+func parseHostname(s string) (hostname, error) {
+ if len(s) == 0 {
+ return nil, nil
+ }
+
+ labelStrs := strings.Split(s, ".")
+ labels := make([]label, 0, len(labelStrs))
+
+ for _, labelStr := range labelStrs {
+ label, err := parseLabel(labelStr)
+ if err != nil {
+ return nil, fmt.Errorf("illegal hostname: %w", err)
}
+ labels = append(labels, label)
}
- return hostname(buf), nil
+ return hostname(labels), nil
}
-// TODO: Add support for more than just alternation in patterns.
type pattern []hostname
-func (pattern pattern) matches(hostname hostname) bool {
- return slices.Contains(pattern, hostname)
+func (pat pattern) matches(hostname hostname) bool {
+ return slices.ContainsFunc(pat, hostname.equal)
}
-func parsePattern(ss []string) (pattern, error) {
- pat := make(pattern, 0, len(ss))
+func parsePattern(hostnameStrs []string) (pattern, error) {
+ pat := make(pattern, 0, len(hostnameStrs))
- for _, s := range ss {
- hostname, err := parseHostname(s)
+ for _, hostnameStr := range hostnameStrs {
+ hostname, err := parseHostname(hostnameStr)
if err != nil {
return nil, err
}
@@ -274,7 +296,7 @@ func manageConfig(cfgPath string) {
}
case msg := <-lookupSinkChan:
- if msg.hostname != "" {
+ if msg.hostname != nil {
for _, sink := range cfg.sinks {
if sink.pattern.matches(msg.hostname) {
msg.reply <- sink
@@ -287,7 +309,7 @@ func manageConfig(cfgPath string) {
close(msg.reply)
case msg := <-lookupCertChan:
- if msg.hostname != "" {
+ if msg.hostname != nil {
for _, cert := range cfg.certs {
if cert.pattern.matches(msg.hostname) {
msg.reply <- cert.cert
@@ -359,7 +381,7 @@ func splice(a, b conn) error {
}
func proxy(client *tls.Conn) {
- logf := func(format string, a ...interface{}) {
+ logf := func(format string, a ...any) {
log.Printf("client %s: %s\n", client.RemoteAddr(), fmt.Sprintf(format, a...))
}
@@ -369,12 +391,12 @@ func proxy(client *tls.Conn) {
err := handshake(client)
if err != nil {
- logf("handshake error: %s", err)
+ logf("%s", err)
return
}
hostname, err := parseHostname(client.ConnectionState().ServerName)
- if err != nil && !errors.Is(err, errHostnameEmpty) {
+ if err != nil {
logf("rejected: %s", err)
return
}
@@ -387,14 +409,14 @@ func proxy(client *tls.Conn) {
sink, err := net.Dial(sinkAddr.Network(), sinkAddr.String())
if err != nil {
- logf("dial error: %s", err)
+ logf("%s", err)
return
}
defer sink.Close()
err = splice(client, sink.(conn))
if err != nil {
- logf("splice error: %s", err)
+ logf("%s", err)
return
}
}
@@ -403,11 +425,19 @@ func accept(listener net.Listener) {
var wg sync.WaitGroup
defer wg.Wait()
+ logf := func(format string, a ...any) {
+ log.Printf("source %s: %s\n", listener.Addr(), fmt.Sprintf(format, a...))
+ }
+
+ logf("accepting")
+ defer logf("closing")
+
for {
conn, err := listener.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
- log.Fatalf("source %s error: %s\n", listener.Addr(), err)
+ logf("%s", err)
+ os.Exit(1)
}
return
}
@@ -449,7 +479,7 @@ func listen(sources []string) ([]net.Listener, error) {
tlsConfig := &tls.Config{
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
hostname, err := parseHostname(chi.ServerName)
- if err != nil && !errors.Is(err, errHostnameEmpty) {
+ if err != nil {
return nil, err
}
return lookupCert(hostname)