|
1 | 1 | # Copyright 2017 Palantir Technologies, Inc. |
2 | | -from functools import partial |
| 2 | +import json |
3 | 3 | import logging |
4 | 4 | import os |
5 | 5 | import socketserver |
6 | 6 | import threading |
| 7 | +from functools import partial |
| 8 | +from hashlib import sha256 |
7 | 9 |
|
8 | 10 | from pyls_jsonrpc.dispatchers import MethodDispatcher |
9 | 11 | from pyls_jsonrpc.endpoint import Endpoint |
@@ -34,19 +36,36 @@ def setup(self): |
34 | 36 | self.delegate = self.DELEGATE_CLASS(self.rfile, self.wfile) |
35 | 37 |
|
36 | 38 | def handle(self): |
37 | | - try: |
38 | | - self.delegate.start() |
39 | | - except OSError as e: |
40 | | - if os.name == 'nt': |
41 | | - # Catch and pass on ConnectionResetError when parent process |
42 | | - # dies |
43 | | - # pylint: disable=no-member, undefined-variable |
44 | | - if isinstance(e, WindowsError) and e.winerror == 10054: |
45 | | - pass |
46 | | - |
| 39 | + self.auth(self.delegate.start) |
47 | 40 | # pylint: disable=no-member |
48 | 41 | self.SHUTDOWN_CALL() |
49 | 42 |
|
| 43 | + def auth(self, cb): |
| 44 | + token = '' |
| 45 | + if "JUPYTER_TOKEN" in os.environ: |
| 46 | + token = os.environ["JUPYTER_TOKEN"] |
| 47 | + else: |
| 48 | + log.warn('! Missing jupyter token !') |
| 49 | + |
| 50 | + data = self.rfile.readline() |
| 51 | + try: |
| 52 | + auth_req = json.loads(data.decode().split('\n')[0]) |
| 53 | + except: |
| 54 | + log.error('Error parsing authentication message') |
| 55 | + auth_error_msg = { 'msg': 'AUTH_ERROR' } |
| 56 | + self.wfile.write(json.dumps(auth_error_msg).encode()) |
| 57 | + return |
| 58 | + |
| 59 | + hashed_token = sha256(token.encode()).hexdigest() |
| 60 | + if auth_req.get('token') == hashed_token: |
| 61 | + auth_success_msg = { 'msg': 'AUTH_SUCCESS' } |
| 62 | + self.wfile.write(json.dumps(auth_success_msg).encode()) |
| 63 | + cb() |
| 64 | + else: |
| 65 | + log.info('Failed to authenticate: invalid credentials') |
| 66 | + auth_invalid_msg = { 'msg': 'AUTH_INVALID_CRED' } |
| 67 | + self.wfile.write(json.dumps(auth_invalid_msg).encode()) |
| 68 | + |
50 | 69 |
|
51 | 70 | def start_tcp_lang_server(bind_addr, port, check_parent_process, handler_class): |
52 | 71 | if not issubclass(handler_class, PythonLanguageServer): |
|
0 commit comments