Skip to content

Commit 33d4f05

Browse files
authored
[sparksql] Improve session reuse and fix corner cases (#2851)
Improve session handling; fix failing corner cases; add checks for different session states; improve statement cancellation; fix failing unit tests.
1 parent cba12fb commit 33d4f05

File tree

3 files changed

+97
-29
lines changed

3 files changed

+97
-29
lines changed

apps/spark/src/spark/livy_client.py

+3
Original file line number | Diff line number | Diff line change
@@ -153,6 +153,9 @@ def close(self, uuid):
153153
def get_batches(self):
154154
return self._root.get('batches')
155155

156+
def cancel_statement(self, session, statement_id):
  '''
  Request cancellation of a single statement of a session.

  Issues a POST on 'sessions/<session>/statements/<statement_id>/cancel' through self._root
  and returns whatever that call returns.
  '''
  cancel_path = 'sessions/%s/statements/%s/cancel' % (session, statement_id)
  return self._root.post(cancel_path)
158+
156159
def submit_batch(self, properties):
157160
properties['proxyUser'] = self.user
158161
return self._root.post('batches', data=json.dumps(properties), contenttype=_JSON_CONTENT_TYPE)

desktop/libs/notebook/src/notebook/connectors/spark_shell.py

+92-29
Original file line number | Diff line number | Diff line change
@@ -118,13 +118,26 @@ def _get_session_key(self):
118118
}
119119

120120

121+
def _check_session(self, session):
  '''
  Check if the session is actually present and its state is healthy.

  The session is looked up through the API using session['id'].

  Returns the session description returned by the API when it exists and its state is not one of
  'dead', 'shutting_down', 'error' or 'killed'. Returns None otherwise, including when the lookup
  itself raises.
  '''
  api = self.get_api()
  try:
    session_present = api.get_session(session['id'])
  except Exception as e:
    # A failed lookup is treated the same as a missing session, but the reason is kept in the
    # debug log instead of being dropped silently.
    LOG.debug('Could not look up session: %s', e)
    return None

  if session_present and session_present['state'] not in ('dead', 'shutting_down', 'error', 'killed'):
    return session_present

  return None
133+
134+
121135
def create_session(self, lang='scala', properties=None):
122136
api = self.get_api()
123137
session_key = self._get_session_key()
124138

125139
if SESSIONS.get(session_key):
126-
# Checking if the session is actually present to avoid stale value
127-
session_present = api.get_session(SESSIONS[session_key]['id'])
140+
session_present = self._check_session(SESSIONS[session_key])
128141
if session_present:
129142
return SESSIONS[session_key]
130143

@@ -161,15 +174,18 @@ def execute(self, notebook, snippet):
161174
api = self.get_api()
162175
session = _get_snippet_session(notebook, snippet)
163176

164-
response = self._execute(api, session, snippet['statement'])
177+
response = self._execute(api, session, snippet.get('type'), snippet['statement'])
165178
return response
166179

167180

168-
def _execute(self, api, session, statement):
181+
def _execute(self, api, session, snippet_type, statement):
169182
session_key = self._get_session_key()
170183

171-
if session['id'] is None and SESSIONS.get(session_key) is not None:
172-
session = SESSIONS[session_key]
184+
if not session or not self._check_session(session):
185+
if SESSIONS.get(session_key) and self._check_session(SESSIONS[session_key]):
186+
session = SESSIONS[session_key]
187+
else:
188+
session = self.create_session(snippet_type)
173189

174190
try:
175191
response = api.submit_statement(session['id'], statement)
@@ -191,6 +207,8 @@ def check_status(self, notebook, snippet):
191207
session = _get_snippet_session(notebook, snippet)
192208
cell = snippet['result']['handle']['id']
193209

