commit a31b76827ed7ed10637b7fe730277a3425071dfc
parent d257c068fa028e68e51aff4e61eeeb5978b51270
Author: Robert Russell <robertrussell.72001@gmail.com>
Date: Sun, 14 Jul 2024 21:51:45 -0700
Add go mod files
Diffstat:
| A | go.mod | | | 5 | +++++ |
| A | go.sum | | | 2 | ++ |
| M | tlsrp.go | | | 103 | ++++++++++++++++++++++++++++++++++++++----------------------------------------- |
3 files changed, 57 insertions(+), 53 deletions(-)
diff --git a/go.mod b/go.mod
@@ -0,0 +1,5 @@
+module rr3.xyz/tlsrp
+
+go 1.22.5
+
+require golang.org/x/sys v0.22.0 // indirect
diff --git a/go.sum b/go.sum
@@ -0,0 +1,2 @@
+golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
+golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
diff --git a/tlsrp.go b/tlsrp.go
@@ -3,7 +3,6 @@ package main
// TODO: check log.Printf newlines
import (
- "bytes"
"context"
"crypto/tls"
"errors"
@@ -91,8 +90,9 @@ func parseHostname(s string) (hostname, error) {
buf = append(buf, '.')
}
- first := true
- for _, r := range label {
+ for i, r := range label {
+ first := i == 0
+ last := i == len(label) - 1
switch {
case 'A' <= r && r <= 'Z':
r += 'a' - 'A'
@@ -100,19 +100,16 @@ func parseHostname(s string) (hostname, error) {
// Ok
case '0' <= r && r <= '9':
// Ok
- case r == '-' && !first:
+ case r == '-' && (!first && !last):
// Ok
case r == '-' && first:
return "", fmt.Errorf("illegal hostname: hyphen at start of label")
+ case r == '-' && last:
+ return "", fmt.Errorf("illegal hostname: hyphen at end of label")
default:
return "", fmt.Errorf("illegal hostname: illegal rune: %q", r)
}
buf = append(buf, byte(r))
- first = false
- }
-
- if buf[len(buf)-1] == '-' {
- return "", fmt.Errorf("illegal hostname: hyphen at end of label")
}
}
@@ -126,11 +123,11 @@ func (pattern pattern) matches(hostname hostname) bool {
return slices.Contains(pattern, hostname)
}
-func parsePattern(ss [][]byte) (pattern, error) {
+func parsePattern(ss []string) (pattern, error) {
pat := make(pattern, 0, len(ss))
for _, s := range ss {
- hostname, err := parseHostname(string(s))
+ hostname, err := parseHostname(s)
if err != nil {
return nil, err
}
@@ -140,9 +137,6 @@ func parsePattern(ss [][]byte) (pattern, error) {
return pat, nil
}
-// TODO: Come up with better names for the types "sink" and "cert".
-// These are really more like "sink specifications" and "cert specifications".
-
type sink struct {
pattern pattern
network string
@@ -183,18 +177,17 @@ func loadConfig(path string) (config, error) {
var cfg config
- lines := bytes.Split(data, []byte{'\n'})
+ lines := strings.Split(string(data), "\n")
for _, line := range lines {
- if len(line) == 0 {
- continue
+ fields := strings.Fields(line)
+ if len(fields) == 0 {
+ continue // Empty line
}
-
- fields := bytes.Fields(line)
if len(fields) < 3 {
return config{}, fmt.Errorf("illegal config: line with fewer than 3 fields")
}
- switch string(fields[0]) {
+ switch fields[0] {
case "sink":
pat, err := parsePattern(fields[3:])
if err != nil {
@@ -203,15 +196,15 @@ func loadConfig(path string) (config, error) {
sink := sink{
pattern: pat,
- network: string(fields[1]),
- address: string(fields[2]),
+ network: fields[1],
+ address: fields[2],
}
cfg.sinks = append(cfg.sinks, sink)
case "cert":
- crtPath := string(fields[1])
- keyPath := string(fields[2])
- c, err := tls.LoadX509KeyPair(crtPath, keyPath)
+ crtPath := fields[1]
+ keyPath := fields[2]
+ tlsCert, err := tls.LoadX509KeyPair(crtPath, keyPath)
if err != nil {
return config{}, err
}
@@ -225,7 +218,7 @@ func loadConfig(path string) (config, error) {
pattern: pat,
crtPath: crtPath,
keyPath: keyPath,
- cert: &c,
+ cert: &tlsCert,
}
cfg.certs = append(cfg.certs, cert)
@@ -237,14 +230,14 @@ func loadConfig(path string) (config, error) {
return cfg, nil
}
-func manageConfig(configPath string) {
- cfg, err := loadConfig(configPath)
+func manageConfig(cfgPath string) {
+ cfg, err := loadConfig(cfgPath)
if err != nil {
log.Printf("failed to load initial configuration: %s\n", err)
+ // Proceed with cfg == config{}, i.e., with no sinks and no certs,
+ // causing every client to be rejected.
}
- // Proceed with cfg == config{}, i.e., with no sinks and no certs, causing
- // every client to be rejected.
-
+
sigusr := make(chan os.Signal, 2)
signal.Notify(sigusr, unix.SIGUSR1, unix.SIGUSR2)
@@ -256,17 +249,19 @@ func manageConfig(configPath string) {
log.Println("received SIGUSR1; reloading certificates")
certs := cfg.certs
for i := range certs {
- c, err := tls.LoadX509KeyPair(certs[i].crtPath, certs[i].keyPath)
+ crtPath := certs[i].crtPath
+ keyPath := certs[i].keyPath
+ tlsCert, err := tls.LoadX509KeyPair(crtPath, keyPath)
if err == nil {
- certs[i].cert = &c
+ certs[i].cert = &tlsCert
} else {
- log.Printf("failed to reload certificates: %s\n", err)
+ log.Printf("failed to reload certificate (%s, %s): %s\n", crtPath, keyPath, err)
}
}
case unix.SIGUSR2:
log.Println("received SIGUSR2; reloading configuration")
- newCfg, err := loadConfig(configPath)
+ newCfg, err := loadConfig(cfgPath)
if err == nil {
cfg = newCfg
} else {
@@ -417,22 +412,24 @@ func accept(listener net.Listener) {
}
func manageExit() {
- softSigs := make(chan os.Signal, 1)
- signal.Notify(softSigs, unix.SIGINT)
-
- hardSigs := make(chan os.Signal, 1)
- signal.Notify(hardSigs, unix.SIGQUIT, unix.SIGTERM)
-
- select {
- case <-softSigs:
- log.Println("received SIGINT; exiting softly")
- close(softExit)
- <-hardSigs
- log.Println("received SIGQUIT/SIGTERM; exiting harshly")
- close(hardExit)
- case <-hardSigs:
- log.Println("received SIGQUIT/SIGTERM; exiting harshly")
- close(hardExit)
+ sigs := make(chan os.Signal, 3)
+ signal.Notify(sigs, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM)
+
+ softExiting := false
+ for sig := range sigs {
+ switch sig {
+ case unix.SIGINT:
+ if !softExiting {
+ log.Println("received SIGINT; exiting softly")
+ close(softExit)
+ softExiting = true
+ }
+ // Keep going in case we receive hard exit signal.
+ case unix.SIGQUIT, unix.SIGTERM:
+ log.Println("received SIGQUIT/SIGTERM; exiting harshly")
+ close(hardExit)
+ return
+ }
}
}
@@ -498,7 +495,7 @@ func main() {
log.Fatalln("expected 2 or more arguments")
}
- configPath := flag.Args()[0]
+ cfgPath := flag.Args()[0]
listeners, err := listen(flag.Args()[1:])
if err != nil {
@@ -506,7 +503,7 @@ func main() {
}
go manageExit()
- go manageConfig(configPath)
+ go manageConfig(cfgPath)
var wg sync.WaitGroup
defer wg.Wait()