# Source code for pipeline.hsd.tasks.baselineflag.baselineflag
import os
import collections
import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.basetask as basetask
import pipeline.infrastructure.sessionutils as sessionutils
import pipeline.infrastructure.utils as utils
import pipeline.infrastructure.vdp as vdp
from pipeline.h.heuristics import fieldnames
from pipeline.infrastructure import casa_tasks
from pipeline.infrastructure import task_registry
from . import worker
from .flagsummary import SDBLFlagSummary
from .. import common
from ..common import utils as sdutils
LOG = infrastructure.get_logger(__name__)
class SDBLFlagInputs(vdp.StandardInputs):
"""
Inputs for single dish flagging
"""
def __to_numeric(self, val):
return sdutils.to_numeric(val)
def __to_bool(self, val):
return sdutils.to_bool(val)
def __to_int(self, val):
return int(val)
intent = vdp.VisDependentProperty(default='TARGET')
iteration = vdp.VisDependentProperty(default=5, fconvert=__to_int)
flag_tsys = vdp.VisDependentProperty(default=True, fconvert=__to_bool)
tsys_thresh = vdp.VisDependentProperty(default=3.0, fconvert=__to_numeric)
flag_weath = vdp.VisDependentProperty(default=False, fconvert=__to_bool)
weath_thresh = vdp.VisDependentProperty(default=3.0, fconvert=__to_numeric)
flag_prfre = vdp.VisDependentProperty(default=True, fconvert=__to_bool)
prfre_thresh = vdp.VisDependentProperty(default=3.0, fconvert=__to_numeric)
flag_pofre = vdp.VisDependentProperty(default=True, fconvert=__to_bool)
pofre_thresh = vdp.VisDependentProperty(default=1.3333, fconvert=__to_numeric)
flag_prfr = vdp.VisDependentProperty(default=True, fconvert=__to_bool)
prfr_thresh = vdp.VisDependentProperty(default=4.5, fconvert=__to_numeric)
flag_pofr = vdp.VisDependentProperty(default=True, fconvert=__to_bool)
pofr_thresh = vdp.VisDependentProperty(default=4.0, fconvert=__to_numeric)
flag_prfrm = vdp.VisDependentProperty(default=True, fconvert=__to_bool)
prfrm_thresh = vdp.VisDependentProperty(default=5.5, fconvert=__to_numeric)
prfrm_nmean = vdp.VisDependentProperty(default=5, fconvert=__to_int)
flag_pofrm = vdp.VisDependentProperty(default=True, fconvert=__to_bool)
pofrm_thresh = vdp.VisDependentProperty(default=5.0, fconvert=__to_numeric)
pofrm_nmean = vdp.VisDependentProperty(default=5, fconvert=__to_int)
flag_user = vdp.VisDependentProperty(default=False, fconvert=__to_bool)
user_thresh = vdp.VisDependentProperty(default=5.0, fconvert=__to_numeric)
plotflag = vdp.VisDependentProperty(default=True, fconvert=__to_bool)
@vdp.VisDependentProperty
def infiles(self):
return self.vis
@infiles.convert
def infiles(self, value):
self.vis = value
return value
@iteration.convert
def iteration(self, value):
return int(value)
edge = vdp.VisDependentProperty(default=[0, 0])
@edge.convert
def edge(self, value):
return sdutils.to_list(value)
@vdp.VisDependentProperty
def antenna(self):
return ''
@antenna.convert
def antenna(self, value):
antennas = self.ms.get_antenna(value)
# if all antennas are selected, return ''
if len(antennas) == len(self.ms.antennas):
return ''
return utils.find_ranges([a.id for a in antennas])
# return ','.join([str(a.id) for a in antennas])
@vdp.VisDependentProperty
def field(self):
# this will give something like '0542+3243,0343+242'
field_finder = fieldnames.IntentFieldnames()
intent_fields = field_finder.calculate(self.ms, self.intent)
# run the answer through a set, just in case there are duplicates
fields = set()
fields.update(utils.safe_split(intent_fields))
return ','.join(fields)
@vdp.VisDependentProperty
def spw(self):
science_spws = self.ms.get_spectral_windows(with_channels=True)
return ','.join([str(spw.id) for spw in science_spws])
@vdp.VisDependentProperty
def pol(self):
# filters polarization by self.spw
selected_spwids = [int(spwobj.id) for spwobj in self.ms.get_spectral_windows(self.spw, with_channels=True)]
pols = set()
for idx in selected_spwids:
pols.update(self.ms.get_data_description(spw=idx).corr_axis)
return ','.join(pols)
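
    # Note on the selection defaults above: antenna resolves to '' (all antennas),
    # field to the fields matching self.intent, spw to the science spectral windows,
    # and pol to the correlations available for the selected spws.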
def __init__(self, context, output_dir=None,
iteration=None, edge=None, flag_tsys=None, tsys_thresh=None,
flag_weath=None, weath_thresh=None,
flag_prfre=None, prfre_thresh=None,
flag_pofre=None, pofre_thresh=None,
flag_prfr=None, prfr_thresh=None,
flag_pofr=None, pofr_thresh=None,
flag_prfrm=None, prfrm_thresh=None, prfrm_nmean=None,
flag_pofrm=None, pofrm_thresh=None, pofrm_nmean=None,
flag_user=None, user_thresh=None,
plotflag=None,
infiles=None, antenna=None, field=None,
spw=None, pol=None):
super(SDBLFlagInputs, self).__init__()
        # context and vis/infiles must be set first so that properties that require
        # domain objects can function
self.context = context
self.infiles = infiles
self.output_dir = output_dir
# task specific parameters
self.iteration = iteration
self.edge = edge
self.flag_tsys = flag_tsys
self.tsys_thresh = tsys_thresh
self.flag_weath = flag_weath
self.weath_thresh = weath_thresh
self.flag_prfre = flag_prfre
self.prfre_thresh = prfre_thresh
self.flag_pofre = flag_pofre
self.pofre_thresh = pofre_thresh
self.flag_prfr = flag_prfr
self.prfr_thresh = prfr_thresh
self.flag_pofr = flag_pofr
self.pofr_thresh = pofr_thresh
self.flag_prfrm = flag_prfrm
self.prfrm_thresh = prfrm_thresh
self.prfrm_nmean = prfrm_nmean
self.flag_pofrm = flag_pofrm
self.pofrm_thresh = pofrm_thresh
self.pofrm_nmean = pofrm_nmean
self.flag_user = flag_user
self.user_thresh = user_thresh
self.plotflag = plotflag
self.antenna = antenna
self.field = field
self.spw = spw
self.pol = pol
### Default Flag rule
from . import SDFlagRule
self.FlagRuleDictionary = SDFlagRule.SDFlagRule
# MUST NOT configure FlagRuleDictionary here.
def _configureFlagRule(self):
"""A private method to convert input parameters to FlagRuleDictionary"""
d = {'TsysFlag': (self.flag_tsys, [self.tsys_thresh]),
'WeatherFlag': (self.flag_weath, [self.weath_thresh]),
'UserFlag': (self.flag_user, [self.user_thresh]),
'RmsPreFitFlag': (self.flag_prfr, [self.prfr_thresh]),
'RmsPostFitFlag': (self.flag_pofr, [self.pofr_thresh]),
'RmsExpectedPreFitFlag': (self.flag_prfre, [self.prfre_thresh]),
'RmsExpectedPostFitFlag': (self.flag_pofre, [self.pofre_thresh]),
'RunMeanPreFitFlag': (self.flag_prfrm, [self.prfrm_thresh, self.prfrm_nmean]),
'RunMeanPostFitFlag': (self.flag_pofrm, [self.pofrm_thresh, self.pofrm_nmean])}
keys = ['Threshold', 'Nmean']
for k, v in d.items():
(b, p) = v
if b == True:
self.activateFlagRule(k)
for i in range(len(p)):
self.FlagRuleDictionary[k][keys[i]] = p[i]
elif b == False:
self.deactivateFlagRule(k)
else:
raise RuntimeError("Invalid flag operation definition for %s" % k)
    def activateFlagRule(self, key):
"""Activates a flag type specified by the input parameter in FlagRuleDictionary"""
if key in self.FlagRuleDictionary:
self.FlagRuleDictionary[key]['isActive'] = True
else:
raise RuntimeError('Error: %s not in predefined Flagging Rules' % key)
    def deactivateFlagRule(self, key):
"""Deactivates a flag type specified by the input parameter in FlagRuleDictionary"""
if key in self.FlagRuleDictionary:
self.FlagRuleDictionary[key]['isActive'] = False
else:
raise RuntimeError('Error: %s not in predefined Flagging Rules' % key)
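
# Illustrative sketch (not part of the task code): with the defaults defined above,
# SDBLFlagInputs._configureFlagRule() is expected to leave FlagRuleDictionary looking
# roughly like
#
#     {'TsysFlag':               {'isActive': True,  'Threshold': 3.0},
#      'RmsPreFitFlag':          {'isActive': True,  'Threshold': 4.5},
#      'RmsPostFitFlag':         {'isActive': True,  'Threshold': 4.0},
#      'RmsExpectedPreFitFlag':  {'isActive': True,  'Threshold': 3.0},
#      'RmsExpectedPostFitFlag': {'isActive': True,  'Threshold': 1.3333},
#      'RunMeanPreFitFlag':      {'isActive': True,  'Threshold': 5.5, 'Nmean': 5},
#      'RunMeanPostFitFlag':     {'isActive': True,  'Threshold': 5.0, 'Nmean': 5},
#      'WeatherFlag':            {'isActive': False, ...},   # thresholds of deactivated rules
#      'UserFlag':               {'isActive': False, ...}}   # keep their SDFlagRule defaults
#
# Only the 'isActive', 'Threshold' and 'Nmean' entries are touched here; everything else
# comes from SDFlagRule.SDFlagRule.
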
class SDBLFlagResults(common.SingleDishResults):
"""
    The results of SDBLFlag
"""
def __init__(self, task=None, success=None, outcome=None):
super(SDBLFlagResults, self).__init__(task, success, outcome)
    def merge_with_context(self, context):
super(SDBLFlagResults, self).merge_with_context(context)
def _outcome_name(self):
return 'none'
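
# For reference: SerialSDBLFlag.prepare() below populates SDBLFlagResults.outcome with a
# dictionary of the form
#
#     {'flagdata_summary': [stats_before, stats_after],   # flagdata(mode='summary') dicts
#      'summary': flagResult}                             # per-MS SDBLFlagSummary outputs
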
# @task_registry.set_equivalent_casa_task('hsd_blflag')
# @task_registry.set_casa_commands_comment(
# 'Perform row-based flagging based on noise level and quality of spectral baseline subtraction.\n'
# 'This stage performs a pipeline calculation without running any CASA commands to be put in this file.'
# )
class SerialSDBLFlag(basetask.StandardTaskTemplate):
"""
Single dish flagging class.
"""
##################################################
# Note
    # The class uses the _handle_multiple_vis framework.
    # The prepare() method is called once per MS. Inputs.ms
    # holds the MS instance currently being processed.
##################################################
Inputs = SDBLFlagInputs
    def prepare(self):
"""
        Iterates over reduction groups and invokes the flagging worker with clip_niteration iterations for each.
"""
inputs = self.inputs
context = inputs.context
# name of MS to process
cal_name = inputs.ms.name
bl_name = inputs.ms.work_data
in_ant = inputs.antenna
in_spw = inputs.spw
in_field = inputs.field
in_pol = '' if inputs.pol in ['', '*'] else inputs.pol.split(',')
clip_niteration = inputs.iteration
reduction_group = context.observing_run.ms_reduction_group
        # Configure FlagRuleDictionary. This has to be done at runtime rather
        # than in Inputs.__init__ to accommodate later overwrites of the parameters.
inputs._configureFlagRule()
flag_rule = inputs.FlagRuleDictionary
LOG.debug("Flag Rule for %s: %s" % (cal_name, flag_rule))
        # summarize flag state before execution
full_intent = utils.to_CASA_intent(self.inputs.ms, self.inputs.intent)
flagdata_summary_job = casa_tasks.flagdata(vis=bl_name, mode='summary',
antenna=in_ant, field=in_field,
spw=in_spw, intent=full_intent,
spwcorr=True, fieldcnt=True,
name='before')
stats_before = self._executor.execute(flagdata_summary_job)
# collection of field, antenna, and spw ids in reduction group per MS
registry = collections.defaultdict(sdutils.RGAccumulator)
# loop over reduction group (spw and source combination)
flagResult = []
for group_id, group_desc in reduction_group.items():
LOG.debug('Processing Reduction Group %s' % group_id)
LOG.debug('Group Summary:')
for m in group_desc:
LOG.debug('\t%s: Antenna %d (%s) Spw %d Field %d (%s)' %
(os.path.basename(m.ms.name), m.antenna_id,
m.antenna_name, m.spw_id, m.field_id, m.field_name))
nchan = group_desc.nchan
if nchan == 1:
                LOG.info('Skipping a group of channel-averaged spws')
continue
field_sel = ''
if len(in_field) == 0:
# fine, just go ahead
field_sel = in_field
elif group_desc.field_name in [x.strip('"') for x in in_field.split(',')]:
# pre-selection of the field name
field_sel = group_desc.field_name
else:
# no field name is included in in_field, skip
LOG.info('Skip reduction group {:d}'.format(group_id))
continue
            # determine which members of group_desc should be processed
member_list = list(common.get_valid_ms_members(group_desc, [cal_name], in_ant, field_sel, in_spw))
LOG.trace('group %s: member_list=%s' % (group_id, member_list))
# skip this group if valid member list is empty
if len(member_list) == 0:
LOG.info('Skip reduction group %d' % group_id)
continue
member_list.sort() # list of group_desc IDs to flag
antenna_list = [group_desc[i].antenna_id for i in member_list]
spwid_list = [group_desc[i].spw_id for i in member_list]
ms_list = [group_desc[i].ms for i in member_list]
fieldid_list = [group_desc[i].field_id for i in member_list]
temp_dd_list = [ms_list[i].get_data_description(spw=spwid_list[i])
for i in range(len(member_list))]
pols_list = [[corr for corr in ddobj.corr_axis if (in_pol == '' or corr in in_pol)]
for ddobj in temp_dd_list]
del temp_dd_list
for i in range(len(member_list)):
member = group_desc[member_list[i]]
registry[member.ms].append(field_id=member.field_id,
antenna_id=member.antenna_id,
spw_id=member.spw_id,
pol_ids=pols_list[i])
# per-MS loop
for msobj, accumulator in registry.items():
rowmap = None
if os.path.abspath(cal_name) == os.path.abspath(bl_name):
                LOG.warn("%s is not yet baselined. Skipping flagging by post-fit statistics for these data."
                         " MASKLIST will also be cleared. You may go on flagging, but the statistics"
                         " will contain line emission." % self.inputs.ms.basename)
else:
                # row map generation is very expensive, so do it as few times as possible
_ms = context.observing_run.get_ms(msobj.name)
rowmap = sdutils.make_row_map_for_baselined_ms(_ms)
antenna_list = accumulator.get_antenna_id_list()
fieldid_list = accumulator.get_field_id_list()
spwid_list = accumulator.get_spw_id_list()
pols_list = accumulator.get_pol_ids_list()
LOG.info("*"*60)
LOG.info('Members to be processed:')
for antenna_id, field_id, spw_id, pol_ids in zip(antenna_list, fieldid_list, spwid_list, pols_list):
LOG.info("\t{}:: Antenna {} ({}) Spw {} Field {} ({}) Pol '{}'".format(
msobj.basename,
antenna_id,
msobj.antennas[antenna_id].name,
spw_id,
field_id,
msobj.fields[field_id].name,
','.join(pol_ids)))
LOG.info("*"*60)
nchan = 0
# Calculate flag and update DataTable
flagging_inputs = worker.SDBLFlagWorkerInputs(
context, clip_niteration,
msobj.name, antenna_list, fieldid_list,
spwid_list, pols_list, nchan, flag_rule,
rowmap=rowmap)
flagging_task = worker.SDBLFlagWorker(flagging_inputs)
flagging_results = self._executor.execute(flagging_task, merge=False)
thresholds = flagging_results.outcome
# Summary
if not basetask.DISABLE_WEBLOG:
renderer = SDBLFlagSummary(context, msobj,
antenna_list, fieldid_list, spwid_list,
pols_list, thresholds, flag_rule)
result = self._executor.execute(renderer, merge=False)
flagResult += result
# Calculate flag fraction after operation.
flagdata_summary_job = casa_tasks.flagdata(vis=bl_name, mode='summary',
antenna=in_ant, field=in_field,
spw=in_spw, intent=full_intent,
spwcorr=True, fieldcnt=True,
name='after')
stats_after = self._executor.execute(flagdata_summary_job)
outcome = {'flagdata_summary': [stats_before, stats_after],
'summary': flagResult}
results = SDBLFlagResults(task=self.__class__,
success=True,
outcome=outcome)
return results
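
# Minimal usage sketch (not executed here): assumes a populated pipeline context with
# registered measurement sets; the MS name below is hypothetical.
#
#     inputs = SDBLFlagInputs(context, infiles=['uid___A002_Xxxx.ms'], iteration=5)
#     task = SerialSDBLFlag(inputs)
#     results = task.execute()     # StandardTaskTemplate entry point; some pipeline
#                                  # versions require execute(dry_run=False)
#     results.accept(context)      # merge the flagging outcome back into the context
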
### Tier-0 parallelization
class HpcSDBLFlagInputs(SDBLFlagInputs):
# use common implementation for parallel inputs argument
parallel = sessionutils.parallel_inputs_impl()
def __init__(self, context, output_dir=None,
iteration=None, edge=None,
flag_tsys=None, tsys_thresh=None,
flag_weath=None, weath_thresh=None,
flag_prfre=None, prfre_thresh=None,
flag_pofre=None, pofre_thresh=None,
flag_prfr=None, prfr_thresh=None,
flag_pofr=None, pofr_thresh=None,
flag_prfrm=None, prfrm_thresh=None, prfrm_nmean=None,
flag_pofrm=None, pofrm_thresh=None, pofrm_nmean=None,
flag_user=None, user_thresh=None, plotflag=None,
infiles=None, antenna=None, field=None,
spw=None, pol=None, parallel=None):
super(HpcSDBLFlagInputs, self).__init__(
context, output_dir=output_dir,
iteration=iteration, edge=edge,
flag_tsys=flag_tsys, tsys_thresh=tsys_thresh,
flag_weath=flag_weath, weath_thresh=weath_thresh,
flag_prfre=flag_prfre, prfre_thresh=prfre_thresh,
flag_pofre=flag_pofre, pofre_thresh=pofre_thresh,
flag_prfr=flag_prfr, prfr_thresh=prfr_thresh,
flag_pofr=flag_pofr, pofr_thresh=pofr_thresh,
flag_prfrm=flag_prfrm, prfrm_thresh=prfrm_thresh, prfrm_nmean=prfrm_nmean,
flag_pofrm=flag_pofrm, pofrm_thresh=pofrm_thresh, pofrm_nmean=pofrm_nmean,
flag_user=flag_user, user_thresh=user_thresh, plotflag=plotflag,
infiles=infiles, antenna=antenna, field=field, spw=spw, pol=pol)
self.parallel = parallel
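
# The 'parallel' input is implemented by sessionutils.parallel_inputs_impl() and controls
# whether HpcSDBLFlag below distributes the per-MS work to Tier-0 parallel servers or
# simply runs SerialSDBLFlag for each MS in turn.
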
@task_registry.set_equivalent_casa_task('hsd_blflag')
@task_registry.set_casa_commands_comment(
'Perform row-based flagging based on noise level and quality of spectral baseline subtraction.\n'
'This stage performs a pipeline calculation without running any CASA commands to be put in this file.'
)
class HpcSDBLFlag(sessionutils.ParallelTemplate):
Inputs = HpcSDBLFlagInputs
Task = SerialSDBLFlag
def __init__(self, inputs):
super(HpcSDBLFlag, self).__init__(inputs)
@basetask.result_finaliser
def get_result_for_exception(self, vis, exception):
        LOG.error('Error while flagging target data for {!s}'.format(os.path.basename(vis)))
LOG.error('{0}({1})'.format(exception.__class__.__name__, str(exception)))
import traceback
tb = traceback.format_exc()
if tb.startswith('None'):
tb = '{0}({1})'.format(exception.__class__.__name__, str(exception))
return basetask.FailedTaskResults(self.__class__, exception, tb)