diff --git a/device/device.go b/device/device.go index 6854ed85a..6dbca4a67 100644 --- a/device/device.go +++ b/device/device.go @@ -49,8 +49,12 @@ type Device struct { staticIdentity struct { sync.RWMutex - privateKey NoisePrivateKey - publicKey NoisePublicKey + privateKey NoisePrivateKey + publicKey NoisePublicKey + mlkemPrivateKey MLKEMPrivateKey + mlkemPublicKey MLKEMPublicKey + mldsaPrivateKey MLDSAPrivateKey + mldsaPublicKey MLDSAPublicKey } peers struct { diff --git a/device/device_test.go b/device/device_test.go index 0091e2052..f6041a753 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -20,6 +20,10 @@ import ( "testing" "time" + "github.com/cloudflare/circl/kem/kyber/kyber1024" + "github.com/cloudflare/circl/sign/dilithium/mode5" + + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn/bindtest" "golang.zx2c4.com/wireguard/tun" @@ -189,6 +193,10 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { // The device is ready. Close it when the test completes. tb.Cleanup(p.dev.Close) } + + installMLKEMKeys(tb, &pair) + installMLDSAKeys(tb, &pair) + return } @@ -405,7 +413,6 @@ func goroutineLeakCheck(t *testing.T) { if t.Failed() { return } - // Give goroutines time to exit, if they need it. for i := 0; i < 10000; i++ { if runtime.NumGoroutine() <= startGoroutines { return @@ -474,3 +481,179 @@ func TestBatchSize(t *testing.T) { t.Errorf("expected batch size %d, got %d", want, got) } } + +func installMLKEMKeys(t testing.TB, pair *testPair) { + t.Helper() + scheme := kyber1024.Scheme() + pk0, sk0, err := scheme.GenerateKeyPair() + if err != nil { t.Fatal(err) } + pk0b, _ := pk0.MarshalBinary() + sk0b, _ := sk0.MarshalBinary() + pk1, sk1, err := scheme.GenerateKeyPair() + if err != nil { t.Fatal(err) } + pk1b, _ := pk1.MarshalBinary() + sk1b, _ := sk1.MarshalBinary() + + if err := pair[0].dev.IpcSet(uapiCfg("mlkem_private_key", hex.EncodeToString(sk0b))); err != nil { t.Fatal(err) } + if err := pair[1].dev.IpcSet(uapiCfg("mlkem_private_key", hex.EncodeToString(sk1b))); err != nil { t.Fatal(err) } + var pub0, pub1 NoisePublicKey + for k := range pair[0].dev.peers.keyMap { pub0 = k; break } + for k := range pair[1].dev.peers.keyMap { pub1 = k; break } + cfgPeer0 := uapiCfg( + "public_key", hex.EncodeToString(pub0[:]), + "mlkem_public_key", hex.EncodeToString(pk1b), + ) + cfgPeer1 := uapiCfg( + "public_key", hex.EncodeToString(pub1[:]), + "mlkem_public_key", hex.EncodeToString(pk0b), + ) + if err := pair[0].dev.IpcSet(cfgPeer0); err != nil { t.Fatal(err) } + if err := pair[1].dev.IpcSet(cfgPeer1); err != nil { t.Fatal(err) } +} + +func installMLDSAKeys(t testing.TB, pair *testPair) { + t.Helper() + + dil := mode5.Scheme() + pk0, sk0, err := dil.GenerateKey() + if err != nil { t.Fatal(err) } + pk1, sk1, err := dil.GenerateKey() + if err != nil { t.Fatal(err) } + + pk0b, _ := pk0.MarshalBinary() + sk0b, _ := sk0.MarshalBinary() + pk1b, _ := pk1.MarshalBinary() + sk1b, _ := sk1.MarshalBinary() + + copy(pair[0].dev.staticIdentity.mldsaPrivateKey[:], sk0b) + copy(pair[1].dev.staticIdentity.mldsaPrivateKey[:], sk1b) + + var peer0, peer1 *Peer + for _, p := range pair[0].dev.peers.keyMap { peer0 = p; break } + for _, p := range pair[1].dev.peers.keyMap { peer1 = p; break } + if peer0 == nil || peer1 == nil { + t.Fatal("não foi possível localizar peers para configurar MLDSA") + } + + copy(peer0.handshake.remoteMLDSAStatic[:], pk1b) + copy(peer1.handshake.remoteMLDSAStatic[:], pk0b) +} + +func TestMLKEMKeyGeneration(t *testing.T) { + pub, priv, err := GenerateQuantumKeyPair() + if err != nil { t.Fatal(err) } + + scheme := kyber1024.Scheme() + if len(pub) != scheme.PublicKeySize() { + t.Fatalf("pub size mismatch: got %d, want %d", len(pub), scheme.PublicKeySize()) + } + if len(priv) != scheme.PrivateKeySize() { + t.Fatalf("priv size mismatch: got %d, want %d", len(priv), scheme.PrivateKeySize()) + } + + if _, err := scheme.UnmarshalBinaryPublicKey(pub); err != nil { t.Fatal(err) } + if _, err := scheme.UnmarshalBinaryPrivateKey(priv); err != nil { t.Fatal(err) } +} + + +func TestMLKEMEncapDecap(t *testing.T) { + scheme := kyber1024.Scheme() + pk, sk, err := scheme.GenerateKeyPair() + if err != nil { t.Fatal(err) } + + ct, ssEnc, err := scheme.Encapsulate(pk) + if err != nil { t.Fatal(err) } + ssDec, err := scheme.Decapsulate(sk, ct) + if err != nil { t.Fatal(err) } + + if !bytes.Equal(ssEnc, ssDec) { + t.Fatal("ML-KEM shared secrets differ") + } +} + +func TestNoiseHandshakeWithMLKEM(t *testing.T) { + skA, _ := newPrivateKey() + skB, _ := newPrivateKey() + + tunA := tuntest.NewChannelTUN() + tunB := tuntest.NewChannelTUN() + + devA := NewDevice(tunA.TUN(), conn.NewDefaultBind(), NewLogger(LogLevelError, "")) + devB := NewDevice(tunB.TUN(), conn.NewDefaultBind(), NewLogger(LogLevelError, "")) + + defer devA.Close() + defer devB.Close() + + if err := devA.SetPrivateKey(skA); err != nil { t.Fatal(err) } + if err := devB.SetPrivateKey(skB); err != nil { t.Fatal(err) } + + peerB, err := devA.NewPeer(skB.publicKey()) + if err != nil { t.Fatal(err) } + peerA, err := devB.NewPeer(skA.publicKey()) + if err != nil { t.Fatal(err) } + + scheme := kyber1024.Scheme() + pkA, skAkem, _ := scheme.GenerateKeyPair() + pkB, skBkem, _ := scheme.GenerateKeyPair() + pkAb, _ := pkA.MarshalBinary() + pkBb, _ := pkB.MarshalBinary() + skAb, _ := skAkem.MarshalBinary() + skBb, _ := skBkem.MarshalBinary() + + if err := devA.IpcSet(uapiCfg("mlkem_private_key", hex.EncodeToString(skAb))); err != nil { t.Fatal(err) } + if err := devB.IpcSet(uapiCfg("mlkem_private_key", hex.EncodeToString(skBb))); err != nil { t.Fatal(err) } + + if err := devA.IpcSet(uapiCfg("public_key", hex.EncodeToString(peerB.handshake.remoteStatic[:]), + "mlkem_public_key", hex.EncodeToString(pkBb))); err != nil { t.Fatal(err) } + if err := devB.IpcSet(uapiCfg("public_key", hex.EncodeToString(peerA.handshake.remoteStatic[:]), + "mlkem_public_key", hex.EncodeToString(pkAb))); err != nil { t.Fatal(err) } + + dil := mode5.Scheme() + pkAS, skAS, _ := dil.GenerateKey() + pkBS, skBS, _ := dil.GenerateKey() + pkASb, _ := pkAS.MarshalBinary() + pkBSb, _ := pkBS.MarshalBinary() + skASb, _ := skAS.MarshalBinary() + skBSb, _ := skBS.MarshalBinary() + + copy(devA.staticIdentity.mldsaPrivateKey[:], skASb) + copy(devB.staticIdentity.mldsaPrivateKey[:], skBSb) + + copy(peerB.handshake.remoteMLDSAStatic[:], pkBSb) + copy(peerA.handshake.remoteMLDSAStatic[:], pkASb) + + peerA.Start() + peerB.Start() + + msg1, err := devA.CreateMessageInitiation(peerB) + if err != nil { t.Fatal(err) } + if p := devB.ConsumeMessageInitiation(msg1); p == nil { + t.Fatal("handshake failed at initiation (ML-KEM)") + } + + msg2, err := devB.CreateMessageResponse(peerA) + if err != nil { t.Fatal(err) } + if p := devA.ConsumeMessageResponse(msg2); p == nil { + t.Fatal("handshake failed at response (ML-KEM)") + } + + if err := peerA.BeginSymmetricSession(); err != nil { t.Fatal(err) } + if err := peerB.BeginSymmetricSession(); err != nil { t.Fatal(err) } + + keyA := peerA.keypairs.next.Load() + keyB := peerB.keypairs.current + + msg := []byte("pqc wireguard ok") + var nonce [12]byte + + out := keyA.send.Seal(nil, nonce[:], msg, nil) + plain, err := keyB.receive.Open(nil, nonce[:], out, nil) + if err != nil { t.Fatal(err) } + if !bytes.Equal(plain, msg) { t.Fatal("A->B decrypt mismatch") } + + out = keyB.send.Seal(nil, nonce[:], msg, nil) + plain, err = keyA.receive.Open(nil, nonce[:], out, nil) + if err != nil { t.Fatal(err) } + if !bytes.Equal(plain, msg) { t.Fatal("B->A decrypt mismatch") } + +} diff --git a/device/mldsa_test.go b/device/mldsa_test.go new file mode 100644 index 000000000..89e47d43d --- /dev/null +++ b/device/mldsa_test.go @@ -0,0 +1,169 @@ +package device + +import ( + "bytes" + "testing" + "encoding/hex" + + "github.com/cloudflare/circl/sign/dilithium/mode5" + "github.com/cloudflare/circl/kem/kyber/kyber1024" +) + +func TestGenerateMLDSAKeyPair(t *testing.T) { + pub, priv, err := GenerateMLDSAKeyPair() + if err != nil { + t.Fatalf("erro gerando MLDSA: %v", err) + } + if len(pub) != MLDSAPublicKeySize || len(priv) != MLDSAPrivateKeySize { + t.Fatalf("tamanhos inválidos: pub=%d priv=%d", len(pub), len(priv)) + } + s := mode5.Scheme() + if _, err := s.UnmarshalBinaryPublicKey(pub); err != nil { + t.Fatalf("publicKey inválida: %v", err) + } + if _, err := s.UnmarshalBinaryPrivateKey(priv); err != nil { + t.Fatalf("privateKey inválida: %v", err) + } +} + +func TestMLDSASignVerify(t *testing.T) { + s := mode5.Scheme() + pk, sk, err := s.GenerateKey() + if err != nil { + t.Fatalf("erro gerando par mldsa: %v", err) + } + msg := []byte("wireguard + mldsa test") + sig := s.Sign(sk, msg, nil) + + if !s.Verify(pk, msg, sig, nil) { + t.Fatalf("assinatura MLDSA não verificou") + } + if s.Verify(pk, append(msg, 0x01), sig, nil) { + t.Fatalf("assinatura deveria falhar em msg alterada") + } +} + +func mustCopy(dst []byte, src []byte) { + if len(dst) != len(src) { panic("tam inválido") } + copy(dst, src) +} + +func TestHybridHandshakeWithMLDSASignature(t *testing.T) { + dev1 := randDevice(t) + dev2 := randDevice(t) + defer dev1.Close() + defer dev2.Close() + + kyb := kyber1024.Scheme() + pkK1, skK1, _ := kyb.GenerateKeyPair() + pkK2, skK2, _ := kyb.GenerateKeyPair() + pubK1, _ := pkK1.MarshalBinary() + privK1, _ := skK1.MarshalBinary() + pubK2, _ := pkK2.MarshalBinary() + privK2, _ := skK2.MarshalBinary() + + mldsa := mode5.Scheme() + pkS1, skS1, _ := mldsa.GenerateKey() + pkS2, skS2, _ := mldsa.GenerateKey() + pubS1, _ := pkS1.MarshalBinary() + privS1, _ := skS1.MarshalBinary() + pubS2, _ := pkS2.MarshalBinary() + privS2, _ := skS2.MarshalBinary() + + mustCopy(dev1.staticIdentity.mlkemPrivateKey[:], privK1) + mustCopy(dev2.staticIdentity.mlkemPrivateKey[:], privK2) + mustCopy(dev1.staticIdentity.mldsaPrivateKey[:], privS1) + mustCopy(dev2.staticIdentity.mldsaPrivateKey[:], privS2) + + peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) + if err != nil { t.Fatal(err) } + peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) + if err != nil { t.Fatal(err) } + + mustCopy(peer1.handshake.remoteMLKEMStatic[:], pubK1) + mustCopy(peer2.handshake.remoteMLKEMStatic[:], pubK2) + mustCopy(peer1.handshake.remoteMLDSAStatic[:], pubS1) + mustCopy(peer2.handshake.remoteMLDSAStatic[:], pubS2) + + peer1.Start() + peer2.Start() + + init, err := dev1.CreateMessageInitiation(peer2) + if err != nil { + t.Fatalf("CreateMessageInitiation falhou: %v", err) + } + if p := dev2.ConsumeMessageInitiation(init); p == nil { + t.Fatalf("ConsumeMessageInitiation falhou (assinatura/MLKEM?)") + } + + resp, err := dev2.CreateMessageResponse(peer1) + if err != nil { + t.Fatalf("CreateMessageResponse falhou: %v", err) + } + if p := dev1.ConsumeMessageResponse(resp); p == nil { + t.Fatalf("ConsumeMessageResponse falhou") + } + + if err := peer1.BeginSymmetricSession(); err != nil { + t.Fatalf("peer1.BeginSymmetricSession: %v", err) + } + if err := peer2.BeginSymmetricSession(); err != nil { + t.Fatalf("peer2.BeginSymmetricSession: %v", err) + } + + key1 := peer1.keypairs.next.Load() + key2 := peer2.keypairs.current + plain := []byte("ok mldsa+mlkem+noise") + var nonce [12]byte + c := key1.send.Seal(nil, nonce[:], plain, nil) + out, err := key2.receive.Open(nil, nonce[:], c, nil) + if err != nil || !bytes.Equal(out, plain) { + t.Fatalf("falha cifrar/decifrar: %v", err) + } +} + +func TestHybridHandshake_MLDSAInvalidSignature(t *testing.T) { + dev1 := randDevice(t) + dev2 := randDevice(t) + defer dev1.Close(); defer dev2.Close() + + kyb := kyber1024.Scheme() + pkK1, skK1, _ := kyb.GenerateKeyPair() + pkK2, skK2, _ := kyb.GenerateKeyPair() + pubK1, _ := pkK1.MarshalBinary() + privK1, _ := skK1.MarshalBinary() + pubK2, _ := pkK2.MarshalBinary() + privK2, _ := skK2.MarshalBinary() + mustCopy(dev1.staticIdentity.mlkemPrivateKey[:], privK1) + mustCopy(dev2.staticIdentity.mlkemPrivateKey[:], privK2) + + s := mode5.Scheme() + pkGood, skGood, _ := s.GenerateKey() + pkWrong, _, _ := s.GenerateKey() + privGood, _ := skGood.MarshalBinary() + pubGood, _ := pkGood.MarshalBinary() + pubWrong, _ := pkWrong.MarshalBinary() + mustCopy(dev1.staticIdentity.mldsaPrivateKey[:], privGood) + + peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) + peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) + peer1.Start(); peer2.Start() + mustCopy(peer1.handshake.remoteMLKEMStatic[:], pubK1) + mustCopy(peer2.handshake.remoteMLKEMStatic[:], pubK2) + + mustCopy(peer1.handshake.remoteMLDSAStatic[:], pubWrong) + + init, err := dev1.CreateMessageInitiation(peer2) + if err != nil { + t.Fatalf("CreateMessageInitiation falhou: %v", err) + } + if p := dev2.ConsumeMessageInitiation(init); p != nil { + t.Fatalf("assinatura inválida deveria falhar") + } + + mustCopy(peer1.handshake.remoteMLDSAStatic[:], pubGood) + if p := dev2.ConsumeMessageInitiation(init); p == nil { + t.Fatalf("deveria aceitar com a pública correta") + } + _ = hex.EncodeToString +} diff --git a/device/mlkem_bench_test.go b/device/mlkem_bench_test.go new file mode 100644 index 000000000..465846221 --- /dev/null +++ b/device/mlkem_bench_test.go @@ -0,0 +1,454 @@ +package device + +import ( + "bytes" + "encoding/hex" + "testing" + "time" + + "golang.zx2c4.com/wireguard/tai64n" + + "github.com/cloudflare/circl/kem/kyber/kyber1024" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/tun/tuntest" +) + +func BenchmarkKyberEncapsulate(b *testing.B) { + scheme := kyber1024.Scheme() + pk, _, err := scheme.GenerateKeyPair() + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _, err := scheme.Encapsulate(pk) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkKyberDecapsulate(b *testing.B) { + scheme := kyber1024.Scheme() + pk, sk, err := scheme.GenerateKeyPair() + if err != nil { + b.Fatal(err) + } + ct, _, err := scheme.Encapsulate(pk) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, err := scheme.Decapsulate(sk, ct) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkHandshakeWithMLKEM(b *testing.B) { + skA, _ := newPrivateKey() + skB, _ := newPrivateKey() + tunA := tuntest.NewChannelTUN() + tunB := tuntest.NewChannelTUN() + devA := NewDevice(tunA.TUN(), conn.NewDefaultBind(), NewLogger(LogLevelError, "")) + devB := NewDevice(tunB.TUN(), conn.NewDefaultBind(), NewLogger(LogLevelError, "")) + defer devA.Close() + defer devB.Close() + + if err := devA.SetPrivateKey(skA); err != nil { + b.Fatal(err) + } + if err := devB.SetPrivateKey(skB); err != nil { + b.Fatal(err) + } + + peerB, err := devA.NewPeer(skB.publicKey()) + if err != nil { + b.Fatal(err) + } + peerA, err := devB.NewPeer(skA.publicKey()) + if err != nil { + b.Fatal(err) + } + + scheme := kyber1024.Scheme() + pkA, skAkem, _ := scheme.GenerateKeyPair() + pkB, skBkem, _ := scheme.GenerateKeyPair() + pkAb, _ := pkA.MarshalBinary() + pkBb, _ := pkB.MarshalBinary() + skAb, _ := skAkem.MarshalBinary() + skBb, _ := skBkem.MarshalBinary() + + if err := devA.IpcSet(uapiCfg("mlkem_private_key", hex.EncodeToString(skAb))); err != nil { + b.Fatal(err) + } + if err := devB.IpcSet(uapiCfg("mlkem_private_key", hex.EncodeToString(skBb))); err != nil { + b.Fatal(err) + } + if err := devA.IpcSet(uapiCfg("public_key", hex.EncodeToString(peerB.handshake.remoteStatic[:]), "mlkem_public_key", hex.EncodeToString(pkBb))); err != nil { + b.Fatal(err) + } + if err := devB.IpcSet(uapiCfg("public_key", hex.EncodeToString(peerA.handshake.remoteStatic[:]), "mlkem_public_key", hex.EncodeToString(pkAb))); err != nil { + b.Fatal(err) + } + + mldsaPubA, mldsaPrivA, err := GenerateMLDSAKeyPair() + if err != nil { + b.Fatal(err) + } + mldsaPubB, mldsaPrivB, err := GenerateMLDSAKeyPair() + if err != nil { + b.Fatal(err) + } + + devA.staticIdentity.Lock() + copy(devA.staticIdentity.mldsaPrivateKey[:], mldsaPrivA) + copy(devA.staticIdentity.mldsaPublicKey[:], mldsaPubA) + devA.staticIdentity.Unlock() + + devB.staticIdentity.Lock() + copy(devB.staticIdentity.mldsaPrivateKey[:], mldsaPrivB) + copy(devB.staticIdentity.mldsaPublicKey[:], mldsaPubB) + devB.staticIdentity.Unlock() + + peerB.handshake.mutex.Lock() + copy(peerB.handshake.remoteMLDSAStatic[:], mldsaPubB) + peerB.handshake.mutex.Unlock() + + peerA.handshake.mutex.Lock() + copy(peerA.handshake.remoteMLDSAStatic[:], mldsaPubA) + peerA.handshake.mutex.Unlock() + + relaxFlood := func() { + peerA.handshake.mutex.Lock() + peerA.handshake.lastInitiationConsumption = time.Now().Add(-10 * time.Second) + peerA.handshake.lastTimestamp = tai64n.Timestamp{} + peerA.handshake.mutex.Unlock() + } + relaxFlood() + msg1, err := devA.CreateMessageInitiation(peerB) + if err != nil { + b.Fatal(err) + } + if p := devB.ConsumeMessageInitiation(msg1); p == nil { + b.Fatal("initiation fail (warmup)") + } + msg2, err := devB.CreateMessageResponse(peerA) + if err != nil { + b.Fatal(err) + } + if p := devA.ConsumeMessageResponse(msg2); p == nil { + b.Fatal("response fail (warmup)") + } + if err := peerA.BeginSymmetricSession(); err != nil { + b.Fatal(err) + } + if err := peerB.BeginSymmetricSession(); err != nil { + b.Fatal(err) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + relaxFlood() + msg1, err := devA.CreateMessageInitiation(peerB) + if err != nil { + b.Fatal(err) + } + if p := devB.ConsumeMessageInitiation(msg1); p == nil { + b.Fatal("initiation fail") + } + msg2, err := devB.CreateMessageResponse(peerA) + if err != nil { + b.Fatal(err) + } + if p := devA.ConsumeMessageResponse(msg2); p == nil { + b.Fatal("response fail") + } + if err := peerA.BeginSymmetricSession(); err != nil { + b.Fatal(err) + } + if err := peerB.BeginSymmetricSession(); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkHandshakeHybrid(b *testing.B) { + skA, _ := newPrivateKey() + skB, _ := newPrivateKey() + tunA := tuntest.NewChannelTUN() + tunB := tuntest.NewChannelTUN() + devA := NewDevice(tunA.TUN(), conn.NewDefaultBind(), NewLogger(LogLevelError, "")) + devB := NewDevice(tunB.TUN(), conn.NewDefaultBind(), NewLogger(LogLevelError, "")) + defer devA.Close() + defer devB.Close() + + if err := devA.IpcSet(uapiCfg("private_key", hex.EncodeToString(skA[:]))); err != nil { + b.Fatal(err) + } + if err := devB.IpcSet(uapiCfg("private_key", hex.EncodeToString(skB[:]))); err != nil { + b.Fatal(err) + } + + scheme := kyber1024.Scheme() + pkA, skAkem, _ := scheme.GenerateKeyPair() + pkB, skBkem, _ := scheme.GenerateKeyPair() + pkAb, _ := pkA.MarshalBinary() + pkBb, _ := pkB.MarshalBinary() + skAb, _ := skAkem.MarshalBinary() + skBb, _ := skBkem.MarshalBinary() + + if err := devA.IpcSet(uapiCfg("mlkem_private_key", hex.EncodeToString(skAb))); err != nil { + b.Fatal(err) + } + if err := devB.IpcSet(uapiCfg("mlkem_private_key", hex.EncodeToString(skBb))); err != nil { + b.Fatal(err) + } + + pkBNoise := skB.publicKey() + pkANoise := skA.publicKey() + if err := devA.IpcSet(uapiCfg("public_key", hex.EncodeToString(pkBNoise[:]), "mlkem_public_key", hex.EncodeToString(pkBb))); err != nil { + b.Fatal(err) + } + if err := devB.IpcSet(uapiCfg("public_key", hex.EncodeToString(pkANoise[:]), "mlkem_public_key", hex.EncodeToString(pkAb))); err != nil { + b.Fatal(err) + } + + peerB := devA.LookupPeer(pkBNoise) + peerA := devB.LookupPeer(pkANoise) + if peerA == nil || peerB == nil { + b.Fatal("peer lookup failed (check IpcSet order)") + } + + mldsaPubA, mldsaPrivA, err := GenerateMLDSAKeyPair() + if err != nil { + b.Fatal(err) + } + mldsaPubB, mldsaPrivB, err := GenerateMLDSAKeyPair() + if err != nil { + b.Fatal(err) + } + + devA.staticIdentity.Lock() + copy(devA.staticIdentity.mldsaPrivateKey[:], mldsaPrivA) + copy(devA.staticIdentity.mldsaPublicKey[:], mldsaPubA) + devA.staticIdentity.Unlock() + + devB.staticIdentity.Lock() + copy(devB.staticIdentity.mldsaPrivateKey[:], mldsaPrivB) + copy(devB.staticIdentity.mldsaPublicKey[:], mldsaPubB) + devB.staticIdentity.Unlock() + + peerB.handshake.mutex.Lock() + copy(peerB.handshake.remoteMLDSAStatic[:], mldsaPubB) + peerB.handshake.mutex.Unlock() + + peerA.handshake.mutex.Lock() + copy(peerA.handshake.remoteMLDSAStatic[:], mldsaPubA) + peerA.handshake.mutex.Unlock() + + relax := func() { + peerA.handshake.mutex.Lock() + peerA.handshake.lastInitiationConsumption = time.Now().Add(-10 * time.Second) + peerA.handshake.lastTimestamp = tai64n.Timestamp{} + peerA.handshake.mutex.Unlock() + } + relax() + msg1, _ := devA.CreateMessageInitiation(peerB) + devB.ConsumeMessageInitiation(msg1) + msg2, _ := devB.CreateMessageResponse(peerA) + devA.ConsumeMessageResponse(msg2) + peerA.BeginSymmetricSession() + peerB.BeginSymmetricSession() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + relax() + msg1, _ := devA.CreateMessageInitiation(peerB) + devB.ConsumeMessageInitiation(msg1) + msg2, _ := devB.CreateMessageResponse(peerA) + devA.ConsumeMessageResponse(msg2) + peerA.BeginSymmetricSession() + peerB.BeginSymmetricSession() + } +} + +func BenchmarkDataPlaneAEAD(b *testing.B) { + skA, _ := newPrivateKey() + skB, _ := newPrivateKey() + tunA := tuntest.NewChannelTUN() + tunB := tuntest.NewChannelTUN() + devA := NewDevice(tunA.TUN(), conn.NewDefaultBind(), NewLogger(LogLevelError, "")) + devB := NewDevice(tunB.TUN(), conn.NewDefaultBind(), NewLogger(LogLevelError, "")) + defer devA.Close() + defer devB.Close() + devA.SetPrivateKey(skA) + devB.SetPrivateKey(skB) + peerB, _ := devA.NewPeer(skB.publicKey()) + peerA, _ := devB.NewPeer(skA.publicKey()) + + scheme := kyber1024.Scheme() + pkA, skAkem, _ := scheme.GenerateKeyPair() + pkB, skBkem, _ := scheme.GenerateKeyPair() + pkAb, _ := pkA.MarshalBinary() + pkBb, _ := pkB.MarshalBinary() + skAb, _ := skAkem.MarshalBinary() + skBb, _ := skBkem.MarshalBinary() + devA.IpcSet(uapiCfg("mlkem_private_key", hex.EncodeToString(skAb))) + devB.IpcSet(uapiCfg("mlkem_private_key", hex.EncodeToString(skBb))) + devA.IpcSet(uapiCfg("public_key", hex.EncodeToString(peerB.handshake.remoteStatic[:]), "mlkem_public_key", hex.EncodeToString(pkBb))) + devB.IpcSet(uapiCfg("public_key", hex.EncodeToString(peerA.handshake.remoteStatic[:]), "mlkem_public_key", hex.EncodeToString(pkAb))) + mldsaPubA, mldsaPrivA, err := GenerateMLDSAKeyPair() + if err != nil { + b.Fatal(err) + } + mldsaPubB, mldsaPrivB, err := GenerateMLDSAKeyPair() + if err != nil { + b.Fatal(err) + } + + devA.staticIdentity.Lock() + copy(devA.staticIdentity.mldsaPrivateKey[:], mldsaPrivA) + copy(devA.staticIdentity.mldsaPublicKey[:], mldsaPubA) + devA.staticIdentity.Unlock() + + devB.staticIdentity.Lock() + copy(devB.staticIdentity.mldsaPrivateKey[:], mldsaPrivB) + copy(devB.staticIdentity.mldsaPublicKey[:], mldsaPubB) + devB.staticIdentity.Unlock() + + peerB.handshake.mutex.Lock() + copy(peerB.handshake.remoteMLDSAStatic[:], mldsaPubB) + peerB.handshake.mutex.Unlock() + + peerA.handshake.mutex.Lock() + copy(peerA.handshake.remoteMLDSAStatic[:], mldsaPubA) + peerA.handshake.mutex.Unlock() + msg1, _ := devA.CreateMessageInitiation(peerB) + devB.ConsumeMessageInitiation(msg1) + msg2, _ := devB.CreateMessageResponse(peerA) + devA.ConsumeMessageResponse(msg2) + peerA.BeginSymmetricSession() + peerB.BeginSymmetricSession() + + keyA := peerA.keypairs.next.Load() + keyB := peerB.keypairs.current + msg := bytes.Repeat([]byte{0x42}, 128) + var nonce [12]byte + + b.ReportAllocs() + b.SetBytes(int64(len(msg))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + out := keyA.send.Seal(nil, nonce[:], msg, nil) + _, err := keyB.receive.Open(nil, nonce[:], out, nil) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkDataPlaneAEADHybrid(b *testing.B) { + skA, _ := newPrivateKey() + skB, _ := newPrivateKey() + tunA := tuntest.NewChannelTUN() + tunB := tuntest.NewChannelTUN() + devA := NewDevice(tunA.TUN(), conn.NewDefaultBind(), NewLogger(LogLevelError, "")) + devB := NewDevice(tunB.TUN(), conn.NewDefaultBind(), NewLogger(LogLevelError, "")) + defer devA.Close() + defer devB.Close() + + if err := devA.IpcSet(uapiCfg("private_key", hex.EncodeToString(skA[:]))); err != nil { + b.Fatal(err) + } + if err := devB.IpcSet(uapiCfg("private_key", hex.EncodeToString(skB[:]))); err != nil { + b.Fatal(err) + } + + scheme := kyber1024.Scheme() + pkA, skAkem, _ := scheme.GenerateKeyPair() + pkB, skBkem, _ := scheme.GenerateKeyPair() + pkAb, _ := pkA.MarshalBinary() + pkBb, _ := pkB.MarshalBinary() + skAb, _ := skAkem.MarshalBinary() + skBb, _ := skBkem.MarshalBinary() + + if err := devA.IpcSet(uapiCfg("mlkem_private_key", hex.EncodeToString(skAb))); err != nil { + b.Fatal(err) + } + if err := devB.IpcSet(uapiCfg("mlkem_private_key", hex.EncodeToString(skBb))); err != nil { + b.Fatal(err) + } + + pkBNoise := skB.publicKey() + pkANoise := skA.publicKey() + if err := devA.IpcSet(uapiCfg("public_key", hex.EncodeToString(pkBNoise[:]), "mlkem_public_key", hex.EncodeToString(pkBb))); err != nil { + b.Fatal(err) + } + if err := devB.IpcSet(uapiCfg("public_key", hex.EncodeToString(pkANoise[:]), "mlkem_public_key", hex.EncodeToString(pkAb))); err != nil { + b.Fatal(err) + } + + peerB := devA.LookupPeer(pkBNoise) + peerA := devB.LookupPeer(pkANoise) + if peerA == nil || peerB == nil { + b.Fatal("peer lookup failed (check IpcSet order)") + } + + mldsaPubA, mldsaPrivA, err := GenerateMLDSAKeyPair() + if err != nil { + b.Fatal(err) + } + mldsaPubB, mldsaPrivB, err := GenerateMLDSAKeyPair() + if err != nil { + b.Fatal(err) + } + + devA.staticIdentity.Lock() + copy(devA.staticIdentity.mldsaPrivateKey[:], mldsaPrivA) + copy(devA.staticIdentity.mldsaPublicKey[:], mldsaPubA) + devA.staticIdentity.Unlock() + + devB.staticIdentity.Lock() + copy(devB.staticIdentity.mldsaPrivateKey[:], mldsaPrivB) + copy(devB.staticIdentity.mldsaPublicKey[:], mldsaPubB) + devB.staticIdentity.Unlock() + + peerB.handshake.mutex.Lock() + copy(peerB.handshake.remoteMLDSAStatic[:], mldsaPubB) + peerB.handshake.mutex.Unlock() + + peerA.handshake.mutex.Lock() + copy(peerA.handshake.remoteMLDSAStatic[:], mldsaPubA) + peerA.handshake.mutex.Unlock() + + msg1, _ := devA.CreateMessageInitiation(peerB) + devB.ConsumeMessageInitiation(msg1) + msg2, _ := devB.CreateMessageResponse(peerA) + devA.ConsumeMessageResponse(msg2) + peerA.BeginSymmetricSession() + peerB.BeginSymmetricSession() + + keyA := peerA.keypairs.next.Load() + keyB := peerB.keypairs.current + msg := bytes.Repeat([]byte{0x42}, 128) + var nonce [12]byte + + b.ReportAllocs() + b.SetBytes(int64(len(msg))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + out := keyA.send.Seal(nil, nonce[:], msg, nil) + _, err := keyB.receive.Open(nil, nonce[:], out, nil) + if err != nil { + b.Fatal(err) + } + } +} + diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 5cf1702b6..16d381278 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -12,6 +12,8 @@ import ( "sync" "time" + "github.com/cloudflare/circl/kem/kyber/kyber1024" + "github.com/cloudflare/circl/sign/dilithium/mode5" "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/poly1305" @@ -61,13 +63,13 @@ const ( ) const ( - MessageInitiationSize = 148 // size of handshake initiation message - MessageResponseSize = 92 // size of response message - MessageCookieReplySize = 64 // size of cookie reply message - MessageTransportHeaderSize = 16 // size of data preceding content in transport message - MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport - MessageKeepaliveSize = MessageTransportSize // size of keepalive - MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message + MessageInitiationSize = 148 + (MLKEMCiphertextSize + poly1305.TagSize) + MLDSASignatureSize // size of handshake initiation message + MessageResponseSize = 92 // size of response message + MessageCookieReplySize = 64 // size of cookie reply message + MessageTransportHeaderSize = 16 // size of data preceding content in transport message + MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport + MessageKeepaliveSize = MessageTransportSize // size of keepalive + MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message ) const ( @@ -87,7 +89,9 @@ type MessageInitiation struct { Sender uint32 Ephemeral NoisePublicKey Static [NoisePublicKeySize + poly1305.TagSize]byte + MLKEM [MLKEMCiphertextSize + poly1305.TagSize]byte Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte + Signature MLDSASignature MAC1 [blake2s.Size128]byte MAC2 [blake2s.Size128]byte } @@ -127,9 +131,12 @@ func (msg *MessageInitiation) unmarshal(b []byte) error { msg.Sender = binary.LittleEndian.Uint32(b[4:]) copy(msg.Ephemeral[:], b[8:]) copy(msg.Static[:], b[8+len(msg.Ephemeral):]) - copy(msg.Timestamp[:], b[8+len(msg.Ephemeral)+len(msg.Static):]) - copy(msg.MAC1[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):]) - copy(msg.MAC2[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):]) + + copy(msg.MLKEM[:], b[8+len(msg.Ephemeral)+len(msg.Static):]) + copy(msg.Timestamp[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.MLKEM):]) + copy(msg.Signature[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.MLKEM)+len(msg.Timestamp):]) + copy(msg.MAC1[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.MLKEM)+len(msg.Timestamp)+len(msg.Signature):]) + copy(msg.MAC2[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.MLKEM)+len(msg.Timestamp)+len(msg.Signature)+len(msg.MAC1):]) return nil } @@ -143,9 +150,12 @@ func (msg *MessageInitiation) marshal(b []byte) error { binary.LittleEndian.PutUint32(b[4:], msg.Sender) copy(b[8:], msg.Ephemeral[:]) copy(b[8+len(msg.Ephemeral):], msg.Static[:]) - copy(b[8+len(msg.Ephemeral)+len(msg.Static):], msg.Timestamp[:]) - copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):], msg.MAC1[:]) - copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):], msg.MAC2[:]) + + copy(b[8+len(msg.Ephemeral)+len(msg.Static):], msg.MLKEM[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.MLKEM):], msg.Timestamp[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.MLKEM)+len(msg.Timestamp):], msg.Signature[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.MLKEM)+len(msg.Timestamp)+len(msg.Signature):], msg.MAC1[:]) + copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.MLKEM)+len(msg.Timestamp)+len(msg.Signature)+len(msg.MAC1):], msg.MAC2[:]) return nil } @@ -218,6 +228,8 @@ type Handshake struct { localIndex uint32 // used to clear hash-table remoteIndex uint32 // index for sending remoteStatic NoisePublicKey // long term key + remoteMLKEMStatic MLKEMPublicKey // long term remote ML-KEM static public key + remoteMLDSAStatic MLDSAPublicKey // long term remote ML-DSA static public key remoteEphemeral NoisePublicKey // ephemeral public key precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret lastTimestamp tai64n.Timestamp @@ -275,7 +287,6 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mutex.Lock() defer handshake.mutex.Unlock() - // create ephemeral key var err error handshake.hash = InitialHash handshake.chainKey = InitialChainKey @@ -294,37 +305,47 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) - // encrypt static key ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) if err != nil { return nil, err } + var key [chacha20poly1305.KeySize]byte - KDF2( - &handshake.chainKey, - &key, - handshake.chainKey[:], - ss[:], - ) + var tempChainKey [blake2s.Size]byte + KDF2(&tempChainKey, &key, handshake.chainKey[:], ss[:]) + aead, _ := chacha20poly1305.New(key[:]) aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) handshake.mixHash(msg.Static[:]) - // encrypt timestamp + scheme := kyber1024.Scheme() + pk, err := scheme.UnmarshalBinaryPublicKey(handshake.remoteMLKEMStatic[:]) + if err != nil { + return nil, err + } + ciphertext, mlkemSecret, err := scheme.Encapsulate(pk) + if err != nil { + return nil, err + } + + KDF1(&key, handshake.chainKey[:], []byte("pqc-ciphertext-key")) + aead, _ = chacha20poly1305.New(key[:]) + aead.Seal(msg.MLKEM[:0], ZeroNonce[:], ciphertext, handshake.hash[:]) + handshake.mixHash(msg.MLKEM[:]) + + var combinedSecret [blake2s.Size]byte + var dummy [blake2s.Size]byte + KDF2(&combinedSecret, &dummy, ss[:], mlkemSecret) + KDF2(&handshake.chainKey, &key, handshake.chainKey[:], combinedSecret[:]) + if isZero(handshake.precomputedStaticStatic[:]) { return nil, errInvalidPublicKey } - KDF2( - &handshake.chainKey, - &key, - handshake.chainKey[:], - handshake.precomputedStaticStatic[:], - ) + KDF2(&handshake.chainKey, &key, handshake.chainKey[:], handshake.precomputedStaticStatic[:]) timestamp := tai64n.Now() aead, _ = chacha20poly1305.New(key[:]) aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) - // assign index device.indexTable.Delete(handshake.localIndex) msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake) if err != nil { @@ -334,6 +355,21 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(msg.Timestamp[:]) handshake.state = handshakeInitiationCreated + + signScheme := mode5.Scheme() + skSign, err := signScheme.UnmarshalBinaryPrivateKey(device.staticIdentity.mldsaPrivateKey[:]) + if err != nil { + return nil, err + } + + messageToSign := make([]byte, MessageInitiationSize) + if err := msg.marshal(messageToSign); err != nil { + return nil, err + } + + signature := signScheme.Sign(skSign, messageToSign[:MessageInitiationSize-blake2s.Size128*2-MLDSASignatureSize], nil) + copy(msg.Signature[:], signature) + return &msg, nil } @@ -361,7 +397,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { if err != nil { return nil } - KDF2(&chainKey, &key, chainKey[:], ss[:]) + + var tempChainKey [blake2s.Size]byte + KDF2(&tempChainKey, &key, chainKey[:], ss[:]) aead, _ := chacha20poly1305.New(key[:]) _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) if err != nil { @@ -370,20 +408,62 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { mixHash(&hash, &hash, msg.Static[:]) // lookup peer - peer := device.LookupPeer(peerPK) if peer == nil || !peer.isRunning.Load() { return nil } + signScheme := mode5.Scheme() + pkSign, err := signScheme.UnmarshalBinaryPublicKey(peer.handshake.remoteMLDSAStatic[:]) + if err != nil { + return nil + } + + messageToCheck := make([]byte, MessageInitiationSize) + if err := msg.marshal(messageToCheck); err != nil { + return nil + } + + if !signScheme.Verify(pkSign, messageToCheck[:MessageInitiationSize-blake2s.Size128*2-MLDSASignatureSize], msg.Signature[:], nil) { + return nil + } + handshake := &peer.handshake - // verify identity + // decrypt KEM ciphertext + KDF1(&key, chainKey[:], []byte("pqc-ciphertext-key")) + aead, _ = chacha20poly1305.New(key[:]) + var ciphertext [MLKEMCiphertextSize]byte + _, err = aead.Open(ciphertext[:0], ZeroNonce[:], msg.MLKEM[:], hash[:]) + if err != nil { + return nil + } - var timestamp tai64n.Timestamp + mixHash(&hash, &hash, msg.MLKEM[:]) - handshake.mutex.RLock() + // post-quantum decapsulation + scheme := kyber1024.Scheme() + sk, err := scheme.UnmarshalBinaryPrivateKey(device.staticIdentity.mlkemPrivateKey[:]) + if err != nil { + return nil + } + + mlkemSecret, err := scheme.Decapsulate(sk, ciphertext[:]) + if err != nil { + return nil + } + // mix classic (ss) and post-quantum secret (mlkemSecret) into a single secret + var combinedSecret [blake2s.Size]byte + var dummy [blake2s.Size]byte + KDF2(&combinedSecret, &dummy, ss[:], mlkemSecret) + + // main chainKey is now updated with the combined secret + KDF2(&chainKey, &key, chainKey[:], combinedSecret[:]) + + // verify identity + var timestamp tai64n.Timestamp + handshake.mutex.RLock() if isZero(handshake.precomputedStaticStatic[:]) { handshake.mutex.RUnlock() return nil @@ -394,16 +474,17 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { chainKey[:], handshake.precomputedStaticStatic[:], ) + handshake.mutex.RUnlock() + aead, _ = chacha20poly1305.New(key[:]) _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) if err != nil { - handshake.mutex.RUnlock() return nil } mixHash(&hash, &hash, msg.Timestamp[:]) // protect against replay & flood - + handshake.mutex.RLock() replay := !timestamp.After(handshake.lastTimestamp) flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate handshake.mutex.RUnlock() @@ -417,9 +498,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { } // update handshake state - handshake.mutex.Lock() - handshake.hash = hash handshake.chainKey = chainKey handshake.remoteIndex = msg.Sender @@ -432,7 +511,6 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { handshake.lastInitiationConsumption = now } handshake.state = handshakeInitiationConsumed - handshake.mutex.Unlock() setZero(hash[:]) diff --git a/device/noise-types.go b/device/noise-types.go index 41c944e14..afa53e784 100644 --- a/device/noise-types.go +++ b/device/noise-types.go @@ -76,3 +76,26 @@ func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { func (key *NoisePresharedKey) FromHex(src string) error { return loadExactHex(key[:], src) } + +const ( + MLKEMPublicKeySize = 1568 + MLKEMPrivateKeySize = 3168 + MLKEMCiphertextSize = 1568 +) + +type ( + MLKEMPublicKey [MLKEMPublicKeySize]byte + MLKEMPrivateKey [MLKEMPrivateKeySize]byte +) + +const ( + MLDSAPublicKeySize = 2592 + MLDSAPrivateKeySize = 4864 + MLDSASignatureSize = 4595 +) + +type ( + MLDSAPublicKey [MLDSAPublicKeySize]byte + MLDSAPrivateKey [MLDSAPrivateKeySize]byte + MLDSASignature [MLDSASignatureSize]byte +) diff --git a/device/noise_test.go b/device/noise_test.go index f0928ac66..e74e74a29 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -1,8 +1,3 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. - */ - package device import ( @@ -71,6 +66,15 @@ func TestNoiseHandshake(t *testing.T) { if err != nil { t.Fatal(err) } + + pair := testPair{} + pair[0].dev = dev1 + pair[1].dev = dev2 + + installMLKEMKeys(t, &pair) + + installMLDSAKeys(t, &pair) + peer1.Start() peer2.Start() @@ -80,10 +84,6 @@ func TestNoiseHandshake(t *testing.T) { peer2.handshake.precomputedStaticStatic[:], ) - /* simulate handshake */ - - // initiation message - t.Log("exchange initiation message") msg1, err := dev1.CreateMessageInitiation(peer2) @@ -110,7 +110,6 @@ func TestNoiseHandshake(t *testing.T) { peer2.handshake.hash[:], ) - // response message t.Log("exchange response message") @@ -134,8 +133,6 @@ func TestNoiseHandshake(t *testing.T) { peer2.handshake.hash[:], ) - // key pairs - t.Log("deriving keys") err = peer1.BeginSymmetricSession() @@ -151,8 +148,6 @@ func TestNoiseHandshake(t *testing.T) { key1 := peer1.keypairs.next.Load() key2 := peer2.keypairs.current - // encrypting / decryption test - t.Log("test key pairs") func() { diff --git a/device/quantum-keys.go b/device/quantum-keys.go new file mode 100644 index 000000000..36575bada --- /dev/null +++ b/device/quantum-keys.go @@ -0,0 +1,27 @@ +package device + +import ( + "github.com/cloudflare/circl/kem/kyber/kyber1024" + "github.com/cloudflare/circl/sign/dilithium/mode5" +) + +func GenerateQuantumKeyPair() (pub []byte, priv []byte, err error) { + pk, sk, err := kyber1024.Scheme().GenerateKeyPair() + if err != nil { + return nil, nil, err + } + pub, _ = pk.MarshalBinary() + priv, _ = sk.MarshalBinary() + return pub, priv, nil +} + +func GenerateMLDSAKeyPair() (pub []byte, priv []byte, err error) { + pk, sk, err := mode5.Scheme().GenerateKey() + if err != nil { + return nil, nil, err + } + + pub, _ = pk.MarshalBinary() + priv, _ = sk.MarshalBinary() + return pub, priv, nil +} diff --git a/device/run_tests.sh b/device/run_tests.sh new file mode 100755 index 000000000..6b81ecd2c --- /dev/null +++ b/device/run_tests.sh @@ -0,0 +1,7 @@ +cd /home/arthur/Documentos/UNB/TCC/wireguard-go/device + +TESTS=$(grep -hEo '^func[[:space:]]+Test[[:alnum:]_]*' device_test.go noise_test.go mldsa_test.go \ + | awk '{print $2}' \ + | paste -sd'|' -) + +go test -v -run "^($TESTS)$" diff --git a/device/uapi.go b/device/uapi.go index cc69488b4..d3ebb9b68 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -171,7 +171,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { // Load/create the peer we are now configuring. err := device.handlePublicKeyLine(peer, value) if err != nil { - return err + return ipcErrorf(ipc.IpcErrorInvalid, "failed to load MLKEM public key: %w", err) } continue } @@ -240,6 +240,16 @@ func (device *Device) handleDeviceLine(key, value string) error { device.log.Verbosef("UAPI: Removing all peers") device.RemoveAllPeers() + case "mlkem_private_key": + var mlkemPrivateKey MLKEMPrivateKey + err := loadExactHex(mlkemPrivateKey[:], value) + if err != nil { + return err + } + device.staticIdentity.Lock() + device.staticIdentity.mlkemPrivateKey = mlkemPrivateKey + device.staticIdentity.Unlock() + default: return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) } @@ -397,6 +407,16 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value) } + case "mlkem_public_key": + device.log.Verbosef("%v - UAPI: Updating mlkem_public_key", peer.Peer) + peer.handshake.mutex.Lock() + err := loadExactHex(peer.handshake.remoteMLKEMStatic[:], value) + if err != nil { + peer.handshake.mutex.Unlock() + return err + } + peer.handshake.mutex.Unlock() + default: return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key) } diff --git a/go.mod b/go.mod index 2a80e0001..3b6b2678e 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module golang.zx2c4.com/wireguard go 1.23.1 require ( + github.com/cloudflare/circl v1.6.1 golang.org/x/crypto v0.37.0 golang.org/x/net v0.39.0 golang.org/x/sys v0.32.0 diff --git a/go.sum b/go.sum index 61875c160..a7bc79b1c 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0= +github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=