Source code for pipeline.hifa.tasks.wvrgcalflag.display

import collections
import functools
import operator
import os

import cachetools
import matplotlib
import matplotlib.pyplot as plt
import numpy

import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.renderer.logger as logger
import pipeline.infrastructure.utils as utils
from pipeline.h.tasks.common.displays import common
from pipeline.hifa.tasks.common.displays import phaseoffset

LOG = infrastructure.get_logger(__name__)


class WVRScoreFinder(object):
    """
    Look up the score for a spectral window / antenna pair in a view result.

    The delegate is expected to expose a ``view`` mapping whose values are
    lists of image results. Lookups are memoised per instance in an LRU
    cache holding up to 1000 entries.
    """

    def __init__(self, delegate):
        self._delegate = delegate
        self._cache = cachetools.LRUCache(maxsize=1000)

    @cachetools.cachedmethod(operator.attrgetter('_cache'),
                             key=functools.partial(cachetools.keys.hashkey, 'get_score'))
    def get_score(self, spw, antenna):
        """
        Return the score stored for the given spw and antenna.

        :param spw: object whose ``id`` identifies the spectral window
        :param antenna: object whose ``id`` identifies the antenna
        :return: first value of the antenna's row in the selected image
            result, or 0.0 when no view list matches the spw
        """
        wanted_spw = spw.id
        wanted_antenna = antenna.id

        # keep only the view lists whose first entry belongs to this spw
        matching = []
        for candidate in self._delegate.view.values():
            if candidate[0].spw == wanted_spw:
                matching.append(candidate)

        if not matching:
            return 0.0

        LOG.todo('Is the QA score the first or last viewlist?')
        # currently: last image result of the first matching view list
        image = matching[0][-1]

        antenna_axes = []
        for axis in image.axes:
            if axis.name == 'Antenna':
                antenna_axes.append(axis)
        antenna_axis = antenna_axes[0]

        # translate the antenna ID into a row index of the image data
        row = list(antenna_axis.data).index(wanted_antenna)
        return image.data[row][0]
