crypto/internal/hpke: add basic implementation

Only implements the sender role, since that's all we need for
client-side ECH for now.

Change-Id: Ia7cba1bc3bad8e8dc801d98d5ea859738b1f2790
Reviewed-on: https://go-review.googlesource.com/c/go/+/585436
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Roland Shoemaker <roland@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
Roland Shoemaker 2024-05-14 11:16:56 -07:00 committed by Gopher Robot
parent ca1d2ead5d
commit 27c302d5d5
4 changed files with 429 additions and 0 deletions

View File

@ -0,0 +1,259 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package hpke
import (
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/ecdh"
"crypto/rand"
"encoding/binary"
"errors"
"math/bits"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/hkdf"
)
// testingOnlyGenerateKey is only used during testing, to provide
// a fixed test key to use when checking the RFC 9180 vectors.
var testingOnlyGenerateKey func() (*ecdh.PrivateKey, error)
type hkdfKDF struct {
hash crypto.Hash
}
func (kdf *hkdfKDF) LabeledExtract(suiteID []byte, salt []byte, label string, inputKey []byte) []byte {
labeledIKM := make([]byte, 0, 7+len(suiteID)+len(label)+len(inputKey))
labeledIKM = append(labeledIKM, []byte("HPKE-v1")...)
labeledIKM = append(labeledIKM, suiteID...)
labeledIKM = append(labeledIKM, label...)
labeledIKM = append(labeledIKM, inputKey...)
return hkdf.Extract(kdf.hash.New, labeledIKM, salt)
}
func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) []byte {
labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info))
labeledInfo = binary.BigEndian.AppendUint16(labeledInfo, length)
labeledInfo = append(labeledInfo, []byte("HPKE-v1")...)
labeledInfo = append(labeledInfo, suiteID...)
labeledInfo = append(labeledInfo, label...)
labeledInfo = append(labeledInfo, info...)
out := make([]byte, length)
n, err := hkdf.Expand(kdf.hash.New, randomKey, labeledInfo).Read(out)
if err != nil || n != int(length) {
panic("hpke: LabeledExpand failed unexpectedly")
}
return out
}
// dhKEM implements the KEM specified in RFC 9180, Section 4.1.
type dhKEM struct {
dh ecdh.Curve
kdf hkdfKDF
suiteID []byte
nSecret uint16
}
var SupportedKEMs = map[uint16]struct {
curve ecdh.Curve
hash crypto.Hash
nSecret uint16
}{
// RFC 9180 Section 7.1
0x0020: {ecdh.X25519(), crypto.SHA256, 32},
}
func newDHKem(kemID uint16) (*dhKEM, error) {
suite, ok := SupportedKEMs[kemID]
if !ok {
return nil, errors.New("unsupported suite ID")
}
return &dhKEM{
dh: suite.curve,
kdf: hkdfKDF{suite.hash},
suiteID: binary.BigEndian.AppendUint16([]byte("KEM"), kemID),
nSecret: suite.nSecret,
}, nil
}
func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) []byte {
eaePRK := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey)
return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret)
}
func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encapPub []byte, err error) {
var privEph *ecdh.PrivateKey
if testingOnlyGenerateKey != nil {
privEph, err = testingOnlyGenerateKey()
} else {
privEph, err = dh.dh.GenerateKey(rand.Reader)
}
if err != nil {
return nil, nil, err
}
dhVal, err := privEph.ECDH(pubRecipient)
if err != nil {
return nil, nil, err
}
encPubEph := privEph.PublicKey().Bytes()
encPubRecip := pubRecipient.Bytes()
kemContext := append(encPubEph, encPubRecip...)
return dh.ExtractAndExpand(dhVal, kemContext), encPubEph, nil
}
type Sender struct {
aead cipher.AEAD
kem *dhKEM
sharedSecret []byte
suiteID []byte
key []byte
baseNonce []byte
exporterSecret []byte
seqNum uint128
}
var aesGCMNew = func(key []byte) (cipher.AEAD, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return cipher.NewGCM(block)
}
var SupportedAEADs = map[uint16]struct {
keySize int
nonceSize int
aead func([]byte) (cipher.AEAD, error)
}{
// RFC 9180, Section 7.3
0x0001: {keySize: 16, nonceSize: 12, aead: aesGCMNew},
0x0002: {keySize: 32, nonceSize: 12, aead: aesGCMNew},
0x0003: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
}
var SupportedKDFs = map[uint16]func() *hkdfKDF{
// RFC 9180, Section 7.2
0x0001: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} },
}
func SetupSender(kemID, kdfID, aeadID uint16, pub crypto.PublicKey, info []byte) ([]byte, *Sender, error) {
suiteID := SuiteID(kemID, kdfID, aeadID)
kem, err := newDHKem(kemID)
if err != nil {
return nil, nil, err
}
pubRecipient, ok := pub.(*ecdh.PublicKey)
if !ok {
return nil, nil, errors.New("incorrect public key type")
}
sharedSecret, encapsulatedKey, err := kem.Encap(pubRecipient)
if err != nil {
return nil, nil, err
}
kdfInit, ok := SupportedKDFs[kdfID]
if !ok {
return nil, nil, errors.New("unsupported KDF id")
}
kdf := kdfInit()
aeadInfo, ok := SupportedAEADs[aeadID]
if !ok {
return nil, nil, errors.New("unsupported AEAD id")
}
pskIDHash := kdf.LabeledExtract(suiteID, nil, "psk_id_hash", nil)
infoHash := kdf.LabeledExtract(suiteID, nil, "info_hash", info)
ksContext := append([]byte{0}, pskIDHash...)
ksContext = append(ksContext, infoHash...)
secret := kdf.LabeledExtract(suiteID, sharedSecret, "secret", nil)
key := kdf.LabeledExpand(suiteID, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */)
baseNonce := kdf.LabeledExpand(suiteID, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */)
exporterSecret := kdf.LabeledExpand(suiteID, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/)
aead, err := aeadInfo.aead(key)
if err != nil {
return nil, nil, err
}
return encapsulatedKey, &Sender{
kem: kem,
aead: aead,
sharedSecret: sharedSecret,
suiteID: suiteID,
key: key,
baseNonce: baseNonce,
exporterSecret: exporterSecret,
}, nil
}
func (s *Sender) nextNonce() []byte {
nonce := s.seqNum.bytes()[16-s.aead.NonceSize():]
for i := range s.baseNonce {
nonce[i] ^= s.baseNonce[i]
}
// Message limit is, according to the RFC, 2^95+1, which
// is somewhat confusing, but we do as we're told.
if s.seqNum.bitLen() >= (s.aead.NonceSize()*8)-1 {
panic("message limit reached")
}
s.seqNum = s.seqNum.addOne()
return nonce
}
func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad)
return ciphertext, nil
}
func SuiteID(kemID, kdfID, aeadID uint16) []byte {
suiteID := make([]byte, 0, 4+2+2+2)
suiteID = append(suiteID, []byte("HPKE")...)
suiteID = binary.BigEndian.AppendUint16(suiteID, kemID)
suiteID = binary.BigEndian.AppendUint16(suiteID, kdfID)
suiteID = binary.BigEndian.AppendUint16(suiteID, aeadID)
return suiteID
}
func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) {
kemInfo, ok := SupportedKEMs[kemID]
if !ok {
return nil, errors.New("unsupported KEM id")
}
return kemInfo.curve.NewPublicKey(bytes)
}
type uint128 struct {
hi, lo uint64
}
func (u uint128) addOne() uint128 {
lo, carry := bits.Add64(u.lo, 1, 0)
return uint128{u.hi + carry, lo}
}
func (u uint128) bitLen() int {
return bits.Len64(u.hi) + bits.Len64(u.lo)
}
func (u uint128) bytes() []byte {
b := make([]byte, 16)
binary.BigEndian.PutUint64(b[0:], u.hi)
binary.BigEndian.PutUint64(b[8:], u.lo)
return b
}

