# Source code for pipeline.hifv.tasks.syspower.syspower

import ast
import os
import re
import shutil
from copy import deepcopy
from math import factorial

import numpy as np

import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.basetask as basetask
import pipeline.infrastructure.vdp as vdp
from pipeline.infrastructure import casa_tools
from pipeline.infrastructure import task_registry

LOG = infrastructure.get_logger(__name__)


# old
#  scipy.signal.savgol_filter(x, window_length, polyorder, deriv=0, delta=1.0, axis=-1, mode='interp', cval=0.0)

# http://scipy-cookbook.readthedocs.io/items/SavitzkyGolay.html
def savitzky_golay(y, window_size, order, deriv=0, rate=1):
    r"""Smooth (and optionally differentiate) data with a Savitzky-Golay filter.

    The Savitzky-Golay filter removes high frequency noise from data. It has
    the advantage of preserving the original shape and features of the signal
    better than other types of filtering approaches, such as moving averages
    techniques.

    Parameters
    ----------
    y : array_like, shape (N,)
        the values of the time history of the signal.
    window_size : int
        the length of the window. Must be an odd integer number.
    order : int
        the order of the polynomial used in the filtering.
        Must be less than `window_size` - 1.
    deriv : int
        the order of the derivative to compute (default = 0 means only
        smoothing)
    rate : int or float
        sample spacing used to scale derivative output (default = 1).

    Returns
    -------
    ys : ndarray, shape (N)
        the smoothed signal (or its n-th derivative).

    Raises
    ------
    ValueError
        if `window_size` or `order` cannot be converted to an integer.
    TypeError
        if `window_size` is not a positive odd number, or is too small for
        the requested polynomial order.

    Notes
    -----
    The Savitzky-Golay is a type of low-pass filter, particularly suited for
    smoothing noisy data. The main idea behind this approach is to make for
    each point a least-square fit with a polynomial of high order over an
    odd-sized window centered at the point.

    Examples
    --------
    t = np.linspace(-4, 4, 500)
    y = np.exp(-t**2) + np.random.normal(0, 0.05, t.shape)
    ysg = savitzky_golay(y, window_size=31, order=4)

    References
    ----------
    .. [1] A. Savitzky, M. J. E. Golay, Smoothing and Differentiation of Data
       by Simplified Least Squares Procedures. Analytical Chemistry, 1964,
       36 (8), pp 1627-1639.
    .. [2] Numerical Recipes 3rd Edition: The Art of Scientific Computing
       W.H. Press, S.A. Teukolsky, W.T. Vetterling, B.P. Flannery
       Cambridge University Press ISBN-13: 9780521880688
    """
    try:
        # np.int was removed in NumPy 1.24; the builtin int behaves the same
        # here (raises ValueError for bad strings, TypeError otherwise).
        window_size = np.abs(int(window_size))
        order = np.abs(int(order))
    except ValueError:
        raise ValueError("window_size and order have to be of type int")
    if window_size % 2 != 1 or window_size < 1:
        raise TypeError("window_size size must be a positive odd number")
    if window_size < order + 2:
        raise TypeError("window_size is too small for the polynomials order")
    half_window = (window_size - 1) // 2

    # precompute coefficients; np.mat (np.matrix) is deprecated, so build a
    # plain ndarray and take the pseudo-inverse row directly instead of .A
    b = np.array([[k ** i for i in range(order + 1)]
                  for k in range(-half_window, half_window + 1)])
    m = np.linalg.pinv(b)[deriv] * rate ** deriv * factorial(deriv)

    # pad the signal at the extremes with
    # values taken from the signal itself
    firstvals = y[0] - np.abs(y[1:half_window + 1][::-1] - y[0])
    lastvals = y[-1] + np.abs(y[-half_window - 1:-1][::-1] - y[-1])
    y = np.concatenate((firstvals, y, lastvals))
    return np.convolve(m[::-1], y, mode='valid')
