# ******************************************************************************
# ALMA - Atacama Large Millimeter Array
# Copyright (c) ATC - Astronomy Technology Center - Royal Observatory Edinburgh, 2011
# (in the framework of the ALMA collaboration).
# All rights reserved.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
# *******************************************************************************
"""Module to plot sky images."""
# History:
# package modules
import copy
import os
import string
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.offsetbox import HPacker, TextArea, AnnotationBbox
# alma modules
import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.renderer.logger as logger
from pipeline.hif.tasks.makeimages.resultobjects import MakeImagesResult
from pipeline.infrastructure import casa_tools
# Module-level logger obtained from the pipeline infrastructure, named after this module.
LOG = infrastructure.get_logger(__name__)
# Characters allowed to appear verbatim in generated plot file names.
_valid_chars = "_." + string.ascii_letters + string.digits


def _char_replacer(s):
    """Return *s* unchanged if it is an allowed file-name character, else '_'.

    Allowed characters are those in ``_valid_chars``: ASCII letters, digits,
    '_' and '.'. Intended to be called with a single character.
    """
    return s if s in _valid_chars else '_'


def sanitize(text):
    """Return *text* with every character outside [A-Za-z0-9_.] replaced by '_'."""
    return ''.join(map(_char_replacer, text))
def plotfilename(image, reportdir):
    """Return the path of the sky-image PNG for *image* inside *reportdir*.

    The file name is the basename of *image* with '.sky.png' appended,
    passed through ``sanitize`` so only [A-Za-z0-9_.] characters remain.
    """
    png_name = sanitize('{}.sky.png'.format(os.path.basename(image)))
    return os.path.join(reportdir, png_name)