class WVRPhaseVsBaselineChart(object):
    """
    Create phase-vs-baseline plots comparing two gain caltables.

    One caltable was made without WVR correction ('before') and one with it
    ('after'). For each spectral window and scan, one figure is written with
    two panels sharing an X axis of distance to the reference antenna: the
    upper panel shows the before/after ratio on a log scale, the lower panel
    shows the before and after phase offsets.
    """

    class WvrChartHelper(object):
        """Hold the antennas for a plot and provide antenna-related labels."""

        def __init__(self, antennas):
            self._antennas = antennas

        def get_antennas(self):
            """Return a shallow copy of the antennas given at construction."""
            return self._antennas[:]

        def label_antenna(self, axes):
            """
            Set the pyplot title to 'All Antennas'.

            The axes argument is accepted but not used; the title is applied
            through the pyplot state machine.
            """
            plt.title('All Antennas', size=10)

        @property
        def antenna_filename_component(self):
            """Filename fragment for the antenna selection: always empty."""
            return ''

    def __init__(self, context, result):
        """
        :param context: pipeline context giving access to the observing run
            and report directory
        :param result: result object exposing ``inputs``, ``dataresult`` and
            ``viewresult``
        :raises AssertionError: if the pre-WVR result pool does not refer to
            exactly one gaintable
        """
        self.context = context
        self.result = result
        self.ms = context.observing_run.get_ms(result.inputs['vis'])
        self._caltables_loaded = False

        nowvr_gaintables = {c.gaintable for c in result.dataresult.nowvr_result.pool}
        # '==' rather than 'is': identity of int objects is an implementation
        # detail and must not be used to compare values.
        assert len(nowvr_gaintables) == 1, ('Unexpected number of pre-WVR phase-up '
                                            'gaintables: %s' % nowvr_gaintables)

        nowvr_gaintable = nowvr_gaintables.pop()
        wvr_gaintable = result.dataresult.qa_wvr.gaintable_wvr
        LOG.debug('Gaintables for WVR plots:\n'
                  'No WVR: %s\tWith WVR: %s' % (nowvr_gaintable, wvr_gaintable))

        self._table_before = nowvr_gaintable
        self._table_after = wvr_gaintable

        self._score_retriever = WVRScoreFinder(result.viewresult)

        self._wrappers = []

        # only the first entry of the comma-separated refant list is used
        refant_name = result.dataresult.nowvr_result.inputs['refant'].split(',')[0]
        self._refant = self.ms.get_antenna(refant_name)[0]

    def _load_caltables(self):
        """Load the before/after caltables once; later calls are no-ops."""
        if self._caltables_loaded:
            return

        # Get phases before and after
        data_before = common.CaltableWrapper.from_caltable(self._table_before)
        data_after = common.CaltableWrapper.from_caltable(self._table_after)

        # some sanity checks, as unequal caltables have bit me before
        # TODO with- and without wvr plots have different times for X16b,
        # causing these assertions to fail. We need to understand why.
        # assert utils.areEqual(data_before.time, data_after.time), 'Time columns are not equal'
        # assert utils.areEqual(data_before.antenna, data_after.antenna), 'Antenna columns are not equal'
        # assert utils.areEqual(data_before.spw, data_after.spw), 'Spw columns are not equal'
        # assert utils.areEqual(data_before.scan, data_after.scan), 'Scan columns are not equal'

        self._data_before = data_before
        self._data_after = data_after
        self._caltables_loaded = True

    def _get_plot_intents(self):
        """Return the set of intents named in the 'qa_intent' input."""
        return set(self.result.dataresult.inputs['qa_intent'].split(','))

    def _get_plot_scans(self):
        """Return the MS scans sharing at least one intent with the QA intents."""
        plot_intents = self._get_plot_intents()
        return [scan for scan in self.ms.scans
                if not plot_intents.isdisjoint(scan.intents)]

    def get_symbol_and_colour(self, pol, state='BEFORE'):
        """
        Get the plot symbol, colour and alpha for a polarization and state.

        :param pol: polarization name, e.g. 'X', 'YY', 'L'
        :param state: 'BEFORE' or 'AFTER'
        :return: (symbol, colour, alpha) tuple. Unknown polarizations or
            states fall back to a grey 'x'; the fallback carries an alpha
            value so that it can be unpacked exactly like a known entry.
        """
        d = {'BEFORE': {'L': ('-', 'orange', 0.3),
                        'R': ('--', 'sandybrown', 0.3),
                        'X': ('^', 'lightslategray', 0.3),
                        'Y': ('o', 'lightslategray', 0.3),
                        'XX': ('^', 'lightslategray', 0.3),
                        'YY': ('o', 'lightslategray', 0.3)},
             'AFTER': {'L': ('-', 'green', 0.6),
                       'R': ('-', 'red', 0.6),
                       'X': ('^', 'green', 0.6),
                       'Y': ('o', 'red', 0.6),
                       'XX': ('^', 'green', 0.6),
                       'YY': ('o', 'red', 0.6)}}

        return d.get(state, {}).get(pol, ('x', 'grey', 0.6))

    def get_data_object(self, data, corr_id):
        """
        Wrap a caltable selection as x/y data for one correlation.

        The selection is wrapped in common.PhaseVsBaselineData relative to
        the reference antenna, then exposed through common.XYData with
        'distance_to_refant' as x and 'median_offset' as y.
        """
        delegate = common.PhaseVsBaselineData(data, self.ms, corr_id,
                                              self._refant.id)
        return common.XYData(delegate, 'distance_to_refant', 'median_offset')

    def plot(self):
        """
        Gather the data for all scans/spws/correlations/antennas, derive the
        common axis limits, then create one plot per spw and scan.

        :return: list of plot wrappers for the plots that exist on disk
        """
        self._load_caltables()
        data_before = self._data_before
        data_after = self._data_after

        # get the windows this was tested on from the caltable.
        spw_ids = set(data_before.spw).intersection(set(data_after.spw))
        spws = {spw for spw in self.ms.spectral_windows if spw.id in spw_ids}

        plot_scans = self._get_plot_scans()

        # phase offsets are plotted per corr, spw and scan. We cannot index
        # the phase arrays with multiple corr/spw/scans as the unwrapped
        # phases would be for the whole data, not for the corr/spw/scan
        # combination we want to plot
        LOG.debug('Finding maximum phase offset over all scans/spws/corrs/antennas')
        for scan in plot_scans:
            # scan may not have all spws, so just process those present
            for spw in scan.spws.intersection(spws):
                # find the data description for this scan. Just one dd
                # expected.
                dds = [dd for dd in scan.data_descriptions if dd.spw.id == spw.id]
                if len(dds) != 1:
                    LOG.info('Bypassing plot generation for %s scan %s spw '
                             '%s. Expected 1 matching data description but '
                             'got %s.',
                             self.ms.basename, scan.id, spw.id, len(dds))
                    continue

                dd = dds[0]

                # we expect the number and identity of the caltable
                # correlations for this scan to match those in the MS, so we
                # can enumerate over the correlations in the MS scan.
                for corr_id, _ in enumerate(dd.polarizations):
                    for antenna in self.ms.antennas:
                        # we don't want the phase RMS for the reference antenna as it
                        # doesn't make any sense, plus it often 'spikes' the scale with
                        # an extreme value. Compare IDs by value, not identity.
                        if antenna.id == self._refant.id:
                            continue

                        try:
                            # 'caltable' tracks which table was being read
                            # so that the failure message can name it
                            caltable = data_before.filename
                            selection_before = data_before.filter(scan=[scan.id],
                                                                  antenna=[antenna.id],
                                                                  spw=[spw.id])
                            baseline_data_before = self.get_data_object(selection_before, corr_id)

                            caltable = data_after.filename
                            selection_after = data_after.filter(scan=[scan.id],
                                                                antenna=[antenna.id],
                                                                spw=[spw.id])
                            baseline_data_after = self.get_data_object(selection_after, corr_id)
                        except (ValueError, KeyError):
                            # We can't construct data objects for completely
                            # flagged selections
                            LOG.debug('Could not evaluate data for %s '
                                      'antenna %s spw %s scan %s. Data '
                                      'completely flagged?' % (caltable, antenna.name, spw.id, scan.id))
                            continue

                        wrapper = common.DataRatio(baseline_data_before, baseline_data_after)
                        self._wrappers.append(wrapper)

        offsets = [w.before.y for w in self._wrappers]
        offsets.extend([w.after.y for w in self._wrappers])
        # offsets could contain None where data was flagged, but that's ok as
        # max ignores it.
        # NOTE(review): None is only orderable on Python 2; confirm that
        # flagged entries cannot reach this point on Python 3.
        self._max_phase_offset = numpy.ma.max(offsets)
        LOG.trace('Maximum phase offset for %s = %s' % (self.ms.basename,
                                                         self._max_phase_offset))

        # Extract ratios, excluding any that are set to None.
        ratios = [w.y for w in self._wrappers if w.y is not None]

        # Convert to masked array that masks out invalid numbers (NaN, Inf).
        ratios = numpy.ma.masked_invalid(ratios)

        # Determine whether an alternate refant might have been used by
        # assessing if the minimum ratio was 0.
        self._alt_refant_used = (ratios.min() == 0.0)

        # If a minimum ratio of 0 was found, raise warning about alternate
        # refant, and update ratios to also mask out values of 0.
        if self._alt_refant_used:
            LOG.warning('Phase ratio of 0 suggests that alternate refant was '
                        'used during gaincal. This plot might be misleading!')
            ratios = numpy.ma.masked_equal(ratios, 0)

        self._max_ratio = ratios.max()
        self._min_ratio = ratios.min()
        self._median_ratio = numpy.ma.median(ratios)
        LOG.trace('Maximum phase ratio for %s = %s' % (self.ms.basename,
                                                        self._max_ratio))
        LOG.trace('Minimum phase ratio for %s = %s' % (self.ms.basename,
                                                        self._min_ratio))

        distances = [w.x for w in self._wrappers]
        self._max_distance = numpy.ma.max(distances)
        LOG.trace('Maximum distance for %s = %s' % (self.ms.basename,
                                                     self._max_distance))

        plots = []
        for spw in spws:
            # plot scans individually as plotting multiple scans on one plot
            # creates an unintelligible mess.
            for scan in plot_scans:
                plots.append(self.get_plot_wrapper(spw, [scan, ], self.ms.antennas))

        return [p for p in plots if p is not None]

    def create_plot(self, spw, scans, helper):
        """
        Draw and save the two-panel figure for one spw and a list of scans.

        Relies on the limits and wrappers computed by plot(). The figure is
        always closed on exit, including when drawing raises.

        :param spw: spectral window to plot
        :param scans: scans to include in the plot
        :param helper: object providing ``antenna_filename_component``
        :raises AssertionError: if the spw is absent from either caltable or
            the scans' data descriptions disagree on correlation axes
        """
        data_before = self._data_before
        data_after = self._data_after

        # check the spw given by the spw argument is present in both caltables
        assert spw.id in data_before.spw, 'Spw %s not in %s' % (spw, self._table_before)
        assert spw.id in data_after.spw, 'Spw %s not in %s' % (spw, self._table_after)

        # get the scan intents from the list of scans
        scan_intents = set()
        for scan in scans:
            scan_intents.update(scan.intents)
        # discard() rather than remove(): the scans need not carry a WVR
        # intent, and its absence must not abort the plot
        scan_intents.discard('WVR')
        # sorted so that the title text is reproducible between runs
        scan_intents = ','.join(sorted(scan_intents))

        # get the fields from the list of scans
        scan_fields = set()
        for scan in scans:
            scan_fields.update([field.name for field in scan.fields])
        scan_fields = ','.join(sorted(scan_fields))

        # create the figure: 2 rows x 1 column, sharing the X axis (baseline
        # length)
        fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
        try:
            plt.subplots_adjust(hspace=0.0, bottom=0.16)
            ax1.set_yscale('log')

            ax1.tick_params(labelsize=8, left=True, right=False, top=False, bottom=False)
            ax2.tick_params(labelsize=8, left=True, right=False, top=False, bottom=True,
                            labelright=False)

            # x in axes coordinates, y in data coordinates
            trans1 = matplotlib.transforms.blended_transform_factory(ax1.transAxes,
                                                                     ax1.transData)
            ax1.axhspan(self._min_ratio, 1, facecolor='k', linewidth=0.0, alpha=0.04)
            # sqrt(min_ratio) is the geometric mean of min_ratio and 1, i.e.
            # the vertical centre of the shaded band on the log axis
            ax1.text(0.012, numpy.sqrt(self._min_ratio), 'No Improvement',
                     transform=trans1, color='k', ha='left', va='center',
                     size=8, alpha=0.4)

            ax1.axhline(y=self._median_ratio, color='k', ls='dotted', alpha=0.4)
            ax1.text(0.012, self._median_ratio, 'Median', transform=trans1,
                     color='k', ha='left', va='baseline', size=8, alpha=0.4)

            # create bottom plot: phase offset vs baseline
            legend = []
            plots = []
            for scan in scans:
                # get the polarisations for the scan
                corr_axes = [tuple(dd.polarizations) for dd in scan.data_descriptions
                             if dd.spw.id == spw.id]
                # the scan may not necessarily contain the spw, as in multi-receiver EBs
                if not corr_axes:
                    continue
                # discard WVR and other strange data descriptions
                corr_axes = {x for x in corr_axes if x not in [(), ('I',)]}
                assert len(corr_axes) == 1, ('Data descriptions have different '
                                             'corr axes for scan %s. Got %s'
                                             '' % (scan.id, corr_axes))
                # go from set(('XX', 'YY')) to the ('XX', 'YY')
                corr_axes = corr_axes.pop()

                for corr_idx, corr_axis in enumerate(corr_axes):
                    wrappers = [w for w in self._wrappers
                                if scan.id in w.scans
                                and spw.id in w.spws
                                and corr_idx in w.corr]

                    unflagged_wrappers = [w for w in wrappers if w.y is not None]

                    # upper plot: ratio improvement
                    x = [float(w.x) for w in unflagged_wrappers]
                    y = [w.y for w in unflagged_wrappers]
                    (symbol, color, alpha) = self.get_symbol_and_colour(corr_axis,
                                                                        state='AFTER')
                    _, = ax1.plot(x, y, symbol, color=color, alpha=alpha)

                    # lower plot: abs(median offset from median phase)
                    x = [float(w.x) for w in unflagged_wrappers]
                    y = [w.before.y for w in unflagged_wrappers]
                    (symbol, color, alpha) = self.get_symbol_and_colour(corr_axis,
                                                                        state='BEFORE')
                    p, = ax2.plot(x, y, symbol, color=color, alpha=alpha)
                    plots.append(p)
                    legend.append('%s %s' % (corr_axis, 'before'))

                    y = [w.after.y for w in unflagged_wrappers]
                    (symbol, color, alpha) = self.get_symbol_and_colour(corr_axis,
                                                                        state='AFTER')
                    p, = ax2.plot(x, y, symbol, color=color, alpha=alpha)
                    plots.append(p)
                    legend.append('%s %s' % (corr_axis, 'after'))

            # same limits for every plot, as computed over all data in plot()
            ax1.set_xlim(0, self._max_distance)
            ax1.set_ylim(self._min_ratio, self._max_ratio)
            ax2.set_ylim(0, self._max_phase_offset)

            ax1.set_ylabel('ratio', size=10)
            # CAS-7955: hif_timegaincal weblog: add refant next to phase(uvdist) plot
            x_axis_title = 'Distance to Reference Antenna {!s} (m)'.format(self._refant.name)
            ax2.set_xlabel(x_axis_title, size=10)
            ax2.set_ylabel('degrees', size=10)

            # frameless axes along the bottom edge hosting only the legend
            rax = plt.axes([0, 0, 1, 0.1], frame_on=False)
            rax.legend(plots, legend, prop={'size': 10}, numpoints=1,
                       loc='lower center', frameon=False, ncol=len(legend))

            spw_msg = 'SPW %s Correlation%s' % (
                spw.id, utils.commafy(corr_axes, quotes=False, multi_prefix='s'))
            plt.text(0.0, 1.026, spw_msg, color='k',
                     transform=ax1.transAxes, size=9)
            plt.text(0.5, 1.15, '%s (%s)' % (scan_fields, scan_intents),
                     color='k', transform=ax1.transAxes, ha='center', size=10)
            plt.text(0.5, 1.026, 'All Antennas', color='k',
                     transform=ax1.transAxes, ha='center', size=9)

            scan_ids = [str(s.id) for s in scans]
            max_scans_for_msg = 8

            # print 'Scans 4, 8, 12 ... 146' if there are too many scans to
            # print
            if len(scans) > max_scans_for_msg:
                start = ','.join(scan_ids[0:max_scans_for_msg - 1])
                end = scan_ids[-1]
                scan_txt = 's %s ... %s' % (start, end)
            else:
                scan_txt = utils.commafy(scan_ids, multi_prefix='s',
                                         quotes=False, separator=',')
            plt.text(1.0, 1.026, 'Scan%s' % scan_txt,
                     color='k', ha='right', transform=ax1.transAxes, size=9)

            plt.text(0.012, 0.97, 'Median Absolute Deviation from Median Phase',
                     color='k', transform=ax2.transAxes, ha='left', va='top', size=9)
            plt.text(0.012, 0.97, 'Phase RMS without WVR / Phase RMS with WVR',
                     color='k', transform=ax1.transAxes, ha='left', va='top', size=9)
            if self._alt_refant_used:
                plt.text(0.012, 0.89,
                         'Warning! Use of alternate refant detected; x-axis values may be unreliable',
                         color='r', transform=ax1.transAxes, ha='left', va='top', size=9)

            # We need to draw the canvas, otherwise the labels won't be positioned and
            # won't have values yet.
            fig.canvas.draw()
            # omit the last y axis tick label from the lower plot
            ax2.set_yticklabels([t.get_text() for t in ax2.get_yticklabels()[0:-1]])

            figfile = self.get_figfile(spw, scans, helper.antenna_filename_component)
            plt.savefig(figfile)
        finally:
            # release the figure even when drawing fails, so that failed
            # plots do not accumulate open figures
            plt.close(fig)

    def get_figfile(self, spw, scans, antennas):
        """
        Return the output PNG path for the given spw and scans.

        The path is <report_dir>/stage<N>/<vis>.phase_vs_baseline.spwNN.scanNN[_NN...].png.
        The antennas argument does not contribute to the filename.
        """
        vis = os.path.basename(self.result.vis)
        scan_ids = '_'.join(['%0.2d' % scan.id for scan in scans])
        return os.path.join(self.context.report_dir,
                            'stage%s' % self.result.stage_number,
                            '%s.phase_vs_baseline.spw%0.2d.scan%s.png' % (vis, spw.id, scan_ids))

    def get_plot_wrapper(self, spw, scans, antenna):
        """
        Return a logger.Plot for the spw/scans, creating the figure file if
        it does not exist yet.

        :return: the plot wrapper, or None if the plot could not be created
            or no file was written
        """
        figfile = self.get_figfile(spw, scans, antenna)

        scan_ids = ','.join([str(scan.id) for scan in scans])
        wrapper = logger.Plot(figfile,
                              x_axis='baseline length',
                              y_axis='phase offset',
                              parameters={'vis': os.path.basename(self.result.vis),
                                          'spw': spw.id,
                                          'scan': scan_ids})

        if not os.path.exists(figfile):
            LOG.trace('WVR phase vs baseline plot for spw %s scan %s not found. Creating new '
                      'plot: %s' % (spw.id, scan_ids, figfile))
            helper = WVRPhaseVsBaselineChart.WvrChartHelper(antenna)
            try:
                self.create_plot(spw, scans, helper)
            except Exception as ex:
                LOG.error('Could not create WVR phase vs baseline plot for'
                          ' spw %s scan %s' % (spw.id, scan_ids))
                LOG.exception(ex)
                # close figure just in case state is transferred between
                # calls. close() is a no-op when nothing is open, whereas
                # clf() would implicitly open a new figure.
                plt.close()
                return None

        # the plot may not be created if all data for that antenna are flagged
        if os.path.exists(figfile):
            return wrapper
        return None