class SyspowerResults(basetask.Results):
    """Results container produced by the hifv_syspower task."""

    def __init__(self, gaintable=None, spowerdict=None, dat_common=None,
                 clip_sp_template=None, template_table=None):
        """Store the syspower products, substituting empty defaults for None.

        Keeping the attributes non-None means downstream consumers never
        need to guard against missing values.
        """
        super().__init__()
        self.pipeline_casa_task = 'Syspower'
        self.gaintable = '' if gaintable is None else gaintable
        self.spowerdict = {} if spowerdict is None else spowerdict
        self.dat_common = np.array([]) if dat_common is None else dat_common
        self.clip_sp_template = [] if clip_sp_template is None else clip_sp_template
        self.template_table = '' if template_table is None else template_table

    def merge_with_context(self, context):
        """See :method:`~pipeline.infrastructure.api.Results.merge_with_context`.

        Nothing is merged back into the context for this task.
        """
        return

    def __repr__(self):
        return 'SyspowerResults:'
class SyspowerInputs(vdp.StandardInputs):
    """Inputs for the hifv_syspower task."""

    # Antenna(s) to exclude from the final p_diff template; empty string
    # means no exclusions.
    antexclude = vdp.VisDependentProperty(default='')
    # When True, excluded antennas are replaced by the median template over
    # all antennas instead of a constant 1.0 (see Syspower.prepare).
    usemedian = vdp.VisDependentProperty(default=False)

    @vdp.VisDependentProperty
    def clip_sp_template(self):
        # Default [lower, upper] clip range for the normalized p_diff
        # template; values outside this range are flagged.
        return [0.7, 1.2]

    def __init__(self, context, vis=None, clip_sp_template=None, antexclude=None, usemedian=None,
                 templatevalue=None):
        """Initialize task inputs.

        Args:
            context: pipeline context.
            vis: name of the measurement set to operate on.
            clip_sp_template: [low, high] clip range for the template
                (may also arrive as a string; parsed in prepare()).
            antexclude: antennas to exclude from the template.
            usemedian: use median template for excluded antennas.
            templatevalue: stored but not referenced in this module —
                presumably consumed elsewhere; verify against callers.
        """
        self.context = context
        self.vis = vis
        self.clip_sp_template = clip_sp_template
        self.antexclude = antexclude
        self.usemedian = usemedian
        self.templatevalue = templatevalue
