Skip to content

Commit 6ac72d6

Browse files
committed
Pluggable authentication mechanisms.
1 parent df5acec commit 6ac72d6

File tree

4 files changed

+82
-42
lines changed

4 files changed

+82
-42
lines changed

auth.go

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,42 +19,42 @@ var (
1919
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
2020
)
2121

22-
// authenticate is used to handle connection authentication
23-
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) error {
24-
// Get the methods
25-
methods, err := readMethods(bufConn)
26-
if err != nil {
27-
return fmt.Errorf("Failed to get auth methods: %v", err)
28-
}
22+
type Authenticator interface {
23+
Authenticate(reader io.Reader, writer io.Writer) error
24+
GetCode() uint8
25+
}
2926

30-
// Determine what is supported
31-
supportUserPass := s.config.Credentials != nil
27+
// NoAuthAuthenticator is used to handle the "No Authentication" mode
28+
type NoAuthAuthenticator struct {}
3229

33-
// Select a usable method
34-
for _, method := range methods {
35-
if method == noAuth && !supportUserPass {
36-
return noAuthMode(conn)
37-
}
38-
if method == userPassAuth && supportUserPass {
39-
return s.userPassAuth(conn, bufConn)
40-
}
41-
}
30+
func (a NoAuthAuthenticator) GetCode() uint8 {
31+
return noAuth
32+
}
4233

43-
// No usable method found
44-
return noAcceptableAuth(conn)
34+
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) error {
35+
_, err := writer.Write([]byte{socks5Version, noAuth})
36+
return err
4537
}
4638

