import os
import matplotlib.pyplot as plt
import numpy as np
import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.api as api
from pipeline.infrastructure import casa_tools
# Module-level logger; the MaskDeviation code in this module reports its
# debug/trace diagnostics through it.
LOG = infrastructure.get_logger(__name__)
def _calculate(worker, consider_flag=False):
worker.SubtractMedian(threshold=3.0, consider_flag=consider_flag)
worker.CalcStdSpectrum(consider_flag=consider_flag)
#worker.PlotSpectrum()
worker.CalcRange(threshold=3.0, detection=5.0, extension=2.0, iteration=10, consider_flag=consider_flag)
#worker.SavePlot()
mask_list = worker.masklist
return mask_list
class MaskDeviationHeuristic(api.Heuristic):
    """Channel-mask heuristic built on the MaskDeviation algorithm."""

    def calculate(self, vis, field_id='', antenna_id='', spw_id='', consider_flag=False):
        """
        Channel mask heuristics using MaskDeviation algorithm implemented
        in MaskDeviation class.

        Args:
            vis: input MS filename.
            field_id: target field identifier (used as the MS field selection).
            antenna_id: target antenna identifier (used to build the baseline
                selection in MaskDeviation.ReadData).
            spw_id: target spw identifier.
            consider_flag: take into account flag in MS or not.

        Returns:
            The ``masklist`` produced by MaskDeviation.CalcRange: a list of
            [start, end] channel ranges (both ends inclusive) to be masked.
        """
        worker = MaskDeviation(vis, spw_id)
        worker.ReadData(field=field_id, antenna=antenna_id)
        return _calculate(worker, consider_flag=consider_flag)
def VarPlot(infile, spw=None):
    """
    Run the MaskDeviation algorithm on one MeasurementSet and save a diagnostic plot.

    The data are read with MaskDeviation.ReadData, i.e. from a MeasurementSet
    via casa_tools (not an ASAP scantable). The max/mean/min/std spectra are
    plotted, the detected mask ranges are overlaid, and the figure is saved as
    ``<infile>.png`` by MaskDeviation.SavePlot.

    Args:
        infile: path to the input MeasurementSet.
        spw: optional spectral window selection forwarded to MaskDeviation.
            The default None keeps the previous behaviour (no spw selection).

    Returns:
        The list of [start, end] masked channel ranges found by CalcRange.
    """
    s = MaskDeviation(infile, spw)
    s.ReadData()
    s.SubtractMedian(threshold=3.0)
    s.CalcStdSpectrum()
    s.PlotSpectrum()
    s.CalcRange(threshold=3.0, detection=5.0, extension=2.0, iteration=10)
    s.SavePlot()
    return s.masklist