View File

@ -0,0 +1,168 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package hpke
import (
"bytes"
"encoding/hex"
"encoding/json"
"os"
"strconv"
"strings"
"testing"
"crypto/ecdh"
_ "crypto/sha256"
_ "crypto/sha512"
)
func mustDecodeHex(t *testing.T, in string) []byte {
b, err := hex.DecodeString(in)
if err != nil {
t.Fatal(err)
}
return b
}
func parseVectorSetup(vector string) map[string]string {
vals := map[string]string{}
for _, l := range strings.Split(vector, "\n") {
fields := strings.Split(l, ": ")
vals[fields[0]] = fields[1]
}
return vals
}
func parseVectorEncryptions(vector string) []map[string]string {
vals := []map[string]string{}
for _, section := range strings.Split(vector, "\n\n") {
e := map[string]string{}
for _, l := range strings.Split(section, "\n") {
fields := strings.Split(l, ": ")
e[fields[0]] = fields[1]
}
vals = append(vals, e)
}
return vals
}
func TestRFC9180Vectors(t *testing.T) {
vectorsJSON, err := os.ReadFile("testdata/rfc9180-vectors.json")
if err != nil {
t.Fatal(err)
}
var vectors []struct {
Name string
Setup string
Encryptions string
}
if err := json.Unmarshal(vectorsJSON, &vectors); err != nil {
t.Fatal(err)
}
for _, vector := range vectors {
t.Run(vector.Name, func(t *testing.T) {
setup := parseVectorSetup(vector.Setup)
kemID, err := strconv.Atoi(setup["kem_id"])
if err != nil {
t.Fatal(err)
}
if _, ok := SupportedKEMs[uint16(kemID)]; !ok {
t.Skip("unsupported KEM")
}
kdfID, err := strconv.Atoi(setup["kdf_id"])
if err != nil {
t.Fatal(err)
}
if _, ok := SupportedKDFs[uint16(kdfID)]; !ok {
t.Skip("unsupported KDF")
}
aeadID, err := strconv.Atoi(setup["aead_id"])
if err != nil {
t.Fatal(err)
}
if _, ok := SupportedAEADs[uint16(aeadID)]; !ok {
t.Skip("unsupported AEAD")
}
info := mustDecodeHex(t, setup["info"])
pubKeyBytes := mustDecodeHex(t, setup["pkRm"])
pub, err := ParseHPKEPublicKey(uint16(kemID), pubKeyBytes)
if err != nil {
t.Fatal(err)
}
ephemeralPrivKey := mustDecodeHex(t, setup["skEm"])
testingOnlyGenerateKey = func() (*ecdh.PrivateKey, error) {
return SupportedKEMs[uint16(kemID)].curve.NewPrivateKey(ephemeralPrivKey)
}
t.Cleanup(func() { testingOnlyGenerateKey = nil })
encap, context, err := SetupSender(
uint16(kemID),
uint16(kdfID),
uint16(aeadID),
pub,
info,
)
if err != nil {
t.Fatal(err)
}
expectedEncap := mustDecodeHex(t, setup["enc"])
if !bytes.Equal(encap, expectedEncap) {
t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap)
}
expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"])
if !bytes.Equal(context.sharedSecret, expectedSharedSecret) {
t.Errorf("unexpected shared secret, got: %x, want %x", context.sharedSecret, expectedSharedSecret)
}
expectedKey := mustDecodeHex(t, setup["key"])
if !bytes.Equal(context.key, expectedKey) {
t.Errorf("unexpected key, got: %x, want %x", context.key, expectedKey)
}
expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"])
if !bytes.Equal(context.baseNonce, expectedBaseNonce) {
t.Errorf("unexpected base nonce, got: %x, want %x", context.baseNonce, expectedBaseNonce)
}
expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"])
if !bytes.Equal(context.exporterSecret, expectedExporterSecret) {
t.Errorf("unexpected exporter secret, got: %x, want %x", context.exporterSecret, expectedExporterSecret)
}
for _, enc := range parseVectorEncryptions(vector.Encryptions) {
t.Run("seq num "+enc["sequence number"], func(t *testing.T) {
seqNum, err := strconv.Atoi(enc["sequence number"])
if err != nil {
t.Fatal(err)
}
context.seqNum = uint128{lo: uint64(seqNum)}
expectedNonce := mustDecodeHex(t, enc["nonce"])
// We can't call nextNonce, because it increments the sequence number,
// so just compute it directly.
computedNonce := context.seqNum.bytes()[16-context.aead.NonceSize():]
for i := range context.baseNonce {
computedNonce[i] ^= context.baseNonce[i]
}
if !bytes.Equal(computedNonce, expectedNonce) {
t.Errorf("unexpected nonce: got %x, want %x", computedNonce, expectedNonce)
}
expectedCiphertext := mustDecodeHex(t, enc["ct"])
ciphertext, err := context.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"]))
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(ciphertext, expectedCiphertext) {
t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, expectedCiphertext)
}
})
}
})
}
}

File diff suppressed because one or more lines are too long

View File

@ -511,6 +511,7 @@ var depsRules = `
< golang.org/x/crypto/internal/poly1305
< golang.org/x/crypto/chacha20poly1305
< golang.org/x/crypto/hkdf
< crypto/internal/hpke
< crypto/x509/internal/macos
< crypto/x509/pkix;