From a074a2b9832e5e0ace0b8fbca3e4245b8c10e32d Mon Sep 17 00:00:00 2001 From: Nolan Woods Date: Wed, 27 Oct 2021 00:34:24 -0700 Subject: [PATCH] lib/http: Fix handling of ssl credentials Adds a test that makes an actual http and https request against the server --- lib/http/http.go | 42 ++++++++++++++++++++------ lib/http/http_test.go | 69 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 101 insertions(+), 10 deletions(-) diff --git a/lib/http/http.go b/lib/http/http.go index 7bfd5922a..abf3b76d4 100644 --- a/lib/http/http.go +++ b/lib/http/http.go @@ -71,8 +71,10 @@ type Options struct { ServerReadTimeout time.Duration // Timeout for server reading data ServerWriteTimeout time.Duration // Timeout for server writing data MaxHeaderBytes int // Maximum size of request header - SslCert string // SSL PEM key (concatenation of certificate and CA certificate) - SslKey string // SSL PEM Private key + SslCert string // Path to SSL PEM key (concatenation of certificate and CA certificate) + SslKey string // Path to SSL PEM Private key + SslCertBody []byte // SSL PEM key (concatenation of certificate and CA certificate) body, ignores SslCert + SslKeyBody []byte // SSL PEM Private key body, ignores SslKey ClientCA string // Client certificate authority to verify clients with } @@ -110,7 +112,7 @@ var ( ) func useSSL(opt Options) bool { - return opt.SslKey != "" + return opt.SslKey != "" || len(opt.SslKeyBody) > 0 } // NewServer instantiates a new http server using provided listeners and options @@ -127,15 +129,31 @@ func NewServer(listeners, tlsListeners []net.Listener, opt Options) (Server, err var tlsConfig *tls.Config useSSL := useSSL(opt) - if (opt.SslCert != "") != useSSL { + if (len(opt.SslCertBody) > 0) != (len(opt.SslKeyBody) > 0) { + err := errors.New("Need both SslCertBody and SslKeyBody to use SSL") + log.Fatalf(err.Error()) + return nil, err + } + if (opt.SslCert != "") != (opt.SslKey != "") { err := errors.New("Need both -cert and -key to use SSL") log.Fatalf(err.Error()) return nil, err } if useSSL { + var cert tls.Certificate + var err error + if len(opt.SslCertBody) > 0 { + cert, err = tls.X509KeyPair(opt.SslCertBody, opt.SslKeyBody) + } else { + cert, err = tls.LoadX509KeyPair(opt.SslCert, opt.SslKey) + } + if err != nil { + log.Fatal(err) + } tlsConfig = &tls.Config{ - MinVersion: tls.VersionTLS10, // disable SSL v3.0 and earlier + MinVersion: tls.VersionTLS10, // disable SSL v3.0 and earlier + Certificates: []tls.Certificate{cert}, } } else if len(listeners) == 0 && len(tlsListeners) != 0 { return nil, errors.New("No SslKey or non-tlsListeners") @@ -211,22 +229,28 @@ func NewServer(listeners, tlsListeners []net.Listener, opt Options) (Server, err } func (s *server) Serve() { - serve := func(l net.Listener) { + serve := func(l net.Listener, tls bool) { defer s.closing.Done() - if err := s.httpServer.Serve(l); err != http.ErrServerClosed && err != nil { + var err error + if tls { + err = s.httpServer.ServeTLS(l, "", "") + } else { + err = s.httpServer.Serve(l) + } + if err != http.ErrServerClosed && err != nil { log.Fatalf(err.Error()) } } s.closing.Add(len(s.listeners)) for _, l := range s.listeners { - go serve(l) + go serve(l, false) } if s.useSSL { s.closing.Add(len(s.tlsListeners)) for _, l := range s.tlsListeners { - go serve(l) + go serve(l, true) } } } diff --git a/lib/http/http_test.go b/lib/http/http_test.go index a499d5658..1c2618be2 100644 --- a/lib/http/http_test.go +++ b/lib/http/http_test.go @@ -1,10 +1,13 @@ package http import ( + "crypto/tls" "net" "net/http" "reflect" + "strings" "testing" + "time" "golang.org/x/net/nettest" @@ -356,9 +359,28 @@ func Test_server_Shutdown(t *testing.T) { } func Test_start(t *testing.T) { + http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + sslServerOptions := defaultServerOptions + sslServerOptions.SslCertBody = []byte(`-----BEGIN CERTIFICATE----- +MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw +DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow +EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d +7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B +5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr +BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1 +NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l +Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc +6MF9+Yw1Yy0t +-----END CERTIFICATE-----`) + sslServerOptions.SslKeyBody = []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49 +AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q +EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== +-----END EC PRIVATE KEY-----`) tests := []struct { name string opt Options + ssl bool wantErr bool }{ { @@ -366,20 +388,55 @@ func Test_start(t *testing.T) { opt: defaultServerOptions, wantErr: false, }, + { + name: "ssl", + opt: sslServerOptions, + ssl: true, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + defer func() { + err := Shutdown() + if err != nil { + t.Fatal("couldn't shutdown server") + } + }() SetOptions(tt.opt) if err := start(); (err != nil) != tt.wantErr { t.Errorf("start() error = %v, wantErr %v", err, tt.wantErr) return } s := defaultServer - if useSSL(tt.opt) { + router := s.Router() + router.Head("/", func(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(201) + }) + testURL := URL() + if tt.ssl { + assert.True(t, useSSL(tt.opt)) assert.Equal(t, tt.opt.ListenAddr, s.tlsAddrs[0].String()) + assert.True(t, strings.HasPrefix(testURL, "https://")) } else { + assert.True(t, strings.HasPrefix(testURL, "http://")) assert.Equal(t, tt.opt.ListenAddr, s.addrs[0].String()) } + + // try to connect to the test server + pause := time.Millisecond + for i := 0; i < 10; i++ { + resp, err := http.Head(testURL) + if err == nil { + _ = resp.Body.Close() + return + } + // t.Logf("couldn't connect, sleeping for %v: %v", pause, err) + time.Sleep(pause) + pause *= 2 + } + t.Fatal("couldn't connect to server") + /* accessing s.httpServer.* can't be done synchronously and is a race condition assert.Equal(t, tt.opt.ServerReadTimeout, defaultServer.httpServer.ReadTimeout) assert.Equal(t, tt.opt.ServerWriteTimeout, defaultServer.httpServer.WriteTimeout) @@ -427,6 +484,16 @@ func Test_useSSL(t *testing.T) { }}, want: true, }, + { + name: "body", + args: args{opt: Options{ + SslCert: "", + SslKey: "", + SslKeyBody: []byte(`test`), + ClientCA: "", + }}, + want: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {