Skip to content

Commit e24dd2f

Browse files
committed
lnwire: let DNSAddress implement RecordProducer
In preparation for using this type as a TLV record, we let it implement the RecordProducer interface.
1 parent e9a4f22 commit e24dd2f

File tree

3 files changed

+204
-0
lines changed

3 files changed

+204
-0
lines changed

lnwire/dns_addr.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
package lnwire
22

33
import (
4+
"bytes"
45
"errors"
56
"fmt"
7+
"io"
68
"net"
79
"strconv"
10+
11+
"github.com/lightningnetwork/lnd/tlv"
812
)
913

1014
var (
@@ -86,3 +90,70 @@ func ValidateDNSAddr(hostname string, port uint16) error {
8690

8791
return nil
8892
}
93+
94+
// Record returns a TLV record that can be used to encode/decode the DNSAddress.
95+
//
96+
// NOTE: this is part of the tlv.RecordProducer interface.
97+
func (d *DNSAddress) Record() tlv.Record {
98+
sizeFunc := func() uint64 {
99+
// Hostname length + 2 bytes for port.
100+
return uint64(len(d.Hostname) + 2)
101+
}
102+
103+
return tlv.MakeDynamicRecord(
104+
0, d, sizeFunc, dnsAddressEncoder, dnsAddressDecoder,
105+
)
106+
}
107+
108+
// dnsAddressEncoder is a TLV encoder for DNSAddress.
109+
func dnsAddressEncoder(w io.Writer, val any, _ *[8]byte) error {
110+
if v, ok := val.(*DNSAddress); ok {
111+
var buf bytes.Buffer
112+
113+
// Write the hostname as raw bytes (no length prefix for TLV).
114+
if _, err := buf.WriteString(v.Hostname); err != nil {
115+
return err
116+
}
117+
118+
// Write the port as 2 bytes.
119+
err := WriteUint16(&buf, v.Port)
120+
if err != nil {
121+
return err
122+
}
123+
124+
_, err = w.Write(buf.Bytes())
125+
126+
return err
127+
}
128+
129+
return tlv.NewTypeForEncodingErr(val, "DNSAddress")
130+
}
131+
132+
// dnsAddressDecoder is a TLV decoder for DNSAddress.
133+
func dnsAddressDecoder(r io.Reader, val any, _ *[8]byte,
134+
l uint64) error {
135+
136+
if v, ok := val.(*DNSAddress); ok {
137+
if l < 2 {
138+
return fmt.Errorf("DNS address must be at least 2 " +
139+
"bytes")
140+
}
141+
142+
// Read hostname (all bytes except last 2).
143+
hostnameLen := l - 2
144+
hostnameBytes := make([]byte, hostnameLen)
145+
if _, err := io.ReadFull(r, hostnameBytes); err != nil {
146+
return err
147+
}
148+
v.Hostname = string(hostnameBytes)
149+
150+
// Read port (last 2 bytes).
151+
if err := ReadElement(r, &v.Port); err != nil {
152+
return err
153+
}
154+
155+
return ValidateDNSAddr(v.Hostname, v.Port)
156+
}
157+
158+
return tlv.NewTypeForDecodingErr(val, "DNSAddress", l, 0)
159+
}

lnwire/dns_addr_test.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package lnwire
22

33
import (
4+
"bytes"
45
"fmt"
56
"strings"
67
"testing"
78

9+
"github.com/lightningnetwork/lnd/tlv"
810
"github.com/stretchr/testify/require"
11+
"pgregory.net/rapid"
912
)
1013

1114
// TestValidateDNSAddr tests hostname and port validation per BOLT #7.
@@ -85,3 +88,112 @@ func TestValidateDNSAddr(t *testing.T) {
8588
})
8689
}
8790
}
91+
92+
// TestDNSAddressTLVEncoding tests the TLV encoding and decoding of DNSAddress
93+
// structs using the ExtraOpaqueData interface.
94+
func TestDNSAddressTLVEncoding(t *testing.T) {
95+
t.Parallel()
96+
97+
testDNSAddr := DNSAddress{
98+
Hostname: "lightning.example.com",
99+
Port: 9000,
100+
}
101+
102+
var extraData ExtraOpaqueData
103+
require.NoError(t, extraData.PackRecords(&testDNSAddr))
104+
105+
var decodedDNSAddr DNSAddress
106+
tlvs, err := extraData.ExtractRecords(&decodedDNSAddr)
107+
require.NoError(t, err)
108+
109+
require.Contains(t, tlvs, tlv.Type(0))
110+
require.Equal(t, testDNSAddr, decodedDNSAddr)
111+
}
112+
113+
// TestDNSAddressRecord tests the TLV Record interface of DNSAddress
114+
// by directly encoding and decoding using the Record method.
115+
func TestDNSAddressRecord(t *testing.T) {
116+
t.Parallel()
117+
118+
testDNSAddr := DNSAddress{
119+
Hostname: "lightning.example.com",
120+
Port: 9000,
121+
}
122+
123+
var buf bytes.Buffer
124+
record := testDNSAddr.Record()
125+
require.NoError(t, record.Encode(&buf))
126+
127+
var decodedDNSAddr DNSAddress
128+
decodedRecord := decodedDNSAddr.Record()
129+
require.NoError(t, decodedRecord.Decode(&buf, uint64(buf.Len())))
130+
131+
require.Equal(t, testDNSAddr, decodedDNSAddr)
132+
}
133+
134+
// TestDNSAddressInvalidDecoding tests error cases during TLV decoding.
135+
func TestDNSAddressInvalidDecoding(t *testing.T) {
136+
t.Parallel()
137+
138+
testCases := []struct {
139+
name string
140+
data []byte
141+
errMsg string
142+
}{
143+
{
144+
name: "too short (only 1 byte)",
145+
data: []byte{0x61},
146+
errMsg: "DNS address must be at least 2 bytes",
147+
},
148+
{
149+
name: "empty data",
150+
data: []byte{},
151+
errMsg: "DNS address must be at least 2 bytes",
152+
},
153+
}
154+
155+
for _, tc := range testCases {
156+
t.Run(tc.name, func(t *testing.T) {
157+
var dnsAddr DNSAddress
158+
record := dnsAddr.Record()
159+
160+
buf := bytes.NewReader(tc.data)
161+
err := record.Decode(buf, uint64(len(tc.data)))
162+
require.Error(t, err)
163+
require.ErrorContains(t, err, tc.errMsg)
164+
})
165+
}
166+
}
167+
168+
// TestDNSAddressProperty uses property-based testing to verify that DNSAddress
169+
// TLV encoding and decoding is correct for random DNSAddress values.
170+
func TestDNSAddressProperty(t *testing.T) {
171+
t.Parallel()
172+
173+
scenario := func(t *rapid.T) {
174+
// Generate a random valid hostname.
175+
hostname := genValidHostname(t)
176+
177+
// Generate a random port (excluding 0 which is invalid).
178+
port := rapid.Uint16Range(1, 65535).Draw(t, "port")
179+
180+
dnsAddr := DNSAddress{
181+
Hostname: hostname,
182+
Port: port,
183+
}
184+
185+
var buf bytes.Buffer
186+
record := dnsAddr.Record()
187+
err := record.Encode(&buf)
188+
require.NoError(t, err)
189+
190+
var decodedDNSAddr DNSAddress
191+
decodedRecord := decodedDNSAddr.Record()
192+
err = decodedRecord.Decode(&buf, uint64(buf.Len()))
193+
require.NoError(t, err)
194+
195+
require.Equal(t, dnsAddr, decodedDNSAddr)
196+
}
197+
198+
rapid.Check(t, scenario)
199+
}

lnwire/test_message.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1867,3 +1867,24 @@ func (c *Error) RandTestMessage(t *rapid.T) Message {
18671867

18681868
return msg
18691869
}
1870+
1871+
// genValidHostname generates a random valid hostname according to BOLT #7
1872+
// rules.
1873+
func genValidHostname(t *rapid.T) string {
1874+
// Valid characters: a-z, A-Z, 0-9, -, .
1875+
validChars := "abcdefghijklmnopqrstuvwxyzABCDE" +
1876+
"FGHIJKLMNOPQRSTUVWXYZ0123456789-."
1877+
1878+
// Generate hostname length between 1 and 255 characters.
1879+
length := rapid.IntRange(1, 255).Draw(t, "hostname_length")
1880+
1881+
hostname := make([]byte, length)
1882+
for i := 0; i < length; i++ {
1883+
charIndex := rapid.IntRange(0, len(validChars)-1).Draw(
1884+
t, fmt.Sprintf("char_%d", i),
1885+
)
1886+
hostname[i] = validChars[charIndex]
1887+
}
1888+
1889+
return string(hostname)
1890+
}

0 commit comments

Comments
 (0)