commit 5f3c7bb1cdf0cd76d4716d158aa205bc5a3418e2
parent edd853eceb61e994d9726533072ed45fd8aad88c
Author: Robert Russell <robertrussell.72001@gmail.com>
Date: Tue, 16 Jul 2024 15:00:09 -0700
Move hostname stuff to separate file
Diffstat:
| A | hostname.go | | | 103 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
| M | tlsrp.go | | | 262 | ++++++++++++++++++++++++++----------------------------------------------------- |
2 files changed, 187 insertions(+), 178 deletions(-)
diff --git a/hostname.go b/hostname.go
@@ -0,0 +1,103 @@
+package main
+
+import (
+ "fmt"
+ "slices"
+ "strings"
+)
+
+type label string
+
+func parseLabel(s string) (label, error) {
+ if len(s) == 0 {
+ return "", fmt.Errorf("empty label")
+ }
+
+ buf := make([]byte, 0, len(s))
+
+ for i, r := range s {
+ first := i == 0
+ last := i == len(s) - 1
+
+ switch {
+ case 'A' <= r && r <= 'Z':
+ r += 'a' - 'A'
+ case 'a' <= r && r <= 'z':
+ // Ok
+ case '0' <= r && r <= '9':
+ // Ok
+ case r == '-' && (!first && !last):
+ // Ok
+ case r == '-' && first:
+ return "", fmt.Errorf("hyphen at start of label")
+ case r == '-' && last:
+ return "", fmt.Errorf("hyphen at end of label")
+ default:
+ return "", fmt.Errorf("illegal rune in label: %q", r)
+ }
+
+ buf = append(buf, byte(r))
+ }
+
+ return label(buf), nil
+}
+
+type hostname []label
+
+func (hostname hostname) String() string {
+ // Ughh, Go can't convert between hostname and []string,
+ // so we can't use strings.Join.
+
+ var sb strings.Builder
+ for i, label := range hostname {
+ if i > 0 {
+ sb.WriteByte('.')
+ }
+ sb.WriteString(string(label))
+ }
+
+ return sb.String()
+}
+
+func (hostname0 hostname) equal(hostname1 hostname) bool {
+ return slices.Equal(hostname0, hostname1)
+}
+
+func parseHostname(s string) (hostname, error) {
+ if len(s) == 0 {
+ return nil, nil
+ }
+
+ labelStrs := strings.Split(s, ".")
+ labels := make([]label, 0, len(labelStrs))
+
+ for _, labelStr := range labelStrs {
+ label, err := parseLabel(labelStr)
+ if err != nil {
+ return nil, fmt.Errorf("illegal hostname: %w", err)
+ }
+ labels = append(labels, label)
+ }
+
+ return hostname(labels), nil
+}
+
+type pattern []hostname
+
+func (pat pattern) matches(hostname hostname) bool {
+ return slices.ContainsFunc(pat, hostname.equal)
+}
+
+func parsePattern(hostnameStrs []string) (pattern, error) {
+ pat := make(pattern, 0, len(hostnameStrs))
+
+ for _, hostnameStr := range hostnameStrs {
+ hostname, err := parseHostname(hostnameStr)
+ if err != nil {
+ return nil, err
+ }
+ pat = append(pat, hostname)
+ }
+
+ return pat, nil
+}
diff --git a/tlsrp.go b/tlsrp.go
@@ -13,7 +13,6 @@ import (
"net"
"os"
"os/signal"
- "slices"
"strings"
"sync"
"time"
@@ -37,10 +36,18 @@ var softExit chan struct{}
var hardExit chan struct{}
var lookupSinkChan chan lookupSinkMsg
+var lookupCertChan chan lookupCertMsg
+
type lookupSinkMsg struct {
hostname hostname
reply chan<- net.Addr
}
+
+type lookupCertMsg struct {
+ hostname hostname
+ reply chan<- *tls.Certificate
+}
+
func lookupSink(hostname hostname) (net.Addr, error) {
reply := make(chan net.Addr, 1)
lookupSinkChan <- lookupSinkMsg{
@@ -54,11 +61,6 @@ func lookupSink(hostname hostname) (net.Addr, error) {
return sink, nil
}
-var lookupCertChan chan lookupCertMsg
-type lookupCertMsg struct {
- hostname hostname
- reply chan<- *tls.Certificate
-}
func lookupCert(hostname hostname) (*tls.Certificate, error) {
reply := make(chan *tls.Certificate, 1)
lookupCertChan <- lookupCertMsg{
@@ -72,102 +74,6 @@ func lookupCert(hostname hostname) (*tls.Certificate, error) {
return cert, nil
}
-type label string
-
-func parseLabel(s string) (label, error) {
- if len(s) == 0 {
- return "", fmt.Errorf("empty label")
- }
-
- buf := make([]byte, 0, len(s))
-
- for i, r := range s {
- first := i == 0
- last := i == len(s) - 1
-
- switch {
- case 'A' <= r && r <= 'Z':
- r += 'a' - 'A'
- case 'a' <= r && r <= 'z':
- // Ok
- case '0' <= r && r <= '9':
- // Ok
- case r == '-' && (!first && !last):
- // Ok
- case r == '-' && first:
- return "", fmt.Errorf("hyphen at start of label")
- case r == '-' && last:
- return "", fmt.Errorf("hyphen at end of label")
- default:
- return "", fmt.Errorf("illegal rune in label: %q", r)
- }
-
- buf = append(buf, byte(r))
- }
-
- return label(buf), nil
-}
-
-type hostname []label
-
-func (hostname hostname) String() string {
- // Ughh, Go can't convert between hostname and []string,
- // so we can't use strings.Join.
-
- var sb strings.Builder
- for i, label := range hostname {
- if i > 0 {
- sb.WriteByte('.')
- }
- sb.WriteString(string(label))
- }
-
- return sb.String()
-}
-
-func (hostname0 hostname) equal(hostname1 hostname) bool {
- return slices.Equal(hostname0, hostname1)
-}
-
-func parseHostname(s string) (hostname, error) {
- if len(s) == 0 {
- return nil, nil
- }
-
- labelStrs := strings.Split(s, ".")
- labels := make([]label, 0, len(labelStrs))
-
- for _, labelStr := range labelStrs {
- label, err := parseLabel(labelStr)
- if err != nil {
- return nil, fmt.Errorf("illegal hostname: %w", err)
- }
- labels = append(labels, label)
- }
-
- return hostname(labels), nil
-}
-
-type pattern []hostname
-
-func (pat pattern) matches(hostname hostname) bool {
- return slices.ContainsFunc(pat, hostname.equal)
-}
-
-func parsePattern(hostnameStrs []string) (pattern, error) {
- pat := make(pattern, 0, len(hostnameStrs))
-
- for _, hostnameStr := range hostnameStrs {
- hostname, err := parseHostname(hostnameStr)
- if err != nil {
- return nil, err
- }
- pat = append(pat, hostname)
- }
-
- return pat, nil
-}
-
type sink struct {
pattern pattern
network string
@@ -273,57 +179,6 @@ func loadConfig(path string) (config, error) {
return cfg, nil
}
-func manageConfig(cfgPath string) {
- cfg, err := loadConfig(cfgPath)
- if err != nil {
- log.Printf("failed to load initial configuration: %s\n", err)
- // Proceed with cfg == config{}, i.e., with no sinks and no certs,
- // causing every client to be rejected.
- }
-
- sighup := make(chan os.Signal, 1)
- signal.Notify(sighup, unix.SIGHUP)
-
- for {
- select {
- case <-sighup:
- log.Println("received SIGHUP; reloading configuration")
- newCfg, err := loadConfig(cfgPath)
- if err == nil {
- cfg = newCfg
- } else {
- log.Printf("failed to reload configuration: %s\n", err)
- }
-
- case msg := <-lookupSinkChan:
- if msg.hostname != nil {
- for _, sink := range cfg.sinks {
- if sink.pattern.matches(msg.hostname) {
- msg.reply <- sink
- break
- }
- }
- } else if len(cfg.sinks) > 0 {
- msg.reply <- cfg.sinks[0]
- }
- close(msg.reply)
-
- case msg := <-lookupCertChan:
- if msg.hostname != nil {
- for _, cert := range cfg.certs {
- if cert.pattern.matches(msg.hostname) {
- msg.reply <- cert.cert
- break
- }
- }
- } else if len(cfg.certs) > 0 {
- msg.reply <- cfg.certs[0].cert
- }
- close(msg.reply)
- }
- }
-}
-
func handshake(conn *tls.Conn) error {
ctx, cancel := context.WithTimeout(context.Background(), handshakeTimeout)
defer cancel()
@@ -450,31 +305,6 @@ func accept(listener net.Listener) {
}
}
-func manageExit() {
- sigs := make(chan os.Signal, 3)
- signal.Notify(sigs, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM)
-
- softExiting := false
- for sig := range sigs {
- switch sig {
- case unix.SIGINT, unix.SIGQUIT:
- if !softExiting {
- log.Println("received SIGINT/SIGQUIT; exiting softly")
- close(softExit)
- softExiting = true
- } else {
- log.Println("received another SIGINT/SIGQUIT; exiting harshly")
- close(hardExit)
- return
- }
- case unix.SIGTERM:
- log.Println("received SIGTERM; exiting harshly")
- close(hardExit)
- return
- }
- }
-}
-
func listen(sources []string) ([]net.Listener, error) {
tlsConfig := &tls.Config{
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
@@ -510,6 +340,82 @@ func listen(sources []string) ([]net.Listener, error) {
return listeners, nil
}
+func manageConfig(cfgPath string) {
+ cfg, err := loadConfig(cfgPath)
+ if err != nil {
+ log.Printf("failed to load initial configuration: %s\n", err)
+ // Proceed with cfg == config{}, i.e., with no sinks and no certs,
+ // causing every client to be rejected.
+ }
+
+ sighup := make(chan os.Signal, 1)
+ signal.Notify(sighup, unix.SIGHUP)
+
+ for {
+ select {
+ case <-sighup:
+ log.Println("received SIGHUP; reloading configuration")
+ newCfg, err := loadConfig(cfgPath)
+ if err == nil {
+ cfg = newCfg
+ } else {
+ log.Printf("failed to reload configuration: %s\n", err)
+ }
+
+ case msg := <-lookupSinkChan:
+ if msg.hostname != nil {
+ for _, sink := range cfg.sinks {
+ if sink.pattern.matches(msg.hostname) {
+ msg.reply <- sink
+ break
+ }
+ }
+ } else if len(cfg.sinks) > 0 {
+ msg.reply <- cfg.sinks[0]
+ }
+ close(msg.reply)
+
+ case msg := <-lookupCertChan:
+ if msg.hostname != nil {
+ for _, cert := range cfg.certs {
+ if cert.pattern.matches(msg.hostname) {
+ msg.reply <- cert.cert
+ break
+ }
+ }
+ } else if len(cfg.certs) > 0 {
+ msg.reply <- cfg.certs[0].cert
+ }
+ close(msg.reply)
+ }
+ }
+}
+
+func manageExit() {
+ sigs := make(chan os.Signal, 3)
+ signal.Notify(sigs, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM)
+
+ softExiting := false
+ for sig := range sigs {
+ switch sig {
+ case unix.SIGINT, unix.SIGQUIT:
+ if !softExiting {
+ log.Println("received SIGINT/SIGQUIT; exiting softly")
+ close(softExit)
+ softExiting = true
+ } else {
+ log.Println("received another SIGINT/SIGQUIT; exiting harshly")
+ close(hardExit)
+ return
+ }
+ case unix.SIGTERM:
+ log.Println("received SIGTERM; exiting harshly")
+ close(hardExit)
+ return
+ }
+ }
+}
+
func init() {
softExit = make(chan struct{})
hardExit = make(chan struct{})