import abc
import os
import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.basetask as basetask
import pipeline.infrastructure.sessionutils as sessionutils
import pipeline.infrastructure.vdp as vdp
from pipeline.domain import DataTable
from pipeline.h.heuristics import caltable as caltable_heuristic
from pipeline.hsd.heuristics import CubicSplineFitParamConfig
from pipeline.hsd.tasks.common import utils as sdutils
from pipeline.infrastructure import casa_tasks
from pipeline.infrastructure import casa_tools
from . import plotter
from .. import common
# Module-level logger obtained from the pipeline infrastructure.
_LOG = infrastructure.get_logger(__name__)
# Wrapper around _LOG used throughout this module. Call sites pass
# str.format-style templates plus arguments, e.g.
# LOG.debug('... {vis}', vis=...); presumably the template is only
# formatted when the message is actually emitted -- confirm in sdutils.
LOG = sdutils.OnDemandStringParseLogger(_LOG)
class BaselineSubtractionResults(common.SingleDishResults):
    """Results object returned by the baseline subtraction workers.

    ``outcome`` is a dict; it must at least provide the 'blparam' and
    'bloutput' entries, which are used to build the outcome name.
    """

    def __init__(self, task=None, success=None, outcome=None):
        super().__init__(task, success, outcome)

    def merge_with_context(self, context):
        """Merge into ``context``; defers entirely to the base class."""
        super().merge_with_context(context)

    def _outcome_name(self):
        """Describe the outcome by its blparam and bloutput file names."""
        names = (self.outcome['blparam'], self.outcome['bloutput'])
        return 'blparam: "%s" bloutput: "%s"' % names
# Base class for workers
class BaselineSubtractionWorker(basetask.StandardTaskTemplate):
    """Base class for workers that fit and subtract spectral baselines.

    ``prepare`` writes a blparam file using the heuristics class given by
    :attr:`Heuristics` and runs ``sdbaseline`` with it; ``analyse`` plots
    the spectra together with the fit. Subclasses must provide
    :attr:`Heuristics`.
    """
    Inputs = BaselineSubtractionWorkerInputs

    @property
    @abc.abstractmethod
    def Heuristics(self):
        """
        A reference to the :class:`Heuristics` class.

        Concrete subclasses override this (e.g. with a plain class
        attribute). ``prepare`` instantiates it with a ``switchpoly``
        keyword and calls the instance once per (field, antenna, spw) of
        the plan; the call must return the blparam file name it was given.
        """
        raise NotImplementedError

    is_multi_vis_task = False

    def __init__(self, inputs):
        super(BaselineSubtractionWorker, self).__init__(inputs)
        # DataTable for this MS (<ms_datatable_name>/<ms basename>), shared by
        # the fit heuristics in prepare() and the plot manager in analyse().
        self.datatable = DataTable(os.path.join(self.inputs.context.observing_run.ms_datatable_name,
                                                self.inputs.ms.basename))

    @staticmethod
    def _get_deviationmask(deviationmask_list, field_id, antenna_id, spw_id):
        """Return the deviation mask registered for the given ids, or None."""
        key = (field_id, antenna_id, spw_id)
        if key in deviationmask_list:
            return deviationmask_list[key]
        return None

    def prepare(self):
        """Generate the blparam file and run ``sdbaseline`` on the MS.

        For every (field, antenna, spw) in ``inputs.plan`` the heuristics
        write fit parameters to the blparam file named by the inputs; then
        ``sdbaseline`` is executed once with ``blfunc='variable'`` using that
        file, subtracting the fit and writing the result to ``outfile``.

        Returns:
            BaselineSubtractionResults whose outcome dict holds 'infile',
            'blparam', 'bloutput' and 'outfile'.

        Raises:
            RuntimeError: if the heuristics report a blparam file other than
                the one that was requested.
        """
        vis = self.inputs.vis
        ms = self.inputs.ms
        fit_order = self.inputs.fit_order
        edge = self.inputs.edge
        args = self.inputs.to_casa_args()
        blparam = args['blparam']
        bloutput = args['bloutput']
        outfile = args['outfile']
        datacolumn = args['datacolumn']
        process_list = self.inputs.plan
        deviationmask_list = self.inputs.deviationmask
        LOG.info('deviationmask_list={}'.format(deviationmask_list))
        field_id_list = self.inputs.field
        antenna_id_list = self.inputs.antenna
        spw_id_list = self.inputs.spw
        LOG.debug('subgroup member for {vis}:\n\tfield: {field}\n\tantenna: {antenna}\n\tspw: {spw}',
                  vis=ms.basename,
                  field=field_id_list,
                  antenna=antenna_id_list,
                  spw=spw_id_list)

        # The blparam file must be removed before iterating through the
        # reduction group so that no stale content from a previous run remains.
        if os.path.exists(blparam):
            LOG.debug('Cleaning up blparam file for {vis}', vis=vis)
            os.remove(blparam)

        for (field_id, antenna_id, spw_id) in process_list.iterate_id():
            deviationmask = self._get_deviationmask(deviationmask_list, field_id, antenna_id, spw_id)
            blparam_heuristic = self.Heuristics(switchpoly=self.inputs.switchpoly)
            formatted_edge = list(common.parseEdge(edge))
            out_blparam = blparam_heuristic(self.datatable, ms, antenna_id, field_id, spw_id,
                                            fit_order, formatted_edge, deviationmask, blparam)
            # Explicit check instead of ``assert`` so it is not stripped
            # when Python runs with -O.
            if out_blparam != blparam:
                raise RuntimeError('Heuristics wrote baseline parameters to "{}" instead of "{}"'
                                   .format(out_blparam, blparam))

        # execute sdbaseline
        job = casa_tasks.sdbaseline(infile=vis, datacolumn=datacolumn, blmode='fit', dosubtract=True,
                                    blformat='table', bloutput=bloutput,
                                    blfunc='variable', blparam=blparam,
                                    outfile=outfile, overwrite=True)
        self._executor.execute(job)

        outcome = {'infile': vis,
                   'blparam': blparam,
                   'bloutput': bloutput,
                   'outfile': outfile}
        results = BaselineSubtractionResults(success=True, outcome=outcome)
        return results

    def analyse(self, results):
        """Plot the spectra with their baseline fit for every plan entry.

        Plots are made only when the plot manager initializes successfully
        (truthy status); the plot manager is always finalized, even if
        plotting raises.

        Args:
            results: BaselineSubtractionResults from ``prepare``; its
                outcome must contain 'outfile'.

        Returns:
            ``results`` with ``outcome['plot_list']`` set to the list of plots.

        Raises:
            RuntimeError: if a field's source name is missing from
                ``inputs.org_directions_dict``.
        """
        plot_manager = plotter.BaselineSubtractionPlotManager(self.inputs.context, self.datatable)
        outfile = results.outcome['outfile']
        ms = self.inputs.ms
        org_directions_dict = self.inputs.org_directions_dict
        accum = self.inputs.plan
        deviationmask_list = self.inputs.deviationmask
        LOG.info('deviationmask_list={}'.format(deviationmask_list))
        status = plot_manager.initialize(ms, outfile)
        plot_list = []
        try:
            for (field_id, antenna_id, spw_id, grid_table, channelmap_range) in accum.iterate_all():
                LOG.info('field {0} antenna {1} spw {2}', field_id, antenna_id, spw_id)
                deviationmask = self._get_deviationmask(deviationmask_list, field_id, antenna_id, spw_id)
                if status:
                    fields = ms.get_fields(field_id=field_id)
                    source_name = fields[0].source.name
                    if source_name not in org_directions_dict:
                        raise RuntimeError("source_name {} not found in org_directions_dict (sources found are {})"
                                           "".format(source_name, list(org_directions_dict.keys())))
                    org_direction = org_directions_dict[source_name]
                    plot_list.extend(plot_manager.plot_spectra_with_fit(field_id, antenna_id, spw_id,
                                                                        org_direction,
                                                                        grid_table,
                                                                        deviationmask, channelmap_range))
        finally:
            # Finalize unconditionally, even when the loop above raises.
            plot_manager.finalize()

        results.outcome['plot_list'] = plot_list
        return results
