@@ -892,10 +892,158 @@ def close(self):
892892 pass
893893
894894
895+ class PsijWorker (Worker ):
896+ """A worker to execute tasks using PSI/J."""
897+
898+ def __init__ (self , subtype , ** kwargs ):
899+ """
900+ Initialize PsijWorker.
901+
902+ Parameters
903+ ----------
904+ subtype : str
905+ Scheduler for PSI/J.
906+ """
907+ try :
908+ import psij
909+ except ImportError :
910+ logger .critical ("Please install psij." )
911+ raise
912+ logger .debug ("Initialize PsijWorker" )
913+ self .psij = psij
914+
915+ # Check if the provided subtype is valid
916+ valid_subtypes = ["local" , "slurm" ]
917+ if subtype not in valid_subtypes :
918+ raise ValueError (
919+ f"Invalid 'subtype' provided. Available options: { ', ' .join (valid_subtypes )} "
920+ )
921+
922+ self .subtype = subtype
923+
924+ def run_el (self , interface , rerun = False , ** kwargs ):
925+ """Run a task."""
926+ return self .exec_psij (interface , rerun = rerun )
927+
928+ def make_spec (self , cmd = None , arg = None ):
929+ """
930+ Create a PSI/J job specification.
931+
932+ Parameters
933+ ----------
934+ cmd : str, optional
935+ Executable command. Defaults to None.
936+ arg : list, optional
937+ List of arguments. Defaults to None.
938+
939+ Returns
940+ -------
941+ psij.JobSpec
942+ PSI/J job specification.
943+ """
944+ spec = self .psij .JobSpec ()
945+ spec .executable = cmd
946+ spec .arguments = arg
947+
948+ return spec
949+
950+ def make_job (self , spec , attributes ):
951+ """
952+ Create a PSI/J job.
953+
954+ Parameters
955+ ----------
956+ spec : psij.JobSpec
957+ PSI/J job specification.
958+ attributes : any
959+ Job attributes.
960+
961+ Returns
962+ -------
963+ psij.Job
964+ PSI/J job.
965+ """
966+ job = self .psij .Job ()
967+ job .spec = spec
968+ return job
969+
970+ async def exec_psij (self , runnable , rerun = False ):
971+ """
972+ Run a task (coroutine wrapper).
973+
974+ Raises
975+ ------
976+ Exception
977+ If stderr is not empty.
978+
979+ Returns
980+ -------
981+ None
982+ """
983+ import pickle
984+ from pathlib import Path
985+
986+ jex = self .psij .JobExecutor .get_instance (self .subtype )
987+ absolute_path = Path (__file__ ).parent
988+
989+ if isinstance (runnable , TaskBase ):
990+ cache_dir = runnable .cache_dir
991+ file_path = cache_dir / "runnable_function.pkl"
992+ with open (file_path , "wb" ) as file :
993+ pickle .dump (runnable ._run , file )
994+ func_path = absolute_path / "run_pickled.py"
995+ spec = self .make_spec ("python" , [func_path , file_path ])
996+ else : # it could be tuple that includes pickle files with tasks and inputs
997+ cache_dir = runnable [- 1 ].cache_dir
998+ file_path_1 = cache_dir / "taskmain.pkl"
999+ file_path_2 = cache_dir / "ind.pkl"
1000+ ind , task_main_pkl , task_orig = runnable
1001+ with open (file_path_1 , "wb" ) as file :
1002+ pickle .dump (task_main_pkl , file )
1003+ with open (file_path_2 , "wb" ) as file :
1004+ pickle .dump (ind , file )
1005+ func_path = absolute_path / "run_pickled.py"
1006+ spec = self .make_spec (
1007+ "python" ,
1008+ [
1009+ func_path ,
1010+ file_path_1 ,
1011+ file_path_2 ,
1012+ ],
1013+ )
1014+
1015+ if rerun :
1016+ spec .arguments .append ("--rerun" )
1017+
1018+ spec .stdout_path = cache_dir / "demo.stdout"
1019+ spec .stderr_path = cache_dir / "demo.stderr"
1020+
1021+ job = self .make_job (spec , None )
1022+ jex .submit (job )
1023+ job .wait ()
1024+
1025+ if spec .stderr_path .stat ().st_size > 0 :
1026+ with open (spec .stderr_path , "r" ) as stderr_file :
1027+ stderr_contents = stderr_file .read ()
1028+ raise Exception (
1029+ f"stderr_path '{ spec .stderr_path } ' is not empty. Contents:\n { stderr_contents } "
1030+ )
1031+
1032+ return
1033+
1034+ def close (self ):
1035+ """Finalize the internal pool of tasks."""
1036+ pass
1037+
1038+
8951039WORKERS = {
8961040 "serial" : SerialWorker ,
8971041 "cf" : ConcurrentFuturesWorker ,
8981042 "slurm" : SlurmWorker ,
8991043 "dask" : DaskWorker ,
9001044 "sge" : SGEWorker ,
1045+ ** {
1046+ "psij-" + subtype : lambda subtype = subtype : PsijWorker (subtype = subtype )
1047+ for subtype in ["local" , "slurm" ]
1048+ },
9011049}
0 commit comments