import numpy
import pipeline.infrastructure.logging as logging
from pipeline.infrastructure import casa_tools
LOG = logging.get_logger(__name__)
class MSWrapper(object):
    """
    MSWrapper is a wrapper around a NumPy array populated with measurement set
    data for a specified scan and spectral window. The MSWrapper can be
    filtered on various criteria, e.g., spw, scan, antenna, etc., to narrow the
    data to a particular data selection.

    The static method MSWrapper.create_from_ms should be used to instantiate
    MSWrapper objects.
    """

    @staticmethod
    def create_from_ms(filename, scan, spw):
        """
        Create a new MSWrapper for the specified scan and spw.

        Reading in raw measurement set data can be a very memory-intensive
        process, so data selection is deliberately limited to one scan and one
        spw at a time.

        :param filename: measurement set filename
        :param scan: integer scan ID
        :param spw: integer spw ID
        :return: MSWrapper holding the selected rows
        """
        LOG.trace('MSWrapper.create_from_ms(%r, %r, %r)', filename, scan, spw)

        data_selection = dict(scan=str(scan), spw=str(spw))
        colnames = ['antenna1', 'antenna2', 'flag', 'time', 'corrected_amplitude', 'corrected_data',
                    'corrected_phase']

        with casa_tools.MSReader(filename) as openms:
            openms.msselect(data_selection)
            raw_data = openms.getdata(colnames)
            axis_info = openms.getdata(['axis_info'])
            num_rows = openms.nrow(selected=True)

        corr_axis = axis_info['axis_info']['corr_axis']
        freq_axis = axis_info['axis_info']['freq_axis']

        # 1-D columns hold one value per row; everything else carries extra
        # per-row dimensions
        scalar_cols = [c for c in colnames if len(raw_data[c].shape) == 1]
        var_cols = [c for c in colnames if c not in scalar_cols]

        # data has axis order pol->channel->time. Swap order to a more natural
        # time->pol->channel
        for c in var_cols:
            raw_data[c] = raw_data[c].swapaxes(0, 2).swapaxes(1, 2)

        # one record-array field per column, in colnames order
        col_dtypes = [get_dtype(raw_data, c) for c in colnames]
        data = numpy.ma.empty(num_rows, dtype=col_dtypes)

        for c in scalar_cols:
            data[c] = raw_data[c]

        # The FLAG column doubles as the mask for every other multi-dimensional
        # column. It is also copied into its own field: the array comes from
        # numpy.ma.empty, so an unassigned field would hold uninitialised
        # memory.
        flags = raw_data['flag']
        for c in var_cols:
            if c == 'flag':
                data[c] = flags
            else:
                data[c] = numpy.ma.MaskedArray(data=raw_data[c], dtype=raw_data[c].dtype, mask=flags)

        # keywords keep this call independent of the constructor's positional
        # order
        return MSWrapper(filename=filename, scan=scan, spw=spw, data=data, corr_axis=corr_axis,
                         freq_axis=freq_axis)

    def filter(self, antenna1=None, antenna2=None, **kwargs):
        """
        Return a new MSWrapper containing rows matching the column selection
        criteria.

        Data for rows meeting all the column criteria will be funnelled into
        the new MSWrapper return object. A boolean AND is effectively
        performed: e.g., antenna1=3 AND antenna2=5 will only return rows for
        one baseline.

        Data can be filtered on any column listed in the wrapper.data.dtype.

        :param antenna1: allowed antenna1 value(s), or None for no restriction
        :param antenna2: allowed antenna2 value(s), or None for no restriction
        :param kwargs: allowed value(s) for any other column, keyed by column
            name
        :return: new MSWrapper holding the matching rows
        :raises KeyError: if a requested value is not present in its column
        """
        criteria = self._selection_criteria(antenna1, antenna2, kwargs, numpy.ma.unique)

        # combine masks to create final data selection mask: a row must satisfy
        # every column criterion
        mask = numpy.ones(len(self), dtype=bool)
        for column_name, allowed in criteria.items():
            mask = mask & self._get_mask(allowed, column_name)

        # create new object for the filtered data
        return self._spawn(self[mask])

    def xor_filter(self, antenna1=None, antenna2=None, **kwargs):
        """
        Return a new MSWrapper containing rows matching the column selection
        criteria.

        Data for rows meeting any column criterion will be funnelled into
        the new MSWrapper return object, after which rows where antenna1
        equals antenna2 (autocorrelations) are removed. A boolean OR is
        effectively performed: e.g., antenna1=3 OR antenna2=3 returns every
        cross-correlation baseline that involves antenna 3.

        An antenna argument left as None allows every value in that column,
        which under OR selects all rows; pass both antenna1 and antenna2 to
        get a meaningful selection.

        Data can be filtered on any column listed in the wrapper.data.dtype.

        DANGER! DANGER! DANGER!

        NOTE! This class has only been tested for baseline selection! Using
        this method for other use cases could be dangerous. Use at your own
        risk!

        :param antenna1: allowed antenna1 value(s), or None to allow all
        :param antenna2: allowed antenna2 value(s), or None to allow all
        :param kwargs: allowed value(s) for any other column, keyed by column
            name
        :return: new MSWrapper holding the matching rows
        :raises KeyError: if a requested value is not present in its column
        """
        # TODO this method could probably be refactored to use numpy xor, but
        # I don't have time right now...
        criteria = self._selection_criteria(antenna1, antenna2, kwargs, numpy.unique)

        # combine masks to create final data selection mask: a row need only
        # satisfy one column criterion
        mask = numpy.zeros(len(self), dtype=bool)
        for column_name, allowed in criteria.items():
            mask = mask | self._get_mask(allowed, column_name)

        # remove autocorrelations
        mask = mask & (self['antenna1'] != self['antenna2'])

        # create new object for the filtered data
        return self._spawn(self[mask])

    def __init__(self, scan, spw, filename, data, corr_axis, freq_axis):
        """
        Store the selection metadata and the data array.

        NOTE: the positional order is scan, spw, filename. Code in this class
        always passes these arguments by keyword so they cannot be mixed up.

        :param scan: scan ID the data was selected for
        :param spw: spw ID the data was selected for
        :param filename: measurement set filename
        :param data: NumPy record array, one element per selected row
        :param corr_axis: 'corr_axis' entry of the measurement set axis info
        :param freq_axis: 'freq_axis' entry of the measurement set axis info
        """
        self.scan = scan
        self.spw = spw
        self.filename = filename
        self.data = data
        self.corr_axis = corr_axis
        self.freq_axis = freq_axis

    def __getitem__(self, key):
        return self.data[key]

    def __contains__(self, key):
        return key in self.data.dtype.names

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return iter(self.data)

    def _spawn(self, data):
        """Return a new MSWrapper for data, carrying over this wrapper's metadata."""
        return MSWrapper(filename=self.filename, scan=self.scan, spw=self.spw, data=data,
                         corr_axis=self.corr_axis, freq_axis=self.freq_axis)

    def _selection_criteria(self, antenna1, antenna2, extra, unique):
        """
        Build the mapping of column name to allowed value(s) for a filter call.

        An antenna argument left as None, with no entry of the same name in
        extra, is expanded to every distinct value in that column, so that the
        column lets all rows through.

        :param antenna1: antenna1 argument given to the filter method
        :param antenna2: antenna2 argument given to the filter method
        :param extra: dict of additional column criteria (not modified)
        :param unique: function returning the distinct values of a column
        :return: new dict mapping column name to allowed value(s)
        """
        criteria = dict(extra)
        for value, column_name in ((antenna1, 'antenna1'), (antenna2, 'antenna2')):
            if value is not None:
                criteria[column_name] = value
            elif column_name not in extra:
                criteria[column_name] = unique(self[column_name])
        return criteria

    def _get_mask(self, allowed, column):
        """
        Return a boolean array flagging rows whose column value is allowed.

        :param allowed: a single value or an iterable of values
        :param column: name of the column to test
        :return: boolean array with one element per row
        :raises KeyError: if an allowed value does not occur in the column
        """
        try:
            iter(allowed)
        except TypeError:
            # a single, non-iterable value was given
            allowed = [allowed]

        column_values = self.data[column]
        mask = numpy.zeros(len(self), dtype=bool)
        for a in allowed:
            if a not in column_values:
                raise KeyError('{} column {} value not found: {}'.format(self.filename, column, a))
            mask = mask | (column_values == a)

        return mask
def get_dtype(data, column_name):
    """
    Build the NumPy structured-dtype field spec for one column.

    :param data: mapping of column name to array; each array must expose
        ``dtype`` and ``shape``
    :param column_name: name of column to process
    :return: ``(column name, dtype)`` for a one-dimensional column, otherwise
        ``(column name, dtype, shape)`` where shape omits the leading axis
    """
    values = data[column_name]
    field = (column_name, values.dtype)
    shape = values.shape

    # only a strictly 1-D column is described without a sub-array shape
    if len(shape) != 1:
        field += (shape[1:],)
    return field