Skip to content

Commit 9dafcb1

Browse files
committed
Merge pull request scikit-learn#5016 from lesteve/update-joblib-to-0.9.0b3
[MRG+1] Update joblib to 0.9.0b3
2 parents 0bcbf92 + 52ed08c commit 9dafcb1

File tree

2 files changed

+153
-131
lines changed

2 files changed

+153
-131
lines changed

sklearn/externals/joblib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@
116116
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
117117
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
118118
#
119-
__version__ = '0.9.0b2'
119+
__version__ = '0.9.0b3'
120120

121121

122122
from .memory import Memory, MemorizedResult

sklearn/externals/joblib/parallel.py

Lines changed: 152 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -433,20 +433,116 @@ def __init__(self, n_jobs=1, backend='multiprocessing', verbose=0,
433433
% batch_size)
434434

435435
self.pre_dispatch = pre_dispatch
436-
self._pool = None
437436
self._temp_folder = temp_folder
438437
if isinstance(max_nbytes, _basestring):
439438
self._max_nbytes = 1024 * memstr_to_kbytes(max_nbytes)
440439
else:
441440
self._max_nbytes = max_nbytes
442441
self._mmap_mode = mmap_mode
443442
# Not starting the pool in the __init__ is a design decision, to be
444-
# able to close it ASAP, and not burden the user with closing it.
443+
# able to close it ASAP, and not burden the user with closing it
444+
# unless they choose to use the context manager API with a with block.
445+
self._pool = None
445446
self._output = None
446447
self._jobs = list()
447-
# A flag used to abort the dispatching of jobs in case an
448-
# exception is found
449-
self._aborting = False
448+
self._managed_pool = False
449+
450+
# This lock is used to coordinate the main thread of this process with
451+
# the async callback thread of the pool.
452+
self._lock = threading.Lock()
453+
454+
def __enter__(self):
455+
self._managed_pool = True
456+
self._initialize_pool()
457+
return self
458+
459+
def __exit__(self, exc_type, exc_value, traceback):
460+
self._terminate_pool()
461+
self._managed_pool = False
462+
463+
def _effective_n_jobs(self):
464+
n_jobs = self.n_jobs
465+
if n_jobs == 0:
466+
raise ValueError('n_jobs == 0 in Parallel has no meaning')
467+
elif mp is None or n_jobs is None:
468+
# multiprocessing is not available or disabled, fallback
469+
# to sequential mode
470+
return 1
471+
elif n_jobs < 0:
472+
n_jobs = max(mp.cpu_count() + 1 + n_jobs, 1)
473+
return n_jobs
474+
475+
def _initialize_pool(self):
476+
"""Build a process or thread pool and return the number of workers"""
477+
n_jobs = self._effective_n_jobs()
478+
# The list of exceptions that we will capture
479+
self.exceptions = [TransportableException]
480+
481+
if n_jobs == 1:
482+
# Sequential mode: do not use a pool instance to avoid any
483+
# useless dispatching overhead
484+
self._pool = None
485+
elif self.backend == 'threading':
486+
self._pool = ThreadPool(n_jobs)
487+
elif self.backend == 'multiprocessing':
488+
if mp.current_process().daemon:
489+
# Daemonic processes cannot have children
490+
self._pool = None
491+
warnings.warn(
492+
'Multiprocessing-backed parallel loops cannot be nested,'
493+
' setting n_jobs=1',
494+
stacklevel=3)
495+
return 1
496+
elif threading.current_thread().name != 'MainThread':
497+
# Prevent posix fork in non-main posix threads
498+
self._pool = None
499+
warnings.warn(
500+
'Multiprocessing backed parallel loops cannot be nested'
501+
' below threads, setting n_jobs=1',
502+
stacklevel=3)
503+
return 1
504+
else:
505+
already_forked = int(os.environ.get(JOBLIB_SPAWNED_PROCESS, 0))
506+
if already_forked:
507+
raise ImportError('[joblib] Attempting to do parallel computing '
508+
'without protecting your import on a system that does '
509+
'not support forking. To use parallel-computing in a '
510+
'script, you must protect your main loop using "if '
511+
"__name__ == '__main__'"
512+
'". Please see the joblib documentation on Parallel '
513+
'for more information'
514+
)
515+
# Set an environment variable to avoid infinite loops
516+
os.environ[JOBLIB_SPAWNED_PROCESS] = '1'
517+
518+
# Make sure to free as much memory as possible before forking
519+
gc.collect()
520+
poolargs = dict(
521+
max_nbytes=self._max_nbytes,
522+
mmap_mode=self._mmap_mode,
523+
temp_folder=self._temp_folder,
524+
verbose=max(0, self.verbose - 50),
525+
context_id=0, # the pool is used only for one call
526+
)
527+
if self._mp_context is not None:
528+
# Use Python 3.4+ multiprocessing context isolation
529+
poolargs['context'] = self._mp_context
530+
self._pool = MemmapingPool(n_jobs, **poolargs)
531+
532+
# We are using multiprocessing, we also want to capture
533+
# KeyboardInterrupts
534+
self.exceptions.extend([KeyboardInterrupt, WorkerInterrupt])
535+
else:
536+
raise ValueError("Unsupported backend: %s" % self.backend)
537+
return n_jobs
538+
539+
def _terminate_pool(self):
540+
if self._pool is not None:
541+
self._pool.close()
542+
self._pool.terminate() # terminate does a join()
543+
self._pool = None
544+
if self.backend == 'multiprocessing':
545+
os.environ.pop(JOBLIB_SPAWNED_PROCESS, 0)
450546