class WVRPhaseOffsetPlotHelper(phaseoffset.PhaseOffsetPlotHelper):
    """
    Plot helper configured with the without-WVR ('BEFORE') and with-WVR
    ('AFTER') gaintables of a WVR result.
    """

    def __init__(self, context, result, plot_per_antenna=True):
        """
        :param context: pipeline context providing ``report_dir``
        :param result: result exposing ``pool``, ``stage_number``,
            ``nowvr_result`` and ``qa_wvr``
        :param plot_per_antenna: forwarded unchanged to the base class
        :raises AssertionError: if the pre-WVR result pool does not refer to
            exactly one gaintable
        """
        calapp = result.pool[0]

        rootdir = os.path.join(context.report_dir,
                               'stage%s' % result.stage_number)
        prefix = '%s.phase_offset' % os.path.basename(calapp.vis)

        nowvr_gaintables = {c.gaintable for c in result.nowvr_result.pool}
        # '==' rather than 'is': identity of int objects is an implementation
        # detail and must not be used to compare values.
        assert len(nowvr_gaintables) == 1, ('Unexpected number of pre-WVR phase-up '
                                            'gaintables: %s' % nowvr_gaintables)

        nowvr_gaintable = nowvr_gaintables.pop()
        wvr_gaintable = result.qa_wvr.gaintable_wvr
        LOG.debug('Gaintables for WVR plots:\n'
                  'No WVR: %s\tWith WVR: %s' % (nowvr_gaintable, wvr_gaintable))

        # ordered so that BEFORE is always handled ahead of AFTER
        caltable_map = collections.OrderedDict()
        caltable_map['BEFORE'] = nowvr_gaintable
        caltable_map['AFTER'] = wvr_gaintable

        super(WVRPhaseOffsetPlotHelper, self).__init__(
            rootdir, prefix, caltable_map, plot_per_antenna=plot_per_antenna)
