@@ -3,7 +3,10 @@ package tunnel
3
3
import (
4
4
"context"
5
5
"encoding/gob"
6
+ "fmt"
7
+ "io"
6
8
"net"
9
+ "strings"
7
10
"sync"
8
11
"sync/atomic"
9
12
"time"
@@ -52,7 +55,7 @@ type udpListener struct {
52
55
remote * settings.Remote
53
56
inbound * net.UDPConn
54
57
outboundMut sync.Mutex
55
- outbound * udpOutbound
58
+ outbound * udpChannel
56
59
sent , recv int64
57
60
}
58
61
@@ -90,13 +93,16 @@ func (u *udpListener) runInbound(ctx context.Context) error {
90
93
return u .Errorf ("read error: %w" , err )
91
94
}
92
95
//upsert ssh channel
93
- o , err := u .getOubound (ctx )
96
+ uc , err := u .getUDPChan (ctx )
94
97
if err != nil {
95
- return u .Errorf ("ssh-chan error: %w" , err )
98
+ if strings .HasSuffix (err .Error (), "EOF" ) {
99
+ continue
100
+ }
101
+ return u .Errorf ("inbound-udpchan: %w" , err )
96
102
}
97
103
//send over channel, including source address
98
104
b := buff [:n ]
99
- if err := o .encode (addr .String (), dstAddr , b ); err != nil {
105
+ if err := uc .encode (addr .String (), dstAddr , b ); err != nil {
100
106
return u .Errorf ("encode error: %w" , err )
101
107
}
102
108
//stats
@@ -108,13 +114,19 @@ func (u *udpListener) runInbound(ctx context.Context) error {
108
114
func (u * udpListener ) runOutbound (ctx context.Context ) error {
109
115
for ! isDone (ctx ) {
110
116
//upsert ssh channel
111
- o , err := u .getOubound (ctx )
117
+ uc , err := u .getUDPChan (ctx )
112
118
if err != nil {
113
- return u .Errorf ("ssh-chan error: %w" , err )
119
+ if strings .HasSuffix (err .Error (), "EOF" ) {
120
+ continue
121
+ }
122
+ return u .Errorf ("outbound-udpchan: %w" , err )
114
123
}
115
124
//receive from channel, including source address
116
125
p := udpPacket {}
117
- if err := o .decode (& p ); err != nil {
126
+ if err := uc .decode (& p ); err == io .EOF {
127
+ //outbound ssh disconnected, get new connection...
128
+ continue
129
+ } else if err != nil {
118
130
return u .Errorf ("decode error: %w" , err )
119
131
}
120
132
//write back to inbound udp
@@ -132,7 +144,7 @@ func (u *udpListener) runOutbound(ctx context.Context) error {
132
144
return nil
133
145
}
134
146
135
- func (u * udpListener ) getOubound (ctx context.Context ) (* udpOutbound , error ) {
147
+ func (u * udpListener ) getUDPChan (ctx context.Context ) (* udpChannel , error ) {
136
148
u .outboundMut .Lock ()
137
149
defer u .outboundMut .Unlock ()
138
150
//cached
@@ -142,21 +154,32 @@ func (u *udpListener) getOubound(ctx context.Context) (*udpOutbound, error) {
142
154
//not cached, bind
143
155
sshConn := u .sshTun .getSSH (ctx )
144
156
if sshConn == nil {
145
- return nil , u .Errorf ("ssh-conn nil" )
157
+ return nil , fmt .Errorf ("ssh-conn nil" )
146
158
}
147
159
//ssh request for udp packets for this proxy's remote,
148
160
//just "udp" since the remote address is sent with each packet
149
161
rwc , reqs , err := sshConn .OpenChannel ("chisel" , []byte ("udp" ))
150
162
if err != nil {
151
- return nil , u .Errorf ("ssh-chan error: %s" , err )
163
+ return nil , fmt .Errorf ("ssh-chan error: %s" , err )
152
164
}
153
165
go ssh .DiscardRequests (reqs )
166
+ //remove on disconnect
167
+ go u .unsetUDPChan (sshConn )
154
168
//ready
155
- o := & udpOutbound {
169
+ o := & udpChannel {
156
170
r : gob .NewDecoder (rwc ),
157
171
w : gob .NewEncoder (rwc ),
158
172
c : rwc ,
159
173
}
160
174
u .outbound = o
175
+ u .Debugf ("aquired channel" )
161
176
return o , nil
162
177
}
178
+
179
+ func (u * udpListener ) unsetUDPChan (sshConn ssh.Conn ) {
180
+ sshConn .Wait ()
181
+ u .Debugf ("lost channel" )
182
+ u .outboundMut .Lock ()
183
+ u .outbound = nil
184
+ u .outboundMut .Unlock ()
185
+ }
0 commit comments