import collections
import contextlib
import functools
import os
import sys
import time
import numpy
import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.logging as logging
import pipeline.infrastructure.mpihelpers as mpihelpers
from pipeline.domain.datatable import OnlineFlagIndex, DataTableIndexer
from pipeline.infrastructure import casa_tools
from . import compress
_LOG = infrastructure.get_logger(__name__)
class OnDemandStringParseLogger(object):
PRIORITY_MAP = {'warn': 'warning'}
def __init__(self, logger):
self.logger = logger
self._func_list = []
    @staticmethod
def parse(msg_template, *args, **kwargs):
if len(args) == 0 and len(kwargs) == 0:
return msg_template
else:
return msg_template.format(*args, **kwargs)
def _post(self, priority, msg_template, *args, **kwargs):
key_for_level = self.PRIORITY_MAP.get(priority, priority)
if self.logger.isEnabledFor(logging.LOGGING_LEVELS[key_for_level]):
getattr(self.logger, priority)(OnDemandStringParseLogger.parse(msg_template, *args, **kwargs))
    def critical(self, msg_template, *args, **kwargs):
        self._post('critical', msg_template, *args, **kwargs)

    def error(self, msg_template, *args, **kwargs):
        self._post('error', msg_template, *args, **kwargs)

    def warn(self, msg_template, *args, **kwargs):
        self._post('warning', msg_template, *args, **kwargs)

    def info(self, msg_template, *args, **kwargs):
        self._post('info', msg_template, *args, **kwargs)

    def debug(self, msg_template, *args, **kwargs):
        self._post('debug', msg_template, *args, **kwargs)

    def todo(self, msg_template, *args, **kwargs):
        self._post('todo', msg_template, *args, **kwargs)

    def trace(self, msg_template, *args, **kwargs):
        self._post('trace', msg_template, *args, **kwargs)
LOG = OnDemandStringParseLogger(_LOG)
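# Illustrative usage of the deferred logger (placeholder variable names): the
# message template is str.format-parsed only if the log level is enabled.
#     LOG.info('processed {} rows of {}', nrow, vis)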
def profiler(func):
@functools.wraps(func)
def wrapper(*args, **kw):
start = time.time()
# LOG.info('#TIMING# Begin {} at {}', func.__name__, start)
result = func(*args, **kw)
end = time.time()
# LOG.info('#TIMING# End {} at {}', func.__name__, end)
LOG.info('#PROFILE# %s: elapsed %s sec' % (func.__name__, end - start))
return result
return wrapper
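# Illustrative usage (hypothetical function name): decorating a function with
# @profiler logs '#PROFILE# <name>: elapsed <seconds> sec' at INFO level.
#     @profiler
#     def reduce_spectra(ms):
#         ...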
def require_virtual_spw_id_handling(observing_run):
    """
    Check whether spw ids vary across EBs. Return True if they vary.

    observing_run -- domain.ObservingRun instance
    """
return numpy.any([spw.id != observing_run.real2virtual_spw_id(spw.id, ms) for ms in observing_run.measurement_sets
for spw in ms.get_spectral_windows(science_windows_only=True)])
def is_nro(context):
mses = context.observing_run.measurement_sets
return numpy.all([ms.antenna_array.name == 'NRO' for ms in mses])
def asdm_name(scantable_object):
    """
    Return the name of the ASDM that the target scantable belongs to.
Assumptions are:
- scantable is generated from MS
- MS is generated from ASDM
- MS name is <uid>.ms
"""
return asdm_name_from_ms(scantable_object.ms)
def asdm_name_from_ms(ms_domain):
    """
    Return the name of the ASDM that the target MS originates from.
Assumptions are:
- MS is generated from ASDM
- MS name is <uid>.ms
"""
ms_basename = ms_domain.basename
index_for_suffix = ms_basename.rfind('.')
asdm = ms_basename[:index_for_suffix] if index_for_suffix > 0 else ms_basename
return asdm
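# Illustrative example (hypothetical basename): an MS whose basename is
# 'uid___A002_Xabc_X12.ms' yields the ASDM name 'uid___A002_Xabc_X12'.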
def get_parent_ms_idx(context, msname):
    """
    Return the index of the corresponding MS in the context.

    The method maps both work_data and the original MS to the proper index.
    The return value is -1 if no match is found.
    """
mslist = context.observing_run.measurement_sets
idx_found = -1
    for idx, msobj in enumerate(mslist):
search_list = [msobj.name, msobj.basename]
if hasattr(msobj, "work_data"):
search_list += [msobj.work_data, os.path.basename(msobj.work_data)]
if msname in search_list:
idx_found = idx
break
return idx_found
def get_parent_ms_name(context, msname):
    """
    Return the name of the corresponding parent MS in the context.

    The method maps both work_data and the original MS to the proper index.
    The return value is "" if no match is found.
    """
idx = get_parent_ms_idx(context, msname)
return context.observing_run.measurement_sets[idx].name if idx >= 0 else ""
####
# ProgressTimer
#
# Show the progress bar on the console if the log level is INFO or less
# verbose.
#
####
class ProgressTimer(object):
    def __init__(self, length=80, maxCount=80, LogLevel='info'):
        """
        Constructor:
            length: length of the progress bar (default 80 characters)
            maxCount: total count corresponding to 100%
            LogLevel: logging level name (str) or numeric logging level
        """
self.currentLevel = 0
self.maxCount = maxCount
self.curCount = 0
self.scale = float(length)/float(maxCount)
if isinstance(LogLevel, str):
self.LogLevel = logging.LOGGING_LEVELS[LogLevel] if LogLevel in logging.LOGGING_LEVELS else logging.INFO
else:
# should be integer
self.LogLevel = LogLevel
if self.LogLevel >= logging.INFO:
print('\n|{} 100% {}|'.format('=' * ((length - 8) // 2), '=' * ((length - 8) // 2)))
def __del__(self):
if self.LogLevel >= logging.INFO:
print('\n')
    def count(self, increment=1):
        if self.LogLevel >= logging.INFO:
            self.curCount += increment
            newLevel = int(self.curCount * self.scale)
            if newLevel != self.currentLevel:
                # write without a trailing newline so the bar grows on one line
                sys.stdout.write('\b{}'.format('*' * (newLevel - self.currentLevel)))
                sys.stdout.flush()
                self.currentLevel = newLevel
# parse edge parameter to tuple
def parseEdge(edge):
    if isinstance(edge, (int, float)):
EdgeL = edge
EdgeR = edge
elif len(edge) == 0:
EdgeL = 0
EdgeR = 0
elif len(edge) == 1:
EdgeL = edge[0]
EdgeR = edge[0]
else:
(EdgeL, EdgeR) = edge[:2]
return EdgeL, EdgeR
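# Illustrative examples (not executed):
#     parseEdge(2)       -> (2, 2)
#     parseEdge([])      -> (0, 0)
#     parseEdge([1])     -> (1, 1)
#     parseEdge([1, 3])  -> (1, 3)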
def mjd_to_datestring(t, unit='sec'):
"""
MJD ---> date string
t: MJD
unit: sec or day
"""
if unit in ['sec', 's']:
mjd = t
elif unit in ['day', 'd']:
mjd = t * 86400.0
else:
mjd = 0.0
    # time is already imported at module level
    import datetime
mjdzero = datetime.datetime(1858, 11, 17, 0, 0, 0)
zt = time.gmtime(0.0)
timezero = datetime.datetime(zt.tm_year, zt.tm_mon, zt.tm_mday, zt.tm_hour, zt.tm_min, zt.tm_sec)
dtd = timezero-mjdzero
dtsec = mjd-(float(dtd.days)*86400.0+float(dtd.seconds)+float(dtd.microseconds)*1.0e-6)
mjdstr = time.asctime(time.gmtime(dtsec))+' UTC'
return mjdstr
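# Illustrative example (not executed): mjd_to_datestring(0.0) corresponds to
# the MJD epoch, 'Wed Nov 17 00:00:00 1858 UTC', on platforms whose
# time.gmtime accepts negative epoch values.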
def to_list(s):
    if s is None:
        return None
    elif isinstance(s, (list, numpy.ndarray)):
        return s
    elif isinstance(s, str):
        if s.startswith('['):
            if s.lstrip('[')[0].isdigit():
                return eval(s)
            else:
                # maybe a list of strings; quote the items before eval
                return eval(s.replace('[', '[\'').replace(']', '\']').replace(',', '\',\''))
        else:
            try:
                return [float(s)]
            except ValueError:
                return [s]
    else:
        return [s]
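# Illustrative examples (not executed; eval is used internally, so inputs are
# assumed to come from trusted pipeline parameters):
#     to_list('[0, 1, 2]')  -> [0, 1, 2]
#     to_list('[a,b]')      -> ['a', 'b']
#     to_list('1.5')        -> [1.5]
#     to_list('M100')       -> ['M100']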
def to_bool(s):
if s is None:
return None
elif isinstance(s, bool):
return s
elif isinstance(s, str):
if s.upper() == 'FALSE' or s == 'F':
return False
elif s.upper() == 'TRUE' or s == 'T':
return True
else:
return s
else:
return bool(s)
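# Illustrative examples (not executed):
#     to_bool('True')  -> True
#     to_bool('F')     -> False
#     to_bool('yes')   -> 'yes'  (unrecognized strings are passed through)
#     to_bool(1)       -> True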
def to_numeric(s):
    if s is None:
        return None
    elif isinstance(s, str):
        try:
            return float(s)
        except ValueError:
            return s
    else:
        return s
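# Illustrative examples (not executed):
#     to_numeric('1.5')  -> 1.5
#     to_numeric('M100') -> 'M100'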
def get_mask_from_flagtra(flagtra):
"""Convert FLAGTRA (unsigned char) to a mask array (1=valid, 0=flagged)"""
return (numpy.asarray(flagtra) == 0).astype(int)
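# Illustrative example (not executed):
#     get_mask_from_flagtra([0, 1, 0]) -> array([1, 0, 1])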
def iterate_group_member(group_desc, member_id_list):
for mid in member_id_list:
member = group_desc[mid]
yield member.ms, member.field_id, member.antenna_id, member.spw_id
def get_index_list_for_ms(datatable, vis_list, antennaid_list, fieldid_list,
                          spwid_list, srctype=None):
return numpy.fromiter(_get_index_list_for_ms(datatable, vis_list, antennaid_list, fieldid_list,
spwid_list, srctype), dtype=numpy.int64)
def _get_index_list_for_ms(datatable, vis_list, antennaid_list, fieldid_list,
spwid_list, srctype=None):
# use time_table instead of data selection
#online_flag = datatable.getcolslice('FLAG_PERMANENT', [0, OnlineFlagIndex], [-1, OnlineFlagIndex], 1)[0]
#LOG.info('online_flag=%s'%(online_flag))
for (_vis, _field, _ant, _spw) in zip(vis_list, fieldid_list, antennaid_list, spwid_list):
try:
time_table = datatable.get_timetable(_ant, _spw, None, os.path.basename(_vis), _field)
except RuntimeError as e:
# data could be missing. just skip.
LOG.warn('Exception reported from datatable.get_timetable:')
LOG.warn(str(e))
continue
# time table separated by large time gap
the_table = time_table[1]
for group in the_table:
for row in group[1]:
permanent_flag = datatable.getcell('FLAG_PERMANENT', row)
online_flag = permanent_flag[:, OnlineFlagIndex]
if any(online_flag == 1):
yield row
def get_index_list_for_ms2(datatable, group_desc, member_list, srctype=None):
# use time_table instead of data selection
#online_flag = datatable.getcolslice('FLAG_PERMANENT', [0, OnlineFlagIndex], [-1, OnlineFlagIndex], 1)[0]
#LOG.info('online_flag=%s'%(online_flag))
for (_ms, _field, _ant, _spw) in iterate_group_member(group_desc, member_list):
_vis = _ms.name
time_table = datatable.get_timetable(_ant, _spw, None, os.path.basename(_vis), _field)
# time table separated by large time gap
the_table = time_table[1]
for group in the_table:
for row in group[1]:
permanent_flag = datatable.getcell('FLAG_PERMANENT', row)
online_flag = permanent_flag[:, OnlineFlagIndex]
if any(online_flag == 1):
yield row
def get_index_list_for_ms3(datatable_dict, group_desc, member_list, srctype=None):
# use time_table instead of data selection
#online_flag = datatable.getcolslice('FLAG_PERMANENT', [0, OnlineFlagIndex], [-1, OnlineFlagIndex], 1)[0]
#LOG.info('online_flag=%s'%(online_flag))
index_dict = collections.defaultdict(list)
for (_ms, _field, _ant, _spw) in iterate_group_member(group_desc, member_list):
        LOG.debug('{0} {1} {2} {3}', _ms.basename, _field, _ant, _spw)
_vis = _ms.name
datatable = datatable_dict[_ms.basename]
time_table = datatable.get_timetable(_ant, _spw, None, os.path.basename(_vis), _field)
# time table separated by large time gap
the_table = time_table[1]
def _g():
for group in the_table:
for row in group[1]:
permanent_flag = datatable.getcell('FLAG_PERMANENT', row)
online_flag = permanent_flag[:, OnlineFlagIndex]
if any(online_flag == 1):
yield row
arr = numpy.fromiter(_g(), dtype=numpy.int64)
index_dict[_ms.basename].extend(arr)
for vis in index_dict:
index_dict[vis] = numpy.asarray(index_dict[vis])
#index_dict[vis].sort()
return index_dict
def get_valid_ms_members(group_desc, msname_filter, ant_selection, field_selection, spw_selection):
for member_id in range(len(group_desc)):
member = group_desc[member_id]
spw_id = member.spw_id
field_id = member.field_id
ant_id = member.antenna_id
msobj = member.ms
if msobj.name in [os.path.abspath(name) for name in msname_filter]:
_field_selection = field_selection
try:
nfields = len(msobj.fields)
if len(field_selection) == 0:
# fine, go ahead
pass
elif not field_selection.isdigit():
                    # selection by name; wrap it in double quotes
LOG.debug('non-digit field selection')
if not _field_selection.startswith('"'):
_field_selection = '"{}"'.format(field_selection)
else:
tmp_id = int(field_selection)
LOG.debug('field_id = {}'.format(tmp_id))
if tmp_id < 0 or nfields <= tmp_id:
                        # could be a selection by name consisting of digits; wrap it in double quotes
                        LOG.debug('field name consisting of digits')
if not _field_selection.startswith('"'):
_field_selection = '"{}"'.format(field_selection)
LOG.debug('field_selection = "{}"'.format(_field_selection))
mssel = casa_tools.ms.msseltoindex(vis=msobj.name, spw=spw_selection,
field=_field_selection, baseline=ant_selection)
except RuntimeError as e:
LOG.trace('RuntimeError: {0}'.format(str(e)))
LOG.trace('vis="{0}" field_selection: "{1}"'.format(msobj.name, _field_selection))
continue
spwsel = mssel['spw']
fieldsel = mssel['field']
antsel = mssel['antenna1']
if ((len(spwsel) == 0 or spw_id in spwsel) and
(len(fieldsel) == 0 or field_id in fieldsel) and
(len(antsel) == 0 or ant_id in antsel)):
yield member_id
def get_valid_ms_members2(group_desc, ms_filter, ant_selection, field_selection, spw_selection):
for member_id in range(len(group_desc)):
member = group_desc[member_id]
spw_id = member.spw_id
field_id = member.field_id
ant_id = member.antenna_id
msobj = member.ms
if msobj in ms_filter:
try:
mssel = casa_tools.ms.msseltoindex(vis=msobj.name, spw=spw_selection,
field=field_selection, baseline=ant_selection)
except RuntimeError as e:
LOG.trace('RuntimeError: {0}'.format(str(e)))
LOG.trace('vis="{0}" field_selection: "{1}"'.format(msobj.name, field_selection))
continue
spwsel = mssel['spw']
fieldsel = mssel['field']
antsel = mssel['antenna1']
if ((spwsel.size == 0 or spw_id in spwsel) and
(fieldsel.size == 0 or field_id in fieldsel) and
(antsel.size == 0 or ant_id in antsel)):
yield member_id
def _collect_logrecords(logger):
capture_handlers = [h for h in logger.handlers if h.__class__.__name__ == 'CapturingHandler']
logrecords = []
for handler in capture_handlers:
logrecords.extend(handler.buffer[:])
return logrecords
@contextlib.contextmanager
def TableSelector(name, query):
    with casa_tools.TableReader(name) as tb:
        tsel = tb.query(query)
        try:
            yield tsel
        finally:
            # ensure the selection is closed even if the with-body raises
            tsel.close()
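# Illustrative usage (hypothetical MS name and TaQL expression): the selected
# table is closed automatically when the block exits.
#     with TableSelector('uid___A002_X.ms', 'ANTENNA1 == ANTENNA2') as tsel:
#         nrow = tsel.nrows()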
# dictionary that always returns key
class EchoDictionary(dict):
def __getitem__(self, x):
return x
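# Illustrative example (not executed): EchoDictionary()[42] -> 42; every
# lookup simply echoes the key, which makes it a no-op row map.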
def make_row_map_for_baselined_ms(ms, table_container=None):
"""
Make row mapping between calibrated MS and baselined MS.
Return value is a dictionary whose key is row number for calibrated MS and
its corresponding value is the one for baselined MS.
ms: measurement set domain object
returns: row mapping dictionary
"""
work_data = ms.work_data
src_tb = None
derived_tb = None
if table_container is not None:
src_tb = table_container.tb1
derived_tb = table_container.tb2
return make_row_map(ms, work_data, src_tb, derived_tb)
#@profiler
def make_row_map(src_ms, derived_vis, src_tb=None, derived_tb=None):
    """
    Make row mapping between a source MS and its associated (derived) MS.

    src_ms: measurement set domain object for the source MS
    derived_vis: name of the MS derived from the source MS
    returns: row mapping dictionary
    """
ms = src_ms
vis0 = ms.name
vis1 = derived_vis
rowmap = {}
if vis0 == vis1:
return EchoDictionary()
# make polarization map between src MS and derived MS
to_derived_polid = make_polid_map(vis0, vis1)
LOG.trace('to_derived_polid=%s' % to_derived_polid)
# make spw map between src MS and derived MS
to_derived_spwid = make_spwid_map(vis0, vis1)
LOG.trace('to_derived_spwid=%s' % to_derived_spwid)
# make a map between (polid, spwid) pair and ddid for derived MS
derived_ddid_map = make_ddid_map(vis1)
LOG.trace('derived_ddid_map=%s' % derived_ddid_map)
scans = ms.get_scans(scan_intent='TARGET')
scan_numbers = [s.id for s in scans]
fields = {}
states = {}
for scan in scans:
fields[scan.id] = [f.id for f in scan.fields if 'TARGET' in f.intents]
states[scan.id] = [s.id for s in scan.states if 'TARGET' in s.intents]
field_values = list(fields.values())
is_unique_field_set = True
for v in field_values:
if v != field_values[0]:
is_unique_field_set = False
state_values = list(states.values())
is_unique_state_set = True
for v in state_values:
if v != state_values[0]:
is_unique_state_set = False
if is_unique_field_set and is_unique_state_set:
taql = 'ANTENNA1 == ANTENNA2 && SCAN_NUMBER IN %s && FIELD_ID IN %s && STATE_ID IN %s' % (scan_numbers, field_values[0], state_values[0])
else:
taql = 'ANTENNA1 == ANTENNA2 && (%s)' % (' || '.join(['(SCAN_NUMBER == %s && FIELD_ID IN %s && STATE_ID IN %s)' % (scan, fields[scan], states[scan]) for scan in scan_numbers]))
LOG.trace('taql=\'%s\'' % (taql))
with casa_tools.TableReader(os.path.join(vis0, 'OBSERVATION')) as tb:
nrow_obs0 = tb.nrows()
with casa_tools.TableReader(os.path.join(vis0, 'PROCESSOR')) as tb:
nrow_proc0 = tb.nrows()
with casa_tools.TableReader(os.path.join(vis1, 'OBSERVATION')) as tb:
nrow_obs1 = tb.nrows()
with casa_tools.TableReader(os.path.join(vis1, 'PROCESSOR')) as tb:
nrow_proc1 = tb.nrows()
assert nrow_obs0 == nrow_obs1
assert nrow_proc0 == nrow_proc1
is_unique_observation_id = nrow_obs0 == 1
is_unique_processor_id = nrow_proc0 == 1
    def _read_rows(table):
        # read the selected columns once; shared by the source and derived
        # tables to avoid duplicating the column-extraction logic
        tsel = table.query(taql)
        try:
            observation_id_list = None if is_unique_observation_id else tsel.getcol('OBSERVATION_ID')
            processor_id_list = None if is_unique_processor_id else tsel.getcol('PROCESSOR_ID')
            return (observation_id_list,
                    processor_id_list,
                    tsel.getcol('SCAN_NUMBER'),
                    tsel.getcol('FIELD_ID'),
                    tsel.getcol('ANTENNA1'),
                    tsel.getcol('STATE_ID'),
                    tsel.getcol('DATA_DESC_ID'),
                    tsel.getcol('TIME'),
                    tsel.rownumbers())
        finally:
            tsel.close()
    if src_tb is None:
        with casa_tools.TableReader(vis0) as tb:
            columns0 = _read_rows(tb)
    else:
        columns0 = _read_rows(src_tb)
    (observation_id_list0, processor_id_list0, scan_number_list0, field_id_list0,
     antenna1_list0, state_id_list0, data_desc_id_list0, time_list0,
     rownumber_list0) = columns0
    observation_id_set = {0} if is_unique_observation_id else set(observation_id_list0)
    processor_id_set = {0} if is_unique_processor_id else set(processor_id_list0)
    if derived_tb is None:
        with casa_tools.TableReader(vis1) as tb:
            columns1 = _read_rows(tb)
    else:
        columns1 = _read_rows(derived_tb)
    (observation_id_list1, processor_id_list1, scan_number_list1, field_id_list1,
     antenna1_list1, state_id_list1, data_desc_id_list1, time_list1,
     rownumber_list1) = columns1
for processor_id in processor_id_set:
LOG.trace('PROCESSOR_ID %s' % processor_id)
for observation_id in observation_id_set:
LOG.trace('OBSERVATION_ID %s' % observation_id)
for scan_number in scan_numbers:
LOG.trace('SCAN_NUMBER %s' % scan_number)
if scan_number not in states:
LOG.trace('No target states in SCAN %s' % scan_number)
continue
for field_id in fields[scan_number]:
LOG.trace('FIELD_ID %s' % field_id)
for antenna in ms.antennas:
antenna_id = antenna.id
LOG.trace('ANTENNA_ID %s' % antenna_id)
for spw in ms.get_spectral_windows(science_windows_only=True):
data_desc = ms.get_data_description(spw=spw)
data_desc_id = data_desc.id
pol_id = data_desc.pol_id
spw_id = spw.id
LOG.trace('START PROCESSOR %s SCAN %s DATA_DESC_ID %s ANTENNA %s FIELD %s' %
(processor_id, scan_number, data_desc_id, antenna_id, field_id))
derived_pol_id = to_derived_polid[pol_id]
derived_spw_id = to_derived_spwid[spw_id]
derived_dd_id = derived_ddid_map[(derived_pol_id, derived_spw_id)]
LOG.trace('SRC DATA_DESC_ID %s (SPW %s)' % (data_desc_id, spw_id))
LOG.trace('DERIVED DATA_DESC_ID %s (SPW %s)' % (derived_dd_id, derived_spw_id))
tmask0 = numpy.logical_and(
data_desc_id_list0 == data_desc_id,
numpy.logical_and(antenna1_list0 == antenna_id,
numpy.logical_and(field_id_list0 == field_id,
scan_number_list0 == scan_number)))
if not is_unique_processor_id:
numpy.logical_and(tmask0, processor_id_list0 == processor_id, out=tmask0)
if not is_unique_observation_id:
numpy.logical_and(tmask0, observation_id_list0 == observation_id, out=tmask0)
tmask1 = numpy.logical_and(
data_desc_id_list1 == derived_dd_id,
numpy.logical_and(antenna1_list1 == antenna_id,
numpy.logical_and(field_id_list1 == field_id,
scan_number_list1 == scan_number)))
if not is_unique_processor_id:
numpy.logical_and(tmask1, processor_id_list1 == processor_id, out=tmask1)
if not is_unique_observation_id:
numpy.logical_and(tmask1, observation_id_list1 == observation_id, out=tmask1)
                            if not numpy.any(tmask0) and not numpy.any(tmask1):
# no corresponding data (probably due to PROCESSOR_ID for non-science windows)
LOG.trace('SKIP PROCESSOR %s SCAN %s DATA_DESC_ID %s ANTENNA %s FIELD %s' %
(processor_id, scan_number, data_desc_id, antenna_id, field_id))
continue
tstate0 = state_id_list0[tmask0]
tstate1 = state_id_list1[tmask1]
ttime0 = time_list0[tmask0]
ttime1 = time_list1[tmask1]
trow0 = rownumber_list0[tmask0]
trow1 = rownumber_list1[tmask1]
sort_index0 = numpy.lexsort((tstate0, ttime0))
sort_index1 = numpy.lexsort((tstate1, ttime1))
LOG.trace('scan %s' % (scan_number)
+ ' actual %s' % (list(set(tstate0)))
+ ' expected %s' % (states[scan_number]))
assert numpy.all(ttime0[sort_index0] == ttime1[sort_index1])
assert numpy.all(tstate0[sort_index0] == tstate1[sort_index1])
# assert set(tstate0) == set(states[scan_number])
assert set(tstate0).issubset(set(states[scan_number]))
for (i0, i1) in zip(sort_index0, sort_index1):
r0 = trow0[i0]
r1 = trow1[i1]
rowmap[r0] = r1
LOG.trace('END PROCESSOR %s SCAN %s DATA_DESC_ID %s ANTENNA %s FIELD %s' %
(processor_id, scan_number, data_desc_id, antenna_id, field_id))
return rowmap
class SpwSimpleView(object):
def __init__(self, spwid, name):
self.id = spwid
self.name = name
class SpwDetailedView(object):
def __init__(self, spwid, name, num_channels, ref_frequency, min_frequency, max_frequency):
self.id = spwid
self.name = name
self.num_channels = num_channels
self.ref_frequency = ref_frequency
self.min_frequency = min_frequency
self.max_frequency = max_frequency
def get_spw_names(vis):
with casa_tools.TableReader(os.path.join(vis, 'SPECTRAL_WINDOW')) as tb:
gen = (SpwSimpleView(i, tb.getcell('NAME', i)) for i in range(tb.nrows()))
spws = list(gen)
return spws
def get_spw_properties(vis):
with casa_tools.TableReader(os.path.join(vis, 'SPECTRAL_WINDOW')) as tb:
spws = []
for irow in range(tb.nrows()):
name = tb.getcell('NAME', irow)
nchan = tb.getcell('NUM_CHAN', irow)
ref_freq = tb.getcell('REF_FREQUENCY', irow)
chan_freq = tb.getcell('CHAN_FREQ', irow)
chan_width = tb.getcell('CHAN_WIDTH', irow)
            # channel widths are assumed to be uniform across the spw
            min_freq = chan_freq.min() - abs(chan_width[0]) / 2
            max_freq = chan_freq.max() + abs(chan_width[0]) / 2
spws.append(SpwDetailedView(irow, name, nchan, ref_freq, min_freq, max_freq))
return spws
# @profiler
def __read_table(reader, method, vis):
if reader is None:
result = method(vis)
else:
with reader(vis) as readerobj:
result = method(readerobj)
return result
def _read_table(reader, table, vis):
rows = __read_table(reader, table._read_table, vis)
return rows
# @profiler
def make_spwid_map(srcvis, dstvis):
# src_spws = __read_table(casa_tools.MSMDReader,
# tablereader.SpectralWindowTable.get_spectral_windows,
# srcvis)
# dst_spws = __read_table(casa_tools.MSMDReader,
# tablereader.SpectralWindowTable.get_spectral_windows,
# dstvis)
src_spws = __read_table(None, get_spw_properties, srcvis)
dst_spws = __read_table(None, get_spw_properties, dstvis)
for spw in src_spws:
LOG.trace('SRC SPWID %s NAME %s' % (spw.id, spw.name))
for spw in dst_spws:
LOG.trace('DST SPWID %s NAME %s' % (spw.id, spw.name))
map_byname = collections.defaultdict(list)
for src_spw in src_spws:
for dst_spw in dst_spws:
if src_spw.name == dst_spw.name:
map_byname[src_spw].append(dst_spw)
spwid_map = {}
for src, dst in map_byname.items():
LOG.trace('map_byname src spw %s: dst spws %s' % (src.id, [spw.id for spw in dst]))
if len(dst) == 0:
continue
elif len(dst) == 1:
# mapping by name worked
spwid_map[src.id] = dst[0].id
else:
# need more investigation
for spw in dst:
if (src.num_channels == spw.num_channels and
src.ref_frequency == spw.ref_frequency and
src.min_frequency == spw.min_frequency and
src.max_frequency == spw.max_frequency):
if src.id in spwid_map:
raise RuntimeError('Failed to create spw map for MSs \'%s\' and \'%s\'' % (srcvis, dstvis))
spwid_map[src.id] = spw.id
return spwid_map
# @profiler
def make_polid_map(srcvis, dstvis):
src_rows = _read_polarization_table(srcvis)
dst_rows = _read_polarization_table(dstvis)
for (src_polid, src_numpol, src_poltype, _, _) in src_rows:
LOG.trace('SRC: POLID %s NPOL %s POLTYPE %s' % (src_polid, src_numpol, src_poltype))
for (dst_polid, dst_numpol, dst_poltype, _, _) in dst_rows:
LOG.trace('DST: POLID %s NPOL %s POLTYPE %s' % (dst_polid, dst_numpol, dst_poltype))
polid_map = {}
for (src_polid, src_numpol, src_poltype, _, _) in src_rows:
for (dst_polid, dst_numpol, dst_poltype, _, _) in dst_rows:
if src_numpol == dst_numpol and numpy.all(src_poltype == dst_poltype):
polid_map[src_polid] = dst_polid
LOG.trace('polid_map = %s' % polid_map)
return polid_map
# @profiler
def make_ddid_map(vis):
with casa_tools.TableReader(os.path.join(vis, 'DATA_DESCRIPTION')) as tb:
pol_ids = tb.getcol('POLARIZATION_ID')
spw_ids = tb.getcol('SPECTRAL_WINDOW_ID')
num_ddids = tb.nrows()
ddid_map = {}
for ddid in range(num_ddids):
ddid_map[(pol_ids[ddid], spw_ids[ddid])] = ddid
return ddid_map
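# Illustrative example (hypothetical table contents): a DATA_DESCRIPTION table
# whose rows hold (POLARIZATION_ID, SPECTRAL_WINDOW_ID) = (0, 0) and (0, 1)
# yields {(0, 0): 0, (0, 1): 1}.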
def get_datacolumn_name(vis):
colname_candidates = ['CORRECTED_DATA', 'FLOAT_DATA', 'DATA']
with casa_tools.TableReader(vis) as tb:
colnames = tb.colnames()
colname = None
for name in colname_candidates:
if name in colnames:
colname = name
break
assert colname is not None
return colname
# helper functions for parallel execution
def create_serial_job(task_cls, task_args, context):
inputs = task_cls.Inputs(context, **task_args)
task = task_cls(inputs)
job = mpihelpers.SyncTask(task)
LOG.debug('Serial Job: %s' % task)
return job
def create_parallel_job(task_cls, task_args, context):
context_path = os.path.join(context.output_dir, context.name + '.context')
if not os.path.exists(context_path):
context.save(context_path)
task = mpihelpers.Tier0PipelineTask(task_cls, task_args, context_path)
job = mpihelpers.AsyncTask(task)
LOG.debug('Parallel Job: %s' % task)
return job
def _read_polarization_table(vis):
"""
Read the POLARIZATION table of the given measurement set.
This function used to be part of tablereader, which has since moved from
direct table reading to using the MSMD tool.
"""
LOG.debug('Analysing POLARIZATION table')
polarization_table = os.path.join(vis, 'POLARIZATION')
with casa_tools.TableReader(polarization_table) as table:
num_corrs = table.getcol('NUM_CORR')
vcorr_types = table.getvarcol('CORR_TYPE')
vcorr_products = table.getvarcol('CORR_PRODUCT')
flag_rows = table.getcol('FLAG_ROW')
rowids = []
corr_types = []
corr_products = []
for i in range(table.nrows()):
rowids.append(i)
corr_types.append(vcorr_types['r%s' % (i + 1)])
corr_products.append(vcorr_products['r%s' % (i + 1)])
rows = list(zip(rowids, num_corrs, corr_types, corr_products, flag_rows))
return rows
def get_restfrequency(vis, spwid, source_id):
source_table = os.path.join(vis, 'SOURCE')
with casa_tools.TableReader(source_table) as tb:
tsel = tb.query('SOURCE_ID == {} && SPECTRAL_WINDOW_ID == {}'.format(source_id, spwid))
try:
if tsel.nrows() == 0:
return None
else:
if tsel.iscelldefined('REST_FREQUENCY', 0):
return tsel.getcell('REST_FREQUENCY', 0)[0]
else:
return None
finally:
tsel.close()
class RGAccumulator(object):
def __init__(self):
self.field = []
self.antenna = []
self.spw = []
self.pols = []
self.grid_table = []
self.channelmap_range = []
    def append(self, field_id, antenna_id, spw_id, pol_ids=None, grid_table=None, channelmap_range=None):
self.field.append(field_id)
self.antenna.append(antenna_id)
self.spw.append(spw_id)
self.pols.append(pol_ids)
if isinstance(grid_table, compress.CompressedObj) or grid_table is None:
self.grid_table.append(grid_table)
else:
self.grid_table.append(compress.CompressedObj(grid_table))
self.channelmap_range.append(channelmap_range)
# def extend(self, field_id_list, antenna_id_list, spw_id_list):
# self.field.extend(field_id_list)
# self.antenna.extend(antenna_id_list)
# self.spw.extend(spw_id_list)
#
    def get_field_id_list(self):
        return self.field

    def get_antenna_id_list(self):
        return self.antenna

    def get_spw_id_list(self):
        return self.spw

    def get_pol_ids_list(self):
        return self.pols

    def get_grid_table_list(self):
        return self.grid_table

    def get_channelmap_range_list(self):
        return self.channelmap_range
    def iterate_id(self):
assert len(self.field) == len(self.antenna)
assert len(self.field) == len(self.spw)
assert len(self.field) == len(self.pols)
for v in zip(self.field, self.antenna, self.spw):
yield v
    def iterate_all(self):
assert len(self.field) == len(self.antenna)
assert len(self.field) == len(self.spw)
assert len(self.field) == len(self.pols)
assert len(self.field) == len(self.grid_table)
assert len(self.field) == len(self.channelmap_range)
for f, a, s, g, c in zip(self.field, self.antenna, self.spw, self.grid_table, self.channelmap_range):
_g = g.decompress()
yield f, a, s, _g, c
del _g
    def get_process_list(self, withpol=False):
field_id_list = self.get_field_id_list()
antenna_id_list = self.get_antenna_id_list()
spw_id_list = self.get_spw_id_list()
assert len(field_id_list) == len(antenna_id_list)
assert len(field_id_list) == len(spw_id_list)
        if withpol:
pol_ids_list = self.get_pol_ids_list()
assert len(field_id_list) == len(pol_ids_list)
return field_id_list, antenna_id_list, spw_id_list, pol_ids_list
else:
return field_id_list, antenna_id_list, spw_id_list
def sort_fields(context):
mses = context.observing_run.measurement_sets
sorted_names = []
sorted_fields = []
for ms in mses:
fields = ms.get_fields(intent='TARGET')
for f in fields:
if f.name not in sorted_names:
sorted_fields.append(f)
sorted_names.append(f.name)
return sorted_fields
def get_brightness_unit(vis, defaultunit='Jy/beam'):
with casa_tools.TableReader(vis) as tb:
colnames = tb.colnames()
target_columns = ['CORRECTED_DATA', 'FLOAT_DATA', 'DATA']
bunit = defaultunit
for col in target_columns:
if col in colnames:
keys = tb.getcolkeywords(col)
if 'UNIT' in keys:
_bunit = keys['UNIT']
if len(_bunit) > 0:
# should be K or Jy
# update bunit only when UNIT is K
if _bunit == 'K':
bunit = 'K'
break
return bunit