class SkyDisplay(object):
    """Class to plot sky images."""

    def plot(self, context, result, reportdir, intent=None, collapseFunction='mean', vmin=None, vmax=None,
             mom8_fc_peak_snr=None, **imshow_args):
        """Plot a sky image to a PNG and return a weblog Plot object describing it.

        Args:
            context: Pipeline context. If its last result is a MakeImagesResult for VLA continuum
                imaging, the first measurement set is used to label the plot with a band.
            result: Path of the CASA image to plot. A falsy value short-circuits to ``[]``.
            reportdir: Directory in which the PNG is written.
            intent: Optional intent appended to the field name reported to the weblog.
            collapseFunction: How the image is collapsed along axes 2 and 3
                ('mean', 'max', 'center', or another function accepted by the image tool).
            vmin, vmax: When both are given, the colour scale is normalised and clipped to this range.
            mom8_fc_peak_snr: Optional peak SNR, annotated on images whose path contains 'mom8_fc'.
            **imshow_args: Extra keyword arguments forwarded to ``plt.imshow``.

        Returns:
            A ``logger.Plot`` instance, or ``[]`` when ``result`` is falsy.
        """
        if not result:
            return []

        if vmin is not None and vmax is not None:
            imshow_args['norm'] = plt.Normalize(vmin, vmax, clip=True)

        ms = self._vla_continuum_ms(context)

        plotfile, coord_names, field, band = self._plot_panel(context, reportdir, result,
                                                              collapseFunction=collapseFunction, ms=ms,
                                                              mom8_fc_peak_snr=mom8_fc_peak_snr, **imshow_args)

        # field names may not be unique, which leads to incorrectly merged
        # plots in the weblog output. As a temporary fix, change to field +
        # intent - which is better but again, not guaranteed unique.
        if intent:
            field = '%s (%s)' % (field, intent)

        with casa_tools.ImageReader(result) as image:
            miscinfo = image.miscinfo()

        parameters = {k: miscinfo[k] for k in ['spw', 'pol', 'field', 'type', 'iter'] if k in miscinfo}
        parameters['ant'] = None
        parameters['band'] = band
        # 'filnam01' is optional in the image metadata; None when absent.
        parameters['prefix'] = miscinfo.get('filnam01')

        return logger.Plot(plotfile, x_axis=coord_names[0], y_axis=coord_names[1], field=field,
                           parameters=parameters)

    @staticmethod
    def _vla_continuum_ms(context):
        """Return the measurement set used for VLA band labelling, or None when not applicable."""
        last_result = context.results[-1]
        if not isinstance(last_result, MakeImagesResult):
            return None
        first = last_result.results[0]
        if first.imaging_mode in ('VLA', 'EVLA', 'JVLA') and first.specmode == 'cont':
            return context.observing_run.get_measurement_sets()[0]  # only 1 ms for VLA
        return None

    def _plot_panel(self, context, reportdir, result, collapseFunction='mean', ms=None, mom8_fc_peak_snr=None,
                    **imshow_args):
        """Collapse an image to 2-D, render it to a PNG in ``reportdir`` and describe the result.

        If the PNG already exists it is not redrawn.

        Args:
            context: Pipeline context (unused here; kept for interface compatibility).
            reportdir: Directory in which the PNG is written.
            result: Path of the CASA image.
            collapseFunction: Collapse mode; see ``plot``.
            ms: Optional measurement set whose ``get_vla_spw2band()`` mapping is used to replace
                the image's spw list with a band name in the plot label.
            mom8_fc_peak_snr: Optional peak SNR annotation for 'mom8_fc' images.
            **imshow_args: Forwarded to ``plt.imshow``.

        Returns:
            Tuple ``(plotfile, coord_names, field, band)``. ``band`` is the band name derived from
            ``ms`` when one was found, otherwise None (always None when the PNG already existed).
        """
        plotfile = plotfilename(image=os.path.basename(result), reportdir=reportdir)
        LOG.info('Plotting %s', result)

        with casa_tools.ImageReader(result) as image:
            collapsed = self._collapse_image(image, result, collapseFunction)

        coordsys = collapsed.coordsys()
        coord_names = coordsys.names()
        coordsys.setunits(type='direction', value='arcsec arcsec')
        coord_units = coordsys.units()
        coord_refs = coordsys.referencevalue(format='s')
        beam = collapsed.restoringbeam()
        brightness_unit = collapsed.brightnessunit()
        miscinfo = collapsed.miscinfo()

        # don't replot if a file of the required name already exists
        if os.path.exists(plotfile):
            LOG.info('plotfile already exists: %s', plotfile)
            # release the CASA tools on this path too, otherwise they stay open
            coordsys.done()
            collapsed.done()
            return plotfile, coord_names, miscinfo.get('field'), None

        # otherwise do the plot; flatten the collapsed chunk to its first two axes
        data = collapsed.getchunk()
        mask = np.invert(collapsed.getchunk(getmask=True))
        shape = np.shape(data)
        data = data.reshape(shape[0], shape[1])
        mask = mask.reshape(shape[0], shape[1])
        mdata = np.ma.array(data, mask=mask)
        collapsed.done()

        # get x and y axes from coordsys of image
        x, y = self._relative_offsets(coordsys, shape)
        coordsys.done()

        # remove any incomplete matplotlib plots, if left these can cause
        # weird errors
        plt.close('all')
        f1 = plt.figure(1)

        # plot data; masked (bad) pixels are drawn black
        if 'cmap' not in imshow_args:
            imshow_args['cmap'] = copy.deepcopy(matplotlib.cm.jet)
        imshow_args['cmap'].set_bad('k', 1.0)
        plt.imshow(np.transpose(mdata), interpolation='nearest', origin='lower', aspect='equal',
                   extent=[x[0], x[-1], y[0], y[-1]], **imshow_args)
        plt.axis('image')
        lims = plt.axis()

        # make tick lines white and halve the tick label font size
        ax = plt.gca()
        for line in ax.xaxis.get_ticklines() + ax.yaxis.get_ticklines():
            line.set_color('white')
        for label in ax.xaxis.get_ticklabels() + ax.yaxis.get_ticklabels():
            label.set_fontsize(0.5 * label.get_fontsize())

        # colour bar
        cb = plt.colorbar(shrink=0.5)
        fontsize = 8
        for label in cb.ax.get_yticklabels() + cb.ax.get_xticklabels():
            label.set_fontsize(fontsize)
        cb.set_label(brightness_unit, fontsize=fontsize)

        # image reference pixel
        yoff = 0.10
        yoff = self.plottext(1.05, yoff, 'Reference position:', 40)
        for coord_name, ref in zip(coord_names, coord_refs['string']):
            yoff = self.plottext(1.05, yoff, '%s: %s' % (coord_name, ref), 40, mult=0.8)

        # if peaksnr is available for the mom8_fc image, include it in the plot
        if 'mom8_fc' in result and mom8_fc_peak_snr is not None:
            self.plottext(1.05, 0.90, 'Peak SNR: {:.5f}'.format(mom8_fc_peak_snr), 40)

        # plot beam (only when a single restoring beam with a 'major' axis is present)
        if 'major' in beam:
            self._plot_beam(beam, lims)

        # print title
        plt.xlabel('%s (%s)' % (coord_names[0], coord_units[0]))
        plt.ylabel('%s (%s)' % (coord_names[1], coord_units[1]))

        mode_texts = {'mean': 'mean', 'max': 'peak line int. (mom8)', 'center': 'center slice'}
        # fall back to the raw function name for collapse modes without a descriptive label
        image_info = {'display': mode_texts.get(collapseFunction, collapseFunction)}
        image_info.update(miscinfo)
        type_labels = {'flux': 'pb', 'mom0_fc': 'Line-free Moment 0', 'mom8_fc': 'Line-free Moment 8'}
        if 'type' in image_info and image_info['type'] in type_labels:
            image_info['type'] = type_labels[image_info['type']]

        # VLA only, not VLASS: replace the spw list by the band containing one of the spws
        if ms and 'spw' in image_info:
            vla_band = self._band_for_spws(ms.get_vla_spw2band(), image_info['spw'])
            if vla_band is not None:
                image_info['band'] = vla_band
                del image_info['spw']

        if 'band' in image_info:
            spw_or_band_key = 'band'
            band = image_info.get('band')
        else:
            spw_or_band_key = 'spw'
            band = None
        label_keys = [('type', 'k'),
                      ('display', 'r'),
                      ('field', 'k'),
                      (spw_or_band_key, 'k'),
                      ('pol', 'k'),
                      ('iter', 'k')]
        label = [TextArea('%s:%s' % (key, image_info[key]), textprops=dict(color=color))
                 for key, color in label_keys
                 if image_info.get(key) is not None]

        txt = HPacker(children=label, align="baseline", pad=0, sep=7)
        bbox = AnnotationBbox(txt, xy=(0.1, 0.5),
                              xycoords='data',
                              frameon=True,
                              box_alignment=(0.5, 0.5),  # alignment center, center
                              )
        ax = plt.Axes(f1, [0.085, 0.9, 0.7, 0.1])
        ax.set_frame_on(False)
        ax.set_axis_off()
        ax.add_artist(bbox)
        f1.add_axes(ax)

        # make axis fit snugly around image
        # NOTE(review): after f1.add_axes(ax) the current axes may be the label axes rather than
        # the image axes; confirm which axes this call targets before changing the order.
        plt.axis([lims[0], lims[1], lims[2], lims[3]])

        # save the image
        plt.savefig(plotfile)
        plt.clf()
        plt.close(1)

        return plotfile, coord_names, miscinfo.get('field'), band

    @staticmethod
    def _collapse_image(image, result, collapseFunction):
        """Collapse ``image`` along axes 2 and 3; on failure return a collapsed all-zero image."""
        try:
            if collapseFunction == 'center':
                # use only the middle plane along axis 3
                return image.collapse(function='mean', chans=str(image.summary()['shape'][3] // 2),
                                      axes=[2, 3])
            # Note: in case 'max' and non-pbcor image a moment 0 map was written to disk
            # in the past. With PIPE-558 this is done in hif/tasks/tclean.py tclean._calc_mom0_8()
            return image.collapse(function=collapseFunction, axes=[2, 3])
        except Exception as e:
            # All channels flagged or some other error. Make collapsed zero image.
            LOG.debug('Collapsing %s failed (%s); plotting an all-zero image instead', result, e)
            collapsed_new = image.newimagefromimage(infile=result)
            collapsed_new.set(pixelmask=True, pixels='0')
            collapsed = collapsed_new.collapse(function='mean', axes=[2, 3])
            collapsed_new.done()
            return collapsed

    @staticmethod
    def _relative_offsets(coordsys, shape):
        """Return arrays of relative world coordinates for each pixel along the first two axes."""
        x = np.zeros(shape[0])
        for pix in range(shape[0]):
            world = coordsys.toworld([float(pix), 0, 0, 0])
            x[pix] = coordsys.torel(world)['numeric'][0]
        y = np.zeros(shape[1])
        for pix in range(shape[1]):
            world = coordsys.toworld([0, float(pix), 0, 0])
            y[pix] = coordsys.torel(world)['numeric'][1]
        return x, y

    @staticmethod
    def _plot_beam(beam, lims):
        """Draw the restoring-beam ellipse in yellow, offset 10% into the plot from the lower-left limits."""
        cqa = casa_tools.quanta
        bpa = cqa.convert(beam['positionangle'], 'rad')['value'] + np.pi / 2.0
        bmaj = cqa.convert(beam['major'], 'arcsec')['value']
        bmin = cqa.convert(beam['minor'], 'arcsec')['value']

        # 37 samples from 0 to 360 deg in 10 deg steps, so the outline is closed
        theta = np.arange(37) * 10.0 * np.pi / 180.0
        xbeam = 0.5 * (bmaj * np.sin(theta) * np.cos(bpa) + bmin * np.cos(theta) * np.sin(bpa))
        ybeam = 0.5 * (-bmaj * np.sin(theta) * np.sin(bpa) + bmin * np.cos(theta) * np.cos(bpa))
        xbeam = xbeam + lims[0] + 0.1 * (lims[1] - lims[0])
        ybeam = ybeam + lims[2] + 0.1 * (lims[3] - lims[2])
        plt.plot(xbeam, ybeam, color='yellow')

    @staticmethod
    def _band_for_spws(spw2band, spws):
        """Return the first band whose spw ids include any of the comma-separated ``spws``, else None.

        ``spw2band`` maps spw id -> band name; bands are tried in the order they first appear in it.
        ``spws`` is converted with ``str()`` and split on ',' before each entry is parsed with ``int()``.
        """
        band_spws = {}
        for spw_id, band_name in spw2band.items():
            band_spws.setdefault(band_name, []).append(spw_id)
        spw_strs = str(spws).split(',')
        for band_name, band_spw_ids in band_spws.items():
            if any(int(spw) in band_spw_ids for spw in spw_strs):
                return band_name
        return None

    @staticmethod
    def plottext(xoff, yoff, text, maxchars, ny_subplot=1, mult=1):
        """Utility method to plot text and put line breaks in to keep the
        text within a given limit.

        Keyword arguments:
        xoff       -- x coord (axes fraction) where text is to start.
        yoff       -- y coord (axes fraction) where text is to start.
        text       -- Text to print.
        maxchars   -- Maximum number of characters before a newline is
                      inserted.
        ny_subplot -- Number of sub-plots along the y-axis of the page.
        mult       -- Factor by which the text fontsize is to be multiplied.

        Returns the y coordinate just below the last line written.
        """
        ax = plt.gca()
        line_step = 0.03 * ny_subplot * mult
        words_in_line = 0
        line = ''
        for word in text.split():
            candidate = line + word + ' '
            words_in_line += 1
            if len(candidate) > maxchars:
                if words_in_line == 1:
                    # a single word longer than maxchars is written on its own line
                    ax.text(xoff, yoff, candidate, va='center', fontsize=mult*8,
                            transform=ax.transAxes, clip_on=False)
                    yoff -= line_step
                    words_in_line = 0
                else:
                    # flush the current line and start a new one with this word
                    ax.text(xoff, yoff, line, va='center', fontsize=mult*8,
                            transform=ax.transAxes, clip_on=False)
                    yoff -= line_step
                    line = word + ' '
                    words_in_line = 1
            else:
                line = candidate
        if line:
            ax.text(xoff, yoff, line, va='center', fontsize=mult*8,
                    transform=ax.transAxes, clip_on=False)
            yoff -= line_step
        yoff -= 0.01 * ny_subplot * mult
        return yoff