@@ -433,20 +433,116 @@ def __init__(self, n_jobs=1, backend='multiprocessing', verbose=0,
433433 % batch_size )
434434
435435 self .pre_dispatch = pre_dispatch
436- self ._pool = None
437436 self ._temp_folder = temp_folder
438437 if isinstance (max_nbytes , _basestring ):
439438 self ._max_nbytes = 1024 * memstr_to_kbytes (max_nbytes )
440439 else :
441440 self ._max_nbytes = max_nbytes
442441 self ._mmap_mode = mmap_mode
443442 # Not starting the pool in the __init__ is a design decision, to be
444- # able to close it ASAP, and not burden the user with closing it.
443+ # able to close it ASAP, and not burden the user with closing it
444+ # unless they choose to use the context manager API with a with block.
445+ self ._pool = None
445446 self ._output = None
446447 self ._jobs = list ()
447- # A flag used to abort the dispatching of jobs in case an
448- # exception is found
449- self ._aborting = False
448+ self ._managed_pool = False
449+
450+ # This lock is used to coordinate the main thread of this process with
451+ # the async callback thread of the pool.
452+ self ._lock = threading .Lock ()
453+
454+ def __enter__ (self ):
455+ self ._managed_pool = True
456+ self ._initialize_pool ()
457+ return self
458+
459+ def __exit__ (self , exc_type , exc_value , traceback ):
460+ self ._terminate_pool ()
461+ self ._managed_pool = False
462+
463+ def _effective_n_jobs (self ):
464+ n_jobs = self .n_jobs
465+ if n_jobs == 0 :
466+ raise ValueError ('n_jobs == 0 in Parallel has no meaning' )
467+ elif mp is None or n_jobs is None :
468+ # multiprocessing is not available or disabled: fall back
469+ # to sequential mode
470+ return 1
471+ elif n_jobs < 0 :
472+ n_jobs = max (mp .cpu_count () + 1 + n_jobs , 1 )
473+ return n_jobs
474+
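
To make the negative n_jobs convention above concrete, here is a hand-worked sketch assuming a hypothetical 8-CPU machine (the helper below is illustrative, not part of joblib):

    def effective_n_jobs_example(n_jobs, n_cpus=8):
        # Mirrors max(mp.cpu_count() + 1 + n_jobs, 1) for n_cpus == 8.
        return max(n_cpus + 1 + n_jobs, 1)

    assert effective_n_jobs_example(-1) == 8   # all CPUs
    assert effective_n_jobs_example(-2) == 7   # all CPUs but one
    assert effective_n_jobs_example(-16) == 1  # clipped to sequential
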
475+ def _initialize_pool (self ):
476+ """Build a process or thread pool and return the number of workers"""
477+ n_jobs = self ._effective_n_jobs ()
478+ # The list of exceptions that we will capture
479+ self .exceptions = [TransportableException ]
480+
481+ if n_jobs == 1 :
482+ # Sequential mode: do not use a pool instance to avoid any
483+ # useless dispatching overhead
484+ self ._pool = None
485+ elif self .backend == 'threading' :
486+ self ._pool = ThreadPool (n_jobs )
487+ elif self .backend == 'multiprocessing' :
488+ if mp .current_process ().daemon :
489+ # Daemonic processes cannot have children
490+ self ._pool = None
491+ warnings .warn (
492+ 'Multiprocessing-backed parallel loops cannot be nested,'
493+ ' setting n_jobs=1' ,
494+ stacklevel = 3 )
495+ return 1
496+ elif threading .current_thread ().name != 'MainThread' :
497+ # Prevent posix fork inside non-main posix threads
498+ self ._pool = None
499+ warnings .warn (
500+ 'Multiprocessing backed parallel loops cannot be nested'
501+ ' below threads, setting n_jobs=1' ,
502+ stacklevel = 3 )
503+ return 1
504+ else :
505+ already_forked = int (os .environ .get (JOBLIB_SPAWNED_PROCESS , 0 ))
506+ if already_forked :
507+ raise ImportError ('[joblib] Attempting to do parallel computing '
508+ 'without protecting your import on a system that does '
509+ 'not support forking. To use parallel-computing in a '
510+ 'script, you must protect your main loop using "if '
511+ "__name__ == '__main__'"
512+ '". Please see the joblib documentation on Parallel '
513+ 'for more information'
514+ )
515+ # Set an environment variable to avoid infinite loops
516+ os .environ [JOBLIB_SPAWNED_PROCESS ] = '1'
517+
518+ # Make sure to free as much memory as possible before forking
519+ gc .collect ()
520+ poolargs = dict (
521+ max_nbytes = self ._max_nbytes ,
522+ mmap_mode = self ._mmap_mode ,
523+ temp_folder = self ._temp_folder ,
524+ verbose = max (0 , self .verbose - 50 ),
525+ context_id = 0 , # the pool is used only for one call
526+ )
527+ if self ._mp_context is not None :
528+ # Use Python 3.4+ multiprocessing context isolation
529+ poolargs ['context' ] = self ._mp_context
530+ self ._pool = MemmapingPool (n_jobs , ** poolargs )
531+
532+ # We are using multiprocessing, we also want to capture
533+ # KeyboardInterrupts
534+ self .exceptions .extend ([KeyboardInterrupt , WorkerInterrupt ])
535+ else :
536+ raise ValueError ("Unsupported backend: %s" % self .backend )
537+ return n_jobs
538+
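
A hedged sketch of the two nesting fallbacks handled above (the function names are invented for illustration): an inner multiprocessing-backed Parallel created inside a daemonic worker process warns and degrades to sequential mode, while a thread-based inner loop remains allowed.

    from joblib import Parallel, delayed

    def inner(x):
        # Runs inside a daemonic worker: _initialize_pool() warns and the
        # inner loop effectively executes with n_jobs=1.
        return Parallel(n_jobs=2)(delayed(abs)(x - i) for i in range(3))

    def inner_threaded(x):
        # backend='threading' does not fork, so nesting it is fine.
        return Parallel(n_jobs=2, backend='threading')(
            delayed(abs)(x - i) for i in range(3))

    nested = Parallel(n_jobs=2)(delayed(inner)(x) for x in range(4))
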
539+ def _terminate_pool (self ):
540+ if self ._pool is not None :
541+ self ._pool .close ()
542+ self ._pool .terminate () # terminate does a join()
543+ self ._pool = None
544+ if self .backend == 'multiprocessing' :
545+ os .environ .pop (JOBLIB_SPAWNED_PROCESS , 0 )
450546
451547 def _dispatch (self , batch ):
452548 """Queue the batch for computing, with or without multiprocessing
@@ -455,6 +551,10 @@ def _dispatch(self, batch):
455551 indirectly via dispatch_one_batch.
456552
457553 """
554+ # If job.get() catches an exception, it closes the queue:
555+ if self ._aborting :
556+ return
557+
458558 if self ._pool is None :
459559 job = ImmediateComputeBatch (batch )
460560 self ._jobs .append (job )
@@ -467,10 +567,6 @@ def _dispatch(self, batch):
467567 short_format_time (time .time () - self ._start_time )
468568 ))
469569 else :
470- # If job.get() catches an exception, it closes the queue:
471- if self ._aborting :
472- return
473-
474570 dispatch_timestamp = time .time ()
475571 cb = BatchCompletionCallBack (dispatch_timestamp , len (batch ), self )
476572 job = self ._pool .apply_async (SafeFunction (batch ), callback = cb )
@@ -600,7 +696,7 @@ def print_progress(self):
600696 if (is_last_item or cursor % frequency ):
601697 return
602698 remaining_time = (elapsed_time / (index + 1 ) *
603- (self .n_dispatched_tasks - index - 1. ))
699+ (self .n_dispatched_tasks - index - 1. ))
604700 self ._print ('Done %3i out of %3i | elapsed: %s remaining: %s' ,
605701 (index + 1 ,
606702 total_tasks ,
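
A hand-worked instance of the remaining-time estimate above, with invented numbers: 4 of 10 dispatched tasks done after 8 seconds gives 2 seconds per task and 6 tasks left, i.e. roughly 12 seconds remaining.

    # index == 3 means the 4th task just completed.
    elapsed_time, index, n_dispatched_tasks = 8.0, 3, 10
    remaining_time = (elapsed_time / (index + 1) *
                      (n_dispatched_tasks - index - 1.))
    assert remaining_time == 12.0
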
@@ -615,127 +711,62 @@ def retrieve(self):
615711 # Wait for an async callback to dispatch new jobs
616712 time .sleep (0.01 )
617713 continue
618- # We need to be careful: the job queue can be filling up as
619- # we empty it
620- if hasattr (self , '_lock' ):
621- self ._lock .acquire ()
622- job = self ._jobs .pop (0 )
623- if hasattr (self , '_lock' ):
624- self ._lock .release ()
714+ # We need to be careful: the job list can be filling up as
715+ # we empty it, and Python lists are not thread-safe by default,
716+ # hence the use of the lock
717+ with self ._lock :
718+ job = self ._jobs .pop (0 )
625719 try :
626720 self ._output .extend (job .get ())
627721 except tuple (self .exceptions ) as exception :
628- try :
629- self ._aborting = True
630- self ._lock .acquire ()
631- if isinstance (exception ,
632- (KeyboardInterrupt , WorkerInterrupt )):
633- # We have captured a user interruption, clean up
634- # everything
635- if hasattr (self , '_pool' ):
636- self ._pool .close ()
637- self ._pool .terminate ()
638- # We can now allow subprocesses again
639- os .environ .pop ('__JOBLIB_SPAWNED_PARALLEL__' , 0 )
640- raise exception
641- elif isinstance (exception , TransportableException ):
642- # Capture exception to add information on the local
643- # stack in addition to the distant stack
644- this_report = format_outer_frames (context = 10 ,
645- stack_start = 1 )
646- report = """Multiprocessing exception:
647- %s
648- ---------------------------------------------------------------------------
649- Sub-process traceback:
650- ---------------------------------------------------------------------------
651- %s""" % (
652- this_report ,
653- exception .message ,
654- )
655- # Convert this to a JoblibException
656- exception_type = _mk_exception (exception .etype )[0 ]
657- raise exception_type (report )
658- raise exception
659- finally :
660- self ._lock .release ()
722+ # Stop dispatching any new job in the async callback thread
723+ self ._aborting = True
724+
725+ if isinstance (exception , TransportableException ):
726+ # Capture exception to add information on the local
727+ # stack in addition to the distant stack
728+ this_report = format_outer_frames (context = 10 ,
729+ stack_start = 1 )
730+ report = """Multiprocessing exception:
731+ %s
732+ ---------------------------------------------------------------------------
733+ Sub-process traceback:
734+ ---------------------------------------------------------------------------
735+ %s""" % (this_report , exception .message )
736+ # Convert this to a JoblibException
737+ exception_type = _mk_exception (exception .etype )[0 ]
738+ exception = exception_type (report )
739+
740+ # Kill remaining running processes without waiting for
741+ # the results as we will raise the exception we got back
742+ # to the caller instead of returning any result.
743+ with self ._lock :
744+ self ._terminate_pool ()
745+ if self ._managed_pool :
746+ # In case we had to terminate a managed pool, let
747+ # us start a new one to ensure that subsequent calls
748+ # to __call__ on the same Parallel instance will get
749+ # a working pool as they expect.
750+ self ._initialize_pool ()
751+ raise exception
661752
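
A hedged sketch of what this error path means for callers (fail_on_two is made up for illustration): the worker-side exception is re-raised in the parent as a JoblibException subclass that also derives from the original exception type, and because a managed pool is re-initialized after termination, the same instance stays usable.

    from joblib import Parallel, delayed

    def fail_on_two(x):
        if x == 2:
            raise ValueError('boom')   # raised inside a worker
        return x

    with Parallel(n_jobs=2) as parallel:
        try:
            parallel(delayed(fail_on_two)(i) for i in range(5))
        except ValueError:
            # The TransportableException was converted back to a
            # ValueError subclass carrying both tracebacks in its report.
            pass
        # The pool was terminated and re-created, so this still works:
        assert parallel(delayed(fail_on_two)(i) for i in range(2)) == [0, 1]
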
662753 def __call__ (self , iterable ):
663754 if self ._jobs :
664755 raise ValueError ('This Parallel instance is already running' )
665- n_jobs = self .n_jobs
666- if n_jobs == 0 :
667- raise ValueError ('n_jobs == 0 in Parallel has no meaning' )
668- if n_jobs < 0 and mp is not None :
669- n_jobs = max (mp .cpu_count () + 1 + n_jobs , 1 )
670-
671- # The list of exceptions that we will capture
672- self .exceptions = [TransportableException ]
673- self ._lock = threading .Lock ()
674-
675- # Whether or not to set an environment flag to track
676- # multiple process spawning
677- set_environ_flag = False
678- if (n_jobs is None or mp is None or n_jobs == 1 ):
679- n_jobs = 1
680- self ._pool = None
681- elif self .backend == 'threading' :
682- self ._pool = ThreadPool (n_jobs )
683- elif self .backend == 'multiprocessing' :
684- if mp .current_process ().daemon :
685- # Daemonic processes cannot have children
686- n_jobs = 1
687- self ._pool = None
688- warnings .warn (
689- 'Multiprocessing-backed parallel loops cannot be nested,'
690- ' setting n_jobs=1' ,
691- stacklevel = 2 )
692- elif threading .current_thread ().name != 'MainThread' :
693- # Prevent posix fork inside in non-main posix threads
694- n_jobs = 1
695- self ._pool = None
696- warnings .warn (
697- 'Multiprocessing backed parallel loops cannot be nested'
698- ' below threads, setting n_jobs=1' ,
699- stacklevel = 2 )
700- else :
701- already_forked = int (os .environ .get ('__JOBLIB_SPAWNED_PARALLEL__' , 0 ))
702- if already_forked :
703- raise ImportError ('[joblib] Attempting to do parallel computing '
704- 'without protecting your import on a system that does '
705- 'not support forking. To use parallel-computing in a '
706- 'script, you must protect your main loop using "if '
707- "__name__ == '__main__'"
708- '". Please see the joblib documentation on Parallel '
709- 'for more information'
710- )
711-
712- # Make sure to free as much memory as possible before forking
713- gc .collect ()
714-
715- # Set an environment variable to avoid infinite loops
716- set_environ_flag = True
717- poolargs = dict (
718- max_nbytes = self ._max_nbytes ,
719- mmap_mode = self ._mmap_mode ,
720- temp_folder = self ._temp_folder ,
721- verbose = max (0 , self .verbose - 50 ),
722- context_id = 0 , # the pool is used only for one call
723- )
724- if self ._mp_context is not None :
725- # Use Python 3.4+ multiprocessing context isolation
726- poolargs ['context' ] = self ._mp_context
727- self ._pool = MemmapingPool (n_jobs , ** poolargs )
728- # We are using multiprocessing, we also want to capture
729- # KeyboardInterrupts
730- self .exceptions .extend ([KeyboardInterrupt , WorkerInterrupt ])
756+ # A flag used to abort the dispatching of jobs in case an
757+ # exception is found
758+ self ._aborting = False
759+ if not self ._managed_pool :
760+ n_jobs = self ._initialize_pool ()
731761 else :
732- raise ValueError ( "Unsupported backend: %s" % self .backend )
762+ n_jobs = self ._effective_n_jobs ( )
733763
734764 if self .batch_size == 'auto' :
735765 self ._effective_batch_size = 1
736766
737767 iterator = iter (iterable )
738768 pre_dispatch = self .pre_dispatch
769+
739770 if pre_dispatch == 'all' or n_jobs == 1 :
740771 # prevent further dispatch via multiprocessing callback thread
741772 self ._original_iterator = None
@@ -757,9 +788,6 @@ def __call__(self, iterable):
757788 self .n_completed_tasks = 0
758789 self ._smoothed_batch_duration = 0.0
759790 try :
760- if set_environ_flag :
761- # Set an environment variable to avoid infinite loops
762- os .environ [JOBLIB_SPAWNED_PROCESS ] = '1'
763791 self ._iterating = True
764792
765793 while self .dispatch_one_batch (iterator ):
@@ -774,17 +802,11 @@ def __call__(self, iterable):
774802 # Make sure that we get a last message telling us we are done
775803 elapsed_time = time .time () - self ._start_time
776804 self ._print ('Done %3i out of %3i | elapsed: %s finished' ,
777- (len (self ._output ),
778- len (self ._output ),
779- short_format_time (elapsed_time )
780- ))
781-
805+ (len (self ._output ), len (self ._output ),
806+ short_format_time (elapsed_time )))
782807 finally :
783- if n_jobs > 1 :
784- self ._pool .close ()
785- self ._pool .terminate () # terminate does a join()
786- if self .backend == 'multiprocessing' :
787- os .environ .pop (JOBLIB_SPAWNED_PROCESS , 0 )
808+ if not self ._managed_pool :
809+ self ._terminate_pool ()
788810 self ._jobs = list ()
789811 output = self ._output
790812 self ._output = None
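
For completeness, a sketch of the unmanaged path after this refactoring (sqrt is only a placeholder workload): a plain call builds its own pool in __call__ via _initialize_pool() and always tears it down in the finally clause via _terminate_pool().

    from math import sqrt
    from joblib import Parallel, delayed

    # No with block: the pool only lives for the duration of this call.
    roots = Parallel(n_jobs=2)(delayed(sqrt)(i) for i in range(10))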