diff --git a/acme/autocert/autocert.go b/acme/autocert/autocert.go index ce2f647..98842b4 100644 --- a/acme/autocert/autocert.go +++ b/acme/autocert/autocert.go @@ -162,6 +162,10 @@ type Manager struct { // The error is propagated back to the caller of GetCertificate and is user-visible. // This does not affect cached certs. See HostPolicy field description for more details. func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + if m.Prompt == nil { + return nil, errors.New("acme/autocert: Manager.Prompt not set") + } + name := hello.ServerName if name == "" { return nil, errors.New("acme/autocert: missing server name") diff --git a/acme/autocert/autocert_test.go b/acme/autocert/autocert_test.go index acef162..6df4cf3 100644 --- a/acme/autocert/autocert_test.go +++ b/acme/autocert/autocert_test.go @@ -159,9 +159,28 @@ func TestGetCertificate_ForceRSA(t *testing.T) { } } -// tests man.GetCertificate flow using the provided hello argument. +func TestGetCertificate_nilPrompt(t *testing.T) { + man := &Manager{} + defer man.stopRenew() + url, finish := startACMEServerStub(t, man, "example.org") + defer finish() + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + man.Client = &acme.Client{ + Key: key, + DirectoryURL: url, + } + hello := &tls.ClientHelloInfo{ServerName: "example.org"} + if _, err := man.GetCertificate(hello); err == nil { + t.Error("got certificate for example.org; wanted error") + } +} + +// startACMEServerStub runs an ACME server // The domain argument is the expected domain name of a certificate request. -func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) { +func startACMEServerStub(t *testing.T, man *Manager, domain string) (url string, finish func()) { // echo token-02 | shasum -a 256 // then divide result in 2 parts separated by dot tokenCertName := "4e8eb87631187e9ff2153b56b13a4dec.13a35d002e485d60ff37354b32f665d9.token.acme.invalid" @@ -187,7 +206,7 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl // discovery case "/": if err := discoTmpl.Execute(w, ca.URL); err != nil { - t.Fatalf("discoTmpl: %v", err) + t.Errorf("discoTmpl: %v", err) } // client key registration case "/new-reg": @@ -197,7 +216,7 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl w.Header().Set("location", ca.URL+"/authz/1") w.WriteHeader(http.StatusCreated) if err := authzTmpl.Execute(w, ca.URL); err != nil { - t.Fatalf("authzTmpl: %v", err) + t.Errorf("authzTmpl: %v", err) } // accept tls-sni-02 challenge case "/challenge/2": @@ -215,14 +234,14 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl b, _ := base64.RawURLEncoding.DecodeString(req.CSR) csr, err := x509.ParseCertificateRequest(b) if err != nil { - t.Fatalf("new-cert: CSR: %v", err) + t.Errorf("new-cert: CSR: %v", err) } if csr.Subject.CommonName != domain { t.Errorf("CommonName in CSR = %q; want %q", csr.Subject.CommonName, domain) } der, err := dummyCert(csr.PublicKey, domain) if err != nil { - t.Fatalf("new-cert: dummyCert: %v", err) + t.Errorf("new-cert: dummyCert: %v", err) } chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL) w.Header().Set("link", chainUp) @@ -232,14 +251,51 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl case "/ca-cert": der, err := dummyCert(nil, "ca") if err != nil { - t.Fatalf("ca-cert: dummyCert: %v", err) + t.Errorf("ca-cert: dummyCert: %v", err) } w.Write(der) default: t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path) } })) - defer ca.Close() + finish = func() { + ca.Close() + + // make sure token cert was removed + cancel := make(chan struct{}) + done := make(chan struct{}) + go func() { + defer close(done) + tick := time.NewTicker(100 * time.Millisecond) + defer tick.Stop() + for { + hello := &tls.ClientHelloInfo{ServerName: tokenCertName} + if _, err := man.GetCertificate(hello); err != nil { + return + } + select { + case <-tick.C: + case <-cancel: + return + } + } + }() + select { + case <-done: + case <-time.After(5 * time.Second): + close(cancel) + t.Error("token cert was not removed") + <-done + } + } + return ca.URL, finish +} + +// tests man.GetCertificate flow using the provided hello argument. +// The domain argument is the expected domain name of a certificate request. +func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) { + url, finish := startACMEServerStub(t, man, domain) + defer finish() // use EC key to run faster on 386 key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) @@ -248,7 +304,7 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl } man.Client = &acme.Client{ Key: key, - DirectoryURL: ca.URL, + DirectoryURL: url, } // simulate tls.Config.GetCertificate @@ -279,23 +335,6 @@ func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.Cl t.Errorf("cert.DNSNames = %v; want %q", cert.DNSNames, domain) } - // make sure token cert was removed - done = make(chan struct{}) - go func() { - for { - hello := &tls.ClientHelloInfo{ServerName: tokenCertName} - if _, err := man.GetCertificate(hello); err != nil { - break - } - time.Sleep(100 * time.Millisecond) - } - close(done) - }() - select { - case <-time.After(5 * time.Second): - t.Error("token cert was not removed") - case <-done: - } } func TestAccountKeyCache(t *testing.T) {