@task_registry.set_equivalent_casa_task('hifv_syspower')
@task_registry.set_casa_commands_comment('Sys power fix compression')
class Syspower(basetask.StandardTaskTemplate):
    """Correct switched-power compression for the VLA.

    Builds a normalized p_diff template per antenna/baseband/polarization
    from the SYSPOWER subtable, flags outliers with moving-window statistics,
    and multiplies the template into the requantizer-gain caltable so that
    switched-power compression is compensated when the table is applied.
    """
    Inputs = SyspowerInputs
    def prepare(self):
        """Build the p_diff template and fold it into the RQ gain table.

        Reads switched-power data from the MS SYSPOWER subtable, normalizes
        it by the requantizer gain and the flux-calibrator level, creates a
        smoothed per-antenna/baseband/polarization template with iterative
        outlier flagging, then scales the FPARAM column of the requantizer
        gain table by the template.  A copy of the template itself is written
        to a separate table for plotting.

        Returns:
            SyspowerResults with the modified gaintable, the raw/derived data
            arrays, and the template table name.
        """
        m = self.inputs.context.observing_run.get_ms(self.inputs.vis)

        # flag normalized p_diff outside this range
        clip_sp_template = self.inputs.clip_sp_template
        if isinstance(self.inputs.clip_sp_template, str):
            # input may arrive as a string like "[0.7, 1.2]"
            clip_sp_template = ast.literal_eval(self.inputs.clip_sp_template)

        # Locate the requantizer-gain table from the earlier pipeline stage;
        # the result structure differs between runs, hence the fallback.
        # NOTE(review): results[4] is a hard-coded stage index — fragile if
        # the recipe ordering changes.
        try:
            rq_table = self.inputs.context.results[4].read()[0].rq_result[0].final[0].gaintable
        except Exception as ex:
            rq_table = self.inputs.context.results[4].read()[0].rq_result.final[0].gaintable
            LOG.debug(ex)

        template_table = 'pdiff.tbl'

        # flux (amplitude) calibrator field; used to normalize p_diff levels
        fields = m.get_fields(intent='AMPLITUDE')
        field = fields[0]
        flux_field = field.id
        flux_times = field.time

        antenna_ids = np.array([a.id for a in m.antennas])
        antenna_names = [a.name for a in m.antennas]
        spws = [spw.id for spw in m.get_spectral_windows(science_windows_only=True)]

        LOG.info("Using flux field: {0} (ID: {1})".format(field.name, flux_field))

        # get switched power from MS
        with casa_tools.TableReader(self.inputs.vis + '/SYSPOWER') as tb:
            stb = tb.query('SPECTRAL_WINDOW_ID > '+str(min(spws)-1))  # VLASS specific?
            sp_time = stb.getcol('TIME')
            sp_ant = stb.getcol('ANTENNA_ID')
            sp_spw = stb.getcol('SPECTRAL_WINDOW_ID')
            p_diff = stb.getcol('SWITCHED_DIFF')
            rq = stb.getcol('REQUANTIZER_GAIN')
            stb.done()

        # setup arrays; axes are (antenna, spw, polarization, time) except
        # dat_common which is (antenna, baseband, polarization, time)
        sorted_time = np.unique(sp_time)
        dat_raw = np.zeros((len(antenna_ids), len(spws), 2, len(sorted_time)))
        dat_rq = np.zeros((len(antenna_ids), len(spws), 2, len(sorted_time)))
        dat_flux = np.zeros((len(antenna_ids), len(spws), 2))
        dat_scaled = np.zeros((len(antenna_ids), len(spws), 2, len(sorted_time)))
        dat_filtered = np.zeros((len(antenna_ids), len(spws), 2, len(sorted_time)))
        dat_common = np.ma.zeros((len(antenna_ids), 2, 2, len(sorted_time)))
        dat_online_flags = np.zeros((len(antenna_ids), len(sorted_time)), dtype='bool')
        dat_sum = np.zeros((len(antenna_ids), len(spws), 2, len(sorted_time)))
        dat_sum_flux = np.zeros((len(antenna_ids), len(spws), 2, len(sorted_time)))

        # Obtain online flagging commands from flagdata result
        # NOTE(review): onlineflagslist is assigned but never used below.
        flagresult = self.inputs.context.results[2]
        result = flagresult.read()
        result = result[0]
        onlineflagslist = result._flagcmds

        # get online flags from .flagonline.txt
        flag_file_name = self.inputs.vis.replace('.ms', '.flagonline.txt')
        if os.path.isfile(flag_file_name):
            with open(flag_file_name, 'r') as flag_file:
                for line in flag_file:
                    # NOTE(review): pattern should be a raw string to avoid
                    # invalid-escape warnings; left unchanged here.
                    try:
                        r = re.search("antenna='ea(\d*)&&\*' timerange='(.*)' reason", line)
                    except Exception as e:
                        r = False
                    if r:
                        this_ant = 'ea' + r.groups()[0]
                        start_time = r.groups()[1].split('~')[0]
                        end_time = r.groups()[1].split('~')[1]
                        # convert timerange endpoints to seconds and flag all
                        # syspower samples inside the range
                        start_time_sec = casa_tools.quanta.convert(casa_tools.quanta.quantity(start_time), 's')['value']
                        end_time_sec = casa_tools.quanta.convert(casa_tools.quanta.quantity(end_time), 's')['value']
                        indices_to_flag = np.where((sorted_time >= start_time_sec) & (sorted_time <= end_time_sec))[0]
                        dat_online_flags[antenna_names.index(this_ant), indices_to_flag] = True

        # remove requantizer changes from p_diff
        pdrq = p_diff / (rq ** 2)

        # read tables into arrays
        for i, this_ant in enumerate(antenna_ids):
            LOG.info('reading antenna {0}'.format(this_ant))
            for j, this_spw in enumerate(spws):
                hits = np.where((sp_ant == this_ant) & (sp_spw == this_spw))[0]
                times = sp_time[hits]
                hits2 = np.where(np.in1d(sorted_time, times))[0]
                # samples taken while on the flux calibrator
                flux_hits = np.where((times >= np.min(flux_times)) & (times <= np.max(flux_times)))[0]
                for pol in [0, 1]:
                    LOG.debug(str(i) + ' ' + str(j) + ' ' + str(pol) + ' ' + str(hits2))
                    dat_raw[i, j, pol, hits2] = p_diff[pol, hits]
                    # normalize by the median level on the flux calibrator
                    dat_flux[i, j, pol] = np.median(pdrq[pol, hits][flux_hits])
                    dat_rq[i, j, pol, hits2] = rq[pol, hits]
                    dat_scaled[i, j, pol, hits2] = pdrq[pol, hits] / dat_flux[i, j, pol]
                    dat_filtered[i, j, pol, hits2] = deepcopy(dat_scaled[i, j, pol, hits2])

        # common baseband template
        for i, this_ant in enumerate(antenna_ids):
            # NOTE(review): indexes antenna_names by antenna id — assumes ids
            # are contiguous 0..N-1; verify for arrays with missing antennas.
            LOG.info('Creating template for antenna {0}'.format(antenna_names[this_ant]))
            for bband in [0, 1]:
                # spws 0-7 belong to baseband 0, spws 8-15 to baseband 1
                common_indices = list(range(0, 8)) if bband == 0 else list(range(8, 16))
                for pol in [0, 1]:
                    LOG.info(' processing baseband {0}, polarization {1}'.format(bband, pol))

                    # create initial template: stack the spws of this baseband
                    # and apply the online flags
                    sp_data = dat_filtered[i, common_indices, pol, :]
                    sp_data = np.ma.array(sp_data)
                    sp_data.mask = np.ma.getmaskarray(sp_data)
                    sp_data.mask = dat_online_flags[i]
                    # two flagging passes on successive differences
                    sp_data, flag_percent = self.flag_with_medfilt(sp_data, sp_data, flag_median=True,
                                                                   k=9, threshold=8, do_shift=True)
                    LOG.info(' total flagged data: {0:.2f}% in first pass'.format(flag_percent))
                    sp_data, flag_percent = self.flag_with_medfilt(sp_data, sp_data, flag_rms=True,
                                                                   k=5, threshold=8, do_shift=True)
                    LOG.info(' total flagged data: {0:.2f}% in second pass'.format(flag_percent))
                    sp_template = np.ma.median(sp_data, axis=0)

                    # flag residuals and recalculate template
                    sp_data, flag_percent = self.flag_with_medfilt(sp_data, sp_template, flag_median=True,
                                                                   k=11, threshold=7, do_shift=False)
                    LOG.info(' total flagged data: {0:.2f}% in third pass'.format(flag_percent))
                    sp_data, flag_percent = self.flag_with_medfilt(sp_data, sp_template, flag_rms=True,
                                                                   k=5, threshold=7, do_shift=False)
                    LOG.info(' total flagged data: {0:.2f}% in fourth pass'.format(flag_percent))
                    sp_median_data = np.ma.median(sp_data, axis=0)
                    sp_median_mask = deepcopy(sp_median_data.mask)
                    # scipy.signal.savgol_filter(x, window_length, polyorder, deriv=0, delta=1.0, axis=-1, mode='interp', cval=0.0)  # OLD
                    # savitzky_golay(y, window_size, order, deriv=0, rate=1): NEW
                    # sp_template = savgol_filter(self.interp_with_medfilt(sp_median_data), 7, 3)
                    sp_template = savitzky_golay(self.interp_with_medfilt(sp_median_data), 7, 3)
                    sp_template = np.ma.array(sp_template)
                    sp_template.mask = np.ma.getmaskarray(sp_template)
                    # re-apply the flags that interpolation filled in
                    sp_template.mask = sp_median_mask
                    LOG.info(' restored {0:.2f}% template flags after interpolation'.format(
                        100.0 * np.sum(sp_median_mask) / sp_median_mask.size))

                    # repeat after square root
                    if isinstance(sp_data.mask, bool):
                        sp_data.mask = np.ma.getmaskarray(sp_data)
                    # negative values would produce NaNs under sqrt; mask them
                    sp_data.mask[sp_data < 0] = True
                    sp_data = sp_data ** .5
                    sp_template = sp_template ** .5
                    # NaN != NaN: mask any NaNs produced by the sqrt
                    sp_data.mask[sp_data != sp_data] = True
                    sp_data, flag_percent = self.flag_with_medfilt(sp_data, sp_template, flag_rms=True,
                                                                   flag_median=True, k=5, threshold=6,
                                                                   do_shift=False)
                    LOG.info(' total flagged data: {0:.2f}% in fifth pass'.format(flag_percent))
                    sp_median_data = np.ma.median(sp_data, axis=0)
                    # sp_median_mask = deepcopy(sp_median_data.mask)
                    sp_template = savitzky_golay(self.interp_with_medfilt(sp_median_data), 7, 3)
                    dat_common[i, bband, pol, :] = sp_template

        # collect intermediate products for the results object / plotting
        spowerdict = {}
        spowerdict['spower_raw'] = dat_raw
        spowerdict['spower_flux_levels'] = dat_flux
        spowerdict['spower_rq'] = dat_rq
        spowerdict['spower_scaled'] = dat_scaled
        spowerdict['spower_filtered'] = dat_filtered
        spowerdict['spower_common'] = np.ma.filled(dat_common, 0)
        spowerdict['spower_online_flags'] = dat_online_flags
        spowerdict['spower_sum'] = dat_sum
        spowerdict['spower_sum_flux'] = dat_sum_flux

        # flag template using clip values
        final_template = np.ma.array(dat_common)
        final_template.mask = np.ma.getmaskarray(final_template)
        final_template.mask[final_template < clip_sp_template[0]] = True
        final_template.mask[final_template > clip_sp_template[1]] = True

        antids = list(antenna_ids)
        if self.inputs.usemedian and self.inputs.antexclude != '':
            # mask out the excluded antennas first so that they do not bias
            # the median template computed below
            for i, this_ant in enumerate(antenna_ids):
                antindex = antids.index(i)
                antname = antenna_names[antindex]
                if antname in self.inputs.antexclude:
                    LOG.info("Antenna " + antname + " to be excluded.")
                    final_template.mask[i, :, :, :] = np.ma.masked  # Change mask values to True for that antenna
            median_final_template = np.ma.median(final_template, axis=0)

        # substitute excluded antennas with the median template or 1.0
        for i, this_ant in enumerate(antenna_ids):
            antindex = antids.index(i)
            antname = antenna_names[antindex]
            if antname in self.inputs.antexclude:
                if self.inputs.usemedian:
                    LOG.info("Using median value in template for antenna " + antname + ".")
                    final_template.data[i, :, :, :] = median_final_template.data
                    final_template.mask[i, :, :, :] = median_final_template.mask
                else:
                    LOG.info("Using value of 1.0 in template for antenna " + antname + ".")
                    final_template.data[i, :, :, :] = 1.0
                    final_template.mask[i, :, :, :] = np.ma.nomask

        # apply the template to the requantizer gain table in place
        with casa_tools.TableReader(rq_table, nomodify=False) as tb:
            rq_time = tb.getcol('TIME')
            rq_spw = tb.getcol('SPECTRAL_WINDOW_ID')
            rq_par = tb.getcol('FPARAM')
            rq_ant = tb.getcol('ANTENNA1')
            rq_flag = tb.getcol('FLAG')

            LOG.info('Starting RQ table')
            spw_offset = 2  # Hardwired for VLASS
            for i, this_ant in enumerate(antenna_ids):
                LOG.info(' writing RQ table for antenna {0}'.format(this_ant))
                for j, this_spw in enumerate(range(len(spws))):
                    # NOTE(review): matches on enumerate index i, not the
                    # antenna id — assumes ids are contiguous 0..N-1.
                    hits = np.where((rq_ant == i) & (rq_spw == j + spw_offset))[0]
                    # hits = np.where((rq_ant == i) & (rq_spw == j))[0]
                    bband = 0 if (j < 8) else 1
                    hits2 = np.where(np.in1d(sorted_time, rq_time[hits]))[0]
                    for pol in [0, 1]:
                        try:
                            rq_par[2 * pol, 0, hits] *= final_template[i, bband, pol, hits2].data
                            rq_flag[2 * pol, 0, hits] = np.logical_or(rq_flag[2 * pol, 0, hits],
                                                                      final_template[i, bband, pol, hits2].mask)
                            if j in [0, 8]:
                                message = ' {2}% of solutions flagged in baseband {0}, polarization {1}'
                                LOG.info(message.format(bband, pol,
                                                        100. * np.sum(rq_flag[2 * pol, 0, hits]) / rq_flag[2 * pol, 0, hits].size))
                        # NOTE(review): bare except; it re-raises, but a
                        # narrower clause would be preferable.
                        except:
                            LOG.warn('Error preparing final RQ table')
                            raise
                            # SystemExit('shape mismatch writing final RQ table')
            try:
                tb.putcol('FPARAM', rq_par)
                tb.putcol('FLAG', rq_flag)
            except Exception as ex:
                LOG.warn('Error writing final RQ table - switched power will not be applied' + str(ex))

        # create new table to plot pdiff template_table
        if os.path.isdir(template_table):
            shutil.rmtree(template_table)
        shutil.copytree(rq_table, template_table)
        with casa_tools.TableReader(template_table, nomodify=False) as tb:
            for i, this_ant in enumerate(antenna_ids):
                for j, this_spw in enumerate(range(len(spws))):
                    hits = np.where((rq_ant == i) & (rq_spw == j + spw_offset))[0]
                    bband = 0 if (j < 8) else 1
                    hits2 = np.where(np.in1d(sorted_time, rq_time[hits]))[0]
                    for pol in [0, 1]:
                        try:
                            # store the raw template values (not multiplied in)
                            rq_par[2 * pol, 0, hits] = final_template[i, bband, pol, hits2].data
                            rq_flag[2 * pol, 0, hits] = final_template[i, bband, pol, hits2].mask
                        except Exception as ex:
                            LOG.error('Shape mismatch writing final template table')
            tb.putcol('FPARAM', rq_par)
            tb.putcol('FLAG', rq_flag)

        return SyspowerResults(gaintable=rq_table, spowerdict=spowerdict, dat_common=dat_common,
                               clip_sp_template=clip_sp_template, template_table=template_table)
