#
# ALMA - Atacama Large Millimeter Array (c) European Southern Observatory, 2002
# Copyright by ESO (in the framework of the ALMA collaboration), All rights
# reserved
#
# This library is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# This library is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
# $Revision: 1.1.2.6 $
# $Date: 2013/03/01 05:07:45 $
# $Author: tnakazat $
#
import bisect
import collections
import os
import re
import time
# import memory_profiler
import numpy
import casatools
import pipeline.infrastructure as infrastructure
from pipeline.infrastructure import casa_tools
LOG = infrastructure.get_logger(__name__)
def __coldesc(vtype, option, maxlen, ndim, comment, unit=None):
d = {'dataManagerGroup': 'StandardStMan',
'dataManagerType': 'StandardStMan',
'valueType': vtype,
'option': option,
'maxlen': maxlen,
'comment': comment}
if ndim > 0:
d['ndim'] = ndim
if unit is not None:
d['keywords'] = {'UNIT': unit}
return d
# Description for data table columns as dictionary.
# Each value is a tuple containing:
#
# (valueType,option,maxlen,ndim,comment[,unit])
#
# dataManagerGroup and dataManagerType are always 'StandardStMan'.
#
# 2018/07/31 HE : added SHIFT_RA, SHIFT_DEC for CAS-11674
# 2019/05/23 HE : added OFS_RA, OFS_DEC for PIPE-220
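#
# For illustration, a column description built by __coldesc above, e.g.
#
#   __coldesc('double', 0, 0, -1, 'Time in MJD', 'd')
#
# expands to the following dictionary (no 'ndim' key since ndim < 0):
#
#   {'dataManagerGroup': 'StandardStMan',
#    'dataManagerType': 'StandardStMan',
#    'valueType': 'double',
#    'option': 0,
#    'maxlen': 0,
#    'comment': 'Time in MJD',
#    'keywords': {'UNIT': 'd'}}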
def __tabledescro():
TD_DESC_RO = [
# __coldesc('integer', 0, 0, -1, 'Primary key'),
__coldesc('integer', 0, 0, -1, 'Row number'),
__coldesc('integer', 0, 0, -1, 'Scan number'),
__coldesc('integer', 0, 0, -1, 'IF number'),
__coldesc('integer', 0, 0, -1, 'Number of Polarizations'),
__coldesc('integer', 0, 0, -1, 'Beam number'),
__coldesc('string', 0, 0, -1, 'Date'),
__coldesc('double', 0, 0, -1, 'Time in MJD', 'd'),
        __coldesc('double', 0, 0, -1, 'Elapsed time since first scan', 'd'),
__coldesc('double', 0, 0, -1, 'Exposure time', 's'),
__coldesc('double', 0, 0, -1, 'Right Ascension', 'deg'),
__coldesc('double', 0, 0, -1, 'Declination', 'deg'),
__coldesc('double', 0, 0, -1, 'Shifted Right Ascension', 'deg'),
__coldesc('double', 0, 0, -1, 'Shifted Declination', 'deg'),
__coldesc('double', 0, 0, -1, 'Offset Right Ascension', 'deg'),
__coldesc('double', 0, 0, -1, 'Offset Declination', 'deg'),
__coldesc('double', 0, 0, -1, 'Azimuth', 'deg'),
__coldesc('double', 0, 0, -1, 'Elevation', 'deg'),
__coldesc('integer', 0, 0, -1, 'Number of channels'),
__coldesc('double', 0, 0, 1, 'Tsys', 'K'),
__coldesc('string', 0, 0, -1, 'Target name'),
__coldesc('integer', 0, 0, -1, 'Antenna index'),
__coldesc('integer', 0, 0, -1, 'Source type enum'),
__coldesc('integer', 0, 0, -1, 'Field ID')
]
name = [
'ROW', 'SCAN', 'IF', 'NPOL', 'BEAM', 'DATE',
'TIME', 'ELAPSED', 'EXPOSURE', 'RA', 'DEC',
'SHIFT_RA', 'SHIFT_DEC', 'OFS_RA', 'OFS_DEC',
'AZ', 'EL', 'NCHAN', 'TSYS', 'TARGET', 'ANTENNA',
'SRCTYPE', 'FIELD_ID'
]
return dict(zip(name, TD_DESC_RO))
def __tabledescrw():
TD_DESC_RW = [
__coldesc('double', 0, 0, 2, 'Statistics'),
        __coldesc('integer', 0, 0, 2, 'Flags'),
__coldesc('integer', 0, 0, 2, 'Permanent flags'),
__coldesc('integer', 0, 0, 1, 'Actual flag'),
__coldesc('integer', 0, 0, -1, 'Number of mask regions'),
__coldesc('integer', 0, 0, 2, 'List of mask ranges'),
__coldesc('integer', 0, 0, -1, 'Unchanged row or not'),
__coldesc('integer', 0, 0, -1, 'Position group id')
]
name = [
'STATISTICS', 'FLAG', 'FLAG_PERMANENT',
'FLAG_SUMMARY', 'NMASK', 'MASKLIST', 'NOCHANGE',
'POSGRP']
return dict(zip(name, TD_DESC_RW))
TABLEDESC_RO = __tabledescro()
TABLEDESC_RW = __tabledescrw()
def create_table(table, name, desc, memtype='plain', nrow=0):
ret = table.create(name, desc, memtype=memtype, nrow=nrow)
    assert ret
for _colname, _coldesc in desc.items():
if 'keywords' in _coldesc:
table.putcolkeywords(_colname, _coldesc['keywords'])
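
# For example, _create() below uses create_table() to build the RO and RW
# parts of the datatable as in-memory tables (the table name here is
# illustrative):
#
#   tb = casatools.table()
#   create_table(tb, 'scratch.MemoryTable', TABLEDESC_RO, 'memory', nrow=0)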
# FLAG_PERMANENT Layout
WeatherFlagIndex = 0
TsysFlagIndex = 1
UserFlagIndex = 2
OnlineFlagIndex = 3
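
# These indices address axis 1 of a FLAG_PERMANENT cell. For instance,
# _update_flag() below writes the online flag of one row for all
# polarizations via (datatable and row are illustrative names):
#
#   datatable.putcellslice('FLAG_PERMANENT', row, online_flag,
#                          blc=[0, OnlineFlagIndex],
#                          trc=[npol - 1, OnlineFlagIndex], incr=[1, 1])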
def absolute_path(name):
return os.path.abspath(os.path.expanduser(os.path.expandvars(name)))
class DataTableIndexer(object):
"""
DataTableIndexer is responsible for mapping between classical
(serial) row indices and per-MS row indices.
"""
@property
def mses(self):
return self.context.observing_run.measurement_sets
def __init__(self, context):
self.context = context
self.nrow_per_ms = []
for ms in context.observing_run.measurement_sets:
ro_table_name = os.path.join(context.observing_run.ms_datatable_name, ms.basename, 'RO')
with casa_tools.TableReader(ro_table_name) as tb:
self.nrow_per_ms.append(tb.nrows())
self.num_mses = len(self.nrow_per_ms)
    def serial2perms(self, i):
        """
        Return the basename of the MS that serial index i belongs to,
        together with the corresponding row index of the datatable for
        that MS.

        i -- serial index
        """
base = 0
for j in range(self.num_mses):
past_base = base
base += self.nrow_per_ms[j]
if i < base:
return self.mses[j].basename, i - past_base
        raise RuntimeError('Internal consistency error.')
    def perms2serial(self, vis, i):
"""
Return serial index.
vis -- basename of the MS
i -- per MS datatable row index
"""
j = self.mses.index(self.context.observing_run.get_ms(vis))
assert j < self.num_mses
assert i < self.nrow_per_ms[j]
base = sum(self.nrow_per_ms[:j])
return base + i
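
    # A minimal round-trip sketch, assuming context is a valid pipeline
    # context and serial row 12345 exists:
    #
    #   indexer = DataTableIndexer(context)
    #   vis, per_ms_row = indexer.serial2perms(12345)
    #   assert indexer.perms2serial(vis, per_ms_row) == 12345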
    def per_ms_index_list(self, ms, index_list):
        j = self.mses.index(ms)
        base = sum(self.nrow_per_ms[:j])
        length = self.nrow_per_ms[j]
        # select serial indices that belong to this MS and convert them
        # to per-MS indices
        in_range = numpy.logical_and(index_list >= base, index_list < base + length)
        perms_list = index_list[in_range]
        return perms_list - base
class DataTableImpl(object):
"""
DataTable is an object to hold meta data of scantable on memory.
row layout: [Row, Scan, IF, Pol, Beam, Date, Time, ElapsedTime,
0, 1, 2, 3, 4, 5, 6, 7,
Exptime, RA, DEC, Az, El, nchan, Tsys, TargetName,
8, 9, 10, 11, 12, 13, 14, 15,
Statistics, Flags, PermanentFlags, SummaryFlag, Nmask, MaskList, NoChange, Ant]
16, 17, 18, 19, 20, 21, 22, 23
Statistics: DataTable[ID][16] =
[LowFreqRMS, NewRMS, OldRMS, NewRMSdiff, OldRMSdiff, ExpectedRMS, ExpectedRMS]
0, 1, 2, 3, 4, 5, 6
Flags: DataTable[ID][17] =
[LowFrRMSFlag, PostFitRMSFlag, PreFitRMSFlag, PostFitRMSdiff, PreFitRMSdiff, PostFitExpRMSFlag, PreFitExpRMSFlag]
0, 1, 2, 3, 4, 5, 6
PermanentFlags: DataTable[ID][18] =
[WeatherFlag, TsysFlag, UserFlag]
0, 1, 2
Note for Flags: 1 is valid, 0 is invalid
"""
    @classmethod
def get_rotable_name(cls, datatable_name):
return os.path.join(datatable_name, 'RO')
    @classmethod
def get_rwtable_name(cls, datatable_name):
return os.path.join(datatable_name, 'RW')
REFKEY = 'DIRECTION_REF'
def __init__(self, name=None, readonly=None):
"""
name (optional) -- name of DataTable
"""
# unique memory table name
timestamp = ('%f' % (time.time())).replace('.', '')
self.memtable1 = 'DataTableImplRO%s.MemoryTable' % timestamp
self.memtable2 = 'DataTableImplRW%s.MemoryTable' % timestamp
self.plaintable = ''
self.cols = {}
self.tb1, self.tb2 = casatools.table(), casatools.table()
self.isopened = False
if name is None or len(name) == 0:
if readonly is None:
readonly = False
self._create(readonly=readonly)
elif not os.path.exists(name):
if readonly is None:
readonly = False
self._create(readonly=readonly)
self.plaintable = absolute_path(name)
else:
if readonly is None:
readonly = True
self.importdata2(name=name, minimal=False, readonly=readonly)
def __del__(self):
# make sure that table is closed
# LOG.debug('__del__ close CASA table...')
self.cols.clear()
self._close()
def __len__(self):
return self.nrow
@property
def nrow(self):
if self.isopened:
return self.tb1.nrows()
else:
return 0
@property
def name(self):
return self.plaintable
@property
def position_group_id(self):
key = 'POSGRP_REP'
if self.haskeyword(key):
return numpy.max(numpy.fromiter((int(x) for x in self.getkeyword(key)), dtype=numpy.int32)) + 1
else:
return 0
@property
def time_group_id_small(self):
return self.__get_time_group_id(True)
@property
def time_group_id_large(self):
return self.__get_time_group_id(False)
@property
def direction_ref(self):
if self.REFKEY in self.tb1.keywordnames():
return self.tb1.getkeyword(self.REFKEY)
else:
return None
@direction_ref.setter
def direction_ref(self, value):
# value must be string
if not isinstance(value, str):
return
# set value only if it is not yet registered to table
if self.REFKEY not in self.tb1.keywordnames():
self.tb1.putkeyword(self.REFKEY, value)
def __get_time_group_id(self, small=True):
if small:
subkey = 'SMALL'
else:
subkey = 'LARGE'
pattern = '^TIMETABLE_%s_.*' % subkey
if numpy.any(numpy.fromiter((re.match(pattern, x) is not None for x in self.keywordnames()), dtype=bool)):
group_id = 0
for key in self.tb2.keywordnames():
if re.match(pattern, key) is not None:
max_id = numpy.max(
numpy.fromiter((int(x) for x in self.getkeyword(key)), dtype=numpy.int32)) + 1
group_id = max(group_id, max_id)
return group_id
else:
return 0
    def haskeyword(self, name):
return name in self.tb2.keywordnames()
    def addrows(self, nrow):
        self.tb1.addrows(nrow)
        self.tb2.addrows(nrow)

        # reset self.plaintable since memory table and corresponding
        # plain table have different nrows
        self.plaintable = ''
    def colnames(self):
return list(self.cols.keys())
    def getcol(self, name, startrow=0, nrow=-1, rowincr=1):
return self.cols[name].getcol(startrow, nrow, rowincr)
    def putcol(self, name, val, startrow=0, nrow=-1, rowincr=1):
self.cols[name].putcol(val, startrow, nrow, rowincr)
    def getcell(self, name, idx):
return self.cols[name].getcell(idx)
    def putcell(self, name, idx, val):
"""
name -- column name
idx -- row index
val -- value to be put
"""
self.cols[name].putcell(idx, val)
    def getcolslice(self, name, blc, trc, incr, startrow=0, nrow=-1, rowincr=1):
return self.cols[name].getcolslice(blc, trc, incr, startrow, nrow, rowincr)
    def putcolslice(self, name, value, blc, trc, incr, startrow=0, nrow=-1, rowincr=1):
self.cols[name].putcolslice(value, blc, trc, incr, startrow, nrow, rowincr)
    def getcellslice(self, name, rownr, blc, trc, incr):
return self.cols[name].getcellslice(rownr, blc, trc, incr)
    def putcellslice(self, name, rownr, value, blc, trc, incr):
self.cols[name].putcellslice(rownr, value, blc, trc, incr)
    def getcolkeyword(self, columnname, keyword):
if columnname in TABLEDESC_RO:
return self.tb1.getcolkeyword(columnname, keyword)
else:
return self.tb2.getcolkeyword(columnname, keyword)
    def putkeyword(self, name, val):
"""
name -- keyword name
val -- keyword value
"""
if isinstance(val, str):
_val = '"{}"'.format(val)
else:
_val = str(val)
self.tb2.putkeyword(name, _val)
    def getkeyword(self, name):
"""
name -- keyword name
"""
_val = self.tb2.getkeyword(name)
val = eval(_val)
return val
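
    # Keyword round-trip sketch (dt is an illustrative instance): putkeyword()
    # stores str(val) (strings get an extra layer of quotes) and getkeyword()
    # reconstructs the value with eval(), so only Python-literal values
    # survive intact:
    #
    #   dt.putkeyword('POSGRP_REP', {'0': 10, '1': 42})
    #   dt.getkeyword('POSGRP_REP')   # -> {'0': 10, '1': 42}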
    def keywordnames(self):
"""
return table keyword names
"""
return self.tb2.keywordnames()
    def importdata(self, name, minimal=True, readonly=True):
"""
name -- name of DataTable to be imported
"""
LOG.debug('Importing DataTable from %s...' % name)
# copy input table to memory
self._copyfrom(name, minimal)
self.plaintable = absolute_path(name)
self.__init_cols(readonly=readonly)
    def importdata2(self, name, minimal=True, readonly=True):
"""
name -- name of DataTable to be imported
"""
LOG.debug('Importing DataTable from %s...' % name)
# copy input table to memory
self._copyfrom2(name, minimal)
self.plaintable = absolute_path(name)
self.__init_cols(readonly=readonly)
    def sync(self, minimal=True):
"""
Sync with DataTable on disk.
"""
self.importdata(name=self.plaintable, minimal=minimal)
    def exportdata(self, name=None, minimal=True, overwrite=False):
"""
name -- name of exported DataTable
overwrite -- overwrite existing DataTable
"""
if name is None or len(name) == 0:
if len(self.plaintable) == 0:
raise IOError('You have to specify name of export table')
else:
name = self.plaintable
overwrite = True
LOG.debug('Exporting DataTable to %s...' % name)
# overwrite check
abspath = absolute_path(name)
basename = os.path.basename(abspath)
if not os.path.exists(abspath):
os.makedirs(abspath)
elif overwrite:
LOG.debug('Overwrite existing DataTable %s...' % name)
# os.system( 'rm -rf %s/*'%(abspath) )
else:
raise IOError('The file %s exists.' % name)
# save
if not minimal or not os.path.exists(os.path.join(abspath, 'RO')):
# LOG.trace('Exporting RO table')
if os.path.exists(self.tb1.name()):
# self.tb1 seems to be plain table, nothing to be done
pass
else:
# tb1 is memory table
tbloc = self.tb1.copy(os.path.join(abspath, 'RO'), deep=True,
valuecopy=True, returnobject=True)
tbloc.close()
# LOG.trace('Exporting RW table')
tbloc = self.tb2.copy(os.path.join(abspath, 'RW'), deep=True,
valuecopy=True, returnobject=True)
tbloc.close()
self.plaintable = abspath
    # NOTE: this method is currently unused.
    def export_rwtable_exclusive(self, dirty_rows=None, cols=None):
        """
        Export "on-memory" RW table to the one on disk.

        To support parallel operation, the method acquires a lock on the RW
        table to ensure the operation in one process doesn't overwrite the
        changes made by other processes.

        dirty_rows -- list of row numbers that are updated. If None, everything
                      including unchanged rows is flushed. Default is None.
        cols -- list of columns that are updated. If None, all RW columns are
                flushed. Default is None.
        """
# RW table name
rwtable = self.get_rwtable_name(self.plaintable)
        # list of ordinary columns
        ordinary_cols = {'STATISTICS', 'FLAG', 'FLAG_PERMANENT', 'FLAG_SUMMARY', 'NMASK', 'POSGRP'}

        # default to flushing every known RW column
        if cols is None:
            cols = ordinary_cols | {'MASKLIST', 'NOCHANGE'}
        intersects = ordinary_cols.intersection(cols)

        # columns with special care
        with_masklist = 'MASKLIST' in cols
        with_nochange = 'NOCHANGE' in cols
# open table
with casa_tools.TableReader(rwtable, nomodify=False, lockoptions={'option': 'user'}) as tb:
# lock table
tb.lock()
LOG.info('Process {0} have acquired a lock for RW table'.format(os.getpid()))
            if dirty_rows is None:
                # process all rows
                dirty_rows = numpy.arange(tb.nrows())
            else:
                # ensure .min()/.max() and boolean indexing below work
                dirty_rows = numpy.asarray(dirty_rows)
try:
nrow_chunk = 2000
# compute number of chunks
nrow = dirty_rows.max() - dirty_rows.min() + 1
nchunk = nrow // nrow_chunk
mod = nrow % nrow_chunk
                chunks = [nrow_chunk] * nchunk + ([mod] if mod > 0 else [])
# LOG.info('chunks={0} (nrow {1})'.format(chunks, nrow))
# for each column
for col in intersects:
start_row = dirty_rows.min()
for size_chunk in chunks:
# LOG.info('start_row {0}, size_chunk {1}'.format(start_row, size_chunk))
# read chunk
chunk_src = self.tb2.getcol(col, startrow=start_row, nrow=size_chunk)
chunk_dst = tb.getcol(col, startrow=start_row, nrow=size_chunk)
# update chunk
chunk_min = start_row
chunk_max = start_row + size_chunk - 1
target_rows = dirty_rows[numpy.logical_and(chunk_min <= dirty_rows,
dirty_rows <= chunk_max)]
for row in target_rows:
chunk_index = row - start_row
# LOG.info('row {0}, chunk_index {1}'.format(row, chunk_index))
chunk_dst[..., chunk_index] = chunk_src[..., chunk_index]
# flush chunk
tb.putcol(col, chunk_dst, startrow=start_row, nrow=size_chunk)
# increment start_row
start_row += size_chunk
# merge MASKLIST if necessary
                if with_masklist:
src = self.cols['MASKLIST']
dst = DataTableColumnMaskList(tb)
for index, row in enumerate(dirty_rows):
data = src.getcell(index)
dst.putcell(row, data)
                if with_nochange:
src = self.cols['NOCHANGE']
dst = DataTableColumnNoChange(tb)
for index, row in enumerate(dirty_rows):
data = src.getcell(index)
dst.putcell(row, data)
finally:
# the table lock must eventually be released
LOG.info('Process {0} is going to release a lock for RW table'.format(os.getpid()))
tb.unlock()
def _create(self, readonly=False):
self._close()
create_table(self.tb1, self.memtable1, TABLEDESC_RO, 'memory', self.nrow)
create_table(self.tb2, self.memtable2, TABLEDESC_RW, 'memory', self.nrow)
self.isopened = True
self.__init_cols(readonly=readonly)
def __init_cols(self, readonly=True):
self.cols.clear()
if readonly:
RO_COLUMN = RODataTableColumn
RW_COLUMN = RWDataTableColumn
else:
RO_COLUMN = RWDataTableColumn
RW_COLUMN = RWDataTableColumn
type_map = {'integer': int,
'double': float,
'string': str}
datatype = lambda desc: list if 'ndim' in desc and desc['ndim'] > 0 else type_map[desc['valueType']]
for k, v in TABLEDESC_RO.items():
self.cols[k] = RO_COLUMN(self.tb1, k, datatype(v))
for k, v in TABLEDESC_RW.items():
if k == 'MASKLIST':
self.cols[k] = DataTableColumnMaskList(self.tb2)
elif k == 'NOCHANGE':
self.cols[k] = DataTableColumnNoChange(self.tb2)
else:
self.cols[k] = RW_COLUMN(self.tb2, k, datatype(v))
def _close(self):
if self.isopened:
self.tb1.close()
self.tb2.close()
self.isopened = False
def _copyfrom(self, name, minimal=True):
self._close()
abspath = absolute_path(name)
if not minimal or abspath != self.plaintable:
with casa_tools.TableReader(os.path.join(name, 'RO')) as tb:
self.tb1 = tb.copy(self.memtable1, deep=True,
valuecopy=True, memorytable=True,
returnobject=True)
with casa_tools.TableReader(os.path.join(name, 'RW')) as tb:
self.tb2 = tb.copy(self.memtable2, deep=True,
valuecopy=True, memorytable=True,
returnobject=True)
self.isopened = True
    def _copyfrom2(self, name, minimal=True):
        # identical to _copyfrom(); kept as a separate entry point for
        # importdata2()
        self._copyfrom(name, minimal)
    def get_posdict(self, ant, spw, pol):
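        # The returned dict is keyed by the ROW value of each position
        # group's representative:
        #   posdict[rep_row] = [[member rows], [member datatable indices]]
        # Non-representative rows also get a back reference:
        #   posdict[row] = [[-1, rep_row], [rep_index]]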
posgrp_list = self.getkeyword('POSGRP_LIST')
try:
mygrp = posgrp_list[str(ant)][str(spw)][str(pol)]
except KeyError:
raise KeyError('ant %s spw %s pol %s not in reduction group list' % (ant, spw, pol))
posgrp_rep = self.getkeyword('POSGRP_REP')
rows = self.getcol('ROW')
posgrp = self.getcol('POSGRP')
posdict = {}
for k, v in posgrp_rep.items():
if int(k) not in mygrp:
continue
key = rows[v]
posdict[key] = [[], []]
for idx in range(len(posgrp)):
grp = posgrp[idx]
if grp not in mygrp:
continue
row = rows[idx]
rep = posgrp_rep[str(grp)]
key = rows[rep]
posdict[key][0].append(row)
posdict[key][1].append(idx)
if row != key:
posdict[row] = [[-1, key], [rep]]
return posdict
    def get_timegap(self, ant, spw, pol, asrow=True, ms=None, field_id=None):
timegap_s = self.getkeyword('TIMEGAP_S')
timegap_l = self.getkeyword('TIMEGAP_L')
if ms is None:
try:
mygap_s = timegap_s[ant][spw][pol]
mygap_l = timegap_l[ant][spw][pol]
except KeyError:
raise KeyError('ant %s spw %s pol %s not in reduction group list' % (ant, spw, pol))
else:
try:
mygap_s = timegap_s[ms.basename.replace('.', '_')][ant][spw][field_id]
mygap_l = timegap_l[ms.basename.replace('.', '_')][ant][spw][field_id]
except KeyError:
raise KeyError(
'ms %s field %s ant %s spw %s not in reduction group list' % (ms.basename, field_id, ant, spw))
if asrow:
rows = self.getcol('ROW')
timegap = [[], []]
for idx in mygap_s:
timegap[0].append(rows[idx])
for idx in mygap_l:
timegap[1].append(rows[idx])
else:
timegap = [mygap_s, mygap_l]
return timegap
def _update_tsys(self, context, infile, tsystable, spwmap, to_fieldid, gainfield):
"""
Transfer Tsys values in a Tsys calibration table and fill Tsys
values in DataTable.
Tsys in cal table are averaged by channels taking into account
of FLAG and linearly interpolated in time to derive values which
corresponds to TIME in DataTable.
Arguments
context: pipeline context
infile: the name of input MS
tsystable: the name of Tsys calibration table
spwmap: the list of SPW mapping
to_fieldid: FIELD_ID of data table to which Tsys is transferred
gainfield: how to find FIELD form which Tsys is extracted in cal table.
"""
start_time = time.time()
msobj = context.observing_run.get_ms(infile)
to_antids = [a.id for a in msobj.antennas]
from_fields = []
if gainfield.upper() == 'NEAREST':
LOG.info('to_fieldid={}'.format(to_fieldid))
to_field = msobj.get_fields(field_id=to_fieldid)[0]
if 'ATMOSPHERE' in to_field.intents:
# if target field has ATMOSPHERE intent, use it
from_fields = [to_fieldid]
else:
atm_fields = msobj.get_fields(intent='ATMOSPHERE')
# absolute OFF
test_prefix = '{}_OFF_'.format(to_field.clean_name)
#LOG.info('test_prefix {}, atm_fields {}'.format(test_prefix, [a.clean_name for a in atm_fields]))
nearest_id = numpy.where([a.clean_name.startswith(test_prefix) for a in atm_fields])[0]
#LOG.info('nearest_id = {}'.format(nearest_id))
if len(nearest_id) > 0:
from_fields = [atm_fields[i].id for i in nearest_id]
else:
# more generic case that requires to search nearest field by separation
rmin = casa_tools.quanta.quantity(180.0, 'deg')
origin = to_field.mdirection
nearest_id = -1
for f in atm_fields:
r = casa_tools.measures.separation(origin, f.mdirection)
#LOG.info('before test: rmin {} r {} nearest_id {}'.format(rmin['value'], r['value'], nearest_id))
# quanta.le is equivalent to <=
if casa_tools.quanta.le(r, rmin):
rmin = r
nearest_id = f.id
#LOG.info('after test: rmin {} r {} nearest_id {}'.format(rmin['value'], r['value'], nearest_id))
if nearest_id != -1:
from_fields = [nearest_id]
else:
raise RuntimeError('No nearest field for Tsys update.')
else:
from_fields = [fld.id for fld in msobj.get_fields(gainfield)]
LOG.info('from_fields = {}'.format(from_fields))
with casa_tools.TableReader(tsystable) as tb:
tsel = tb.query('FIELD_ID IN {}'.format(list(from_fields)))
spws = tsel.getcol('SPECTRAL_WINDOW_ID')
times = tsel.getcol('TIME')
#fieldids = tsel.getcol('FIELD_ID')
antids = tsel.getcol('ANTENNA1')
tsys_masked = {}
for i in range(tsel.nrows()):
tsys = tsel.getcell('FPARAM', i)
flag = tsel.getcell('FLAG', i)
tsys_masked[i] = numpy.ma.masked_array(tsys, mask=(flag == True))
tsel.close()
#LOG.info('tsys={}'.format(tsys_masked))
def map_spwchans(atm_spw, science_spw):
"""
Map the channel ID ranges of ATMCal spw that covers frequency range of a science spw
Arguments: spw object of ATMCal and science spws
"""
atm_freqs = numpy.array(atm_spw.channels.chan_freqs)
min_chan = numpy.where(abs(atm_freqs - float(science_spw.min_frequency.value)) == min(
abs(atm_freqs - float(science_spw.min_frequency.value))))[0][0]
max_chan = numpy.where(abs(atm_freqs - float(science_spw.max_frequency.value)) == min(
abs(atm_freqs - float(science_spw.max_frequency.value))))[0][-1]
start_atmchan = min(min_chan, max_chan)
end_atmchan = max(min_chan, max_chan)
            # LOG.trace('calculate_average_tsys: start_atmchan = %d' % start_atmchan)
            # LOG.trace('calculate_average_tsys: end_atmchan = %d' % end_atmchan)
if end_atmchan == start_atmchan:
end_atmchan = start_atmchan + 1
return start_atmchan, end_atmchan
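
        # Example behavior of map_spwchans (channel numbers illustrative):
        # if the ATMCal channel nearest the science spw's minimum frequency
        # is 60 and the one nearest its maximum frequency is 10, the
        # returned range is (10, 60).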
_dt_antenna = self.getcol('ANTENNA')
_dt_spw = self.getcol('IF')
dt_field = self.getcol('FIELD_ID')
field_sel = numpy.where(dt_field == to_fieldid)[0]
dt_antenna = _dt_antenna[field_sel]
dt_spw = _dt_spw[field_sel]
atm_spws = set(spws)
science_spws = [x.id for x in msobj.get_spectral_windows(science_windows_only=True)]
for spw_to, spw_from in enumerate(spwmap):
# only process atm spws
if spw_from not in atm_spws:
continue
# only process science spws
if spw_to not in science_spws:
continue
atm_spw = msobj.get_spectral_window(spw_from)
science_spw = msobj.get_spectral_window(spw_to)
science_dd = msobj.get_data_description(spw=science_spw)
corr_index = [science_dd.get_polarization_id(corr) for corr in science_dd.corr_axis]
start_atmchan, end_atmchan = map_spwchans(atm_spw, science_spw)
LOG.info('Transfer Tsys from spw {} (chans: {}~{}) to {}'.format(spw_from, start_atmchan, end_atmchan, spw_to))
for ant_to in to_antids:
# select caltable row id by SPW and ANT
cal_idxs = numpy.where(numpy.logical_and(spws == spw_from, antids == ant_to))[0]
if len(cal_idxs) == 0:
continue
# atsys.shape = (nrow, npol)
atsys = numpy.asarray([tsys_masked[i].take(corr_index, axis=0)[:, start_atmchan:end_atmchan+1].mean(axis=1).data
for i in cal_idxs])
dtrows = field_sel[numpy.where(numpy.logical_and(dt_antenna == ant_to, dt_spw == spw_to))[0]]
#LOG.info('ant {} spw {} dtrows {}'.format(ant_to, spw_to, len(dtrows)))
time_sel = times.take(cal_idxs) # in sec
for dt_id in dtrows:
#LOG.info('ant {} spw {} field {}'.format(self.getcell('ANTENNA', dt_id),
# self.getcell('IF', dt_id),
# self.getcell('FIELD_ID', dt_id)))
tref = self.getcell('TIME', dt_id) * 86400 # day->sec
# LOG.trace("cal_field_ids=%s" % cal_field_idxs)
# LOG.trace('atsys = %s' % str(atsys))
if atsys.shape[0] == 1: # only one Tsys measurement selected
self.putcell('TSYS', dt_id, atsys[0, :])
else:
itsys = _interpolate(atsys, time_sel, tref)
self.putcell('TSYS', dt_id, itsys)
end_time = time.time()
LOG.info('_update_tsys: elapsed {} sec'.format(end_time - start_time))
# @memory_profiler.profile
def _update_flag(self, infile):
"""
Read MS and update online flag status of DataTable.
Arguments:
context: pipeline context instance
infile: the name of MS to transfer flag from
NOTE this method should be called before applying the other flags.
"""
LOG.info('Updating online flag for %s' % (os.path.basename(infile)))
filename = self.getkeyword('FILENAME')
assert os.path.basename(infile) == os.path.basename(filename)
# back to previous impl. with reduced memory usage
# (performance degraded)
ms_rows = self.getcol('ROW')
tmp_array = numpy.empty((4, 1,), dtype=numpy.int32)
with casa_tools.TableReader(infile) as tb:
# for dt_row in index[0]:
for dt_row, ms_row in enumerate(ms_rows):
# ms_row = rows[dt_row]
flag = tb.getcell('FLAG', ms_row)
rowflag = tb.getcell('FLAG_ROW', ms_row)
# irow += 1
npol = flag.shape[0]
#online_flag = numpy.empty((npol, 1,), dtype=numpy.int32)
online_flag = tmp_array[:npol, :]
                if rowflag:
online_flag[:] = 0
else:
for ipol in range(npol):
online_flag[ipol, 0] = 0 if flag[ipol].all() else 1
self.putcellslice('FLAG_PERMANENT', int(dt_row), online_flag,
blc=[0, OnlineFlagIndex], trc=[npol - 1, OnlineFlagIndex],
incr=[1, 1])
class RODataTableColumn(object):
def __init__(self, table, name, dtype):
self.tb = table
self.name = name
self.caster_get = dtype
def __repr__(self):
return '%s("%s","%s")' % (self.__class__.__name__, self.name, self.caster_get)
    def getcell(self, idx):
return self.tb.getcell(self.name, idx)
    def getcol(self, startrow=0, nrow=-1, rowincr=1):
return self.tb.getcol(self.name, startrow, nrow, rowincr)
    def getcellslice(self, rownr, blc, trc, incr):
return self.tb.getcellslice(self.name, rownr, blc, trc, incr)
    def getcolslice(self, blc, trc, incr, startrow=0, nrow=-1, rowincr=1):
return self.tb.getcolslice(self.name, blc, trc, incr, startrow, nrow, rowincr)
    def putcell(self, idx, val):
self.__raise()
    def putcol(self, val, startrow=0, nrow=-1, rowincr=1):
self.__raise()
    def putcellslice(self, rownr, value, blc, trc, incr):
self.__raise()
    def putcolslice(self, value, blc, trc, incr, startrow=0, nrow=-1, rowincr=1):
self.__raise()
def __raise(self):
raise NotImplementedError('column %s is read-only' % self.name)
class RWDataTableColumn(RODataTableColumn):
def __init__(self, table, name, dtype):
super(RWDataTableColumn, self).__init__(table, name, dtype)
if dtype == list:
self.caster_put = numpy.asarray
else:
self.caster_put = dtype
    def putcell(self, idx, val):
self.tb.putcell(self.name, int(idx), self.caster_put(val))
    def putcol(self, val, startrow=0, nrow=-1, rowincr=1):
self.tb.putcol(self.name, numpy.asarray(val), int(startrow), int(nrow), int(rowincr))
    def putcellslice(self, rownr, value, blc, trc, incr):
return self.tb.putcellslice(self.name, rownr, value, blc, trc, incr)
    def putcolslice(self, value, blc, trc, incr, startrow=0, nrow=-1, rowincr=1):
return self.tb.putcolslice(self.name, value, blc, trc, incr, startrow, nrow, rowincr)
class DataTableColumnNoChange(RWDataTableColumn):
def __init__(self, table):
super(RWDataTableColumn, self).__init__(table, "NOCHANGE", int)
    def putcell(self, idx, val):
if isinstance(val, bool):
v = -1
else:
v = val
self.tb.putcell(self.name, int(idx), int(v))
class DataTableColumnMaskList(RWDataTableColumn):
NoMask = numpy.zeros((1, 2), dtype=numpy.int32) - 1 # [[-1,-1]]
def __init__(self, table):
super(RWDataTableColumn, self).__init__(table, "MASKLIST", list)
    def getcell(self, idx):
v = self.tb.getcell(self.name, int(idx))
if sum(v[0]) < 0:
return numpy.zeros(0, dtype=numpy.int32)
else:
return v
    def getcol(self, startrow=0, nrow=-1, rowincr=1):
"""
Note: returned array has shape (nrow,nmask), in
contrast to (nmask,nrow) for return value of
tb.getcol().
"""
if nrow < 0:
nrow = self.tb.nrows()
ret = collections.defaultdict(list)
idx = 0
for i in range(startrow, nrow, rowincr):
tMASKLIST = self.getcell(i)
if len(tMASKLIST) == 1 and tMASKLIST[0][0] == 0 and \
tMASKLIST[0][1] == 0:
ret[idx] = tMASKLIST
idx += 1
return ret
    def putcell(self, idx, val):
if len(val) == 0:
v = self.NoMask
else:
v = val
self.tb.putcell(self.name, int(idx), numpy.asarray(v))
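
    # Sentinel round-trip sketch (col is an illustrative instance): an empty
    # mask list is stored as NoMask ([[-1, -1]]) and read back as a
    # zero-length array:
    #
    #   col.putcell(0, [])   # writes [[-1, -1]]
    #   col.getcell(0)       # -> numpy.zeros(0, dtype=numpy.int32)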
    def putcol(self, val, startrow=0, nrow=-1, rowincr=1):
"""
Note: input array should have shape (nrow,nmask), in
contrast to (nmask,nrow) for tb.putcol()
"""
if nrow < 0:
nrow = min(startrow + len(val) * rowincr, self.tb.nrows())
idx = 0
for i in range(startrow, nrow, rowincr):
self.putcell(i, numpy.asarray(val[idx]))
idx += 1
def _interpolate(v, t, tref):
    # bisect.bisect_left(a, x)
    # bisect_left returns the insertion point of x in a.
    # If x matches any value in a, bisect_left returns its index
    # (bisect_right and bisect return the index next to the matched value).
idx = bisect.bisect_left(t, tref)
#LOG.info('len(t) = {}, idx = {}'.format(len(t), idx))
if idx == 0:
return v[0]
elif idx == len(t):
return v[-1]
else:
t1 = t[idx] - tref
t0 = tref - t[idx - 1]
return (v[idx] * t0 + v[idx-1] * t1) / (t[idx] - t[idx-1])
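
# Worked example of _interpolate (values illustrative): with two Tsys rows at
# t = 0 s and t = 100 s, a reference time of 25 s weights the earlier value
# by 75/100 and the later one by 25/100:
#
#   v = numpy.array([[10.0], [20.0]])
#   t = numpy.array([0.0, 100.0])
#   _interpolate(v, t, 25.0)   # -> array([12.5])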
def construct_timegroup(rows, group_id_list, group_association_list):
timetable_dict = {x: [[], []] for x in group_id_list}
for (idx, group_id) in enumerate(group_association_list):
if group_id not in group_id_list:
continue
timetable_dict[group_id][0].append(rows[idx])
timetable_dict[group_id][1].append(idx)
return timetable_dict
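
# Worked example (values illustrative):
#
#   construct_timegroup(rows=[10, 11, 12], group_id_list={0, 1},
#                       group_association_list=[0, 1, 0])
#   # -> {0: [[10, 12], [0, 2]], 1: [[11], [1]]}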