Source code for pipeline.hsd.tasks.common.utils

import collections
import contextlib
import functools
import os
import sys
import time

import numpy

import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.logging as logging
import pipeline.infrastructure.mpihelpers as mpihelpers
from pipeline.domain.datatable import OnlineFlagIndex, DataTableIndexer
from pipeline.infrastructure import casa_tools
from . import compress

_LOG = infrastructure.get_logger(__name__)


class OnDemandStringParseLogger(object):
    PRIORITY_MAP = {'warn': 'warning'}

    def __init__(self, logger):
        self.logger = logger
        self._func_list = []

    @staticmethod
    def parse(msg_template, *args, **kwargs):
        if len(args) == 0 and len(kwargs) == 0:
            return msg_template
        else:
            return msg_template.format(*args, **kwargs)

    def _post(self, priority, msg_template, *args, **kwargs):
        key_for_level = self.PRIORITY_MAP.get(priority, priority)
        if self.logger.isEnabledFor(logging.LOGGING_LEVELS[key_for_level]):
            getattr(self.logger, priority)(
                OnDemandStringParseLogger.parse(msg_template, *args, **kwargs))

    def critical(self, msg_template, *args, **kwargs):
        self._post('critical', msg_template, *args, **kwargs)

    def error(self, msg_template, *args, **kwargs):
        self._post('error', msg_template, *args, **kwargs)

    def warn(self, msg_template, *args, **kwargs):
        self._post('warning', msg_template, *args, **kwargs)

    def info(self, msg_template, *args, **kwargs):
        self._post('info', msg_template, *args, **kwargs)

    def debug(self, msg_template, *args, **kwargs):
        self._post('debug', msg_template, *args, **kwargs)

    def todo(self, msg_template, *args, **kwargs):
        self._post('todo', msg_template, *args, **kwargs)

    def trace(self, msg_template, *args, **kwargs):
        self._post('trace', msg_template, *args, **kwargs)


LOG = OnDemandStringParseLogger(_LOG)
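

# Illustrative usage of the deferred-formatting logger above (a sketch, not
# part of the original module): str.format() is applied only when the
# requested log level is actually enabled, so disabled log calls stay cheap
# even with expensive-to-render arguments.
#
#     LOG.info('Processing spw {} of {}', 3, 16)
#     LOG.debug('statistics: {stats}', stats=some_expensive_summary)
#
# 'some_expensive_summary' is a hypothetical placeholder.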


def profiler(func):
    @functools.wraps(func)
    def wrapper(*args, **kw):
        start = time.time()
        # LOG.info('#TIMING# Begin {} at {}', func.__name__, start)
        result = func(*args, **kw)
        end = time.time()
        # LOG.info('#TIMING# End {} at {}', func.__name__, end)
        LOG.info('#PROFILE# %s: elapsed %s sec' % (func.__name__, end - start))
        return result
    return wrapper
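

# Example application of the profiler decorator ('example_task' is a
# hypothetical function, not part of the pipeline):
#
#     @profiler
#     def example_task(n):
#         return sum(range(n))
#
# Calling example_task(10**6) then also emits a log line of the form
# '#PROFILE# example_task: elapsed 0.05 sec'.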


def require_virtual_spw_id_handling(observing_run):
    """
    Judge if spw ids vary across EBs. Return True if ids vary.

    observing_run -- domain.ObservingRun instance
    """
    return numpy.any([spw.id != observing_run.real2virtual_spw_id(spw.id, ms)
                      for ms in observing_run.measurement_sets
                      for spw in ms.get_spectral_windows(science_windows_only=True)])


def is_nro(context):
    mses = context.observing_run.measurement_sets
    return numpy.all([ms.antenna_array.name == 'NRO' for ms in mses])


def asdm_name(scantable_object):
    """
    Return the name of the ASDM that the target scantable belongs to.

    Assumptions are:
        - scantable is generated from an MS
        - the MS is generated from an ASDM
        - the MS name is <uid>.ms
    """
    return asdm_name_from_ms(scantable_object.ms)


def asdm_name_from_ms(ms_domain):
    """
    Return the name of the ASDM that the target MS originates from.

    Assumptions are:
        - the MS is generated from an ASDM
        - the MS name is <uid>.ms
    """
    ms_basename = ms_domain.basename
    index_for_suffix = ms_basename.rfind('.')
    asdm = ms_basename[:index_for_suffix] if index_for_suffix > 0 else ms_basename
    return asdm
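

# For example, for an MS named 'uid___A002_X85c183_X36f.ms' the functions
# above return the ASDM name 'uid___A002_X85c183_X36f' (hypothetical uid).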


def get_parent_ms_idx(context, msname):
    """
    Return the index of the corresponding MS in the context.

    The method matches msname against both work_data and the original MS
    name. The return value is -1 if no match is found.
    """
    mslist = context.observing_run.measurement_sets
    idx_found = -1
    for idx in range(len(mslist)):
        msobj = mslist[idx]
        search_list = [msobj.name, msobj.basename]
        if hasattr(msobj, "work_data"):
            search_list += [msobj.work_data, os.path.basename(msobj.work_data)]
        if msname in search_list:
            idx_found = idx
            break
    return idx_found


def get_parent_ms_name(context, msname):
    """
    Return the name of the corresponding parent MS in the context.

    The method matches msname against both work_data and the original MS
    name. The return value is "" if no match is found.
    """
    idx = get_parent_ms_idx(context, msname)
    return context.observing_run.measurement_sets[idx].name if idx >= 0 else ""


####
# ProgressTimer
#
# Show the progress bar on the console if LogLevel is lower than or equal to 2.
####
class ProgressTimer(object):
    def __init__(self, length=80, maxCount=80, LogLevel='info'):
        """
        Constructor:
            length: length of the progress bar (default 80 characters)
        """
        self.currentLevel = 0
        self.maxCount = maxCount
        self.curCount = 0
        self.scale = float(length) / float(maxCount)
        if isinstance(LogLevel, str):
            self.LogLevel = logging.LOGGING_LEVELS[LogLevel] if LogLevel in logging.LOGGING_LEVELS else logging.INFO
        else:
            # should be an integer
            self.LogLevel = LogLevel
        if self.LogLevel >= logging.INFO:
            print('\n|{} 100% {}|'.format('=' * ((length - 8) // 2), '=' * ((length - 8) // 2)))

    def __del__(self):
        if self.LogLevel >= logging.INFO:
            print('\n')

    def count(self, increment=1):
        if self.LogLevel >= logging.INFO:
            self.curCount += increment
            newLevel = int(self.curCount * self.scale)
            if newLevel != self.currentLevel:
                print('\b{}'.format('*' * (newLevel - self.currentLevel)))
                sys.stdout.flush()
                self.currentLevel = newLevel


# parse edge parameter to tuple
def parseEdge(edge):
    if isinstance(edge, int) or isinstance(edge, float):
        EdgeL = edge
        EdgeR = edge
    elif len(edge) == 0:
        EdgeL = 0
        EdgeR = 0
    elif len(edge) == 1:
        EdgeL = edge[0]
        EdgeR = edge[0]
    else:
        (EdgeL, EdgeR) = edge[:2]
    return EdgeL, EdgeR
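

# Illustrative inputs and outputs for parseEdge:
#
#     parseEdge(2)          -> (2, 2)
#     parseEdge([])         -> (0, 0)
#     parseEdge([3])        -> (3, 3)
#     parseEdge([1, 4, 9])  -> (1, 4)   # elements beyond the first two are ignored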


def mjd_to_datestring(t, unit='sec'):
    """
    MJD ---> date string

    t: MJD
    unit: sec or day
    """
    import datetime
    if unit in ['sec', 's']:
        mjd = t
    elif unit in ['day', 'd']:
        mjd = t * 86400.0
    else:
        mjd = 0.0
    mjdzero = datetime.datetime(1858, 11, 17, 0, 0, 0)
    zt = time.gmtime(0.0)
    timezero = datetime.datetime(zt.tm_year, zt.tm_mon, zt.tm_mday, zt.tm_hour, zt.tm_min, zt.tm_sec)
    dtd = timezero - mjdzero
    dtsec = mjd - (float(dtd.days) * 86400.0 + float(dtd.seconds) + float(dtd.microseconds) * 1.0e-6)
    mjdstr = time.asctime(time.gmtime(dtsec)) + ' UTC'
    return mjdstr
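

# Illustrative calls (t is MJD expressed in seconds by default); the returned
# string has the time.asctime() form, e.g. 'Mon Jan  1 00:00:00 2018 UTC':
#
#     mjd_to_datestring(4.9e9)           # MJD given in seconds
#     mjd_to_datestring(56712.5, 'day')  # MJD given in days (hypothetical value)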


def to_list(s):
    if s is None:
        return None
    elif isinstance(s, list) or isinstance(s, numpy.ndarray):
        return s
    elif isinstance(s, str):
        if s.startswith('['):
            # note: relies on eval to parse string-form lists
            if s.lstrip('[')[0].isdigit():
                return eval(s)
            else:
                # maybe string list
                return eval(s.replace('[', '[\'').replace(']', '\']').replace(',', '\',\''))
        else:
            try:
                return [float(s)]
            except ValueError:
                return [s]
    else:
        return [s]


def to_bool(s):
    if s is None:
        return None
    elif isinstance(s, bool):
        return s
    elif isinstance(s, str):
        if s.upper() == 'FALSE' or s == 'F':
            return False
        elif s.upper() == 'TRUE' or s == 'T':
            return True
        else:
            return s
    else:
        return bool(s)


def to_numeric(s):
    if s is None:
        return None
    elif isinstance(s, str):
        try:
            return float(s)
        except ValueError:
            return s
    else:
        return s
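

# Illustrative conversions performed by the three helpers above:
#
#     to_list('[0,1,2]')   -> [0, 1, 2]
#     to_list('[a,b]')     -> ['a', 'b']
#     to_list('3.5')       -> [3.5]
#     to_bool('False')     -> False
#     to_bool('T')         -> True
#     to_numeric('1.5e9')  -> 1500000000.0
#     to_numeric('abc')    -> 'abc'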


def get_mask_from_flagtra(flagtra):
    """Convert FLAGTRA (unsigned char) to a mask array (1=valid, 0=flagged)."""
    return (numpy.asarray(flagtra) == 0).astype(int)
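

# Example: a FLAGTRA value of 0 means valid, anything else means flagged, so
#
#     get_mask_from_flagtra([0, 1, 0, 128])  -> array([1, 0, 1, 0])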


def iterate_group_member(group_desc, member_id_list):
    for mid in member_id_list:
        member = group_desc[mid]
        yield member.ms, member.field_id, member.antenna_id, member.spw_id


def get_index_list_for_ms(datatable, vis_list, antennaid_list, fieldid_list, spwid_list,
                          srctype=None):
    return numpy.fromiter(_get_index_list_for_ms(datatable, vis_list, antennaid_list,
                                                 fieldid_list, spwid_list, srctype),
                          dtype=numpy.int64)


def _get_index_list_for_ms(datatable, vis_list, antennaid_list, fieldid_list, spwid_list,
                           srctype=None):
    # use time_table instead of data selection
    #online_flag = datatable.getcolslice('FLAG_PERMANENT', [0, OnlineFlagIndex], [-1, OnlineFlagIndex], 1)[0]
    #LOG.info('online_flag=%s'%(online_flag))
    for (_vis, _field, _ant, _spw) in zip(vis_list, fieldid_list, antennaid_list, spwid_list):
        try:
            time_table = datatable.get_timetable(_ant, _spw, None, os.path.basename(_vis), _field)
        except RuntimeError as e:
            # data could be missing. just skip.
            LOG.warn('Exception reported from datatable.get_timetable:')
            LOG.warn(str(e))
            continue
        # time table separated by large time gap
        the_table = time_table[1]
        for group in the_table:
            for row in group[1]:
                permanent_flag = datatable.getcell('FLAG_PERMANENT', row)
                online_flag = permanent_flag[:, OnlineFlagIndex]
                if any(online_flag == 1):
                    yield row


def get_index_list_for_ms2(datatable, group_desc, member_list, srctype=None):
    # use time_table instead of data selection
    #online_flag = datatable.getcolslice('FLAG_PERMANENT', [0, OnlineFlagIndex], [-1, OnlineFlagIndex], 1)[0]
    #LOG.info('online_flag=%s'%(online_flag))
    for (_ms, _field, _ant, _spw) in iterate_group_member(group_desc, member_list):
        _vis = _ms.name
        time_table = datatable.get_timetable(_ant, _spw, None, os.path.basename(_vis), _field)
        # time table separated by large time gap
        the_table = time_table[1]
        for group in the_table:
            for row in group[1]:
                permanent_flag = datatable.getcell('FLAG_PERMANENT', row)
                online_flag = permanent_flag[:, OnlineFlagIndex]
                if any(online_flag == 1):
                    yield row


def get_index_list_for_ms3(datatable_dict, group_desc, member_list, srctype=None):
    # use time_table instead of data selection
    #online_flag = datatable.getcolslice('FLAG_PERMANENT', [0, OnlineFlagIndex], [-1, OnlineFlagIndex], 1)[0]
    #LOG.info('online_flag=%s'%(online_flag))
    index_dict = collections.defaultdict(list)
    for (_ms, _field, _ant, _spw) in iterate_group_member(group_desc, member_list):
        LOG.debug('{0} {1} {2} {3}', _ms.basename, _field, _ant, _spw)
        _vis = _ms.name
        datatable = datatable_dict[_ms.basename]
        time_table = datatable.get_timetable(_ant, _spw, None, os.path.basename(_vis), _field)
        # time table separated by large time gap
        the_table = time_table[1]

        def _g():
            for group in the_table:
                for row in group[1]:
                    permanent_flag = datatable.getcell('FLAG_PERMANENT', row)
                    online_flag = permanent_flag[:, OnlineFlagIndex]
                    if any(online_flag == 1):
                        yield row
        arr = numpy.fromiter(_g(), dtype=numpy.int64)
        index_dict[_ms.basename].extend(arr)
    for vis in index_dict:
        index_dict[vis] = numpy.asarray(index_dict[vis])
        #index_dict[vis].sort()
    return index_dict


def get_valid_ms_members(group_desc, msname_filter, ant_selection, field_selection, spw_selection):
    for member_id in range(len(group_desc)):
        member = group_desc[member_id]
        spw_id = member.spw_id
        field_id = member.field_id
        ant_id = member.antenna_id
        msobj = member.ms
        if msobj.name in [os.path.abspath(name) for name in msname_filter]:
            _field_selection = field_selection
            try:
                nfields = len(msobj.fields)
                if len(field_selection) == 0:
                    # fine, go ahead
                    pass
                elif not field_selection.isdigit():
                    # selection by name, bracket by ""
                    LOG.debug('non-digit field selection')
                    if not _field_selection.startswith('"'):
                        _field_selection = '"{}"'.format(field_selection)
                else:
                    tmp_id = int(field_selection)
                    LOG.debug('field_id = {}'.format(tmp_id))
                    if tmp_id < 0 or nfields <= tmp_id:
                        # could be selection by name consisting of digits, bracket by ""
                        LOG.debug('field name consisting of digits')
                        if not _field_selection.startswith('"'):
                            _field_selection = '"{}"'.format(field_selection)
                LOG.debug('field_selection = "{}"'.format(_field_selection))
                mssel = casa_tools.ms.msseltoindex(vis=msobj.name, spw=spw_selection,
                                                   field=_field_selection, baseline=ant_selection)
            except RuntimeError as e:
                LOG.trace('RuntimeError: {0}'.format(str(e)))
                LOG.trace('vis="{0}" field_selection: "{1}"'.format(msobj.name, _field_selection))
                continue
            spwsel = mssel['spw']
            fieldsel = mssel['field']
            antsel = mssel['antenna1']
            if ((len(spwsel) == 0 or spw_id in spwsel)
                    and (len(fieldsel) == 0 or field_id in fieldsel)
                    and (len(antsel) == 0 or ant_id in antsel)):
                yield member_id


def get_valid_ms_members2(group_desc, ms_filter, ant_selection, field_selection, spw_selection):
    for member_id in range(len(group_desc)):
        member = group_desc[member_id]
        spw_id = member.spw_id
        field_id = member.field_id
        ant_id = member.antenna_id
        msobj = member.ms
        if msobj in ms_filter:
            try:
                mssel = casa_tools.ms.msseltoindex(vis=msobj.name, spw=spw_selection,
                                                   field=field_selection, baseline=ant_selection)
            except RuntimeError as e:
                LOG.trace('RuntimeError: {0}'.format(str(e)))
                LOG.trace('vis="{0}" field_selection: "{1}"'.format(msobj.name, field_selection))
                continue
            spwsel = mssel['spw']
            fieldsel = mssel['field']
            antsel = mssel['antenna1']
            if ((spwsel.size == 0 or spw_id in spwsel)
                    and (fieldsel.size == 0 or field_id in fieldsel)
                    and (antsel.size == 0 or ant_id in antsel)):
                yield member_id


def _collect_logrecords(logger):
    capture_handlers = [h for h in logger.handlers if h.__class__.__name__ == 'CapturingHandler']
    logrecords = []
    for handler in capture_handlers:
        logrecords.extend(handler.buffer[:])
    return logrecords


@contextlib.contextmanager
def TableSelector(name, query):
    with casa_tools.TableReader(name) as tb:
        tsel = tb.query(query)
        try:
            yield tsel
        finally:
            # ensure the selection is closed even if the caller raises
            tsel.close()
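

# Typical usage of TableSelector (a sketch; the MS path and TaQL expression
# are hypothetical):
#
#     with TableSelector('uid___A002_X.ms', 'ANTENNA1 == ANTENNA2') as tsel:
#         nrow = tsel.nrows()
#         times = tsel.getcol('TIME')

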
# dictionary that always returns key
class EchoDictionary(dict):
    def __getitem__(self, x):
        return x
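

# EchoDictionary maps every key to itself, which lets callers treat the
# identity row mapping and a real row mapping uniformly:
#
#     rowmap = EchoDictionary()
#     rowmap[42]   # -> 42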


def make_row_map_for_baselined_ms(ms, table_container=None):
    """
    Make row mapping between calibrated MS and baselined MS.
    The return value is a dictionary whose key is the row number in the
    calibrated MS and whose value is the corresponding row number in the
    baselined MS.

    ms: measurement set domain object

    returns: row mapping dictionary
    """
    work_data = ms.work_data
    src_tb = None
    derived_tb = None
    if table_container is not None:
        src_tb = table_container.tb1
        derived_tb = table_container.tb2
    return make_row_map(ms, work_data, src_tb, derived_tb)


# @profiler
def make_row_map(src_ms, derived_vis, src_tb=None, derived_tb=None):
    """
    Make row mapping between source MS and associated MS.

    src_ms: measurement set domain object for source MS
    derived_vis: name of the MS that derives from source MS

    returns: row mapping dictionary
    """
    ms = src_ms
    vis0 = ms.name
    vis1 = derived_vis

    rowmap = {}

    if vis0 == vis1:
        return EchoDictionary()

    # make polarization map between src MS and derived MS
    to_derived_polid = make_polid_map(vis0, vis1)
    LOG.trace('to_derived_polid=%s' % to_derived_polid)

    # make spw map between src MS and derived MS
    to_derived_spwid = make_spwid_map(vis0, vis1)
    LOG.trace('to_derived_spwid=%s' % to_derived_spwid)

    # make a map between (polid, spwid) pair and ddid for derived MS
    derived_ddid_map = make_ddid_map(vis1)
    LOG.trace('derived_ddid_map=%s' % derived_ddid_map)

    scans = ms.get_scans(scan_intent='TARGET')
    scan_numbers = [s.id for s in scans]
    fields = {}
    states = {}
    for scan in scans:
        fields[scan.id] = [f.id for f in scan.fields if 'TARGET' in f.intents]
        states[scan.id] = [s.id for s in scan.states if 'TARGET' in s.intents]

    field_values = list(fields.values())
    is_unique_field_set = True
    for v in field_values:
        if v != field_values[0]:
            is_unique_field_set = False
    state_values = list(states.values())
    is_unique_state_set = True
    for v in state_values:
        if v != state_values[0]:
            is_unique_state_set = False
    if is_unique_field_set and is_unique_state_set:
        taql = 'ANTENNA1 == ANTENNA2 && SCAN_NUMBER IN %s && FIELD_ID IN %s && STATE_ID IN %s' % (
            scan_numbers, field_values[0], state_values[0])
    else:
        taql = 'ANTENNA1 == ANTENNA2 && (%s)' % (' || '.join(
            ['(SCAN_NUMBER == %s && FIELD_ID IN %s && STATE_ID IN %s)' % (scan, fields[scan], states[scan])
             for scan in scan_numbers]))
    LOG.trace('taql=\'%s\'' % taql)

    with casa_tools.TableReader(os.path.join(vis0, 'OBSERVATION')) as tb:
        nrow_obs0 = tb.nrows()
    with casa_tools.TableReader(os.path.join(vis0, 'PROCESSOR')) as tb:
        nrow_proc0 = tb.nrows()
    with casa_tools.TableReader(os.path.join(vis1, 'OBSERVATION')) as tb:
        nrow_obs1 = tb.nrows()
    with casa_tools.TableReader(os.path.join(vis1, 'PROCESSOR')) as tb:
        nrow_proc1 = tb.nrows()
    assert nrow_obs0 == nrow_obs1
    assert nrow_proc0 == nrow_proc1
    is_unique_observation_id = nrow_obs0 == 1
    is_unique_processor_id = nrow_proc0 == 1

    if src_tb is None:
        with casa_tools.TableReader(vis0) as tb:
            tsel = tb.query(taql)
            try:
                if is_unique_observation_id:
                    observation_id_list0 = None
                    observation_id_set = {0}
                else:
                    observation_id_list0 = tsel.getcol('OBSERVATION_ID')
                    observation_id_set = set(observation_id_list0)
                if is_unique_processor_id:
                    processor_id_list0 = None
                    processor_id_set = {0}
                else:
                    processor_id_list0 = tsel.getcol('PROCESSOR_ID')
                    processor_id_set = set(processor_id_list0)
                scan_number_list0 = tsel.getcol('SCAN_NUMBER')
                field_id_list0 = tsel.getcol('FIELD_ID')
                antenna1_list0 = tsel.getcol('ANTENNA1')
                state_id_list0 = tsel.getcol('STATE_ID')
                data_desc_id_list0 = tsel.getcol('DATA_DESC_ID')
                time_list0 = tsel.getcol('TIME')
                rownumber_list0 = tsel.rownumbers()
            finally:
                tsel.close()
    else:
        tsel = src_tb.query(taql)
        try:
            if is_unique_observation_id:
                observation_id_list0 = None
                observation_id_set = {0}
            else:
                observation_id_list0 = tsel.getcol('OBSERVATION_ID')
                observation_id_set = set(observation_id_list0)
            if is_unique_processor_id:
                processor_id_list0 = None
                processor_id_set = {0}
            else:
                processor_id_list0 = tsel.getcol('PROCESSOR_ID')
                processor_id_set = set(processor_id_list0)
            scan_number_list0 = tsel.getcol('SCAN_NUMBER')
            field_id_list0 = tsel.getcol('FIELD_ID')
            antenna1_list0 = tsel.getcol('ANTENNA1')
            state_id_list0 = tsel.getcol('STATE_ID')
            data_desc_id_list0 = tsel.getcol('DATA_DESC_ID')
            time_list0 = tsel.getcol('TIME')
            rownumber_list0 = tsel.rownumbers()
        finally:
            tsel.close()

    if derived_tb is None:
        with casa_tools.TableReader(vis1) as tb:
            tsel = tb.query(taql)
            try:
                if is_unique_observation_id:
                    observation_id_list1 = None
                else:
                    observation_id_list1 = tsel.getcol('OBSERVATION_ID')
                if is_unique_processor_id:
                    processor_id_list1 = None
                else:
                    processor_id_list1 = tsel.getcol('PROCESSOR_ID')
                scan_number_list1 = tsel.getcol('SCAN_NUMBER')
                field_id_list1 = tsel.getcol('FIELD_ID')
                antenna1_list1 = tsel.getcol('ANTENNA1')
                state_id_list1 = tsel.getcol('STATE_ID')
                data_desc_id_list1 = tsel.getcol('DATA_DESC_ID')
                time_list1 = tsel.getcol('TIME')
                rownumber_list1 = tsel.rownumbers()
            finally:
                tsel.close()
    else:
        tsel = derived_tb.query(taql)
        try:
            if is_unique_observation_id:
                observation_id_list1 = None
            else:
                observation_id_list1 = tsel.getcol('OBSERVATION_ID')
            if is_unique_processor_id:
                processor_id_list1 = None
            else:
                processor_id_list1 = tsel.getcol('PROCESSOR_ID')
            scan_number_list1 = tsel.getcol('SCAN_NUMBER')
            field_id_list1 = tsel.getcol('FIELD_ID')
            antenna1_list1 = tsel.getcol('ANTENNA1')
            state_id_list1 = tsel.getcol('STATE_ID')
            data_desc_id_list1 = tsel.getcol('DATA_DESC_ID')
            time_list1 = tsel.getcol('TIME')
            rownumber_list1 = tsel.rownumbers()
        finally:
            tsel.close()

    for processor_id in processor_id_set:
        LOG.trace('PROCESSOR_ID %s' % processor_id)
        for observation_id in observation_id_set:
            LOG.trace('OBSERVATION_ID %s' % observation_id)
            for scan_number in scan_numbers:
                LOG.trace('SCAN_NUMBER %s' % scan_number)
                if scan_number not in states:
                    LOG.trace('No target states in SCAN %s' % scan_number)
                    continue
                for field_id in fields[scan_number]:
                    LOG.trace('FIELD_ID %s' % field_id)
                    for antenna in ms.antennas:
                        antenna_id = antenna.id
                        LOG.trace('ANTENNA_ID %s' % antenna_id)
                        for spw in ms.get_spectral_windows(science_windows_only=True):
                            data_desc = ms.get_data_description(spw=spw)
                            data_desc_id = data_desc.id
                            pol_id = data_desc.pol_id
                            spw_id = spw.id
                            LOG.trace('START PROCESSOR %s SCAN %s DATA_DESC_ID %s ANTENNA %s FIELD %s' % (
                                processor_id, scan_number, data_desc_id, antenna_id, field_id))

                            derived_pol_id = to_derived_polid[pol_id]
                            derived_spw_id = to_derived_spwid[spw_id]
                            derived_dd_id = derived_ddid_map[(derived_pol_id, derived_spw_id)]
                            LOG.trace('SRC DATA_DESC_ID %s (SPW %s)' % (data_desc_id, spw_id))
                            LOG.trace('DERIVED DATA_DESC_ID %s (SPW %s)' % (derived_dd_id, derived_spw_id))

                            tmask0 = numpy.logical_and(
                                data_desc_id_list0 == data_desc_id,
                                numpy.logical_and(
                                    antenna1_list0 == antenna_id,
                                    numpy.logical_and(field_id_list0 == field_id,
                                                      scan_number_list0 == scan_number)))
                            if not is_unique_processor_id:
                                numpy.logical_and(tmask0, processor_id_list0 == processor_id, out=tmask0)
                            if not is_unique_observation_id:
                                numpy.logical_and(tmask0, observation_id_list0 == observation_id, out=tmask0)

                            tmask1 = numpy.logical_and(
                                data_desc_id_list1 == derived_dd_id,
                                numpy.logical_and(
                                    antenna1_list1 == antenna_id,
                                    numpy.logical_and(field_id_list1 == field_id,
                                                      scan_number_list1 == scan_number)))
                            if not is_unique_processor_id:
                                numpy.logical_and(tmask1, processor_id_list1 == processor_id, out=tmask1)
                            if not is_unique_observation_id:
                                numpy.logical_and(tmask1, observation_id_list1 == observation_id, out=tmask1)

                            if not numpy.any(tmask0) and not numpy.any(tmask1):
                                # no corresponding data (probably due to PROCESSOR_ID for non-science windows)
                                LOG.trace('SKIP PROCESSOR %s SCAN %s DATA_DESC_ID %s ANTENNA %s FIELD %s' % (
                                    processor_id, scan_number, data_desc_id, antenna_id, field_id))
                                continue

                            tstate0 = state_id_list0[tmask0]
                            tstate1 = state_id_list1[tmask1]
                            ttime0 = time_list0[tmask0]
                            ttime1 = time_list1[tmask1]
                            trow0 = rownumber_list0[tmask0]
                            trow1 = rownumber_list1[tmask1]
                            sort_index0 = numpy.lexsort((tstate0, ttime0))
                            sort_index1 = numpy.lexsort((tstate1, ttime1))

                            LOG.trace('scan %s' % scan_number
                                      + ' actual %s' % list(set(tstate0))
                                      + ' expected %s' % states[scan_number])
                            assert numpy.all(ttime0[sort_index0] == ttime1[sort_index1])
                            assert numpy.all(tstate0[sort_index0] == tstate1[sort_index1])
                            # assert set(tstate0) == set(states[scan_number])
                            assert set(tstate0).issubset(set(states[scan_number]))

                            for (i0, i1) in zip(sort_index0, sort_index1):
                                r0 = trow0[i0]
                                r1 = trow1[i1]
                                rowmap[r0] = r1

                            LOG.trace('END PROCESSOR %s SCAN %s DATA_DESC_ID %s ANTENNA %s FIELD %s' % (
                                processor_id, scan_number, data_desc_id, antenna_id, field_id))

    return rowmap


class SpwSimpleView(object):
    def __init__(self, spwid, name):
        self.id = spwid
        self.name = name


class SpwDetailedView(object):
    def __init__(self, spwid, name, num_channels, ref_frequency, min_frequency, max_frequency):
        self.id = spwid
        self.name = name
        self.num_channels = num_channels
        self.ref_frequency = ref_frequency
        self.min_frequency = min_frequency
        self.max_frequency = max_frequency


def get_spw_names(vis):
    with casa_tools.TableReader(os.path.join(vis, 'SPECTRAL_WINDOW')) as tb:
        gen = (SpwSimpleView(i, tb.getcell('NAME', i)) for i in range(tb.nrows()))
        spws = list(gen)
    return spws


def get_spw_properties(vis):
    with casa_tools.TableReader(os.path.join(vis, 'SPECTRAL_WINDOW')) as tb:
        spws = []
        for irow in range(tb.nrows()):
            name = tb.getcell('NAME', irow)
            nchan = tb.getcell('NUM_CHAN', irow)
            ref_freq = tb.getcell('REF_FREQUENCY', irow)
            chan_freq = tb.getcell('CHAN_FREQ', irow)
            chan_width = tb.getcell('CHAN_WIDTH', irow)
            min_freq = chan_freq.min() - abs(chan_width[0]) / 2
            max_freq = chan_freq.max() + abs(chan_width[0]) / 2
            spws.append(SpwDetailedView(irow, name, nchan, ref_freq, min_freq, max_freq))
    return spws


# @profiler
def __read_table(reader, method, vis):
    if reader is None:
        result = method(vis)
    else:
        with reader(vis) as readerobj:
            result = method(readerobj)
    return result


def _read_table(reader, table, vis):
    rows = __read_table(reader, table._read_table, vis)
    return rows


# @profiler
def make_spwid_map(srcvis, dstvis):
    # src_spws = __read_table(casa_tools.MSMDReader,
    #                         tablereader.SpectralWindowTable.get_spectral_windows,
    #                         srcvis)
    # dst_spws = __read_table(casa_tools.MSMDReader,
    #                         tablereader.SpectralWindowTable.get_spectral_windows,
    #                         dstvis)
    src_spws = __read_table(None, get_spw_properties, srcvis)
    dst_spws = __read_table(None, get_spw_properties, dstvis)
    for spw in src_spws:
        LOG.trace('SRC SPWID %s NAME %s' % (spw.id, spw.name))
    for spw in dst_spws:
        LOG.trace('DST SPWID %s NAME %s' % (spw.id, spw.name))

    map_byname = collections.defaultdict(list)
    for src_spw in src_spws:
        for dst_spw in dst_spws:
            if src_spw.name == dst_spw.name:
                map_byname[src_spw].append(dst_spw)

    spwid_map = {}
    for src, dst in map_byname.items():
        LOG.trace('map_byname src spw %s: dst spws %s' % (src.id, [spw.id for spw in dst]))
        if len(dst) == 0:
            continue
        elif len(dst) == 1:
            # mapping by name worked
            spwid_map[src.id] = dst[0].id
        else:
            # need more investigation
            for spw in dst:
                if (src.num_channels == spw.num_channels
                        and src.ref_frequency == spw.ref_frequency
                        and src.min_frequency == spw.min_frequency
                        and src.max_frequency == spw.max_frequency):
                    if src.id in spwid_map:
                        raise RuntimeError('Failed to create spw map for MSs \'%s\' and \'%s\'' % (srcvis, dstvis))
                    spwid_map[src.id] = spw.id
    return spwid_map


# @profiler
def make_polid_map(srcvis, dstvis):
    src_rows = _read_polarization_table(srcvis)
    dst_rows = _read_polarization_table(dstvis)
    for (src_polid, src_numpol, src_poltype, _, _) in src_rows:
        LOG.trace('SRC: POLID %s NPOL %s POLTYPE %s' % (src_polid, src_numpol, src_poltype))
    for (dst_polid, dst_numpol, dst_poltype, _, _) in dst_rows:
        LOG.trace('DST: POLID %s NPOL %s POLTYPE %s' % (dst_polid, dst_numpol, dst_poltype))
    polid_map = {}
    for (src_polid, src_numpol, src_poltype, _, _) in src_rows:
        for (dst_polid, dst_numpol, dst_poltype, _, _) in dst_rows:
            if src_numpol == dst_numpol and numpy.all(src_poltype == dst_poltype):
                polid_map[src_polid] = dst_polid
    LOG.trace('polid_map = %s' % polid_map)
    return polid_map


# @profiler
def make_ddid_map(vis):
    with casa_tools.TableReader(os.path.join(vis, 'DATA_DESCRIPTION')) as tb:
        pol_ids = tb.getcol('POLARIZATION_ID')
        spw_ids = tb.getcol('SPECTRAL_WINDOW_ID')
        num_ddids = tb.nrows()
    ddid_map = {}
    for ddid in range(num_ddids):
        ddid_map[(pol_ids[ddid], spw_ids[ddid])] = ddid
    return ddid_map
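

# The returned dictionary maps (POLARIZATION_ID, SPECTRAL_WINDOW_ID) pairs to
# DATA_DESC_ID values, e.g. {(0, 0): 0, (0, 1): 1, (1, 1): 2} (hypothetical ids).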


def get_datacolumn_name(vis):
    colname_candidates = ['CORRECTED_DATA', 'FLOAT_DATA', 'DATA']
    with casa_tools.TableReader(vis) as tb:
        colnames = tb.colnames()
    colname = None
    for name in colname_candidates:
        if name in colnames:
            colname = name
            break
    assert colname is not None
    return colname


# helper functions for parallel execution
def create_serial_job(task_cls, task_args, context):
    inputs = task_cls.Inputs(context, **task_args)
    task = task_cls(inputs)
    job = mpihelpers.SyncTask(task)
    LOG.debug('Serial Job: %s' % task)
    return job


def create_parallel_job(task_cls, task_args, context):
    context_path = os.path.join(context.output_dir, context.name + '.context')
    if not os.path.exists(context_path):
        context.save(context_path)
    task = mpihelpers.Tier0PipelineTask(task_cls, task_args, context_path)
    job = mpihelpers.AsyncTask(task)
    LOG.debug('Parallel Job: %s' % task)
    return job
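

# Typical dispatch pattern for the helpers above (a sketch; 'ExampleTask' and
# 'task_args' are hypothetical, and the serial/parallel choice shown here is
# only one plausible way the caller might decide):
#
#     if mpihelpers.is_mpi_ready():
#         job = create_parallel_job(ExampleTask, task_args, context)
#     else:
#         job = create_serial_job(ExampleTask, task_args, context)
#     result = job.get_result()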


def _read_polarization_table(vis):
    """
    Read the POLARIZATION table of the given measurement set.

    This function used to be part of tablereader, which has since moved
    from direct table reading to using the MSMD tool.
    """
    LOG.debug('Analysing POLARIZATION table')
    polarization_table = os.path.join(vis, 'POLARIZATION')
    with casa_tools.TableReader(polarization_table) as table:
        num_corrs = table.getcol('NUM_CORR')
        vcorr_types = table.getvarcol('CORR_TYPE')
        vcorr_products = table.getvarcol('CORR_PRODUCT')
        flag_rows = table.getcol('FLAG_ROW')
        rowids = []
        corr_types = []
        corr_products = []
        for i in range(table.nrows()):
            rowids.append(i)
            corr_types.append(vcorr_types['r%s' % (i + 1)])
            corr_products.append(vcorr_products['r%s' % (i + 1)])
    rows = list(zip(rowids, num_corrs, corr_types, corr_products, flag_rows))
    return rows


def get_restfrequency(vis, spwid, source_id):
    source_table = os.path.join(vis, 'SOURCE')
    with casa_tools.TableReader(source_table) as tb:
        tsel = tb.query('SOURCE_ID == {} && SPECTRAL_WINDOW_ID == {}'.format(source_id, spwid))
        try:
            if tsel.nrows() == 0:
                return None
            else:
                if tsel.iscelldefined('REST_FREQUENCY', 0):
                    return tsel.getcell('REST_FREQUENCY', 0)[0]
                else:
                    return None
        finally:
            tsel.close()


class RGAccumulator(object):
    def __init__(self):
        self.field = []
        self.antenna = []
        self.spw = []
        self.pols = []
        self.grid_table = []
        self.channelmap_range = []

    def append(self, field_id, antenna_id, spw_id, pol_ids=None, grid_table=None, channelmap_range=None):
        self.field.append(field_id)
        self.antenna.append(antenna_id)
        self.spw.append(spw_id)
        self.pols.append(pol_ids)
        if isinstance(grid_table, compress.CompressedObj) or grid_table is None:
            self.grid_table.append(grid_table)
        else:
            self.grid_table.append(compress.CompressedObj(grid_table))
        self.channelmap_range.append(channelmap_range)

#     def extend(self, field_id_list, antenna_id_list, spw_id_list):
#         self.field.extend(field_id_list)
#         self.antenna.extend(antenna_id_list)
#         self.spw.extend(spw_id_list)
#

    def get_field_id_list(self):
        return self.field

    def get_antenna_id_list(self):
        return self.antenna

    def get_spw_id_list(self):
        return self.spw

    def get_pol_ids_list(self):
        return self.pols

    def get_grid_table_list(self):
        return self.grid_table

    def get_channelmap_range_list(self):
        return self.channelmap_range

    def iterate_id(self):
        assert len(self.field) == len(self.antenna)
        assert len(self.field) == len(self.spw)
        assert len(self.field) == len(self.pols)
        for v in zip(self.field, self.antenna, self.spw):
            yield v

    def iterate_all(self):
        assert len(self.field) == len(self.antenna)
        assert len(self.field) == len(self.spw)
        assert len(self.field) == len(self.pols)
        assert len(self.field) == len(self.grid_table)
        assert len(self.field) == len(self.channelmap_range)
        for f, a, s, g, c in zip(self.field, self.antenna, self.spw, self.grid_table, self.channelmap_range):
            _g = g.decompress()
            yield f, a, s, _g, c
            del _g

    def get_process_list(self, withpol=False):
        field_id_list = self.get_field_id_list()
        antenna_id_list = self.get_antenna_id_list()
        spw_id_list = self.get_spw_id_list()

        assert len(field_id_list) == len(antenna_id_list)
        assert len(field_id_list) == len(spw_id_list)

        if withpol:
            pol_ids_list = self.get_pol_ids_list()
            assert len(field_id_list) == len(pol_ids_list)
            return field_id_list, antenna_id_list, spw_id_list, pol_ids_list
        else:
            return field_id_list, antenna_id_list, spw_id_list
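

# Minimal usage sketch for RGAccumulator (the ids are hypothetical):
#
#     accumulator = RGAccumulator()
#     accumulator.append(field_id=0, antenna_id=2, spw_id=17, pol_ids=[0, 1])
#     for field_id, antenna_id, spw_id in accumulator.iterate_id():
#         ...
#     fields, antennas, spws, pols = accumulator.get_process_list(withpol=True)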


def sort_fields(context):
    mses = context.observing_run.measurement_sets
    sorted_names = []
    sorted_fields = []
    for ms in mses:
        fields = ms.get_fields(intent='TARGET')
        for f in fields:
            if f.name not in sorted_names:
                sorted_fields.append(f)
                sorted_names.append(f.name)
    return sorted_fields


def get_brightness_unit(vis, defaultunit='Jy/beam'):
    with casa_tools.TableReader(vis) as tb:
        colnames = tb.colnames()
        target_columns = ['CORRECTED_DATA', 'FLOAT_DATA', 'DATA']
        bunit = defaultunit
        for col in target_columns:
            if col in colnames:
                keys = tb.getcolkeywords(col)
                if 'UNIT' in keys:
                    _bunit = keys['UNIT']
                    if len(_bunit) > 0:
                        # should be K or Jy
                        # update bunit only when UNIT is K
                        if _bunit == 'K':
                            bunit = 'K'
                        break
    return bunit