47-
// userPassAuth is used to handle username/password based
39+
// UserPassAuthenticator is used to handle username/password based
4840
// authentication
49-
func (s *Server) userPassAuth(conn io.Writer, bufConn io.Reader) error {
41+
type UserPassAuthenticator struct {
42+
Credentials CredentialStore
43+
}
44+
45+
func (a UserPassAuthenticator) GetCode() uint8 {
46+
return userPassAuth
47+
}
48+
49+
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) error {
5050
// Tell the client to use user/pass auth
51-
if _, err := conn.Write([]byte{socks5Version, userPassAuth}); err != nil {
51+
if _, err := writer.Write([]byte{socks5Version, userPassAuth}); err != nil {
5252
return err
5353
}
5454

5555
// Get the version and username length
5656
header := []byte{0, 0}
57-
if _, err := io.ReadAtLeast(bufConn, header, 2); err != nil {
57+
if _, err := io.ReadAtLeast(reader, header, 2); err != nil {
5858
return err
5959
}
6060

@@ -66,44 +66,63 @@ func (s *Server) userPassAuth(conn io.Writer, bufConn io.Reader) error {
6666
// Get the user name
6767
userLen := int(header[1])
6868
user := make([]byte, userLen)
69-
if _, err := io.ReadAtLeast(bufConn, user, userLen); err != nil {
69+
if _, err := io.ReadAtLeast(reader, user, userLen); err != nil {
7070
return err
7171
}
7272

7373
// Get the password length
74-
if _, err := bufConn.Read(header[:1]); err != nil {
74+
if _, err := reader.Read(header[:1]); err != nil {
7575
return err
7676
}
7777

7878
// Get the password
7979
passLen := int(header[0])
8080
pass := make([]byte, passLen)
81-
if _, err := io.ReadAtLeast(bufConn, pass, passLen); err != nil {
81+
if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil {
8282
return err
8383
}
8484

8585
// Verify the password
86-
if s.config.Credentials.Valid(string(user), string(pass)) {
87-
if _, err := conn.Write([]byte{userAuthVersion, authSuccess}); err != nil {
86+
if a.Credentials.Valid(string(user), string(pass)) {
87+
if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil {
8888
return err
8989
}
9090
} else {
91-
if _, err := conn.Write([]byte{userAuthVersion, authFailure}); err != nil {
91+
if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil {
9292
return err
9393
}
9494
return UserAuthFailed
9595
}
9696

9797
// Done
9898
return nil
99+
99100
}
100101

101-
// noAuth is used to handle the "No Authentication" mode
102-
func noAuthMode(conn io.Writer) error {
103-
_, err := conn.Write([]byte{socks5Version, noAuth})
104-
return err
102+
103+
104+
// authenticate is used to handle connection authentication
105+
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) error {
106+
// Get the methods
107+
methods, err := readMethods(bufConn)
108+
if err != nil {
109+
return fmt.Errorf("Failed to get auth methods: %v", err)
110+
}
111+
112+
// Select a usable method
113+
for _, method := range methods {
114+
cator, found := s.authMethods[method]
115+
if found {
116+
return cator.Authenticate(bufConn, conn)
117+
}
118+
}
119+
120+
// No usable method found
121+
return noAcceptableAuth(conn)
105122
}
106123

124+
125+
107126
// noAcceptableAuth is used to handle when we have no eligible
108127
// authentication mechanism
109128
func noAcceptableAuth(conn io.Writer) error {

auth_test.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ func TestNoAuth(t *testing.T) {
1010
req.Write([]byte{1, noAuth})
1111
var resp bytes.Buffer
1212

13-
s := &Server{config: &Config{}}
13+
s, _ := New(&Config{})
1414
if err := s.authenticate(&resp, req); err != nil {
1515
t.Fatalf("err: %v", err)
1616
}
@@ -30,7 +30,11 @@ func TestPasswordAuth_Valid(t *testing.T) {
3030
cred := StaticCredentials{
3131
"foo": "bar",
3232
}
33-
s := &Server{config: &Config{Credentials: cred}}
33+
34+
cator := UserPassAuthenticator{Credentials: cred}
35+
36+
s, _ := New(&Config{AuthMethods:[]Authenticator{cator}})
37+
3438
if err := s.authenticate(&resp, req); err != nil {
3539
t.Fatalf("err: %v", err)
3640
}
@@ -50,7 +54,8 @@ func TestPasswordAuth_Invalid(t *testing.T) {
5054
cred := StaticCredentials{
5155
"foo": "bar",
5256
}
53-
s := &Server{config: &Config{Credentials: cred}}
57+
cator := UserPassAuthenticator{Credentials: cred}
58+
s, _ := New(&Config{AuthMethods:[]Authenticator{cator}})
5459
if err := s.authenticate(&resp, req); err != UserAuthFailed {
5560
t.Fatalf("err: %v", err)
5661
}
@@ -69,7 +74,9 @@ func TestNoSupportedAuth(t *testing.T) {
6974
cred := StaticCredentials{
7075
"foo": "bar",
7176
}
72-
s := &Server{config: &Config{Credentials: cred}}
77+
cator := UserPassAuthenticator{Credentials: cred}
78+
79+
s, _ := New(&Config{AuthMethods:[]Authenticator{cator}})
7380
if err := s.authenticate(&resp, req); err != NoSupportedAuth {
7481
t.Fatalf("err: %v", err)
7582
}

socks5.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ const (
1313

1414
// Config is used to setup and configure a Server
1515
type Config struct {
16-
// If provided, username/password authentication is enabled
17-
// otherwise, non-authenticated mode is allowed
18-
Credentials CredentialStore
16+
// AuthMethods can be provided to implement custom authentication
17+
// By default, "auth-less" mode is enabled. For password-based auth use UserPassAuthenticator.
18+
AuthMethods []Authenticator
1919

2020
// Resolver can be provided to do custom name resolution.
2121
// Defaults to DNSResolver if not provided.
@@ -38,10 +38,16 @@ type Config struct {
3838
// the details of the SOCKS5 protocol
3939
type Server struct {
4040
config *Config
41+
authMethods map[uint8]Authenticator
4142
}
4243

4344
// New creates a new Server and potentially returns an error
4445
func New(conf *Config) (*Server, error) {
46+
// Ensure we have at least one authentication method enabled
47+
if conf.AuthMethods == nil || len(conf.AuthMethods) == 0 {
48+
conf.AuthMethods = []Authenticator{&NoAuthAuthenticator{}}
49+
}
50+
4551
// Ensure we have a DNS resolver
4652
if conf.Resolver == nil {
4753
conf.Resolver = DNSResolver{}
@@ -55,6 +61,13 @@ func New(conf *Config) (*Server, error) {
5561
server := &Server{
5662
config: conf,
5763
}
64+
65+
server.authMethods = make(map[uint8]Authenticator)
66+
67+
for _, a := range conf.AuthMethods {
68+
server.authMethods[a.GetCode()] = a
69+
}
70+
5871
return server, nil
5972
}
6073

socks5_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ func TestSOCKS5_Connect(t *testing.T) {
3838
creds := StaticCredentials{
3939
"foo": "bar",
4040
}
41+
cator := UserPassAuthenticator{Credentials : creds}
4142
conf := &Config{
42-
Credentials: creds,
43+
AuthMethods : []Authenticator{cator},
4344
}
4445
serv, err := New(conf)
4546
if err != nil {

0 commit comments

Comments
 (0)