diff --git a/src/crypto/x509/cert_pool.go b/src/crypto/x509/cert_pool.go index 873ffeee1d..ae43c84424 100644 --- a/src/crypto/x509/cert_pool.go +++ b/src/crypto/x509/cert_pool.go @@ -249,3 +249,16 @@ func (s *CertPool) Subjects() [][]byte { } return res } + +// Equal reports whether s and other are equal. +func (s *CertPool) Equal(other *CertPool) bool { + if s.systemPool != other.systemPool || len(s.haveSum) != len(other.haveSum) { + return false + } + for h := range s.haveSum { + if !other.haveSum[h] { + return false + } + } + return true +} diff --git a/src/crypto/x509/cert_pool_test.go b/src/crypto/x509/cert_pool_test.go new file mode 100644 index 0000000000..d1ec9aaefd --- /dev/null +++ b/src/crypto/x509/cert_pool_test.go @@ -0,0 +1,58 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package x509 + +import "testing" + +func TestCertPoolEqual(t *testing.T) { + a, b := NewCertPool(), NewCertPool() + if !a.Equal(b) { + t.Error("two empty pools not equal") + } + + tc := &Certificate{Raw: []byte{1, 2, 3}, RawSubject: []byte{2}} + a.AddCert(tc) + if a.Equal(b) { + t.Error("empty pool equals non-empty pool") + } + + b.AddCert(tc) + if !a.Equal(b) { + t.Error("two non-empty pools not equal") + } + + otherTC := &Certificate{Raw: []byte{9, 8, 7}, RawSubject: []byte{8}} + a.AddCert(otherTC) + if a.Equal(b) { + t.Error("non-equal pools equal") + } + + systemA, err := SystemCertPool() + if err != nil { + t.Fatalf("unable to load system cert pool: %s", err) + } + systemB, err := SystemCertPool() + if err != nil { + t.Fatalf("unable to load system cert pool: %s", err) + } + if !systemA.Equal(systemB) { + t.Error("two empty system pools not equal") + } + + systemA.AddCert(tc) + if systemA.Equal(systemB) { + t.Error("empty system pool equals non-empty system pool") + } + + systemB.AddCert(tc) + if !systemA.Equal(systemB) { + t.Error("two non-empty system pools not equal") + } + + systemA.AddCert(otherTC) + if systemA.Equal(systemB) { + t.Error("non-equal system pools equal") + } +}