import abc
import collections
import itertools
import os
import tempfile
from pipeline.infrastructure import basetask
from pipeline.infrastructure import exceptions
from . import mpihelpers
from . import utils
from . import vdp
__all__ = [
'as_list',
'group_into_sessions',
'parallel_inputs_impl',
'ParallelTemplate',
'remap_spw_int',
'remap_spw_str',
'VDPTaskFactory',
'VisResultTuple'
]
# VisResultTuple is a data structure used by VDPTaskFactor to group
# inputs and results.
VisResultTuple = collections.namedtuple('VisResultTuple', 'vis inputs result')
[docs]def as_list(o):
return o if isinstance(o, list) else [o]
[docs]def group_into_sessions(context, all_results):
"""
Return results grouped into lists by session.
:param context: pipeline context
:type context: Context
:param all_results: result to be grouped
:type all_results: list
:return: dict of sessions to results for that session
:rtype: dict {session name: [result, result, ...]
"""
session_map = {ms.basename: ms.session
for ms in context.observing_run.measurement_sets}
ms_start_times = {ms.basename: utils.get_epoch_as_datetime(ms.start_time)
for ms in context.observing_run.measurement_sets}
def get_session(r):
basename = os.path.basename(r[0])
return session_map.get(basename, 'Shared')
def get_start_time(r):
basename = os.path.basename(r[0])
return ms_start_times.get(basename, None)
results_by_session = sorted(all_results, key=get_session)
return {session_id: sorted(results_for_session, key=get_start_time)
for session_id, results_for_session in itertools.groupby(results_by_session, get_session)}
def group_vislist_into_sessions(context, vislist):
"""
Group the specified list of measurement sets 'vislist' by their session,
and return a dictionary that maps a session name to the corresponding list
of measurement sets.
:param context: pipeline context
:type context: :class:`~pipeline.infrastructure.launcher.Context`
:param vislist: list of vis to be grouped
:type vislist: list
:return: dictionary of sessions to vislist for that session
:rtype: dict {session name: [vis, vis, ...]}
"""
sessions = collections.defaultdict(list)
for vis in vislist:
sessions[getattr(context.observing_run.get_ms(vis), 'session', 'Shared')].append(vis)
return sessions
def get_vislist_for_session(context, session):
"""
Return list of measurement sets for specified session name.
:param context: pipeline context
:type context: :class:`~pipeline.infrastructure.launcher.Context`
:param session: name of session for which to retrieve list of vis
:type session: str
:return: list of vis for session
:rtype: list [vis, vis, ...]
"""
return [ms.name for ms in context.observing_run.get_measurement_sets() if ms.session == session]
[docs]class VDPTaskFactory(object):
"""
VDPTaskFactory is a class that implements the Factory design
pattern, returning tasks that execute on an MPI client or locally
as appropriate.
The correctness of this task is dependent on the correct mapping of
Inputs arguments to measurement set, hence it is dependent on
Inputs objects that sub-class VDP StandardInputs.
"""
def __init__(self, inputs, executor, task):
"""
Create a new VDPTaskFactory.
:param inputs: inputs for the task
:type inputs: class that extends vdp.StandardInputs
:param executor: pipeline task executor
:type executor: basetask.Executor
:param task: task to execute
"""
self.__inputs = inputs
self.__context = inputs.context
self.__executor = executor
self.__task = task
self.__context_path = None
def __enter__(self):
# If there's a possibility that we'll submit MPI jobs, save the context
# to disk ready for import by the MPI servers.
if mpihelpers.mpiclient:
# Use the tempfile module to generate a unique temporary filename,
# which we use as the output path for our pickled context
tmpfile = tempfile.NamedTemporaryFile(suffix='.context',
dir=self.__context.output_dir,
delete=True)
self.__context_path = tmpfile.name
tmpfile.close()
self.__context.save(self.__context_path)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.__context_path and os.path.exists(self.__context_path):
os.unlink(self.__context_path)
[docs] def get_task(self, vis):
"""
Create and return a SyncTask or AsyncTask for the job.
:param vis: measurement set to create a job for
:type vis: str
:return: task object ready for execution
:rtype: a tuple of (task arguments, (SyncTask|AsyncTask)
"""
# get the task arguments for the targetted MS
task_args = self.__get_task_args(vis)
# Validate the task arguments against the signature expected by
# the task. This is necessary as self.__inputs could be a
# different type from the Inputs instance associated with the
# task, and hence and hence have a different signature. For
# instance, the session-aware Inputs classes accept a
# 'parallel' argument, which the non-session tasks do not.
#
# To make things complicated, ModeInputs are a special case.
# Here we shouldn't inspect the constructor signature of the
# task as it doesn't reflect that of the active task, which is
# the one that we would want to filter on. So, if the task's
# Inputs are ModeInputs, dereference the task's Inputs class.
if issubclass(self.__task.Inputs, vdp.ModeInputs):
active_mode_task = self.__task.Inputs._modes[self.__inputs.mode]
task_inputs_cls = active_mode_task.Inputs
else:
task_inputs_cls = self.__task.Inputs
valid_args = validate_args(task_inputs_cls, task_args)
is_mpi_ready = mpihelpers.is_mpi_ready()
is_tier0_job = is_mpi_ready
parallel_wanted = mpihelpers.parse_mpi_input_parameter(self.__inputs.parallel)
if is_tier0_job and parallel_wanted:
executable = mpihelpers.Tier0PipelineTask(self.__task, valid_args, self.__context_path)
return valid_args, mpihelpers.AsyncTask(executable)
else:
inputs = vdp.InputsContainer(self.__task, self.__context, **valid_args)
task = self.__task(inputs)
return valid_args, mpihelpers.SyncTask(task, self.__executor)
def __get_task_args(self, vis):
inputs = self.__inputs
original_vis = inputs.vis
try:
inputs.vis = vis
task_args = inputs.as_dict()
# support for single-dish tasks
if 'infiles' in task_args:
task_args['infiles'] = task_args['vis']
finally:
inputs.vis = original_vis
return task_args
def remove_unexpected_args(fn, fn_args):
# get the argument names for the function
code = fn.__code__
arg_count = code.co_argcount
arg_names = code.co_varnames[:arg_count]
# identify arguments that are not expected by the function
unexpected = [k for k in fn_args if k not in arg_names]
# return the fn args purged of any unexpected items
x = {k: v for k, v in fn_args.items() if k not in unexpected}
# LOG.info('Arg names: {!s}'.format(arg_names))
# LOG.info('Unexpected: {!s}'.format(unexpected))
# LOG.info('Valid: {!s}'.format(x))
return x
def validate_args(inputs_cls, task_args):
inputs_constructor_fn = getattr(inputs_cls, '__init__')
valid_args = remove_unexpected_args(inputs_constructor_fn, task_args)
return valid_args
def get_spwmap(source_ms, target_ms):
"""
Get a map of spectral windows IDs that map from a source spw ID in
the source MS to its equivalent spw in the target MS.
:param source_ms: the MS to map spws from
:type source_ms: domain.MeasurementSet
:param target_ms: the MS to map spws to
:type target_ms: domain.MeasurementSet
:param spws: the spw argument to convert
:return: dict of integer spw IDs
"""
# spw names are not guaranteed to be unique. They seem to be unique
# across science intents, but they could be repeated for other
# scans (pointing, sideband, etc.) in non-science spectral windows.
# if not eliminated, these non-science window names collide with
# the science windows and you end up having science windows map to
# non-science windows and vice versa. Not what we want! This set
# will be used to filter for the spectral windows we want to
# consider.
science_intents = {'AMPLITUDE', 'BANDPASS', 'PHASE', 'TARGET', 'CHECK'}
# map spw id to spw name for source MS - just for science intents
id_to_name = {spw.id: spw.name
for spw in source_ms.spectral_windows
if not science_intents.isdisjoint(spw.intents)}
# map spw name to spw id for target MS - just for science intents
name_to_id = {spw.name: spw.id
for spw in target_ms.spectral_windows
if not science_intents.isdisjoint(spw.intents)}
# note the get(v, k) here. This says that for non-science windows,
# which have been filtered from the maps, use the original spw ID -
# hence non-science spws are not remapped.
return {k: name_to_id.get(v, k)
for k, v in id_to_name.items()
if v in name_to_id}
[docs]def remap_spw_int(source_ms, target_ms, spws):
"""
Map integer spw arguments from one MS to their equivalent spw in
the target ms.
:param source_ms: the MS to map spws from
:type source_ms: domain.MeasurementSet
:param target_ms: the MS to map spws to
:type target_ms: domain.MeasurementSet
:param spws: the spw argument to convert
:return: a list of remapped integer spw IDs
:rtype: list
"""
int_spw_map = get_spwmap(source_ms, target_ms)
return [int_spw_map[spw_id] for spw_id in spws]
[docs]def remap_spw_str(source_ms, target_ms, spws):
"""
Remap a string spw argument, e.g., '16,18,20,22', from one MS to
the equivalent map in the target ms.
:param source_ms: the MS to map spws from
:type source_ms: domain.MeasurementSet
:param target_ms: the MS to map spws to
:type target_ms: domain.MeasurementSet
:param spws: the spw argument to convert
:return: a list of remapped integer spw IDs
:rtype: str
"""
spw_ints = [int(i) for i in spws.split(',')]
l = remap_spw_int(source_ms, target_ms, spw_ints)
return ','.join([str(i) for i in l])
[docs]class ParallelTemplate(basetask.StandardTaskTemplate):
is_multi_vis_task = True
@abc.abstractproperty
def Task(self):
"""
A reference to the :class:`Task` class containing the implementation
for this pipeline stage.
"""
raise NotImplementedError
def __init__(self, inputs):
super(ParallelTemplate, self).__init__(inputs)
[docs] def get_result_for_exception(self, vis, result):
raise NotImplementedError
[docs] def prepare(self):
inputs = self.inputs
# this will hold the tuples of ms, jobs and results
assessed = []
vis_list = as_list(inputs.vis)
with VDPTaskFactory(inputs, self._executor, self.Task) as factory:
task_queue = [(vis, factory.get_task(vis)) for vis in vis_list]
# Jobs must complete within the scope of the VDPTaskFactory as the
# context copies used by the MPI clients are removed on __exit__.
for (vis, (task_args, task)) in task_queue:
try:
worker_result = task.get_result()
except exceptions.PipelineException as e:
assessed.append((vis, task_args, e))
else:
assessed.append((vis, task_args, worker_result))
return assessed
[docs] def analyse(self, assessed):
# all results will be added to this object
final_result = basetask.ResultsList()
context = self.inputs.context
session_groups = group_into_sessions(context, assessed)
for session_id, session_results in session_groups.items():
for vis, task_args, vis_result in session_results:
if isinstance(vis_result, Exception):
fake_result = self.get_result_for_exception(vis, vis_result)
fake_result.inputs = task_args
final_result.append(fake_result)
else:
if isinstance(vis_result, collections.Iterable):
final_result.extend(vis_result)
else:
final_result.append(vis_result)
return final_result