210+
session = self._handle_session_health_check(session)
211+
194212
try:
195213
response = api.fetch_data(session['id'], cell)
196214
return {
@@ -209,6 +227,8 @@ def fetch_result(self, notebook, snippet, rows, start_over):
209227
session = _get_snippet_session(notebook, snippet)
210228
cell = snippet['result']['handle']['id']
211229

230+
session = self._handle_session_health_check(session)
231+
212232
response = self._fetch_result(api, session, cell, start_over)
213233
return response
214234

@@ -279,16 +299,43 @@ def _fetch_result(self, api, session, cell, start_over):
279299
def cancel(self, notebook, snippet):
  '''
  Cancel the work attached to the snippet's session.

  The snippet session is first validated with _handle_session_health_check, which raises
  PopupException when neither it nor the cached session is usable. api.cancel is then called
  with the session id.

  Returns {'status': 0} whether or not the cancel call succeeded.
  '''
  api = self.get_api()
  session = _get_snippet_session(notebook, snippet)

  session = self._handle_session_health_check(session)

  try:
    # The value returned by the API is not used by this method.
    api.cancel(session['id'])
  except Exception as e:
    # Best effort: a failing cancel is only recorded in the debug log.
    message = force_unicode(str(e)).lower()
    LOG.debug(message)

  return {'status': 0}
285312

286313

287314
def get_log(self, notebook, snippet, startFrom=0, size=None):
  '''
  Fetch the log of the snippet's session.

  The snippet session is validated with _handle_session_health_check before the lookup.

  Returns the value of api.get_log(session id, startFrom=..., size=...), or {'status': 0}
  when that call raises RestException (the error is written to the debug log).
  '''
  api = self.get_api()
  snippet_session = _get_snippet_session(notebook, snippet)
  healthy_session = self._handle_session_health_check(snippet_session)

  try:
    return api.get_log(healthy_session['id'], startFrom=startFrom, size=size)
  except RestException as e:
    LOG.debug(force_unicode(str(e)).lower())
    return {'status': 0}
327+
328+
329+
def _handle_session_health_check(self, session):
  '''
  Return a usable session, or fail if there is none.

  The given session is returned when it is set and passes _check_session. Otherwise the session
  cached in SESSIONS under this connector's session key is returned when it passes the same check.

  Raises PopupException when neither session is usable.
  '''
  session_key = self._get_session_key()

  # Fast path: the caller's session is still fine.
  if session and self._check_session(session):
    return session

  # Fall back to the cached session for this key.
  cached_session = SESSIONS.get(session_key)
  if cached_session and self._check_session(cached_session):
    return cached_session

  raise PopupException(_("Session expired. Please create new session and try again."))
292339

293340

294341
def close_statement(self, notebook, snippet): # Individual statements cannot be closed
@@ -327,9 +374,9 @@ def get_jobs(self, notebook, snippet, logs):
327374

328375
def autocomplete(self, snippet, database=None, table=None, column=None, nested=None, operation=None):
329376
response = {}
330-
331377
# As booting a new SQL session is slow and we don't send the id of the current one in /autocomplete
332378
# we could implement this by introducing an API cache per user similarly to SqlAlchemy.
379+
333380
api = self.get_api()
334381
session_key = self._get_session_key()
335382

@@ -338,14 +385,17 @@ def autocomplete(self, snippet, database=None, table=None, column=None, nested=N
338385
if SESSIONS.get(session_key):
339386
self._close_unused_sessions()
340387

341-
session = SESSIONS[session_key] if SESSIONS.get(session_key) else self.create_session(snippet.get('type'))
388+
if SESSIONS.get(session_key) and self._check_session(SESSIONS[session_key]):
389+
session = SESSIONS[session_key]
390+
else:
391+
session = self.create_session(snippet.get('type'))
342392

343393
if database is None:
344-
response['databases'] = self._show_databases(api, session)
394+
response['databases'] = self._show_databases(api, session, snippet.get('type'))
345395
elif table is None:
346-
response['tables_meta'] = self._show_tables(api, session, database)
396+
response['tables_meta'] = self._show_tables(api, session, snippet.get('type'), database)
347397
elif column is None:
348-
columns = self._get_columns(api, session, database, table)
398+
columns = self._get_columns(api, session, snippet.get('type'), database, table)
349399
response['columns'] = [col['name'] for col in columns]
350400
response['extended_columns'] = [{
351401
'comment': col.get('comment'),
@@ -360,52 +410,62 @@ def autocomplete(self, snippet, database=None, table=None, column=None, nested=N
360410

361411
def _close_unused_sessions(self):
  '''
  Closes all unused Livy sessions for a particular user to free up session resources.

  A session is closed when it is owned by self.user.username, its id differs from the id of the
  session cached in SESSIONS for this connector's session key, and its state is one of
  'idle', 'shutting_down', 'error', 'dead' or 'killed'.

  Listing the sessions is best effort: when api.get_sessions() raises, the error is written to the
  debug log and nothing is closed.
  '''
  api = self.get_api()
  session_key = self._get_session_key()

  try:
    all_sessions = api.get_sessions()
  except Exception as e:
    message = force_unicode(str(e)).lower()
    LOG.debug(message)
    # Return here so the name is never read while unbound after a failed listing.
    return

  if not all_sessions:
    return

  for session in all_sessions['sessions']:
    if session['owner'] == self.user.username and session['id'] != SESSIONS[session_key]['id'] and \
        session['state'] in ('idle', 'shutting_down', 'error', 'dead', 'killed'):
      self.close_session(session)
372430

373431

374432
def _check_status_and_fetch_result(self, api, session, execute_resp, max_polls=120, poll_interval=1):
  '''
  Poll a submitted statement until it leaves the 'running'/'waiting' states, then fetch its result.

  api: object exposing fetch_data(session_id, statement_id).
  session: dict holding the session 'id'.
  execute_resp: dict holding the statement 'id'.
  max_polls: maximum number of re-polls after the initial status check (default 120).
  poll_interval: seconds to sleep before each re-poll (default 1).

  Returns the value of self._fetch_result(..., start_over=True) when the final state is 'available'.
  Returns None for any other final state, including when the poll limit is reached.
  '''
  check_status = api.fetch_data(session['id'], execute_resp['id'])

  polls = 0
  while check_status['state'] in ('running', 'waiting') and polls < max_polls:
    # Sleep before asking again: the status was fetched just above, and sleeping after the fetch
    # would waste one interval once the final state is already known.
    time.sleep(poll_interval)
    check_status = api.fetch_data(session['id'], execute_resp['id'])
    polls += 1

  if check_status['state'] == 'available':
    return self._fetch_result(api, session, execute_resp['id'], start_over=True)

  return None
383443

384444

385-
def _show_databases(self, api, session, snippet_type):
  '''
  Run 'SHOW DATABASES' and return the first column of every result row.

  Returns a list built from result['data'], or None when no result was fetched.
  '''
  statement_handle = self._execute(api, session, snippet_type, 'SHOW DATABASES')
  result = self._check_status_and_fetch_result(api, session, statement_handle)

  if not result:
    return None

  return [row[0] for row in result['data']]
391451

392452

393-
def _show_tables(self, api, session, snippet_type, database):
  '''
  Switch to the given database, run 'SHOW TABLES' and return the second column of every result row.

  Returns a list built from the fetched result's 'data', or None when no result was fetched.
  '''
  use_db_execute = self._execute(api, session, snippet_type, 'USE %(database)s' % {'database': database})
  # The result of USE is not needed; the call is kept because it polls the statement status
  # before the next statement is submitted.
  self._check_status_and_fetch_result(api, session, use_db_execute)

  show_tables_execute = self._execute(api, session, snippet_type, 'SHOW TABLES')
  tables_list = self._check_status_and_fetch_result(api, session, show_tables_execute)

  if tables_list:
    return [table[1] for table in tables_list['data']]

  return None
402462

403463

404-
def _get_columns(self, api, session, database, table):
405-
use_db_execute = self._execute(api, session, 'USE %(database)s' % {'database': database})
464+
def _get_columns(self, api, session, snippet_type, database, table):
465+
use_db_execute = self._execute(api, session, snippet_type, 'USE %(database)s' % {'database': database})
406466
use_db_resp = self._check_status_and_fetch_result(api, session, use_db_execute)
407467

408-
describe_tables_execute = self._execute(api, session, 'DESCRIBE %(table)s' % {'table': table})
468+
describe_tables_execute = self._execute(api, session, snippet_type, 'DESCRIBE %(table)s' % {'table': table})
409469
columns_list = self._check_status_and_fetch_result(api, session, describe_tables_execute)
410470

411471
if columns_list:
@@ -425,11 +485,14 @@ def get_sample_data(self, snippet, database=None, table=None, column=None, is_as
425485
if SESSIONS.get(session_key):
426486
self._close_unused_sessions()
427487

428-
session = SESSIONS[session_key] if SESSIONS.get(session_key) else self.create_session(snippet.get('type'))
488+
if SESSIONS.get(session_key) and self._check_session(SESSIONS[session_key]):
489+
session = SESSIONS[session_key]
490+
else:
491+
session = self.create_session(snippet.get('type'))
429492

430493
statement = self._get_select_query(database, table, column, operation)
431494

432-
sample_execute = self._execute(api, session, statement)
495+
sample_execute = self._execute(api, session, snippet.get('type'), statement)
433496
sample_result = self._check_status_and_fetch_result(api, session, sample_execute)
434497

435498
response = {

desktop/libs/notebook/src/notebook/connectors/spark_shell_tests.py

+2
Original file line number | Diff line number | Diff line change
@@ -169,6 +169,7 @@ def test_execute(self):
169169
return_value={'id': 'test_id'}
170170
)
171171
)
172+
self.api._check_session = Mock(return_value={'id': '1'})
172173

173174
response = self.api.execute(notebook, snippet)
174175
assert_equal(response['id'], 'test_id')
@@ -197,6 +198,7 @@ def test_check_status(self):
197198
return_value={'state': 'test_state'}
198199
)
199200
)
201+
self.api._handle_session_health_check = Mock(return_value={'id': '1'})
200202

201203
response = self.api.check_status(notebook, snippet)
202204
assert_equal(response['status'], 'test_state')

0 commit comments

Comments
 (0)