# Worker class for cubic spline fit
class CubicSplineBaselineSubtractionWorker(BaselineSubtractionWorker):
    """Baseline subtraction worker for cubic spline fitting.

    Fit parameters are generated by :class:`CubicSplineFitParamConfig`.
    """
    Heuristics = CubicSplineFitParamConfig
    Inputs = BaselineSubtractionWorkerInputs
# Tier-0 Parallelization
# This is an abstract class because Task is not specified yet.
class HpcBaselineSubtractionWorker(sessionutils.ParallelTemplate):
    """Tier-0 parallelization wrapper for the baseline subtraction workers.

    This class is abstract: concrete subclasses set ``Task`` to the worker
    class to execute.
    """
    Inputs = HpcBaselineSubtractionWorkerInputs

    def __init__(self, inputs):
        super(HpcBaselineSubtractionWorker, self).__init__(inputs)

    @basetask.result_finaliser
    def get_result_for_exception(self, vis, exception):
        """Log a baseline subtraction failure and wrap it in a result.

        Args:
            vis: path of the MS whose processing failed.
            exception: the exception that was raised.

        Returns:
            basetask.FailedTaskResults holding this class, the exception and
            a traceback string, or a short ``Type(message)`` summary when no
            traceback is available.
        """
        import traceback
        LOG.error('Error operating baseline subtraction for {!s}'.format(os.path.basename(vis)))
        LOG.error('{0}({1})'.format(exception.__class__.__name__, str(exception)))
        # Prefer the traceback attached to the exception itself; it does not
        # depend on being called from inside the matching ``except`` block.
        exc_tb = getattr(exception, '__traceback__', None)
        if exc_tb is not None:
            tb = ''.join(traceback.format_exception(exception.__class__, exception, exc_tb))
        else:
            tb = traceback.format_exc()
            # format_exc() yields 'None...' when no exception is being handled.
            if tb.startswith('None'):
                tb = '{0}({1})'.format(exception.__class__.__name__, str(exception))
        return basetask.FailedTaskResults(self.__class__, exception, tb)
class HpcCubicSplineBaselineSubtractionWorker(HpcBaselineSubtractionWorker):
    """Tier-0 parallel wrapper whose ``Task`` is CubicSplineBaselineSubtractionWorker."""
    Task = CubicSplineBaselineSubtractionWorker

    def __init__(self, inputs):
        super().__init__(inputs)
# # facade for FitParam
# class BaselineSubtractionInputs(vdp.ModeInputs):
# _modes = {'spline': CubicSplineBaselineSubtractionWorker,
# 'cspline': CubicSplineBaselineSubtractionWorker}
#
# def __init__(self, context, fitfunc, **parameters):
# super(BaselineSubtractionInputs, self).__init__(context=context, mode=fitfunc, **parameters)
#
#
# class BaselineSubtractionTask(basetask.ModeTask):
# Inputs = BaselineSubtractionInputs