[docs] def analyse(self, results): return results
    # function for smoothing and statistical flagging
    # adapted from https://gist.github.com/bhawkins/3535131
    def medfilt(self, x, k, threshold=6.0, flag_rms=False, flag_median=False, flag_only=False, fill_gaps=False):
        """Moving-window median filter with optional statistical flagging.

        Builds a (len(x), k) window matrix whose rows are k-sample windows
        centered on each sample (edges padded with the first/last value and
        masked), then operates on the per-row medians.

        Args:
            x: 1-D np.ma.MaskedArray to filter; may be modified in place.
            k: window length (odd).
            threshold: outlier threshold in units of the median window RMS.
            flag_rms: mask samples whose window RMS exceeds threshold*dev.
            flag_median: mask samples deviating from the window median by
                more than threshold*dev.
            flag_only: if True, merge the new flags into x's mask and return
                x instead of the medians.
            fill_gaps: if True, replace masked samples of x with the window
                medians and return x immediately (no flagging).

        Returns:
            The median-filtered array, or x (flags merged / gaps filled)
            depending on the mode flags above.
        """
        k2 = (k - 1) // 2
        y = np.ma.zeros((len(x), k))
        y.mask = np.ma.resize(x.mask, (len(x), k))
        y[:, k2] = x
        # fill the window columns with shifted copies of x; the first/last
        # j samples of each shifted column are edge-padded and masked
        for i in range(k2):
            j = k2 - i
            y[j:, i] = x[:-j]
            y[:j, i] = x[0]
            y.mask[:j, i] = True
            y[:-j, -(i + 1)] = x[j:]
            y[-j:, -(i + 1)] = x[-1]
            y.mask[-j:, -(i + 1)] = True
        medians = np.ma.median(y, axis=1)
        # np.ma.median can collapse the mask to a scalar; expand it back
        if isinstance(medians.mask, np.bool_):
            medians.mask = np.ma.getmaskarray(medians)
        # everything masked: nothing useful to flag or fill
        if np.ma.all(medians.mask):
            return medians
        if fill_gaps:
            x[x.mask] = medians[x.mask]
            return x
        if flag_median:
            # dev = typical (median) window RMS over non-degenerate windows
            rms = np.ma.std(y, axis=1)
            dev = np.ma.median(rms[rms != 0])
            medians.mask[abs(x - medians) > (dev * threshold)] = True
            medians.mask[rms == 0] = True
            # rms != rms catches NaN RMS values
            medians.mask[rms != rms] = True
        if flag_rms:
            rms = np.ma.std(y, axis=1)
            dev = np.ma.median(rms[rms != 0])
            medians.mask[rms > (dev * threshold)] = True
            medians.mask[rms == 0] = True
            medians.mask[rms != rms] = True
        if not flag_only:
            return medians
        else:
            # merge the newly derived flags into the input's mask
            x.mask = np.logical_or(x.mask, medians.mask)
            return x