class MaskDeviation(object):
    """
    Detect channels having large variation or deviation across spectra.

    If there are emission lines or atmospheric absorption/emission on some
    channels, their values change strongly with positional and environmental
    conditions. Emission lines and atmospheric features often degrade the
    quality of the baseline subtraction. Therefore, channels with large
    variation should be masked before baseline fitting order determination
    and baseline subtraction.

    CalcRange stores the detected channel ranges in ``masklist``.
    """

    def __init__(self, infile, spw=None):
        """
        Args:
            infile: path to the input MeasurementSet; a trailing '/' is removed.
            spw: spectral window selection used by ReadData. None results in
                an empty spw selection string.
        """
        self.infile = infile.rstrip('/')
        self.spw = spw
        LOG.debug('MaskDeviation.__init__: infile %s spw %s'%(os.path.basename(self.infile), self.spw))
        # [start, end] channel ranges (both inclusive) to be masked; set by CalcRange.
        self.masklist = []

    def ReadData(self, vis='', field='', antenna='', colname=None):
        """
        Read data and flags from the input MS.

        Rows are selected by field, spw (``self.spw``), the '<antenna>&&&'
        baseline selection and the 'OBSERVE_TARGET#ON_SOURCE*' scan intent.
        The real part of the data and the flags are stored in ``self.data``
        and ``self.flag`` as 2D arrays of shape (nrow * npol, nchan).

        Args:
            vis: MS to read. If non-empty it also replaces ``self.infile``;
                if empty, ``self.infile`` is read.
            field: field selection string.
            antenna: antenna selection used to build the baseline selection.
            colname: data column to read. If None, the first available of
                CORRECTED_DATA, FLOAT_DATA and DATA is used.

        Returns:
            The dictionary returned by ``getdata`` for the data column and 'flag'.

        Raises:
            RuntimeError: if ``colname`` is None and the MS has none of the
                supported data columns.
        """
        if vis != '':
            self.infile = vis
        else:
            vis = self.infile
        spwsel = '' if self.spw is None else str(self.spw)
        mssel = {'field': str(field),
                 'spw': str(spwsel),
                 'scanintent': 'OBSERVE_TARGET#ON_SOURCE*'}
        LOG.debug('vis="%s"'%(vis))
        LOG.debug('mssel=%s'%(mssel))

        if colname is None:
            with casa_tools.TableReader(vis) as mytb:
                colnames = mytb.colnames()
            if 'CORRECTED_DATA' in colnames:
                colname = 'corrected_data'
            elif 'FLOAT_DATA' in colnames:
                colname = 'float_data'
            elif 'DATA' in colnames:
                colname = 'data'
            else:
                raise RuntimeError('{} doesn\'t have any data column (CORRECTED, FLOAT, DATA)'.format(os.path.basename(vis)))

        with casa_tools.MSReader(vis) as myms:
            # '<antenna>&&&' is the MS-selection syntax for that antenna's
            # auto-correlations.
            mssel['baseline'] = '%s&&&'%(antenna)
            myms.msselect(mssel)
            r = myms.getdata([colname, 'flag'])

        # Arrays come back as (npol, nchan, nrow); reorder to (nrow, npol, nchan)
        # and flatten so that every row of the result is one (row, pol) spectrum.
        npol, nchan, nrow = r['flag'].shape
        self.nrow = npol * nrow
        self.nchan = nchan
        self.data = np.real(r[colname.lower()]).transpose((2, 0, 1)).reshape((nrow * npol, nchan))
        self.flag = r['flag'].transpose((2, 0, 1)).reshape((nrow * npol, nchan))
        LOG.debug('MaskDeviation.ReadDataFromMS: %s %s'%(self.nrow, self.nchan))
        return r

    def CalcStdSpectrum(self, consider_flag=False):
        """
        Compute per-channel statistics over all spectra in ``self.data``.

        Sets ``stdSP`` (consumed by CalcRange). It also sets ``meanSP``, ``maxSP``,
        ``minSP``, ``ymax`` and ``ymin``, which are only used for plotting.

        Args:
            consider_flag: if true and ``self.flag`` exists, flagged samples are
                excluded by wrapping the data in a numpy masked array; the
                resulting statistics are then masked arrays.
        """
        with_flag = bool(consider_flag) and hasattr(self, 'flag')
        if with_flag:
            work_data = np.ma.masked_array(self.data, self.flag)
        else:
            work_data = self.data
        self.stdSP = work_data.std(axis=0)
        self.meanSP = work_data.mean(axis=0)
        self.maxSP = work_data.max(axis=0)
        self.minSP = work_data.min(axis=0)
        self.ymax = self.maxSP.max()
        self.ymin = self.minSP.min()
        LOG.trace('std %s\nmean %s\n max %s\n min %s\n ymax %s ymin %s' %
                  (self.stdSP, self.meanSP, self.maxSP, self.minSP, self.ymax, self.ymin))

    def CalcRange(self, threshold=3.0, detection=5.0, extension=2.0, iteration=10, consider_flag=False):
        """
        Find channel regions whose standard deviation (``self.stdSP``) is large.

        The median and standard deviation of ``self.stdSP`` are estimated
        iteratively. Each pass keeps only channels below ``median + threshold * std``,
        and passes continue until the number of kept channels stops changing
        or ``iteration`` passes are done. Channels not below
        ``median + detection * std`` are then masked. Each masked region is
        extended while neighbouring channels exceed ``median + extension * std``.

        Args:
            threshold: clipping factor used in the iterative median/std estimate.
            detection: factor used to detect mask regions.
            extension: factor used to extend the detected mask regions.
            iteration: maximum number of clipping passes (must be >= 1).
            consider_flag: if true and ``self.stdSP`` is a masked array, its
                masked channels are excluded from the median/std estimate.

        Sets:
            mask: indices of the masked channels.
            masklist: list of [start, end] masked channel ranges (inclusive).

        Raises:
            ValueError: if ``iteration`` is less than 1.

        Note:
            PlotRange is always called, so this draws on the current
            matplotlib figure.
        """
        if iteration < 1:
            raise ValueError('iteration must be >= 1')
        with_flag = bool(consider_flag) and hasattr(self.stdSP, 'mask')
        stdSP = self.stdSP.data if with_flag else self.stdSP
        # mask: True => valid, False => invalid
        mask = (stdSP > -99999)
        if with_flag:
            mask = np.logical_and(mask, self.stdSP.mask == False)

        Nmask0 = 0
        for _ in range(iteration):
            kept = stdSP[np.where(mask == True)]
            median = np.median(kept)
            std = kept.std()
            mask = stdSP < (median + threshold * std)
            if with_flag:
                mask = np.logical_and(mask, self.stdSP.mask == False)
            Nmask = mask.sum()
            LOG.trace('MaskDeviation.CalcRange: %s %s %s %s'%(median, std, Nmask, self.nchan))
            if Nmask == Nmask0:
                break
            Nmask0 = Nmask

        # TODO(review): unlike the clipping loop above, this detection step
        # does not exclude flagged channels; confirm that is intended.
        mask = stdSP < (median + detection * std)
        LOG.trace('MaskDeviation.CalcRange: before ExtendMask %s'%(mask))
        mask = self.ExtendMask(mask, median + extension * std)
        LOG.trace('MaskDeviation.CalcRange: after ExtendMask %s'%(mask))

        # Indices of masked (invalid) channels.
        self.mask = np.arange(self.nchan)[np.where(mask == False)]
        LOG.trace('MaskDeviation.CalcRange: self.mask=%s'%(self.mask))
        # RL[i] = mask[i+1] - mask[i]: -1 where a masked run starts at i+1,
        # +1 where a masked run ends at i.
        RL = (mask * 1)[1:] - (mask * 1)[:-1]
        LOG.trace('MaskDeviation.CalcRange: RL=%s'%(RL))
        L = np.arange(self.nchan)[np.where(RL == -1)] + 1
        R = np.arange(self.nchan)[np.where(RL == 1)]
        # Runs touching either edge of the spectrum have no transition there.
        if len(self.mask) > 0 and self.mask[0] == 0:
            L = np.insert(L, 0, 0)
        if len(self.mask) > 0 and self.mask[-1] == self.nchan - 1:
            R = np.insert(R, len(R), self.nchan - 1)
        self.masklist = [[left, right] for left, right in zip(L, R)]
        self.PlotRange(L, R)
        if len(self.mask) > 0:
            LOG.trace('MaskDeviation.CalcRange: %s %s %s %s %s'%(self.masklist, L, R, self.mask[0], self.mask[-1]))
        else:
            LOG.trace('MaskDeviation.CalcRange: %s %s %s'%(self.masklist, L, R))

    def ExtendMask(self, mask, threshold):
        """
        Extend the mask region as long as Standard Deviation value is higher than the given threshold.

        Args:
            mask: boolean per-channel array (True => valid, False => masked);
                it is modified in place.
            threshold: neighbouring channels whose ``self.stdSP`` value exceeds
                this are added to the adjacent masked region.

        Returns:
            The updated ``mask``.
        """
        LOG.trace('MaskDeviation.ExtendMask: threshold = %s'%(threshold))
        return self._extend_mask(mask, self.stdSP, threshold)

    @staticmethod
    def _extend_mask(mask, std_spectrum, threshold):
        """Grow masked (False) runs of ``mask`` into neighbours with ``std_spectrum > threshold``, in place."""
        nchan = len(mask)
        # Forward pass: grow masked runs towards higher channel indices.
        for i in range(nchan - 1):
            if mask[i] == False and std_spectrum[i + 1] > threshold:
                mask[i + 1] = False
        # Backward pass: grow masked runs towards lower channel indices. The
        # range goes down to i == 1 so that channel 0 can be reached as well.
        for i in range(nchan - 1, 0, -1):
            if mask[i] == False and std_spectrum[i - 1] > threshold:
                mask[i - 1] = False
        return mask

    def PlotSpectrum(self):
        """
        Plot max, mean, min and standard deviation spectra on the cleared current figure.

        Requires the statistics set by CalcStdSpectrum. A text legend is drawn
        near the upper right, and the figure is titled with ``self.infile``.
        """
        color = ['r', 'm', 'b', 'g', 'k']
        label = ['max', 'mean', 'min', 'STD', 'MASK']
        plt.clf()
        plt.plot(self.maxSP, color=color[0])
        plt.plot(self.meanSP, color=color[1])
        plt.plot(self.minSP, color=color[2])
        plt.plot(self.stdSP, color=color[3])
        plt.xlim(-10, self.nchan + 9)
        posx = (self.nchan + 20) * 0.8 - 10
        deltax = (self.nchan + 20) * 0.05
        posy = (self.ymax - self.ymin) * 0.95 + self.ymin
        deltay = (self.ymax - self.ymin) * 0.06
        for i, (text, col) in enumerate(zip(label, color)):
            plt.text(posx, posy - i * deltay, text, color=col)
        plt.title(self.infile)

    def PlotRange(self, L, R):
        """
        Plot masked ranges on the current figure.

        Args:
            L: first channel of each masked range.
            R: last channel of each masked range (same length as ``L``).
        """
        if len(L) > 0:
            plt.vlines(L, self.ymin, self.ymax)
            plt.vlines(R, self.ymin, self.ymax)
            Y = [(self.ymax - self.ymin) * 0.8 + self.ymin] * len(L)
            plt.hlines(Y, L, R)

    def SavePlot(self):
        """
        Save the current figure in PNG format as ``<self.infile>.png``.
        """
        plt.savefig(self.infile + '.png', format='png')