@@ -19,42 +19,42 @@ var (
1919 NoSupportedAuth = fmt .Errorf ("No supported authentication mechanism" )
2020)
2121
22- // authenticate is used to handle connection authentication
23- func (s * Server ) authenticate (conn io.Writer , bufConn io.Reader ) error {
24- // Get the methods
25- methods , err := readMethods (bufConn )
26- if err != nil {
27- return fmt .Errorf ("Failed to get auth methods: %v" , err )
28- }
22+ type Authenticator interface {
23+ Authenticate (reader io.Reader , writer io.Writer ) error
24+ GetCode () uint8
25+ }
2926
30- // Determine what is supported
31- supportUserPass := s . config . Credentials != nil
27+ // NoAuthAuthenticator is used to handle the "No Authentication" mode
28+ type NoAuthAuthenticator struct {}
3229
33- // Select a usable method
34- for _ , method := range methods {
35- if method == noAuth && ! supportUserPass {
36- return noAuthMode (conn )
37- }
38- if method == userPassAuth && supportUserPass {
39- return s .userPassAuth (conn , bufConn )
40- }
41- }
30+ func (a NoAuthAuthenticator ) GetCode () uint8 {
31+ return noAuth
32+ }
4233
43- // No usable method found
44- return noAcceptableAuth (conn )
34+ func (a NoAuthAuthenticator ) Authenticate (reader io.Reader , writer io.Writer ) error {
35+ _ , err := writer .Write ([]byte {socks5Version , noAuth })
36+ return err
4537}
4638
47- // userPassAuth is used to handle username/password based
39+ // UserPassAuthenticator is used to handle username/password based
4840// authentication
49- func (s * Server ) userPassAuth (conn io.Writer , bufConn io.Reader ) error {
41+ type UserPassAuthenticator struct {
42+ Credentials CredentialStore
43+ }
44+
45+ func (a UserPassAuthenticator ) GetCode () uint8 {
46+ return userPassAuth
47+ }
48+
49+ func (a UserPassAuthenticator ) Authenticate (reader io.Reader , writer io.Writer ) error {
5050 // Tell the client to use user/pass auth
51- if _ , err := conn .Write ([]byte {socks5Version , userPassAuth }); err != nil {
51+ if _ , err := writer .Write ([]byte {socks5Version , userPassAuth }); err != nil {
5252 return err
5353 }
5454
5555 // Get the version and username length
5656 header := []byte {0 , 0 }
57- if _ , err := io .ReadAtLeast (bufConn , header , 2 ); err != nil {
57+ if _ , err := io .ReadAtLeast (reader , header , 2 ); err != nil {
5858 return err
5959 }
6060
@@ -66,44 +66,63 @@ func (s *Server) userPassAuth(conn io.Writer, bufConn io.Reader) error {
6666 // Get the user name
6767 userLen := int (header [1 ])
6868 user := make ([]byte , userLen )
69- if _ , err := io .ReadAtLeast (bufConn , user , userLen ); err != nil {
69+ if _ , err := io .ReadAtLeast (reader , user , userLen ); err != nil {
7070 return err
7171 }
7272
7373 // Get the password length
74- if _ , err := bufConn .Read (header [:1 ]); err != nil {
74+ if _ , err := reader .Read (header [:1 ]); err != nil {
7575 return err
7676 }
7777
7878 // Get the password
7979 passLen := int (header [0 ])
8080 pass := make ([]byte , passLen )
81- if _ , err := io .ReadAtLeast (bufConn , pass , passLen ); err != nil {
81+ if _ , err := io .ReadAtLeast (reader , pass , passLen ); err != nil {
8282 return err
8383 }
8484
8585 // Verify the password
86- if s . config .Credentials .Valid (string (user ), string (pass )) {
87- if _ , err := conn .Write ([]byte {userAuthVersion , authSuccess }); err != nil {
86+ if a .Credentials .Valid (string (user ), string (pass )) {
87+ if _ , err := writer .Write ([]byte {userAuthVersion , authSuccess }); err != nil {
8888 return err
8989 }
9090 } else {
91- if _ , err := conn .Write ([]byte {userAuthVersion , authFailure }); err != nil {
91+ if _ , err := writer .Write ([]byte {userAuthVersion , authFailure }); err != nil {
9292 return err
9393 }
9494 return UserAuthFailed
9595 }
9696
9797 // Done
9898 return nil
99+
99100}
100101
101- // noAuth is used to handle the "No Authentication" mode
102- func noAuthMode (conn io.Writer ) error {
103- _ , err := conn .Write ([]byte {socks5Version , noAuth })
104- return err
102+
103+
104+ // authenticate is used to handle connection authentication
105+ func (s * Server ) authenticate (conn io.Writer , bufConn io.Reader ) error {
106+ // Get the methods
107+ methods , err := readMethods (bufConn )
108+ if err != nil {
109+ return fmt .Errorf ("Failed to get auth methods: %v" , err )
110+ }
111+
112+ // Select a usable method
113+ for _ , method := range methods {
114+ cator , found := s .authMethods [method ]
115+ if found {
116+ return cator .Authenticate (bufConn , conn )
117+ }
118+ }
119+
120+ // No usable method found
121+ return noAcceptableAuth (conn )
105122}
106123
124+
125+
107126// noAcceptableAuth is used to handle when we have no eligible
108127// authentication mechanism
109128func noAcceptableAuth (conn io.Writer ) error {
0 commit comments