451547
def _dispatch(self, batch):
452548
"""Queue the batch for computing, with or without multiprocessing
@@ -455,6 +551,10 @@ def _dispatch(self, batch):
455551
indirectly via dispatch_one_batch.
456552
457553
"""
554+
# If job.get() catches an exception, it closes the queue:
555+
if self._aborting:
556+
return
557+
458558
if self._pool is None:
459559
job = ImmediateComputeBatch(batch)
460560
self._jobs.append(job)
@@ -467,10 +567,6 @@ def _dispatch(self, batch):
467567
short_format_time(time.time() - self._start_time)
468568
))
469569
else:
470-
# If job.get() catches an exception, it closes the queue:
471-
if self._aborting:
472-
return
473-
474570
dispatch_timestamp = time.time()
475571
cb = BatchCompletionCallBack(dispatch_timestamp, len(batch), self)
476572
job = self._pool.apply_async(SafeFunction(batch), callback=cb)
@@ -600,7 +696,7 @@ def print_progress(self):
600696
if (is_last_item or cursor % frequency):
601697
return
602698
remaining_time = (elapsed_time / (index + 1) *
603-
(self.n_dispatched_tasks - index - 1.))
699+
(self.n_dispatched_tasks - index - 1.))
604700
self._print('Done %3i out of %3i | elapsed: %s remaining: %s',
605701
(index + 1,
606702
total_tasks,
@@ -615,127 +711,62 @@ def retrieve(self):
615711
# Wait for an async callback to dispatch new jobs
616712
time.sleep(0.01)
617713
continue
618-
# We need to be careful: the job queue can be filling up as
619-
# we empty it
620-
if hasattr(self, '_lock'):
621-
self._lock.acquire()
622-
job = self._jobs.pop(0)
623-
if hasattr(self, '_lock'):
624-
self._lock.release()
714+
# We need to be careful: the job list can be filling up as
715+
# we empty it and Python lists are not thread-safe by default hence
716+
# the use of the lock
717+
with self._lock:
718+
job = self._jobs.pop(0)
625719
try:
626720
self._output.extend(job.get())
627721
except tuple(self.exceptions) as exception:
628-
try:
629-
self._aborting = True
630-
self._lock.acquire()
631-
if isinstance(exception,
632-
(KeyboardInterrupt, WorkerInterrupt)):
633-
# We have captured a user interruption, clean up
634-
# everything
635-
if hasattr(self, '_pool'):
636-
self._pool.close()
637-
self._pool.terminate()
638-
# We can now allow subprocesses again
639-
os.environ.pop('__JOBLIB_SPAWNED_PARALLEL__', 0)
640-
raise exception
641-
elif isinstance(exception, TransportableException):
642-
# Capture exception to add information on the local
643-
# stack in addition to the distant stack
644-
this_report = format_outer_frames(context=10,
645-
stack_start=1)
646-
report = """Multiprocessing exception:
647-
%s
648-
---------------------------------------------------------------------------
649-
Sub-process traceback:
650-
---------------------------------------------------------------------------
651-
%s""" % (
652-
this_report,
653-
exception.message,
654-
)
655-
# Convert this to a JoblibException
656-
exception_type = _mk_exception(exception.etype)[0]
657-
raise exception_type(report)
658-
raise exception
659-
finally:
660-
self._lock.release()
722+
# Stop dispatching any new job in the async callback thread
723+
self._aborting = True
724+
725+
if isinstance(exception, TransportableException):
726+
# Capture exception to add information on the local
727+
# stack in addition to the distant stack
728+
this_report = format_outer_frames(context=10,
729+
stack_start=1)
730+
report = """Multiprocessing exception:
731+
%s
732+
---------------------------------------------------------------------------
733+
Sub-process traceback:
734+
---------------------------------------------------------------------------
735+
%s""" % (this_report, exception.message)
736+
# Convert this to a JoblibException
737+
exception_type = _mk_exception(exception.etype)[0]
738+
exception = exception_type(report)
739+
740+
# Kill remaining running processes without waiting for
741+
# the results as we will raise the exception we got back
742+
# to the caller instead of returning any result.
743+
with self._lock:
744+
self._terminate_pool()
745+
if self._managed_pool:
746+
# In case we had to terminate a managed pool, let
747+
# us start a new one to ensure that subsequent calls
748+
# to __call__ on the same Parallel instance will get
749+
# a working pool as they expect.
750+
self._initialize_pool()
751+
raise exception
661752

662753
def __call__(self, iterable):
663754
if self._jobs:
664755
raise ValueError('This Parallel instance is already running')
665-
n_jobs = self.n_jobs
666-
if n_jobs == 0:
667-
raise ValueError('n_jobs == 0 in Parallel has no meaning')
668-
if n_jobs < 0 and mp is not None:
669-
n_jobs = max(mp.cpu_count() + 1 + n_jobs, 1)
670-
671-
# The list of exceptions that we will capture
672-
self.exceptions = [TransportableException]
673-
self._lock = threading.Lock()
674-
675-
# Whether or not to set an environment flag to track
676-
# multiple process spawning
677-
set_environ_flag = False
678-
if (n_jobs is None or mp is None or n_jobs == 1):
679-
n_jobs = 1
680-
self._pool = None
681-
elif self.backend == 'threading':
682-
self._pool = ThreadPool(n_jobs)
683-
elif self.backend == 'multiprocessing':
684-
if mp.current_process().daemon:
685-
# Daemonic processes cannot have children
686-
n_jobs = 1
687-
self._pool = None
688-
warnings.warn(
689-
'Multiprocessing-backed parallel loops cannot be nested,'
690-
' setting n_jobs=1',
691-
stacklevel=2)
692-
elif threading.current_thread().name != 'MainThread':
693-
# Prevent posix fork inside in non-main posix threads
694-
n_jobs = 1
695-
self._pool = None
696-
warnings.warn(
697-
'Multiprocessing backed parallel loops cannot be nested'
698-
' below threads, setting n_jobs=1',
699-
stacklevel=2)
700-
else:
701-
already_forked = int(os.environ.get('__JOBLIB_SPAWNED_PARALLEL__', 0))
702-
if already_forked:
703-
raise ImportError('[joblib] Attempting to do parallel computing '
704-
'without protecting your import on a system that does '
705-
'not support forking. To use parallel-computing in a '
706-
'script, you must protect your main loop using "if '
707-
"__name__ == '__main__'"
708-
'". Please see the joblib documentation on Parallel '
709-
'for more information'
710-
)
711-
712-
# Make sure to free as much memory as possible before forking
713-
gc.collect()
714-
715-
# Set an environment variable to avoid infinite loops
716-
set_environ_flag = True
717-
poolargs = dict(
718-
max_nbytes=self._max_nbytes,
719-
mmap_mode=self._mmap_mode,
720-
temp_folder=self._temp_folder,
721-
verbose=max(0, self.verbose - 50),
722-
context_id=0, # the pool is used only for one call
723-
)
724-
if self._mp_context is not None:
725-
# Use Python 3.4+ multiprocessing context isolation
726-
poolargs['context'] = self._mp_context
727-
self._pool = MemmapingPool(n_jobs, **poolargs)
728-
# We are using multiprocessing, we also want to capture
729-
# KeyboardInterrupts
730-
self.exceptions.extend([KeyboardInterrupt, WorkerInterrupt])
756+
# A flag used to abort the dispatching of jobs in case an
757+
# exception is found
758+
self._aborting = False
759+
if not self._managed_pool:
760+
n_jobs = self._initialize_pool()
731761
else:
732-
raise ValueError("Unsupported backend: %s" % self.backend)
762+
n_jobs = self._effective_n_jobs()
733763

734764
if self.batch_size == 'auto':
735765
self._effective_batch_size = 1
736766

737767
iterator = iter(iterable)
738768
pre_dispatch = self.pre_dispatch
769+
739770
if pre_dispatch == 'all' or n_jobs == 1:
740771
# prevent further dispatch via multiprocessing callback thread
741772
self._original_iterator = None
@@ -757,9 +788,6 @@ def __call__(self, iterable):
757788
self.n_completed_tasks = 0
758789
self._smoothed_batch_duration = 0.0
759790
try:
760-
if set_environ_flag:
761-
# Set an environment variable to avoid infinite loops
762-
os.environ[JOBLIB_SPAWNED_PROCESS] = '1'
763791
self._iterating = True
764792

765793
while self.dispatch_one_batch(iterator):
@@ -774,17 +802,11 @@ def __call__(self, iterable):
774802
# Make sure that we get a last message telling us we are done
775803
elapsed_time = time.time() - self._start_time
776804
self._print('Done %3i out of %3i | elapsed: %s finished',
777-
(len(self._output),
778-
len(self._output),
779-
short_format_time(elapsed_time)
780-
))
781-
805+
(len(self._output), len(self._output),
806+
short_format_time(elapsed_time)))
782807
finally:
783-
if n_jobs > 1:
784-
self._pool.close()
785-
self._pool.terminate() # terminate does a join()
786-
if self.backend == 'multiprocessing':
787-
os.environ.pop(JOBLIB_SPAWNED_PROCESS, 0)
808+
if not self._managed_pool:
809+
self._terminate_pool()
788810
self._jobs = list()
789811
output = self._output
790812
self._output = None

0 commit comments

Comments
 (0)