import os
import numpy
import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.basetask as basetask
import pipeline.infrastructure.vdp as vdp
from pipeline.domain import DataTable
from pipeline.infrastructure import casa_tools
from .. import common
LOG = infrastructure.get_logger(__name__)
[docs]class WeightMSResults(common.SingleDishResults):
def __init__(self, task=None, success=None, outcome=None):
super(WeightMSResults, self).__init__(task, success, outcome)
[docs] def merge_with_context(self, context):
super(WeightMSResults, self).merge_with_context(context)
def _outcome_name(self):
# return [image.imagename for image in self.outcome]
return self.outcome
[docs]class WeightMS(basetask.StandardTaskTemplate):
Inputs = WeightMSInputs
Rule = {'WeightDistance': 'Gauss',
'Clipping': 'MinMaxReject',
'WeightRMS': True,
'WeightTsysExpTime': False}
[docs] def prepare(self, datatable_dict=None):
# for each data
outfile = self.inputs.outfiles
is_full_resolution = (self.inputs.spwtype.upper() in ["TDM", "FDM"])
LOG.info('Setting weight for %s Antenna %s Spw %s Field %s' % \
(os.path.basename(outfile), self.inputs.ms.antennas[self.inputs.antenna].name,
self.inputs.spwid, self.inputs.ms.fields[self.inputs.fieldid].name))
# make row mapping table between scantable and ms
row_map = self._make_row_map()
# set weight
if is_full_resolution:
minmaxclip = (WeightMS.Rule['Clipping'].upper() == 'MINMAXREJECT')
weight_rms = WeightMS.Rule['WeightRMS']
weight_tintsys = WeightMS.Rule['WeightTsysExpTime']
else:
minmaxclip = False
weight_rms = False
weight_tintsys = True
datatable = None
if datatable_dict is not None:
datatable = datatable_dict[self.inputs.ms.basename]
self._set_weight(row_map, minmaxclip=minmaxclip, weight_rms=weight_rms,
weight_tintsys=weight_tintsys, try_fallback=is_full_resolution,
default_datatable=datatable)
result = WeightMSResults(task=self.__class__,
success=True,
outcome=outfile)
result.task = self.__class__
return result
[docs] def analyse(self, result):
return result
def _make_row_map(self):
"""
Returns a dictionary which maps input and output MS row IDs
The dictionary has
key: input MS row IDs
value: output MS row IDs
It has row IDs of all intents of selected field, spw, and antenna.
"""
infile = self.inputs.infiles
outfile = self.inputs.outfiles
spwid = self.inputs.spwid
antid = self.inputs.antenna
fieldid = self.inputs.fieldid
in_rows = []
with casa_tools.TableReader(os.path.join(infile, 'DATA_DESCRIPTION')) as tb:
spwids = tb.getcol('SPECTRAL_WINDOW_ID')
data_desc_id = numpy.where(spwids == spwid)[0][0]
with casa_tools.TableReader(infile) as tb:
tsel = tb.query('DATA_DESC_ID==%d && FIELD_ID==%d && ANTENNA1==%d && ANTENNA2==%d' %
(data_desc_id, fieldid, antid, antid),
sortlist='TIME')
if tsel.nrows() > 0:
in_rows = tsel.rownumbers()
tsel.close()
with casa_tools.TableReader(os.path.join(outfile, 'DATA_DESCRIPTION')) as tb:
spwids = tb.getcol('SPECTRAL_WINDOW_ID')
data_desc_id = numpy.where(spwids == spwid)[0][0]
with casa_tools.TableReader(outfile) as tb:
tsel = tb.query('DATA_DESC_ID==%s && FIELD_ID==%d && ANTENNA1==%d && ANTENNA2==%d' %
(data_desc_id, fieldid, antid, antid),
sortlist='TIME')
out_rows = tsel.rownumbers()
tsel.close()
row_map = {}
# in_row: out_row
for index in range(len(in_rows)):
row_map[in_rows[index]] = out_rows[index]
return row_map
def _set_weight(self, row_map, minmaxclip, weight_rms, weight_tintsys, try_fallback=False, default_datatable=None):
inputs = self.inputs
infile = inputs.infiles
outfile = inputs.outfiles
spwid = inputs.spwid
antid = inputs.antenna
fieldid = self.inputs.fieldid
context = inputs.context
datatable = None
if default_datatable is not None:
filename = default_datatable.getkeyword('FILENAME')
if os.path.basename(filename) == os.path.basename(infile):
datatable = default_datatable
if datatable is None:
datatable_name = os.path.join(context.observing_run.ms_datatable_name, os.path.basename(infile))
datatable = DataTable(name=datatable_name, readonly=True)
# get corresponding datatable rows (only IDs of target scans will be retruned)
index_list = common.get_index_list_for_ms(datatable, [infile], [antid], [fieldid], [spwid])
in_rows = datatable.getcol('ROW').take(index_list)
# row map filtered by target scans (key: target input
target_row_map = {}
for idx in in_rows:
target_row_map[idx] = row_map.get(idx, -1)
weight = {}
# for row in in_rows:
# weight[row] = 1.0
# set weight (key: input MS row ID, value: weight)
# TODO: proper handling of pols
if weight_rms:
stats = datatable.getcol('STATISTICS').take(index_list, axis=2)
for index in range(len(in_rows)):
row = in_rows[index]
cell_stat = stats[:, :, index]
weight[row] = numpy.ones(cell_stat.shape[0])
for ipol in range(weight[row].shape[0]):
stat = cell_stat[ipol, 1] # baselined RMS
if stat > 0.0:
weight[row][ipol] /= (stat * stat)
elif stat < 0.0 and cell_stat[ipol, 2] > 0.0:
stat = cell_stat[ipol, 2] # RMS before baseline
weight[row][ipol] /= (stat * stat)
elif try_fallback:
weight_tintsys = True
else:
weight[row][ipol] = 0.0
if weight_tintsys:
exposures = datatable.getcol('EXPOSURE').take(index_list)
tsyss = datatable.getcol('TSYS').take(index_list, axis=1)
for index in range(len(in_rows)):
row = in_rows[index]
exposure = exposures[index]
tsys = tsyss[:, index]
if row not in weight:
weight[row] = numpy.ones(tsys.shape[0])
for ipol in range(weight[row].shape[0]):
if tsys[ipol] > 0.5:
weight[row][ipol] *= (exposure / (tsys[ipol] * tsys[ipol]))
else:
weight[row][ipol] = 0.0
# put weight
with casa_tools.TableReader(outfile, nomodify=False) as tb:
# The selection, tsel, contains all intents.
# Need to match output element in selected table by rownumbers().
tsel = tb.query('ROWNUMBER() IN %s' % (list(row_map.values())), style='python')
ms_weights = tsel.getcol('WEIGHT')
rownumbers = tsel.rownumbers().tolist()
for idx in range(len(in_rows)):
in_row = in_rows[idx]
out_row = row_map[in_row]
ms_weight = weight[in_row]
# search for the place of selected MS col to put back the value.
ms_index = rownumbers.index(out_row)
ms_weights[:, ms_index] = ms_weight
tsel.putcol('WEIGHT', ms_weights)
tsel.close()
# set channel flag for min/max in each channel
minmaxclip = False
if minmaxclip:
with casa_tools.TableReader(outfile, nomodify=False) as tb:
tsel = tb.query('ROWNUMBER() IN %s' % (list(row_map.keys())),
style='python')
if 'FLOAT_DATA' in tsel.colnames():
data_column = 'FLOAT_DATA'
else:
data_column = 'DATA'
data = tsel.getcol(data_column) # (npol, nchan, nrow)
(npol, nchan, nrow) = data.shape
if nrow > 2:
argmax = data.argmax(axis=2)
argmin = data.argmin(axis=2)
print(argmax.tolist())
print(argmin.tolist())
flag = tsel.getcol('FLAG')
for ipol in range(npol):
maxrow = argmax[ipol]
minrow = argmin[ipol]
for ichan in range(nchan):
flag[ipol, ichan, maxrow[ichan]] = True
flag[ipol, ichan, minrow[ichan]] = True
tsel.putcol('FLAG', flag)
tsel.close()