import os
from functools import reduce
import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy
import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.renderer.logger as logger
import pipeline.infrastructure.utils as utils
from pipeline.h.tasks.common.displays import common as common
# Module-level logger, requested from the pipeline infrastructure with this
# module's name.
LOG = infrastructure.get_logger(__name__)
class PhaseOffsetPlotHelper(object):
    """
    Supply file names, antenna grouping and line styling for phase offset plots.

    caltable_map should be a dictionary mapping the state (a key of
    colour_map, i.e. 'BEFORE' or 'AFTER') to caltable.
    """

    # state -> polarisation/correlation -> (line format, colour, alpha)
    colour_map = {'BEFORE': {'L': ('-', 'orange', 0.6),
                             'R': ('--', 'sandybrown', 0.6),
                             'X': ('-', 'lightslategray', 0.6),
                             'Y': ('--', 'lightslategray', 0.6),
                             'XX': ('-', 'lightslategray', 0.6),
                             'YY': ('--', 'lightslategray', 0.6)},
                  'AFTER': {'L': ('-', 'green', 0.6),
                            'R': ('-', 'red', 0.6),
                            'X': ('-', 'green', 0.6),
                            'Y': ('-', 'red', 0.6),
                            'XX': ('-', 'green', 0.6),
                            'YY': ('-', 'red', 0.6)}}

    def __init__(self, rootdir, prefix, caltable_map=None, plot_per_antenna=True):
        """
        rootdir -- directory in which plot files are placed
        prefix -- leading part of every plot file name
        caltable_map -- dictionary mapping state to caltable; None is
            treated as an empty mapping
        plot_per_antenna -- True to create one plot per antenna, False to
            put all antennas on one plot
        """
        # The declared default of None used to fail on .keys(); normalise it
        # so that the default is actually usable.
        if caltable_map is None:
            caltable_map = {}

        # Validate against the same map that get_symbol_and_colour reads.
        assert set(self.colour_map).issuperset(caltable_map), \
            'caltables argument defines states not in colour_map'

        self._rootdir = rootdir
        self._prefix = prefix
        self.caltable_map = caltable_map
        self.plot_per_antenna = plot_per_antenna

    def get_symbol_and_colour(self, pol, state):
        """
        Return the (line format, colour, alpha) tuple for a polarisation
        and state. Raises KeyError if either is not in colour_map.
        """
        return self.colour_map[state][pol]

    def get_figfile(self, spw, antennas):
        """
        Return the path of the plot file for a spw and list of antennas.

        The antenna name is included in the file name only when exactly one
        antenna is given.
        """
        # '==' rather than 'is': identity comparison of ints only works by
        # accident of small-integer caching.
        if len(antennas) == 1:
            antenna = '.ant%s' % antennas[0].name
        else:
            antenna = ''
        return os.path.join(self._rootdir,
                            '%s%s.spw%0.2d.png' % (self._prefix, antenna, spw.id))

    def group_antennas(self, antennas):
        """
        Return the antennas as a list of groups, one group per plot: one
        single-antenna group per antenna when plotting per antenna,
        otherwise a single group holding all antennas.
        """
        if self.plot_per_antenna:
            return [[ant] for ant in antennas]
        return [antennas]

    def label_antenna(self, fig, antennas):
        """
        Write the antenna label, positioned in figure coordinates of fig:
        the first antenna's name when plotting per antenna, otherwise
        'All Antennas'.
        """
        if self.plot_per_antenna:
            text = '%s' % antennas[0].name
        else:
            text = 'All Antennas'
        plt.text(0.5, 0.89, '%s' % text, color='k',
                 transform=fig.transFigure, ha='center', size=9)
