Skip to content

Commit 2da7d92

Browse files
committed
fix udp reconnects
1 parent edad853 commit 2da7d92

File tree

4 files changed

+42
-19
lines changed

4 files changed

+42
-19
lines changed

main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ func server(args []string) {
200200
log.Fatal(err)
201201
}
202202
if err := s.Wait(); err != nil {
203-
log.Fatal()
203+
log.Fatal(err)
204204
}
205205
}
206206

@@ -372,6 +372,6 @@ func client(args []string) {
372372
log.Fatal(err)
373373
}
374374
if err := c.Wait(); err != nil {
375-
log.Fatal()
375+
log.Fatal(err)
376376
}
377377
}

share/tunnel/tunnel_in_proxy_udp.go

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ package tunnel
33
import (
44
"context"
55
"encoding/gob"
6+
"fmt"
7+
"io"
68
"net"
9+
"strings"
710
"sync"
811
"sync/atomic"
912
"time"
@@ -52,7 +55,7 @@ type udpListener struct {
5255
remote *settings.Remote
5356
inbound *net.UDPConn
5457
outboundMut sync.Mutex
55-
outbound *udpOutbound
58+
outbound *udpChannel
5659
sent, recv int64
5760
}
5861

@@ -90,13 +93,16 @@ func (u *udpListener) runInbound(ctx context.Context) error {
9093
return u.Errorf("read error: %w", err)
9194
}
9295
//upsert ssh channel
93-
o, err := u.getOubound(ctx)
96+
uc, err := u.getUDPChan(ctx)
9497
if err != nil {
95-
return u.Errorf("ssh-chan error: %w", err)
98+
if strings.HasSuffix(err.Error(), "EOF") {
99+
continue
100+
}
101+
return u.Errorf("inbound-udpchan: %w", err)
96102
}
97103
//send over channel, including source address
98104
b := buff[:n]
99-
if err := o.encode(addr.String(), dstAddr, b); err != nil {
105+
if err := uc.encode(addr.String(), dstAddr, b); err != nil {
100106
return u.Errorf("encode error: %w", err)
101107
}
102108
//stats
@@ -108,13 +114,19 @@ func (u *udpListener) runInbound(ctx context.Context) error {
108114
func (u *udpListener) runOutbound(ctx context.Context) error {
109115
for !isDone(ctx) {
110116
//upsert ssh channel
111-
o, err := u.getOubound(ctx)
117+
uc, err := u.getUDPChan(ctx)
112118
if err != nil {
113-
return u.Errorf("ssh-chan error: %w", err)
119+
if strings.HasSuffix(err.Error(), "EOF") {
120+
continue
121+
}
122+
return u.Errorf("outbound-udpchan: %w", err)
114123
}
115124
//receive from channel, including source address
116125
p := udpPacket{}
117-
if err := o.decode(&p); err != nil {
126+
if err := uc.decode(&p); err == io.EOF {
127+
//outbound ssh disconnected, get new connection...
128+
continue
129+
} else if err != nil {
118130
return u.Errorf("decode error: %w", err)
119131
}
120132
//write back to inbound udp
@@ -132,7 +144,7 @@ func (u *udpListener) runOutbound(ctx context.Context) error {
132144
return nil
133145
}
134146

135-
func (u *udpListener) getOubound(ctx context.Context) (*udpOutbound, error) {
147+
func (u *udpListener) getUDPChan(ctx context.Context) (*udpChannel, error) {
136148
u.outboundMut.Lock()
137149
defer u.outboundMut.Unlock()
138150
//cached
@@ -142,21 +154,32 @@ func (u *udpListener) getOubound(ctx context.Context) (*udpOutbound, error) {
142154
//not cached, bind
143155
sshConn := u.sshTun.getSSH(ctx)
144156
if sshConn == nil {
145-
return nil, u.Errorf("ssh-conn nil")
157+
return nil, fmt.Errorf("ssh-conn nil")
146158
}
147159
//ssh request for udp packets for this proxy's remote,
148160
//just "udp" since the remote address is sent with each packet
149161
rwc, reqs, err := sshConn.OpenChannel("chisel", []byte("udp"))
150162
if err != nil {
151-
return nil, u.Errorf("ssh-chan error: %s", err)
163+
return nil, fmt.Errorf("ssh-chan error: %s", err)
152164
}
153165
go ssh.DiscardRequests(reqs)
166+
//remove on disconnect
167+
go u.unsetUDPChan(sshConn)
154168
//ready
155-
o := &udpOutbound{
169+
o := &udpChannel{
156170
r: gob.NewDecoder(rwc),
157171
w: gob.NewEncoder(rwc),
158172
c: rwc,
159173
}
160174
u.outbound = o
175+
u.Debugf("aquired channel")
161176
return o, nil
162177
}
178+
179+
func (u *udpListener) unsetUDPChan(sshConn ssh.Conn) {
180+
sshConn.Wait()
181+
u.Debugf("lost channel")
182+
u.outboundMut.Lock()
183+
u.outbound = nil
184+
u.outboundMut.Unlock()
185+
}

share/tunnel/tunnel_out_ssh_udp.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
func (t *Tunnel) handleUDP(l *cio.Logger, rwc io.ReadWriteCloser) error {
1616
h := &udpHandler{
1717
Logger: l,
18-
udpOutbound: &udpOutbound{
18+
udpChannel: &udpChannel{
1919
r: gob.NewDecoder(rwc),
2020
w: gob.NewEncoder(rwc),
2121
c: rwc,
@@ -34,7 +34,7 @@ func (t *Tunnel) handleUDP(l *cio.Logger, rwc io.ReadWriteCloser) error {
3434

3535
type udpHandler struct {
3636
*cio.Logger
37-
*udpOutbound
37+
*udpChannel
3838
*udpConns
3939
sent, recv int64
4040
}
@@ -92,7 +92,7 @@ func (h *udpHandler) handleRead(p *udpPacket, conn *udpConn) {
9292
}
9393
b := buff[:n]
9494
//encode back over ssh connection
95-
err = h.udpOutbound.encode(p.Src, p.Dst, b)
95+
err = h.udpChannel.encode(p.Src, p.Dst, b)
9696
if err != nil {
9797
h.Debugf("encode error: %s", err)
9898
return

share/tunnel/udp.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,21 @@ func init() {
1616
gob.Register(&udpPacket{})
1717
}
1818

19-
type udpOutbound struct {
19+
type udpChannel struct {
2020
r *gob.Decoder
2121
w *gob.Encoder
2222
c io.Closer
2323
}
2424

25-
func (o *udpOutbound) encode(src, dst string, b []byte) error {
25+
func (o *udpChannel) encode(src, dst string, b []byte) error {
2626
return o.w.Encode(udpPacket{
2727
Src: src,
2828
Dst: dst,
2929
Payload: b,
3030
})
3131
}
3232

33-
func (o *udpOutbound) decode(p *udpPacket) error {
33+
func (o *udpChannel) decode(p *udpPacket) error {
3434
return o.r.Decode(p)
3535
}
3636

0 commit comments

Comments
 (0)