diff --git a/src/crypto/internal/hpke/hpke.go b/src/crypto/internal/hpke/hpke.go index d451bff250..de601c4bfe 100644 --- a/src/crypto/internal/hpke/hpke.go +++ b/src/crypto/internal/hpke/hpke.go @@ -26,31 +26,23 @@ type hkdfKDF struct { hash crypto.Hash } -func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) []byte { +func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) ([]byte, error) { labeledIKM := make([]byte, 0, 7+len(sid)+len(label)+len(inputKey)) labeledIKM = append(labeledIKM, []byte("HPKE-v1")...) labeledIKM = append(labeledIKM, sid...) labeledIKM = append(labeledIKM, label...) labeledIKM = append(labeledIKM, inputKey...) - prk, err := hkdf.Extract(kdf.hash.New, labeledIKM, salt) - if err != nil { - panic(err) - } - return prk + return hkdf.Extract(kdf.hash.New, labeledIKM, salt) } -func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) []byte { +func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) ([]byte, error) { labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info)) labeledInfo = byteorder.BEAppendUint16(labeledInfo, length) labeledInfo = append(labeledInfo, []byte("HPKE-v1")...) labeledInfo = append(labeledInfo, suiteID...) labeledInfo = append(labeledInfo, label...) labeledInfo = append(labeledInfo, info...) - key, err := hkdf.Expand(kdf.hash.New, randomKey, string(labeledInfo), int(length)) - if err != nil { - panic(err) - } - return key + return hkdf.Expand(kdf.hash.New, randomKey, string(labeledInfo), int(length)) } // dhKEM implements the KEM specified in RFC 9180, Section 4.1. @@ -88,8 +80,11 @@ func newDHKem(kemID uint16) (*dhKEM, error) { }, nil } -func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) []byte { - eaePRK := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey) +func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) ([]byte, error) { + eaePRK, err := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey) + if err != nil { + return nil, err + } return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret) } @@ -111,8 +106,11 @@ func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encap encPubRecip := pubRecipient.Bytes() kemContext := append(encPubEph, encPubRecip...) - - return dh.ExtractAndExpand(dhVal, kemContext), encPubEph, nil + sharedSecret, err = dh.ExtractAndExpand(dhVal, kemContext) + if err != nil { + return nil, nil, err + } + return sharedSecret, encPubEph, nil } func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, error) { @@ -125,8 +123,7 @@ func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, return nil, err } kemContext := append(encPubEph, secRecipient.PublicKey().Bytes()...) - - return dh.ExtractAndExpand(dhVal, kemContext), nil + return dh.ExtractAndExpand(dhVal, kemContext) } type context struct { @@ -201,16 +198,33 @@ func newContext(sharedSecret []byte, kemID, kdfID, aeadID uint16, info []byte) ( return nil, errors.New("unsupported AEAD id") } - pskIDHash := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil) - infoHash := kdf.LabeledExtract(sid, nil, "info_hash", info) + pskIDHash, err := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil) + if err != nil { + return nil, err + } + infoHash, err := kdf.LabeledExtract(sid, nil, "info_hash", info) + if err != nil { + return nil, err + } ksContext := append([]byte{0}, pskIDHash...) ksContext = append(ksContext, infoHash...) - secret := kdf.LabeledExtract(sid, sharedSecret, "secret", nil) - - key := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */) - baseNonce := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */) - exporterSecret := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/) + secret, err := kdf.LabeledExtract(sid, sharedSecret, "secret", nil) + if err != nil { + return nil, err + } + key, err := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */) + if err != nil { + return nil, err + } + baseNonce, err := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */) + if err != nil { + return nil, err + } + exporterSecret, err := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/) + if err != nil { + return nil, err + } aead, err := aeadInfo.aead(key) if err != nil {