class PhaseOffsetPlot(object):
    """
    Create plots of the deviation of phase from the scan median, one plot
    per spectral window and antenna group, for the caltables listed in a
    PhaseOffsetPlotHelper.
    """

    def __init__(self, context, ms, plothelper, scan_intent=None, scan_id=None, score_retriever=None):
        """
        context -- the pipeline context (stored, not otherwise used here)
        ms -- the measurement set domain object; supplies scans, spectral
            windows and antennas
        plothelper -- the PhaseOffsetPlotHelper giving caltables, file
            names and styling
        scan_intent, scan_id -- passed to ms.get_scans to select the scans
            to plot
        score_retriever -- object with a get_score(spw, antenna) method;
            common.NullScoreFinder is used when omitted
        """
        self._context = context
        self._ms = ms
        self._plothelper = plothelper
        self._scans = ms.get_scans(scan_id=scan_id, scan_intent=scan_intent)
        self._score_retriever = score_retriever if score_retriever else common.NullScoreFinder()

        self._caltables_loaded = False
        self._load_caltables(plothelper.caltable_map)

    def _load_caltables(self, caltable_map):
        """
        Wrap every caltable in caltable_map and store the result on
        self.data as a list of (state, CaltableWrapper) tuples. Does
        nothing if the caltables have already been loaded.
        """
        if self._caltables_loaded:
            return

        # NOTE: the time/antenna/spw/scan columns of the caltables are
        # deliberately not checked for equality: such a check doesn't work
        # when WVR data are missing and should be interpolated over.
        self.data = [(state, common.CaltableWrapper.from_caltable(caltable))
                     for state, caltable in caltable_map.items()]
        self._caltables_loaded = True

    def plot(self, spw_ids=None, antenna_ids=None, antenna_names=None):
        """
        Return the list of plot wrappers for all spws and antenna groups
        common to all caltables, creating plot files where necessary.

        The function arguments are used for debugging the plot routines,
        so we can call plot for a particular troublesome spw/antenna.

        Raises AssertionError if no spw or no antenna remains to plot.
        """
        ms = self._ms
        plothelper = self._plothelper

        # get the spw IDs common to all caltables..
        spw_id_sets = [set(wrapper.spw) for _, wrapper in self.data]
        common_spw_ids = reduce(lambda x, y: x.intersection(y), spw_id_sets)
        # .. filter to match those specified as function arguments..
        if spw_ids is not None:
            common_spw_ids = common_spw_ids.intersection(set(spw_ids))
        # .. before converting to domain objects
        spws = [spw for spw in ms.spectral_windows
                if spw.id in common_spw_ids]

        # Do the same for antenna IDs, finding those common to all caltables..
        antenna_id_sets = [set(wrapper.antenna) for _, wrapper in self.data]
        common_antenna_ids = reduce(lambda x, y: x.intersection(y), antenna_id_sets)
        # .. filtering to match those specified as function arguments..
        if antenna_ids is not None:
            common_antenna_ids = common_antenna_ids.intersection(set(antenna_ids))
        if antenna_names is not None:
            named_antenna_ids = {ant.id for ant in ms.antennas if ant.name in antenna_names}
            common_antenna_ids = common_antenna_ids.intersection(named_antenna_ids)
        # .. before converting to domain objects
        antennas = [ant for ant in ms.antennas
                    if ant.id in common_antenna_ids]

        # self.data holds (state, wrapper) tuples, so the wrapper has to be
        # unpacked before its filename can be read for the message.
        assert len(spws) > 0, ('No common spws to plot in %s'
                               % utils.commafy([wrapper.filename for _, wrapper in self.data]))
        assert len(antennas) > 0, ('No common antennas to plot in %s'
                                   % utils.commafy([wrapper.filename for _, wrapper in self.data]))

        plots = []
        for spw in spws:
            for antenna_group in plothelper.group_antennas(antennas):
                plots.append(self.get_plot_wrapper(spw, self._scans, antenna_group))

        return [p for p in plots if p is not None]

    def create_plot(self, spw, scans, antennas):
        """
        Draw the phase offset plot for one spw and antenna group and save
        it to the file named by the plot helper.

        One subplot is created per scan, with width proportional to the
        scan's time on source. For every caltable state and correlation the
        phase is unwrapped, the median subtracted and the result plotted in
        degrees; flagged and missing data are shaded.
        """
        # get the fields and scan intents from the list of scans. These are
        # used in the plot title, eg. NGC123 (PHASE)
        scan_fields = set()
        scan_intents = set()
        for scan in scans:
            scan_fields.update(field.name for field in scan.fields)
            scan_intents.update(scan.intents)
        scan_intents.discard('WVR')
        # sorted so that the title does not depend on set iteration order
        scan_fields = ','.join(sorted(scan_fields))
        scan_intents = ','.join(sorted(scan_intents))

        num_scans = len(scans)
        autoscale_yaxis_range = [-200, 200]
        plothelper = self._plothelper

        fig = plt.figure()
        # Close the figure even when plotting fails: get_plot_wrapper
        # swallows exceptions, so an unclosed figure would otherwise leak.
        try:
            # size subplots proportional to the scan time
            integration_times = [int(s.time_on_source.total_seconds()) for s in scans]
            gs = gridspec.GridSpec(1, num_scans, width_ratios=integration_times)

            # all subplots share the y axis of the first; only the first
            # shows y tick labels
            ax0 = fig.add_subplot(gs[0])
            axes = [ax0]
            for i in range(1, num_scans):
                shared = fig.add_subplot(gs[i], sharey=ax0)
                for label in shared.get_yticklabels():
                    label.set_visible(False)
                axes.append(shared)

            for i, axis in enumerate(axes):
                axis.spines['left'].set_linestyle('dotted')
                axis.spines['right'].set_visible(False)
                axis.tick_params(
                    left=(i == 0),
                    right=False,
                    top=False,
                    bottom=False,
                    labelbottom=False,
                    labelright=False,
                )
                axis.tick_params(axis='y', labelsize=8)

            # outer frame: solid left edge on first subplot, right edge on last
            axes[0].spines['left'].set_visible(True)
            axes[0].spines['left'].set_linestyle('solid')
            axes[-1].spines['right'].set_visible(True)

            # The label has no conversion specifier, so it must not be
            # %-formatted: doing so raises TypeError for a plain int scan id.
            axes[0].set_ylabel('Deviation from Scan Median Phase (degrees)', size=10)

            plt.subplots_adjust(wspace=0.0)

            flag_annotate = len(antennas) == 1

            # Legend handles are gathered over the whole figure so that the
            # legend does not depend on the last scan/antenna having data.
            plots = []
            legends = []

            for scan_idx, scan in enumerate(scans):
                axis = axes[scan_idx]
                for antenna in antennas:
                    for state, state_data in self.data:
                        try:
                            data = state_data.filter(scan=[scan.id],
                                                     antenna=[antenna.id],
                                                     spw=[spw.id])
                        except KeyError:
                            # scan/antenna/id not in caltable, probably flagged.
                            self._plot_missing_data(axis, scan, autoscale_yaxis_range)
                            continue

                        # get the polarisations for this scan
                        corr_axes = {tuple(dd.polarizations) for dd in scan.data_descriptions
                                     if dd.spw.id == spw.id}
                        # discard WVR and other strange data descriptions
                        corr_axes.difference_update([(), ('I',)])
                        assert len(corr_axes) == 1, ('Data descriptions have different '
                                                     'corr axes for scan %s. Got %s'
                                                     '' % (scan.id, corr_axes))
                        # go from set(('XX', 'YY')) to the ('XX', 'YY')
                        corr_axes = corr_axes.pop()

                        # Nothing to draw, and data.time[0] below would fail,
                        # so report each correlation and move on.
                        if len(data.time) == 0:
                            for corr_axis in corr_axes:
                                LOG.info('No data to plot for antenna %s scan %s corr %s' %
                                         (antenna.name, scan.id, corr_axis))
                            continue

                        for corr_idx, corr_axis in enumerate(corr_axes):
                            # the phase is converted from degrees to radians
                            # for unwrapping, then back again
                            phase_for_corr = data.data[:, corr_idx]
                            rad_phase = numpy.deg2rad(phase_for_corr)
                            unwrapped_phase = numpy.unwrap(rad_phase)
                            offset_rad = unwrapped_phase - numpy.ma.median(unwrapped_phase)
                            # the operation above removed the mask, so add it back.
                            offset_rad = numpy.ma.MaskedArray(offset_rad, mask=phase_for_corr.mask)
                            offset_deg = numpy.rad2deg(offset_rad)

                            for masked_slice in numpy.ma.clump_masked(offset_deg):
                                self._plot_flagged_data(data, masked_slice, axis, flag_annotate)

                            (symbol, color, alpha) = plothelper.get_symbol_and_colour(corr_axis,
                                                                                      state)
                            # points first, then the connecting line
                            axis.plot_date(data.time, offset_deg, '.',
                                           color=color, alpha=alpha)
                            p, = axis.plot_date(data.time, offset_deg, symbol,
                                                color=color, alpha=alpha)

                            legend_entry = '%s %s' % (corr_axis, state.lower())
                            if legend_entry not in legends:
                                legends.append(legend_entry)
                                plots.append(p)

                        axis.set_xlim(data.time[0], data.time[-1])
                        axis.set_ylim(autoscale_yaxis_range)
                        axis.set_xlabel('%s' % scan.id, size=8)

            # shrink the y height slightly to make room for the legend
            for axis in axes:
                box = axis.get_position()
                axis.set_position([box.x0, box.y0 + box.height * 0.04,
                                   box.width, box.height * 0.96])

            axes[-1].legend(plots, legends, prop={'size': 10}, numpoints=1,
                            loc='upper center', bbox_to_anchor=(0.5, 0.07),
                            frameon=False, ncol=len(legends),
                            bbox_transform=fig.transFigure)

            spw_msg = 'SPW %s Correlation%s' % (spw.id,
                                                utils.commafy(corr_axes, quotes=False,
                                                              multi_prefix='s'))
            plt.text(0.0, 1.013, spw_msg, color='k',
                     transform=axes[0].transAxes, size=9)
            plt.text(0.5, 0.945, '%s (%s)' % (scan_fields, scan_intents),
                     color='k', transform=fig.transFigure, ha='center', size=10)
            plothelper.label_antenna(fig, antennas)
            plt.text(0.5, 0.07, 'Scan', color='k', transform=fig.transFigure,
                     ha='center', size=10)

            scan_ids = [str(s.id) for s in scans]
            max_scans_for_msg = 8

            # print 'Scans 4, 8, 12 ... 146' if there are too many scans to
            # print
            if num_scans > max_scans_for_msg:
                start = ','.join(scan_ids[0:max_scans_for_msg-1])
                end = scan_ids[-1]
                scan_txt = 's %s ... %s' % (start, end)
            else:
                scan_txt = utils.commafy(scan_ids, multi_prefix='s',
                                         quotes=False, separator=',')
            plt.text(1.0, 1.013, 'Scan%s' % scan_txt, color='k', ha='right',
                     transform=axes[-1].transAxes, size=9)

            figfile = plothelper.get_figfile(spw, antennas)
            plt.savefig(figfile)
        finally:
            plt.close(fig)

    def _plot_missing_data(self, axis, scan, yaxis_range):
        """
        Shade the whole scan as 'NO DATA' on an axis.

        Used when the scan/antenna/spw selection is absent from a caltable.

        axis -- the Axis object to be used for plotting
        scan -- the scan domain object; gives the start and end times
        yaxis_range -- the y axis limits to apply
        """
        class ScanExtent(object):
            """Stand-in for a CaltableWrapper, holding only .time."""

            def __init__(self, time):
                self.time = time

        start_dt = utils.get_epoch_as_datetime(scan.start_time)
        end_dt = utils.get_epoch_as_datetime(scan.end_time)
        extent = ScanExtent([matplotlib.dates.date2num(start_dt),
                             matplotlib.dates.date2num(end_dt)])

        # Fully masked values: nothing is drawn by these calls.
        # NOTE(review): presumably they exist to configure the axis as a
        # date axis like the populated subplots - confirm.
        placeholder = numpy.ma.MaskedArray(data=[0, 1], mask=True)
        axis.plot_date(extent.time, placeholder, '.')
        axis.plot_date(extent.time, placeholder)

        # slice(0, 1) covers the two time stamps, i.e. the whole scan
        self._plot_flagged_data(extent, slice(0, 1), axis,
                                True, annotation='NO DATA')

        axis.set_xlim(extent.time[0], extent.time[-1])
        axis.set_ylim(yaxis_range)

    def get_plot_wrapper(self, spw, scans, antennas):
        """
        Return the logger.Plot describing the plot for a spw and antenna
        group, creating the plot file first if it does not exist.

        Returns None if the plot could not be created or the plot file
        does not exist afterwards.
        """
        plothelper = self._plothelper
        antenna_names = [ant.name for ant in antennas]

        figfile = plothelper.get_figfile(spw, antennas)
        wrapper = logger.Plot(figfile,
                              x_axis='scan',
                              y_axis='phase',
                              parameters={'vis': self._ms.basename,
                                          'spw': spw.id,
                                          'ant': antenna_names})

        # a QA score is attached to single-antenna plots only
        if plothelper.plot_per_antenna and len(antennas) == 1:
            wrapper.qa_score = self._score_retriever.get_score(spw, antennas[0])

        if not os.path.exists(figfile):
            LOG.trace('Phase offset plot for antenna %s spw %s not found.'
                      ' Creating new plot: %s' % (utils.commafy(antenna_names, quotes=False),
                                                  spw.id, figfile))
            try:
                self.create_plot(spw, scans, antennas)
            except Exception as ex:
                LOG.error('Could not create phase offset plot for antenna'
                          ' %s spw %s' % (utils.commafy(antenna_names, quotes=False), spw.id))
                LOG.exception(ex)
                return None

        # the plot may not be created if all data for that antenna are flagged
        if os.path.exists(figfile):
            return wrapper
        return None

    def _plot_flagged_data(self, data, masked_slice, axis, annotate=True,
                           annotation='FLAGGED'):
        """
        Plot flagged data.

        The flagged extent is drawn as a light grey rectangle spanning the
        full height of the axis, from the time stamp before the flagged
        range to the time stamp at its end, clamped to the available times.

        data -- the CaltableWrapper for the data selection
        masked_slice -- the Slice object defining the flagged extent
        axis -- the Axis object to be used for plotting
        annotate -- True to write the annotation text over the rectangle
        annotation -- the text to write
        """
        idx_start = max(masked_slice.start-1, 0)
        idx_stop = min(masked_slice.stop, len(data.time)-1)
        start = data.time[idx_start]
        stop = data.time[idx_stop]
        width = stop - start

        # the x coords of this transformation are data, and the
        # y coord are axes
        trans = matplotlib.transforms.blended_transform_factory(axis.transData,
                                                                axis.transAxes)

        # We want x to be in data coordinates and y to
        # span from 0..1 in axes coords
        rect = matplotlib.patches.Rectangle((start, 0), width=width, height=1,
                                            transform=trans, color='#EEEEEE',
                                            alpha=0.2)
        axis.add_patch(rect)

        if annotate:
            axis.text(start + width/2, 0.5, annotation, color='k', transform=trans,
                      size=9, ha='center', va='center', rotation=90)