import hashlib
import os
import re
import string
import textwrap

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from matplotlib.colors import ColorConverter, Colormap, Normalize
from matplotlib.patches import Rectangle
from numpy import ma

import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.renderer.logger as logger
import pipeline.infrastructure.utils as utils
LOG = infrastructure.get_logger(__name__)

# Characters that may appear verbatim in a generated plot filename; every
# other character is replaced by '_' (see _char_replacer / sanitize).
_valid_chars = '_.' + string.ascii_letters + string.digits

# Matplotlib colour name used for each flagging-rule name, both for the
# flagged pixels in the image panels and for the legend colour blocks.
flag_color = {
    'outlier': 'red',
    'high outlier': 'orange',
    'low outlier': 'yellow',
    'too many flags': 'lightblue',
    'too many entirely flagged': 'darkblue',
    'nmedian': 'darkred',
    'max abs': 'pink',
    'min abs': 'darkcyan',
    'bad quadrant': 'yellow',
    'bad antenna': 'red',
}
def _char_replacer(s):
    """
    Map one character to itself if it is allowed in a filename, else to '_'.

    Allowed characters are those in the module-level ``_valid_chars`` string.
    """
    return s if s in _valid_chars else '_'
def sanitize(text):
    """
    Return ``text`` with every character that is not a letter, digit, '_' or
    '.' replaced by an underscore, making it safe to use as a filename.
    """
    return ''.join(map(_char_replacer, text))