# combine SPWs and flag based on moving window statistics
[docs] def flag_with_medfilt(self, x, temp, k=21, threshold=6, do_shift=False, **kwargs): if do_shift: resid = x.ravel() - np.roll(x.ravel(), -1) else: resid = (x - temp[np.newaxis, :]).ravel() new_flags = self.medfilt(resid, k, threshold=threshold, flag_only=True, **kwargs) x.mask = np.reshape(new_flags.mask, newshape=x.shape) flag_percent = 100.0 * np.sum(x.mask) / x.size x.mask[x == 0] = True return x, flag_percent
    # use median filter to interpolate flagged values
    def interp_with_medfilt(self, x, k=21, threshold=99, max_interp=10):
        """Iteratively fill masked samples of x with window medians.

        Repeatedly calls :meth:`medfilt` in fill_gaps mode until no masked
        samples remain or ``max_interp`` iterations have run; each pass can
        fill gaps adjacent to valid data, so wide gaps shrink iteratively.

        Args:
            x: 1-D np.ma.MaskedArray; modified in place and returned.
            k: medfilt window length.
            threshold: passed positionally into medfilt's threshold argument
                (high default effectively disables outlier flagging here).
            max_interp: maximum number of fill iterations.

        Returns:
            x with gaps interpolated; zero-valued samples are re-masked at
            the end.
        """
        x.mask = np.ma.getmaskarray(x)
        this_interp = 0
        while np.any(x.mask == True):
            flag_percent = 100.0 * np.sum(x.mask) / x.size
            message = ' will attempt to interpolate {0:.2f}% of data in iteration {1}'.format(flag_percent,
                                                                                              this_interp + 1)
            # only the first iteration is logged at INFO to limit noise
            if this_interp == 0:
                LOG.info(message)
            else:
                LOG.debug(message)
            x = self.medfilt(x, k, threshold, fill_gaps=True)
            this_interp += 1
            if this_interp > max_interp:
                break
        flag_percent2 = 100.0 * np.sum(x.mask) / x.size
        LOG.info(' finished interpolation with {0:.2f}% of data flagged'.format(flag_percent2))
        # re-mask exact zeros left over from unfilled gaps
        x.mask[x == 0] = True
        return x