class WVRPhaseOffsetPlot(phaseoffset.PhaseOffsetPlot):
    """
    Phase offset plot for a WVR result, built on a WVRPhaseOffsetPlotHelper
    with its default per-antenna setting.
    """

    def __init__(self, context, result):
        """
        :param context: pipeline context used to look up the measurement set
        :param result: result exposing ``dataresult`` and ``viewresult``
        """
        ms_name = os.path.basename(result.dataresult.pool[0].vis)
        measurement_set = context.observing_run.get_ms(ms_name)
        helper = WVRPhaseOffsetPlotHelper(context, result.dataresult)

        super(WVRPhaseOffsetPlot, self).__init__(
            context,
            measurement_set,
            helper,
            scan_intent=result.dataresult.inputs['qa_intent'],
            score_retriever=WVRScoreFinder(result.viewresult))
class WVRPhaseOffsetSummaryPlotHelper(WVRPhaseOffsetPlotHelper):
    """WVRPhaseOffsetPlotHelper variant with per-antenna plotting switched off."""

    def __init__(self, context, result):
        """
        :param context: pipeline context, forwarded to the parent helper
        :param result: WVR data result, forwarded to the parent helper
        """
        parent = super(WVRPhaseOffsetSummaryPlotHelper, self)
        parent.__init__(context, result, plot_per_antenna=False)
class WVRPhaseOffsetSummaryPlot(phaseoffset.PhaseOffsetPlot):
    """
    Phase offset plot for a WVR result, built on a
    WVRPhaseOffsetSummaryPlotHelper (per-antenna plotting disabled).
    """

    def __init__(self, context, result):
        """
        :param context: pipeline context used to look up the measurement set
        :param result: result exposing ``dataresult`` and ``viewresult``
        """
        ms_name = os.path.basename(result.dataresult.pool[0].vis)
        measurement_set = context.observing_run.get_ms(ms_name)
        helper = WVRPhaseOffsetSummaryPlotHelper(context, result.dataresult)

        super(WVRPhaseOffsetSummaryPlot, self).__init__(
            context,
            measurement_set,
            helper,
            scan_intent=result.dataresult.inputs['qa_intent'],
            score_retriever=WVRScoreFinder(result.viewresult))