commit d257c068fa028e68e51aff4e61eeeb5978b51270
parent 0766dcb9d85f4636416d4c249da3997d5879504f
Author: Robert Russell <robertrussell.72001@gmail.com>
Date: Sun, 14 Jul 2024 17:05:53 -0700
Make it build
Diffstat:
| M | tlsrp.go | | | 77 | +++++++++++++++++++++++++++++++++++++++-------------------------------------- |
1 file changed, 39 insertions(+), 38 deletions(-)
diff --git a/tlsrp.go b/tlsrp.go
@@ -10,6 +10,7 @@ import (
"flag"
"fmt"
"golang.org/x/sys/unix"
+ "io"
"log"
"net"
"os"
@@ -73,7 +74,7 @@ var errHostnameEmpty = errors.New("empty hostname")
type hostname string
// XXX: We currently don't length check hostnames or the labels within.
-func parseHostname(s string) (Hostname, error) {
+func parseHostname(s string) (hostname, error) {
if len(s) == 0 {
return "", errHostnameEmpty
}
@@ -99,15 +100,15 @@ func parseHostname(s string) (Hostname, error) {
// Ok
case '0' <= r && r <= '9':
// Ok
- case '-' && !first:
+ case r == '-' && !first:
// Ok
- case '-' && first:
+ case r == '-' && first:
return "", fmt.Errorf("illegal hostname: hyphen at start of label")
default:
return "", fmt.Errorf("illegal hostname: illegal rune: %q", r)
}
- buf = append(buf, r)
- first := false
+ buf = append(buf, byte(r))
+ first = false
}
if buf[len(buf)-1] == '-' {
@@ -115,7 +116,7 @@ func parseHostname(s string) (Hostname, error) {
}
}
- return string(buf), nil
+ return hostname(buf), nil
}
// TODO: Add support for more than just alternation in patterns.
@@ -125,11 +126,11 @@ func (pattern pattern) matches(hostname hostname) bool {
return slices.Contains(pattern, hostname)
}
-func parsePattern(ss []string) (pattern, error) {
+func parsePattern(ss [][]byte) (pattern, error) {
pat := make(pattern, 0, len(ss))
for _, s := range ss {
- hostname, err := parseHostname(s)
+ hostname, err := parseHostname(string(s))
if err != nil {
return nil, err
}
@@ -143,7 +144,7 @@ func parsePattern(ss []string) (pattern, error) {
// These are really more like "sink specifications" and "cert specifications".
type sink struct {
- pattern Pattern
+ pattern pattern
network string
address string
}
@@ -157,15 +158,15 @@ func (sink sink) String() string {
}
type cert struct {
- pattern Pattern
+ pattern pattern
crtPath string
keyPath string
cert *tls.Certificate
}
type config struct {
- sinks []Sink
- certs []Cert
+ sinks []sink
+ certs []cert
}
func loadConfig(path string) (config, error) {
@@ -182,7 +183,7 @@ func loadConfig(path string) (config, error) {
var cfg config
- lines := bytes.Split(data, '\n')
+ lines := bytes.Split(data, []byte{'\n'})
for _, line := range lines {
if len(line) == 0 {
continue
@@ -190,10 +191,10 @@ func loadConfig(path string) (config, error) {
fields := bytes.Fields(line)
if len(fields) < 3 {
- return &config{}, fmt.Errorf("illegal config: line with fewer than 3 fields")
+ return config{}, fmt.Errorf("illegal config: line with fewer than 3 fields")
}
- switch fields[0] {
+ switch string(fields[0]) {
case "sink":
pat, err := parsePattern(fields[3:])
if err != nil {
@@ -202,14 +203,14 @@ func loadConfig(path string) (config, error) {
sink := sink{
pattern: pat,
- network: fields[1],
- address: fields[2],
+ network: string(fields[1]),
+ address: string(fields[2]),
}
- config.sinks = append(config.sinks, sink)
+ cfg.sinks = append(cfg.sinks, sink)
case "cert":
- crtPath := fields[1]
- keyPath := fields[2]
+ crtPath := string(fields[1])
+ keyPath := string(fields[2])
c, err := tls.LoadX509KeyPair(crtPath, keyPath)
if err != nil {
return config{}, err
@@ -222,11 +223,11 @@ func loadConfig(path string) (config, error) {
cert := cert{
pattern: pat,
- crtPath: fields[1],
- keyPath: fields[2],
- cert: c,
+ crtPath: crtPath,
+ keyPath: keyPath,
+ cert: &c,
}
- config.certs = append(config.certs, cert)
+ cfg.certs = append(cfg.certs, cert)
default:
return config{}, fmt.Errorf("illegal config: expected \"sink\" or \"cert\" as first field")
@@ -257,7 +258,7 @@ func manageConfig(configPath string) {
for i := range certs {
c, err := tls.LoadX509KeyPair(certs[i].crtPath, certs[i].keyPath)
if err == nil {
- certs[i].cert = c
+ certs[i].cert = &c
} else {
log.Printf("failed to reload certificates: %s\n", err)
}
@@ -303,18 +304,16 @@ func manageConfig(configPath string) {
}
func handshake(conn *tls.Conn) error {
+ ctx, cancel := context.WithTimeout(context.Background(), handshakeTimeout)
+ defer cancel()
+
handshakeErr := make(chan error, 1)
- go func() {
- ctx := context.WithTimeout(context.Background(), handshakeTimeout)
- handshakeErr <- conn.HandshakeContext(ctx)
- }()
+ go func() { handshakeErr <- conn.HandshakeContext(ctx) }()
var err error
select {
case err = <-handshakeErr:
case <-hardExit:
- conn.Close()
- <-handshakeErr
}
return err
@@ -389,7 +388,7 @@ func proxy(client *tls.Conn) {
}
defer sink.Close()
- err = splice(client, sink)
+ err = splice(client, sink.(*tls.Conn))
if err != nil {
logf("splice error: %s", err)
return
@@ -429,7 +428,8 @@ func manageExit() {
log.Println("received SIGINT; exiting softly")
close(softExit)
<-hardSigs
- fallthrough
+ log.Println("received SIGQUIT/SIGTERM; exiting harshly")
+ close(hardExit)
case <-hardSigs:
log.Println("received SIGQUIT/SIGTERM; exiting harshly")
close(hardExit)
@@ -444,7 +444,7 @@ func listen(sources []string) ([]net.Listener, error) {
return nil, err
}
return lookupCert(hostname)
- }
+ },
}
listeners := make([]net.Listener, 0, len(sources))
@@ -486,20 +486,21 @@ client using the Server Name Indication TLS extension (RFC 3546).
func init() {
softExit = make(chan struct{})
hardExit = make(chan struct{})
- lookupSink = make(chan lookupSinkMsg, 16)
- lookupCert = make(chan lookupCertMsg, 16)
+
+ lookupSinkChan = make(chan lookupSinkMsg, 16)
+ lookupCertChan = make(chan lookupCertMsg, 16)
}
func main() {
flag.Usage = usage
flag.Parse()
- if flag.NArgs() < 2 {
+ if flag.NArg() < 2 {
log.Fatalln("expected 2 or more arguments")
}
configPath := flag.Args()[0]
- listeners, err := listen(flags.Args()[1:])
+ listeners, err := listen(flag.Args()[1:])
if err != nil {
log.Fatalln(err.Error())
}