class ImageDisplay(object):
    """
    Plot 2-D flagging-view images (with flagged / no-data pixels highlighted
    in colour) into PNG files, one file per flagging-view description.
    """

    @staticmethod
    def _findchunks(times):
        """
        Return a list of arrays, each containing the indices of a chunk
        of data i.e. a sequence of equally spaced measurements separated
        from other chunks by larger time gaps.

        A new chunk starts wherever the gap to the previous sample is not
        smaller than 1.5 times the median gap.

        Keyword arguments:
        times -- Numeric array of times at which the measurements
                 were taken.
        """
        if len(times) < 2:
            # No gaps to measure (and np.median of an empty array would warn).
            return [np.arange(len(times))]
        difference = times[1:] - times[:-1]
        median_diff = np.median(difference)
        # "not <" rather than ">=" so that NaN gaps also start a new chunk,
        # exactly as the original element-by-element loop did.
        breaks = np.flatnonzero(~(difference < 1.5 * median_diff)) + 1
        return np.split(np.arange(len(times)), breaks)

    @staticmethod
    def _format_hms(seconds):
        """
        Format a number of seconds as '<h>h<m>m<s>s'.

        Fields are not zero-padded and the seconds are truncated (floored).
        """
        h = int(np.floor(seconds / 3600.0))
        seconds -= h * 3600.0
        m = int(np.floor(seconds / 60.0))
        seconds -= m * 60.0
        s = int(np.floor(seconds))
        return '%sh%sm%ss' % (h, m, s)

    @staticmethod
    def _get_plot_filename(result, prefix=''):
        """
        Build the (sanitized) PNG filename for one flagging-view image.

        The name encodes prefix, datatype, axes, file, intent, field, spw, pol,
        antenna and time of day. If it would exceed the 255-byte filename limit
        of Lustre filesystems it is replaced by a SHA-256 hex digest of the
        full name, which is stable across interpreter runs.

        Keyword arguments:
        result -- The flagging-view image to name.
        prefix -- String placed at the start of the filename.
        """
        fileparts = {
            'prefix': prefix,
            'datatype': result.datatype,
            'x': result.axes[0].name,
            'y': result.axes[1].name,
            'file': '' if result.filename is None else 'File_%s' % os.path.basename(result.filename),
            'intent': '' if result.intent == '' else 'Intent_%s' % result.intent.replace(',', '_'),
            'fieldname': '' if result.fieldname == '' else 'Field_%s' % result.fieldname.replace(',', '_'),
            'fieldid': '' if result.field_id is None else 'ID_%s' % str(result.field_id).replace(',', '_'),
            'pol': '' if result.pol is None else 'Pol_%s' % result.pol.replace(',', '_'),
        }

        if result.spw == '':
            fileparts['spw'] = ''
        else:
            # zero-pad spw ids so that filenames sort numerically
            spws = ['%0.2d' % int(spw) for spw in str(result.spw).split(',')]
            fileparts['spw'] = 'SpW_%s' % '_'.join(spws)

        if result.ant is None or result.ant == '':
            fileparts['ant'] = ''
        else:
            fileparts['ant'] = 'Ant_%s' % utils.find_ranges(result.ant)

        if result.time is None or result.time == '':
            fileparts['time'] = ''
        else:
            # represent time sensibly relative to day start
            seconds_of_day = result.time - 86400.0 * np.floor(result.time / 86400.0)
            fileparts['time'] = ImageDisplay._format_hms(seconds_of_day)

        png = "{prefix}_{datatype}_{y}_vs_{x}_{file}_{intent}_{fieldname}_" \
              "{fieldid}_{spw}_{pol}_{ant}_{time}.png".format(**fileparts)
        png = sanitize(png)

        # Maximum filename size for Lustre filesystems is 255 bytes.
        # ImageDisplayMosaics can exceed this limit due to including the IDs
        # of all antennas. Truncate filename while keeping it unique by
        # replacing it with a digest. hashlib is used instead of the built-in
        # hash(), which is salted per process and may be negative, so that
        # the same plot gets the same name on every run.
        if len(png) > 251:  # 255 - '.png'
            new_png = '{!s}.png'.format(hashlib.sha256(png.encode('utf-8')).hexdigest())
            LOG.info('Renaming plot to avoid exceeding filesystem limit on filename length.\n'
                     'Old: {!s}\nNew: {!s}'.format(png, new_png))
            png = new_png

        return png

    def plot(self, context, results, reportdir, prefix='', change='Flagging', dpi=None):
        """
        Create one plot per flagging view in ``results``.

        Plot files that already exist on disk are not regenerated, but a
        logger.Plot entry is still returned for them.

        Keyword arguments:
        context   -- Pipeline context (stage number and MS metadata are read).
        results   -- Flagging-view results (descriptions(), first(), last(),
                     flagcmds() and vis are used).
        reportdir -- Directory in which the PNG files are written.
        prefix    -- Prefix for the generated filenames.
        change    -- Name used in the 'Before <change>' panel title.
        dpi       -- Resolution passed to savefig (None: matplotlib default).

        Returns:
        List of logger.Plot objects, one per description (empty if
        ``results`` is falsy).
        """
        if not results:
            return []

        plots = []
        for description in sorted(results.descriptions()):
            first = results.first(description)

            # Derive output filename.
            plotfile = os.path.join(reportdir, self._get_plot_filename(first, prefix))

            # Same None handling as _get_plot_filename; basename(None) raises.
            filename = '' if first.filename is None else os.path.basename(first.filename)

            plot = logger.Plot(
                plotfile,
                x_axis=first.axes[0].name, y_axis=first.axes[1].name,
                field=first.fieldname,
                parameters={'vis': os.path.basename(results.vis),
                            'intent': first.intent,
                            'spw': first.spw,
                            'pol': first.pol,
                            'ant': first.ant,
                            'type': first.datatype,
                            'file': filename})
            plots.append(plot)

            # If the plot figure already exists on disk, then skip to next one.
            if os.path.exists(plotfile):
                LOG.trace('Not overwriting existing image at %s' % plotfile)
                continue

            self._create_plot_file(context, results, description, change, plotfile, dpi=dpi)

        return plots

    def _create_plot_file(self, context, results, description, change, plotfile, dpi=None):
        """
        Draw the figure for one flagging view and save it to ``plotfile``.

        With flagging commands present the figure has three panels
        ('Before <change>', 'After', legend), otherwise two (data, legend).
        The figure is always closed, even if plotting fails.
        """
        stagenumber = context.stage
        ms = context.observing_run.get_ms(name=results.vis)
        antennas = ms.antennas
        flagcmds = results.flagcmds()

        flagged = len(flagcmds) > 0
        width_ratios = [3, 3, 2] if flagged else [3, 1]
        nsubplots = len(width_ratios)
        fig, axs = plt.subplots(1, nsubplots, constrained_layout=True,
                                gridspec_kw={'width_ratios': width_ratios})
        try:
            if flagged:
                self._plot_panel(fig, axs[0], nsubplots, 1, results.first(description), 'Before %s' % change)
                self._plot_panel(fig, axs[1], nsubplots, 2, results.last(description), 'After')
            else:
                self._plot_panel(fig, axs[0], nsubplots, 1, results.first(description), '')

            # Reduce the padding of the constrained layout.
            fig.set_constrained_layout_pads(w_pad=0.02, h_pad=0.02)

            self._plot_legend_panel(axs[-1], antennas, flagcmds)

            figtitle = 'Stage %s - %s' % (stagenumber, description)
            fig.suptitle("\n".join(textwrap.wrap(figtitle, 100)), size='small')

            # Save this figure explicitly rather than pyplot's "current" one.
            fig.savefig(plotfile, dpi=dpi)
        finally:
            plt.close(fig)

    def _plot_legend_panel(self, ax, antennas, flagcmds):
        """
        Plot the antenna and flagging legend information into a panel.

        Keyword arguments:
        ax       -- Matplotlib Axes object for current panel.
        antennas -- List of antennas (``id`` and ``name`` are shown).
        flagcmds -- List of flagging commands; one colour key is drawn per
                    distinct (rulename, ruleaxis), skipping rule 'ignore'.
        """
        # Do not show axes.
        ax.axis('off')

        # Antenna key, in columns of 22 antennas spaced 0.4 axes-units apart.
        xoff = 0.
        xoffstart = xoff
        yoff = self.plottext(ax, xoffstart, 1.03, 'Antenna key:', 40, mult=0.8)
        yoffstart = yoff
        for idx, antenna in enumerate(antennas):
            yoff = self.plottext(ax, xoff, yoff, '%s:%s' % (antenna.id, antenna.name), 40, mult=0.7)
            if (idx + 1) % 22 == 0:
                yoff = yoffstart
                xoff += 0.4

        # Geometry of the colour-key entries (axes coordinates).
        xlen = 0.20  # length of colour block
        ylen = 0.02  # height of colour block
        strlen = 20  # max length of string for flag reason
        rectyoff = -0.003  # y-off for colour block, to align with text

        def add_key(y, color, label):
            """Draw one colour block plus its label; return the next y."""
            ax.add_patch(Rectangle((xoffstart, y + rectyoff), xlen, ylen, facecolor=color, edgecolor=color,
                                   transform=ax.transAxes))
            return self.plottext(ax, xoffstart + 0.25, y, label, strlen, mult=0.8)

        # Always show "no data" and "cannot calculate" in the legend.
        yoff = self.plottext(ax, xoffstart, 0.30, 'Key for masked data:', 45, mult=0.8)
        yoff = add_key(yoff, 'indigo', 'no data')
        yoff = add_key(yoff, 'violet', 'cannot calculate')

        # Add key for data flagged during this stage.
        rulesplotted = set()
        for flagcmd in flagcmds:
            if flagcmd.rulename == 'ignore':
                continue
            key = (flagcmd.rulename, flagcmd.ruleaxis)
            if key in rulesplotted:
                continue
            rulesplotted.add(key)
            if flagcmd.ruleaxis is not None:
                label = '%s axis - %s' % (flagcmd.ruleaxis, flagcmd.rulename)
            else:
                label = flagcmd.rulename
            yoff = add_key(yoff, flag_color[flagcmd.rulename], label)

    @staticmethod
    def _sentinel_data(image):
        """
        Return ``(data, sentinels)`` for plotting ``image``.

        ``data`` is a COPY of ``image.data`` (so the result object is not
        modified) in which:
          * flagged pixels (image.flag != 0) are set to 2.0 (violet, shown as
            'cannot calculate' in the legend);
          * pixels with flag reason r > 0 are set to r + 10 (colour from
            flag_color[image.flag_reason_key[r]]);
          * pixels with no data (image.nodata != 0) are set to 5.0 (indigo).
        ``sentinels`` maps each sentinel value to its RGB tuple.

        NOTE(review): genuine data values equal to a sentinel value are
        indistinguishable from sentinels and will be coloured as such.
        """
        cc = ColorConverter()
        sentinels = {}
        flag_reason_plane = image.flag_reason_plane
        flag_reason_key = image.flag_reason_key

        data = image.data.copy()

        # Flagged points first; overwritten below by specific flag reasons.
        data[image.flag != 0] = 2.0
        sentinels[2.0] = cc.to_rgb('violet')

        # set points to their flag reason
        has_reason = flag_reason_plane > 0
        data[has_reason] = flag_reason_plane[has_reason] + 10.0
        reasons = set(np.ravel(flag_reason_plane))
        reasons.discard(0)
        for sentinelvalue in np.array(list(reasons), float) + 10.0:
            sentinels[sentinelvalue] = cc.to_rgb(flag_color[flag_reason_key[int(sentinelvalue) - 10]])

        # plot points with no data indigo.
        data[image.nodata != 0] = 5.0
        sentinels[5.0] = cc.to_rgb('indigo')

        return data, sentinels

    @staticmethod
    def _numeric_ydata(ax, ydata):
        """
        Return y-axis values that matplotlib can use for the image extent.

        Numeric ``ydata`` is returned as-is. Baseline strings 'a&b' become
        floats a.b (the last one is raised to <a>.99 so the scale ends where
        the next antenna would start) and a '%05.2f' major formatter is set
        on ``ax``. Any other strings are replaced by their index.
        """
        if not isinstance(ydata[0], str):
            return ydata
        if re.match(r'\d+&\d+', ydata[0]):
            ydata_numeric = [float(b.replace('&', '.')) for b in ydata]
            # highest baseline number is am.am where 'am' is the largest
            # antenna id; e.g. 34.34 must become 34.99 otherwise the scale
            # will not look right (the next baseline would be 35.00).
            am = int(ydata_numeric[-1])
            ydata_numeric[-1] = am + 0.99
            ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%05.2f'))
            return np.array(ydata_numeric)
        return np.arange(len(ydata))

    @staticmethod
    def _image_extent(xdata, ydata_numeric, x_is_antenna, y_is_antenna):
        """
        Return the imshow extent [x0, x1, y0, y1].

        Antenna x-axes are laid out by index (0 .. len(xdata)-1); antenna
        axes are widened by half a pixel so ticks sit on pixel centres.
        """
        if x_is_antenna:
            x0, x1 = 0, len(xdata) - 1
        else:
            x0, x1 = xdata[0], xdata[-1]
        y0, y1 = ydata_numeric[0], ydata_numeric[-1]
        if y0 == y1:
            # sometimes causes empty plots if min==max
            y1 = y1 + 1
        extent = [x0, x1, y0, y1]
        if x_is_antenna:
            extent[0] -= 0.5
            extent[1] += 0.5
        if y_is_antenna:
            extent[2] -= 0.5
            extent[3] += 0.5
        return extent

    @staticmethod
    def _label_antenna_xaxis(ax, xdata):
        """
        Label an antenna x-axis: even-index labels on the bottom axis,
        odd-index labels on a twin top axis, one tick per antenna pixel.
        """
        # Pixel centres are at indices 0 .. len(xdata)-1 (see _image_extent),
        # so both major and minor ticks are placed by index, not by ID.
        xticks = np.arange(0, len(xdata), 1)

        # Set size of x-labels based on number of antennas, with minimum
        # label size of 5.
        xlabel_size = max(np.ceil(10 - len(xdata) // 9), 5)

        # Add labels for even-indices in antenna array.
        ax.set_xticks(xticks[::2])
        ax.set_xticklabels([str(x) for x in xdata[::2]], rotation=90)
        ax.xaxis.set_minor_locator(ticker.FixedLocator(xticks[1::2]))
        ax.tick_params(axis='both', which='both', direction='out')
        for label in ax.get_xticklabels():
            label.set_fontsize(xlabel_size)

        # Add labels for odd-indices in antenna array.
        if len(xdata) > 1:
            axt = ax.twiny()
            axt.set_xlim(ax.get_xlim())  # copy over limits.
            axt.set_xticks(xticks[1::2])
            axt.set_xticklabels([str(x) for x in xdata[1::2]], rotation=90)
            axt.xaxis.set_minor_locator(ticker.FixedLocator(xticks[::2]))
            axt.tick_params(axis='both', which='both', direction='out')
            for label in axt.get_xticklabels():
                label.set_fontsize(xlabel_size)

    def _label_time_chunks(self, ax, ydata):
        """
        Write the time of day at the start of each contiguous time chunk
        just left of the y-axis (times relative to the day of ydata[0]).
        """
        base_time = 86400.0 * np.floor(ydata[0] / 86400.0)
        tim_plot = ydata - base_time
        xleft = ax.axis()[0] - 0.25
        for chunk in self._findchunks(ydata):
            tstring = self._format_hms(tim_plot[chunk[0]])
            ax.text(xleft, ydata[chunk[0]], tstring, fontsize=8, ha='right', va='bottom',
                    clip_on=False)

    def _plot_panel(self, fig, ax, nplots, plotnumber, image, subtitle):
        """
        Plot the 2d data into one panel.

        Flagged and missing pixels are drawn in fixed colours (see
        _sentinel_data); other pixels on a grey scale whose limits come from
        the non-sentinel, non-NaN data. ``image`` itself is not modified.

        Keyword arguments:
        fig        -- Matplotlib figure object.
        ax         -- Matplotlib Axes object for current panel.
        nplots     -- The number of sub-plots on the page.
        plotnumber -- The (1-based) index of this sub-plot.
        image      -- The 2d data.
        subtitle   -- The title to be given to this subplot.
        """
        xtitle = image.axes[0].name
        xdata = image.axes[0].data
        xunits = image.axes[0].units
        ytitle = image.axes[1].name
        ydata = image.axes[1].data
        x_is_antenna = 'ANTENNA' in xtitle.upper()
        y_is_antenna = 'ANTENNA' in ytitle.upper()

        data, sentinels = self._sentinel_data(image)

        # calculate vmin, vmax without the sentinels. Leaving norm to do
        # this is not sufficient; the standard Normalize gets called
        # by something in matplotlib and initialises vmin and vmax incorrectly.
        sentinel_mask = np.zeros(np.shape(data), bool)
        for sentinel in sentinels:
            sentinel_mask |= (data == sentinel)
        actual_data = data[np.logical_not(sentinel_mask)]
        # watch out for nans which mess up vmin, vmax
        actual_data = actual_data[np.logical_not(np.isnan(actual_data))]
        if len(actual_data):
            vmin = actual_data.min()
            vmax = actual_data.max()
        else:
            vmin = vmax = 0.0

        # Colormap/norm that render sentinel values in fixed colours. The
        # limits are given to the norm itself because newer matplotlib
        # versions reject imshow(norm=..., vmin=..., vmax=...).
        cmap = _SentinelMap(plt.cm.gray, sentinels=sentinels)
        norm = _SentinelNorm(vmin=vmin, vmax=vmax, sentinels=list(sentinels.keys()))

        # make antenna x antenna plots square
        aspect = 'auto'
        shrink = 0.8
        fraction = 0.15
        if x_is_antenna and y_is_antenna:
            aspect = 'equal'
            shrink = 0.4
            fraction = 0.1

        # look out for yaxis values that would trip up matplotlib
        ydata_numeric = self._numeric_ydata(ax, ydata)

        # only plot y tick labels on first panel to avoid collision
        # between y tick labels for second panel with greyscale for first
        if plotnumber > 1:
            ax.yaxis.set_major_formatter(ticker.NullFormatter())

        extent = self._image_extent(xdata, ydata_numeric, x_is_antenna, y_is_antenna)

        # Transpose data to get [x,y] into the [row,column] expected by matplotlib.
        img = ax.imshow(np.transpose(data), cmap=cmap, norm=norm, interpolation='nearest',
                        origin='lower', aspect=aspect, extent=extent)

        # Set y-axis title, only add this to the first panel.
        if plotnumber == 1:
            ax.set_ylabel(ytitle, size='medium')

        # Set x-axis title, add units to title if available.
        xlabel = xtitle
        if xunits:
            xlabel = '%s [%s]' % (xlabel, xunits)
        ax.set_xlabel(xlabel, size='medium')

        # Colour bar: plot wedge, make tick numbers smaller, label with units.
        cb_kwargs = {'ax': ax, 'shrink': shrink, 'fraction': fraction, 'pad': 0, 'aspect': 50}
        if vmin == vmax:
            # degenerate range: give the colour bar explicit ticks
            cb_kwargs['ticks'] = [-1, 0, 1]
        cb = fig.colorbar(img, **cb_kwargs)
        cb.formatter.set_scientific(True)
        cb.formatter.set_powerlimits((-2, 2))
        cb.ax.yaxis.set_offset_position('left')
        cb.update_ticks()
        for label in cb.ax.get_yticklabels():
            label.set_fontsize('small')

        # Label the colour bar of the right-most data panel, with units if available.
        datatype = image.datatype
        data_label = datatype if image.units is None else '%s (%s)' % (datatype, image.units)
        if nplots == 2 or plotnumber == 2:
            cb.set_label(data_label, fontsize='medium')

        # Rotate x tick labels to avoid them clashing
        ax.tick_params(axis='x', rotation=35)

        if x_is_antenna:
            # Offset the title to leave room for labels above the upper x-axis.
            ax.set_title(subtitle, fontsize='medium', y=1.06)
            self._label_antenna_xaxis(ax, xdata)
        else:
            ax.set_title(subtitle, fontsize='medium')
            ax.xaxis.set_minor_locator(ticker.AutoMinorLocator())

        # If plotting by time on y-axis, disable automatic tickmarks and
        # label chunks of contiguous time manually.
        if ytitle.upper() == 'TIME':
            ax.yaxis.set_major_locator(plt.NullLocator())
            if plotnumber == 1:
                self._label_time_chunks(ax, ydata)

        # If plotting by baseline on y-axis, add minor tick marks (should mark
        # start of each new antenna)
        if 'BASELINE' in ytitle.upper():
            ax.yaxis.set_minor_locator(ticker.AutoMinorLocator())

    @staticmethod
    def plottext(ax, xoff, yoff, text, maxchars, ny_subplot=1, mult=1):
        """
        Utility method to plot text and put line breaks in to keep the
        text within a given limit.

        Words are wrapped onto new lines; any single word longer than
        ``maxchars`` is hard-split into ``maxchars``-sized pieces. Each
        wrapped line moves down by 0.03*ny_subplot*mult, the final line by
        0.02*ny_subplot*mult, plus a further 0.02*ny_subplot*mult gap.

        Keyword arguments:
        ax         -- Matplotlib Axes object for current panel.
        xoff       -- axes-fraction x coord where text is to start.
        yoff       -- axes-fraction y coord where text is to start.
        text       -- Text to print.
        maxchars   -- Maximum number of characters before a newline is
                      inserted (must be >= 1).
        ny_subplot -- Number of sub-plots along the y-axis of the page.
        mult       -- Factor by which the text fontsize is to be multiplied.

        Returns:
        The y coordinate at which the next text should start.

        Raises:
        ValueError -- if maxchars < 1 (previously an infinite loop).
        """
        if maxchars < 1:
            raise ValueError('maxchars must be at least 1, got %r' % (maxchars,))

        step = 0.03 * ny_subplot * mult

        def draw(y, s):
            ax.text(xoff, y, s, va='center', fontsize=mult * 8,
                    transform=ax.transAxes, clip_on=False)

        line = ''
        for word in text.split():
            candidate = line + word + ' '
            if len(candidate) <= maxchars:
                line = candidate
                continue
            if line:
                # Current line is full: emit it and restart with this word.
                draw(yoff, line)
                yoff -= step
                candidate = word + ' '
            if len(candidate) > maxchars:
                # The word does not fit on a line of its own: hard-split it.
                for start in range(0, len(candidate), maxchars):
                    draw(yoff, candidate[start:start + maxchars])
                    yoff -= step
                line = ''
            else:
                line = candidate

        if line:
            draw(yoff, line)
            yoff -= 0.02 * ny_subplot * mult
        yoff -= 0.02 * ny_subplot * mult
        return yoff
class _SentinelMap(Colormap):
    """
    Utility class for plotting sentinel pixels in colours.

    Values equal to one of the sentinel keys are painted with that key's RGB
    colour; all other values are looked up in the wrapped colormap.
    """

    def __init__(self, cmap, sentinels=None):
        """
        Constructor.

        Keyword arguments:
        cmap      -- Matplotlib Colormap used for non-sentinel values.
        sentinels -- Dict mapping sentinel value -> (r, g, b) tuple with
                     components in [0, 1]. Defaults to no sentinels.
        """
        self.name = 'SentinelMap'
        cmap._init()
        self.cmap = cmap
        self._lut = cmap._lut
        self.N = cmap.N
        # Fresh dict per instance; a mutable default would be shared.
        self.sentinels = {} if sentinels is None else sentinels
        self._isinit = True

    def __call__(self, scaledData, alpha=1.0, bytes=False):
        """
        Return RGBA values for ``scaledData`` from the wrapped colormap, with
        pixels equal to a sentinel value overwritten by the sentinel colour.

        When at least one sentinel is defined and ``alpha`` is not None, the
        alpha channel of every pixel is set to ``alpha``. RGBA arrays that are
        neither 2-D nor 3-D (e.g. from a scalar input) are returned unchanged.
        """
        rgba = self.cmap(scaledData, alpha, bytes)
        mult = 255 if bytes else 1
        ndim = np.ndim(rgba)
        for sentinel, (r, g, b) in self.sentinels.items():
            if ndim not in (2, 3):
                continue
            hit = scaledData == sentinel
            # rgba[..., k] is rgba[:, :, k] for 3-D and rgba[:, k] for 2-D.
            rgba[..., 0][hit] = r * mult
            rgba[..., 1][hit] = g * mult
            rgba[..., 2][hit] = b * mult
            if alpha is not None:
                rgba[..., 3] = alpha * mult
        return rgba
class _SentinelNorm(Normalize):
    """
    Normalise but leave sentinel values unchanged.

    Values equal to one of ``sentinels`` are excluded while normalising and
    are written back, unscaled, into the normalised output so that
    _SentinelMap can recognise them.
    """

    def __init__(self, vmin=None, vmax=None, clip=True, sentinels=None):
        # Let Normalize set up its own state: in newer matplotlib vmin, vmax
        # and clip are properties backed by attributes created in __init__,
        # so assigning them without calling it raises AttributeError.
        Normalize.__init__(self, vmin=vmin, vmax=vmax, clip=clip)
        # Fresh list per instance; a mutable default would be shared.
        self.sentinels = [] if sentinels is None else sentinels

    def __call__(self, value, clip=None):
        # Work on a copy so the caller's array does not lose its sentinels.
        if isinstance(value, np.ndarray):
            value = value.copy()

        # remove sentinels, keeping a mask of where they were.
        sentinel_mask = np.zeros(np.shape(value), bool)
        for sentinel in self.sentinels:
            sentinel_mask |= (value == sentinel)
        sentinel_values = value[sentinel_mask]
        actual_data = value[np.logical_not(sentinel_mask)]
        if len(actual_data):
            # Replace sentinels with an in-range value so they do not affect
            # the normalisation of the real data.
            value[sentinel_mask] = actual_data.min()

        value = ma.asarray(value)
        value = Normalize.__call__(self, value, clip)

        # restore sentinels
        value[sentinel_mask] = sentinel_values
        return value