import abc
import math
import os
import string
import time
import matplotlib.pyplot as plt
import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.renderer.logger as logger
import pipeline.infrastructure.displays.pointing as pointing
from pipeline.domain.datatable import DataTableImpl as DataTable
from pipeline.hsd.tasks.common.display import DPISummary, DPIDetail, SingleDishDisplayInputs, ShowPlot, LightSpeed
from pipeline.infrastructure import casa_tools
from pipeline.infrastructure.displays.pointing import MapAxesManagerBase
from pipeline.infrastructure.displays.plotstyle import casa5style_plot
from ..common import direction_utils as dirutil
LOG = infrastructure.get_logger(__name__)

# Tick-label rotations for the R.A. (x) and Dec. (y) axes of the cluster map
# panels, taken from the shared pointing display module.
RArotation, DECrotation = pointing.RArotation, pointing.DECrotation
class ClusterValidationAxesManager(MapAxesManagerBase):
    """Create and cache the panel and legend axes for cluster validation maps.

    One map panel is created in the current figure for every entry of
    ``clusters_to_plot`` on an ``nh`` x ``nv`` grid below ``legend_y``; a
    legend strip without axis decorations fills the figure above ``legend_y``.
    """

    def __init__(self, clusters_to_plot, nh, nv, aspect_ratio,
                 xformatter, yformatter, xlocator, ylocator,
                 xrotation, yrotation, ticksize, labelsize, titlesize):
        """Store layout and axis-formatting parameters.

        Args:
            clusters_to_plot: sequence of cluster indices; panels are filled in
                this order, left to right and top to bottom.
            nh: number of panels per row.
            nv: number of panel rows.
            aspect_ratio: value passed to ``Axes.set_aspect`` for each panel.
            xformatter, yformatter: major tick formatters for x / y.
            xlocator, ylocator: major tick locators for x / y.
            xrotation, yrotation: tick label rotation for x / y.
            ticksize: tick size; ticks are drawn with half this length.
            labelsize: font size for axis labels and tick labels.
            titlesize: title font size, used to reserve room above each panel.
        """
        super(ClusterValidationAxesManager, self).__init__()
        self.clusters_to_plot = clusters_to_plot
        self.nh = nh
        self.nv = nv
        self.aspect_ratio = aspect_ratio
        self.xformatter = xformatter
        self.yformatter = yformatter
        self.xlocator = xlocator
        self.ylocator = ylocator
        self.xrotation = xrotation
        self.yrotation = yrotation
        self.ticksize = ticksize
        self.labelsize = labelsize
        self.titlesize = titlesize
        self._legend = None
        self._axes = None
        # Figure-fraction height where the legend strip begins; the panel
        # grid is laid out below it.
        self.legend_y = 0.85

    @property
    def axes_legend(self):
        """Legend axes spanning the full figure width above ``legend_y`` (lazy)."""
        if self._legend is None:
            self._legend = plt.axes([0.0, self.legend_y, 1.0, 1.0 - self.legend_y])
            self._legend.set_axis_off()
        return self._legend

    @property
    def axes_list(self):
        """List of ``(icluster, axes, title_x, title_y)`` tuples (created lazily)."""
        if self._axes is None:
            self._axes = list(self.__axes_list())
        return self._axes

    def __axes_list(self):
        """Create one formatted panel per cluster.

        Yields:
            ``(icluster, axes, title_x, title_y)`` where ``title_x``/``title_y``
            is the title anchor position relative to the panel.
        """
        fig = plt.gcf()
        # enumerate gives each panel its own grid position in O(1); the
        # previous list.index() lookup was O(n) per panel.
        for loc, icluster in enumerate(self.clusters_to_plot):
            ix = loc % self.nh
            iy = loc // self.nh
            x0, y0, width, height, tpos_x, tpos_y = self.__calc_axes(fig, ix, iy)
            axes = plt.axes([x0, y0, width, height])
            # DEC effect: apply the caller-supplied aspect ratio.
            axes.set_aspect(self.aspect_ratio)
            xlabel, ylabel = self.get_axes_labels()
            # Fold the y label at its first '(' when there are many rows.
            if self.nv > 3:
                ylabel = ylabel.replace('(', '\n(', 1)
            axes.set_xlabel(xlabel, size=self.labelsize, labelpad=2)
            axes.set_ylabel(ylabel, size=self.labelsize, labelpad=2)
            axes.xaxis.set_major_formatter(self.xformatter)
            axes.yaxis.set_major_formatter(self.yformatter)
            axes.xaxis.set_major_locator(self.xlocator)
            axes.yaxis.set_major_locator(self.ylocator)
            axes.tick_params(axis='x', pad=1, labelrotation=self.xrotation,
                             labelsize=self.labelsize, length=self.ticksize / 2)
            axes.tick_params(axis='y', pad=1, labelrotation=self.yrotation,
                             labelsize=self.labelsize, length=self.ticksize / 2)
            yield icluster, axes, tpos_x, tpos_y

    def __calc_axes(self, fig, ix, iy):
        """Compute the rectangle of panel (ix, iy) and its title anchor.

        Args:
            fig: figure whose size (inches) drives the layout.
            ix: column index, 0 is leftmost.
            iy: row index, 0 is the top row.

        Returns:
            ``(x0, y0, width, height, tpos_x, tpos_y)``: lower-left corner and
            size in figure fractions (as taken by ``plt.axes``), followed by
            the title anchor relative to the panel.
        """
        # unit conversion constant for points->inch
        ppi = 72
        # padding between panels (unit: points)
        px, py = 7, 11
        # title height in units of the title font size
        title_v = 1.7
        # label extent per unit of label font size (fraction of figure size)
        label_extent = 0.014
        # axes size limit (unit: points)
        limit = 240
        # figure size (unit: points)
        fx = fig.get_figwidth() * ppi
        fy = fig.get_figheight() * ppi
        # margins at figure edge
        mx1 = fx * 0.01  # left
        mx2 = fx * 0.04  # right
        my1 = 0.0        # bottom
        my2 = fy * 0.08  # top
        # label extents (unit: points)
        lx = fx * label_extent * self.labelsize
        ly = fy * label_extent * self.labelsize
        # panel boundary max including ticks and labels
        max_x = (fx - mx1 - mx2 - px * (self.nh - 1)) / self.nh
        max_y = (fy * self.legend_y - my1 - my2 - py * (self.nv - 1)) / self.nv
        # limit the panel size
        if max_x > limit and max_y - ly * 2 > limit:
            max_x = limit
            max_y = limit
        # extent and offset of plot area (centers the grid)
        extent_x = max_x * self.nh + px * (self.nh - 1)
        extent_y = max_y * self.nv + py * (self.nv - 1)
        offset_x = (fx - extent_x) / 2
        offset_y = (fy * self.legend_y - extent_y) / 2
        # plot area size in points, excluding labels (and title vertically)
        ax = max_x - lx
        ay = max_y - title_v * self.titlesize - ly
        width = ax / fx
        if self.nh == 1:
            x0 = 0.5 - width / 2.0
        else:
            x0 = ((max_x + px) * ix + lx + mx1 + offset_x) / fx
        height = ay / fy
        # rows are counted from the top, figure coordinates from the bottom
        y0 = ((max_y + py) * (self.nv - iy - 1) + ly + my1 + offset_y) / fy
        # relative position of the title
        if self.nh < 4:
            tpos_x = 0.5  # locate title at axes center
        else:
            tpos_x = (ax - lx) / (2 * ax)  # locate title at panel center
        tpos_y = 1.008  # equiv. to titlepad
        return x0, y0, width, height, tpos_x, tpos_y
class ClusterDisplay(object):
    """Produce all clustering-analysis plots of a baseline-subtraction result.

    For every baselined group carrying both ``clusters`` and ``lines`` it
    creates the cluster validation maps, the cluster score plot and the
    cluster property plot.
    """
    Inputs = SingleDishDisplayInputs

    def __init__(self, inputs):
        self.inputs = inputs

    @property
    def context(self):
        return self.inputs.context

    def __baselined(self):
        """Yield baselined groups that have both ``clusters`` and ``lines``."""
        for group in self.inputs.result.outcome['baselined']:
            if 'clusters' in group and 'lines' in group:
                yield group

    @casa5style_plot
    def plot(self):
        """Create the plots for all eligible groups.

        Returns:
            list of ``logger.Plot`` objects. A group whose validation step
            yields no plots contributes nothing (no score/property plots).
        """
        plot_list = []
        stage_dir = os.path.join(self.context.report_dir,
                                 'stage%d' % (self.inputs.result.stage_number))
        start_time = time.time()
        reduction_group = self.context.observing_run.ms_reduction_group
        for group in self.__baselined():
            group_id = group['group_id']
            cluster = group['clusters']
            flag_digits = group['flag_digits']
            org_direction = group['org_direction']
            lines = group['lines']
            rep_member_id = group['members'][0]
            rep_member = reduction_group[group_id][rep_member_id]
            # Whether anything is plotted for this group is decided exclusively
            # in ClusterValidationDisplay._plot(); no pre-filtering here.
            if 'index' in group:
                # having key 'index' indicates the result comes from old
                # (Scantable-based) procedure
                # NOTE(review): vis stays None on this path and is passed to
                # get_ms() below; confirm this path is still supported.
                antenna = group['index'][0]
                vis = None
            else:
                # having key 'antenna' instead of 'index' indicates the result
                # comes from new (MS-based) procedure
                antenna = rep_member.antenna_id
                vis = rep_member.ms.name
            spw = rep_member.spw_id
            field = rep_member.field_id
            ms = self.context.observing_run.get_ms(vis)
            virtual_spw = self.context.observing_run.real2virtual_spw_id(spw, ms)
            source_name = ms.fields[field].source.name.replace(' ', '_').replace('/', '_')
            iteration = group['iteration']

            t0 = time.time()
            plot_validation = ClusterValidationDisplay(self.context, group_id, iteration, cluster,
                                                       flag_digits, vis,
                                                       virtual_spw, source_name, antenna, lines, stage_dir,
                                                       org_direction)
            validation_plot = plot_validation.plot()
            # if there are no validated lines, then skip all the plots
            if len(validation_plot) == 0:
                continue
            plot_list.extend(validation_plot)

            t1 = time.time()
            plot_score = ClusterScoreDisplay(group_id, iteration, cluster, virtual_spw, source_name, stage_dir)
            plot_list.extend(plot_score.plot())

            t2 = time.time()
            plot_property = ClusterPropertyDisplay(group_id, iteration, cluster, virtual_spw, source_name, stage_dir)
            plot_list.extend(plot_property.plot())
            t3 = time.time()

            # Lazy %-args: messages are only formatted if DEBUG is enabled.
            LOG.debug('PROFILE: ClusterScoreDisplay elapsed time is %s sec', t2 - t1)
            LOG.debug('PROFILE: ClusterPropertyDisplay elapsed time is %s sec', t3 - t2)
            LOG.debug('PROFILE: ClusterValidationDisplay elapsed time is %s sec', t1 - t0)

        end_time = time.time()
        LOG.debug('PROFILE: plot elapsed time is %s sec', end_time - start_time)
        return plot_list
class ClusterDisplayWorker(object, metaclass=abc.ABCMeta):
    """Abstract base for one kind of clustering plot.

    Every worker draws on the same matplotlib figure
    (``MATPLOTLIB_FIGURE_ID``). Subclasses implement ``_plot`` as a generator
    that draws, saves and yields ``logger.Plot`` objects.
    """
    MATPLOTLIB_FIGURE_ID = 8907

    def __init__(self, group_id, iteration, cluster, spw, field, stage_dir):
        """Store identifiers and clustering data shared by all workers.

        spw is a virtual spw id
        """
        self.cluster = cluster
        self.group_id = group_id
        self.iteration = iteration
        self.field = field
        self.spw = spw
        self.stage_dir = stage_dir

    def plot(self):
        """Select and clear the shared figure, then collect ``_plot`` results.

        Returns:
            list of the objects yielded by ``_plot``.
        """
        if not ShowPlot:
            plt.ioff()
        else:
            plt.ion()
        plt.figure(self.MATPLOTLIB_FIGURE_ID)
        # Turn interactive mode back off before drawing.
        if ShowPlot:
            plt.ioff()
        plt.cla()
        plt.clf()
        plots = list(self._plot())
        return plots

    def _create_plot(self, plotfile, type, x_axis, y_axis):
        """Wrap a saved image in a ``logger.Plot`` tagged with this worker's metadata.

        Args:
            plotfile: path of the saved image.
            type: plot type recorded in the plot parameters.
            x_axis, y_axis: axis descriptions for the weblog.
        """
        parameters = {
            'intent': 'TARGET',
            'spw': self.spw,  # spw id should be virtual one
            'pol': 0,
            'ant': 'all',
            'type': type,
        }
        return logger.Plot(plotfile,
                           x_axis=x_axis,
                           y_axis=y_axis,
                           field=self.field,
                           parameters=parameters)

    @abc.abstractmethod
    def _plot(self):
        """Draw, save and yield ``logger.Plot`` objects (implemented by subclasses)."""
        raise NotImplementedError
class ClusterScoreDisplay(ClusterDisplayWorker):
    """Plot the clustering score against the number of clusters."""

    def _plot(self):
        """Draw ``cluster['cluster_score']``, save it as PNG and yield the Plot."""
        num_clusters, scores = self.cluster['cluster_score']
        plt.plot(num_clusters, scores, 'bx', markersize=10)
        _, x_upper, y_lower, y_upper = plt.axis()
        plt.xlabel('Number of Clusters', fontsize=11)
        plt.ylabel('Score (Lower is better)', fontsize=11)
        plt.title('Score are plotted versus number of the cluster', fontsize=11)
        # Start the x axis at zero and leave one unit of room on the right.
        plt.axis([0, x_upper + 1, y_lower, y_upper])
        if ShowPlot:
            plt.draw()

        filename = 'cluster_score_group%s_spw%s_iter%s.png' % (self.group_id, self.spw, self.iteration)
        plotfile = os.path.join(self.stage_dir, filename)
        plt.savefig(plotfile, format='png', dpi=DPIDetail)
        yield self._create_plot(plotfile, 'cluster_score', 'Number of Clusters', 'Score')
class ClusterPropertyDisplay(ClusterDisplayWorker):
    """Plot detected lines in the line center-width plane with cluster ovals."""

    def _plot(self):
        """Draw detected lines plus one labelled oval per cluster; yield the Plot.

        Reads ``detected_lines`` (column 0 is used as width, column 1 as
        center, both in channels), ``cluster_property`` (4-element entries
        ``[cx, cy, _, r]``) and ``cluster_scale`` from the cluster dict.
        """
        lines = self.cluster['detected_lines']
        properties = self.cluster['cluster_property']
        scaling = self.cluster['cluster_scale']

        width = lines[:, 0]
        center = lines[:, 1]
        plt.plot(center, width, 'bs', markersize=1)
        [xmin, xmax, ymin, ymax] = plt.axis()
        axes = plt.gcf().gca()

        # Presumably cluster_property stores the width axis divided by
        # cluster_scale, so y and radius are multiplied back here and the
        # oval is squeezed by 1/scaling -- TODO confirm against clustering code.
        aspect = 1.0 / scaling
        # NOTE(review): ids follow sorted(properties) order; confirm it matches
        # the cluster indices used in the validation map titles.
        for cluster_id, (cx, cy, _, r) in enumerate(sorted(properties)):
            radius = r * scaling
            x_base = cx
            y_base = cy * scaling
            pointing.draw_beam(axes, radius, aspect, x_base, y_base, offset=0)
            plt.text(x_base, y_base, str(cluster_id), fontsize=10, color='red')

        plt.xlabel('Line Center (Channel)', fontsize=11)
        plt.ylabel('Line Width (Channel)', fontsize=11)
        plt.axis([xmin - 1, xmax + 1, 0, ymax + 1])
        plt.title('Clusters in the line Center-Width space\n\nRed Oval(s) shows each clustering region. '
                  'Size of the oval represents cluster radius', fontsize=11)
        if ShowPlot:
            plt.draw()
        plotfile = os.path.join(self.stage_dir,
                                'cluster_property_group%s_spw%s_iter%s.png' % (self.group_id, self.spw, self.iteration))
        plt.savefig(plotfile, format='png', dpi=DPISummary)
        plot = self._create_plot(plotfile, 'line_property',
                                 'Line Center', 'Line Width')
        yield plot
class ClusterValidationDisplay(ClusterDisplayWorker):
    """Plot the grid map of every validated cluster for each clustering stage.

    One PNG is written per stage key of ``flag_digits``. Each figure holds one
    panel per cluster whose line is flagged valid, with grid cells marked
    according to that stage's digit of ``cluster['cluster_flag']``.
    """
    # Legend headline per stage.
    Description1 = {
        'detection': 'Clustering Analysis at Detection stage',
        'validation': 'Clustering Analysis at Validation stage',
        'smoothing': 'Clustering Analysis at Smoothing stage',
        'final': 'Clustering Analysis at Final stage'
    }
    # Legend body per stage. The 'validation' entry is a string.Template whose
    # ${valid}, ${marginal} and ${questionable} are filled in by __stages().
    Description2 = {
        'detection': 'Yellow Square: Single spectrum is detected in the grid\nCyan Square: More than one spectra are detected in the grid\n',
        'validation': 'Validation by the rate (Number of clustering member [Nmember] v.s. Number of total spectra belong to the Grid [Nspectra])\n Blue Square: Validated: Nmember > ${valid} x Nspectra\nCyan Square: Marginally validated: Nmember > ${marginal} x Nspectra\nYellow Square: Questionable: Nmember > ${questionable} x Nspectrum\n',
        'smoothing': 'Blue Square: Passed continuity check\nCyan Square: Border\nYellow Square: Questionable\n',
        'final': 'Green Square: Final Grid where the line protection channels are calculated and applied to the baseline subtraction\nBlue Square: Final Grid where the calculated line protection channels are applied to the baseline subtraction\n\nIsolated Grids are eliminated.\n'
    }

    def __init__(self, context, group_id, iteration, cluster, flag_digits, vis, spw, field,
                 antenna, lines, stage_dir, org_direction):
        """Store the inputs needed to draw the validation maps.

        Args:
            context: pipeline context (observing run, MS lookup, datatables).
            group_id: reduction group id.
            iteration: iteration number used in file names.
            cluster: clustering result dict (``cluster_flag``, ``grid``,
                ``<stage>_threshold`` ...).
            flag_digits: mapping stage name -> decimal place value of that
                stage's digit in ``cluster['cluster_flag']``.
            vis: MS name used for spectral window / SOURCE table lookups.
            spw: virtual spw id.
            field: source name used to label plots.
            antenna: antenna identifier (stored on the instance).
            lines: per-cluster entries; ``[0]`` line center and ``[1]`` line
                width in channels, ``[2]`` validity flag.
            stage_dir: output directory.
            org_direction: origin direction used to recover absolute
                coordinates (ephemeris sources), or None.
        """
        super(ClusterValidationDisplay, self).__init__(group_id, iteration, cluster, spw, field, stage_dir)
        self.context = context
        self.antenna = antenna
        self.lines = lines
        self.flag_digits = flag_digits
        self.vis = vis
        self.org_direction = org_direction

    def _plot(self):
        """Yield one ``logger.Plot`` per clustering stage.

        Nothing is yielded if the cluster has no ``cluster_flag`` or no
        cluster line is flagged valid.
        """
        plt.clf()
        # Marker per level; the last len(threshold) entries are used.
        marks = ['gs', 'bs', 'cs', 'ys']
        if 'cluster_flag' not in self.cluster:
            return

        flags = self.cluster['cluster_flag']
        # One panel per cluster whose line survived validation.
        clusters_to_plot = [icluster for icluster in range(len(flags))
                            if self.lines[icluster][2] == True]  # noqa: E712 (may be numpy bool)
        num_cluster = len(clusters_to_plot)
        # no clusters to plot
        if num_cluster == 0:
            return

        # Near-square layout: ceil(sqrt(n)) columns, ceil(n / columns) rows.
        num_panel_h = int(math.sqrt(num_cluster - 0.1)) + 1
        num_panel_v = int((num_cluster - 0.1) // num_panel_h) + 1

        grid = self.cluster['grid']
        ra0 = grid['ra_min']
        dec0 = grid['dec_min']
        scale_ra = grid['grid_ra']
        scale_dec = grid['grid_dec']
        # convert ra0/dec0 to SHIFT_RA/DEC and adjust scale_ra for Ephemeris sources
        if self.org_direction is not None:
            ra1, dec1 = dirutil.direction_recover(ra0, dec0, self.org_direction)
            ra2, _ = dirutil.direction_recover(ra0 + scale_ra, dec0, self.org_direction)
            scale_ra = ra2 - ra1
            ra0, dec0 = ra1, dec1

        # DEC effect: R.A. offsets shrink by cos(Dec). math.radians uses full
        # precision pi instead of a truncated literal.
        aspect_ratio = 1.0 / math.cos(math.radians(dec0))
        # common message for legends
        scale_msg = self.__scale_msg(scale_ra, scale_dec, aspect_ratio)

        # Grid dimensions from the per-cluster flag cube.
        nx = len(flags[0])
        ny = len(flags[0][0])
        xmin = ra0
        xmax = nx * scale_ra + xmin
        ymin = dec0
        ymax = ny * scale_dec + ymin

        tick_size, label_size, title_size = self.__set_size(num_panel_h, num_panel_v)

        # direction reference, taken from the first MS's DataTable
        reference_ms = self.context.observing_run.measurement_sets[0]
        datatable_name = os.path.join(self.context.observing_run.ms_datatable_name, reference_ms.basename)
        datatable = DataTable()
        datatable.importdata(datatable_name, minimal=False, readonly=True)
        direction_reference = datatable.direction_ref
        del datatable

        span = max(xmax - xmin, ymax - ymin)
        (RAlocator, DEClocator, RAformatter, DECformatter) = pointing.XYlabel(span,
                                                                            direction_reference)
        axes_manager = ClusterValidationAxesManager(clusters_to_plot,
                                                    num_panel_h,
                                                    num_panel_v,
                                                    aspect_ratio,
                                                    RAformatter,
                                                    DECformatter,
                                                    RAlocator,
                                                    DEClocator,
                                                    RArotation,
                                                    DECrotation,
                                                    tick_size,
                                                    label_size,
                                                    title_size)
        axes_manager.direction_reference = direction_reference
        axes_db = axes_manager.axes_list
        axes_list = {k: v for (k, v, x, y) in axes_db}
        title_pos = {k: [x, y] for (k, v, x, y) in axes_db}
        axes_legend = axes_manager.axes_legend

        # Line frequency/width depend only on the cluster, not on the stage;
        # look them up once (each lookup reads the MS SOURCE table).
        line_properties = {icluster: self.__line_property(icluster)
                           for icluster in clusters_to_plot}

        for (mode, data, threshold, description1, description2) in self.__stages():
            plot_objects = []
            nlevel = len(threshold)
            for icluster in clusters_to_plot:
                axes_cluster = axes_list[icluster]
                # R.A. increases to the left
                axes_cluster.axis([xmax, xmin, ymin, ymax])
                # calculate the optimum marker_size for axes
                marker_size = self.__marker_size(axes_cluster, nx, ny)

                # Group grid-cell centers by level: cells whose value is
                # nlevel - i go to xdata[i] / ydata[i].
                xdata = [[] for _ in range(nlevel)]
                ydata = [[] for _ in range(nlevel)]
                for ix in range(nx):
                    for iy in range(ny):
                        for i in range(nlevel):
                            if data[icluster][ix][iy] == nlevel - i:
                                xdata[i].append(xmin + (0.5 + ix) * scale_ra)
                                ydata[i].append(ymin + (0.5 + iy) * scale_dec)
                                break

                frequency, width = line_properties[icluster]
                title_x, title_y = title_pos[icluster]
                plot_objects.append(
                    axes_cluster.text(title_x, title_y,
                                      "Cluster {}\n"
                                      r"$f_\mathrm{{center}}$={:.4f} GHz $\Delta v$={:.1f} km/s".format(icluster, frequency, width),
                                      transform=axes_cluster.transAxes,
                                      linespacing=1,
                                      fontsize=title_size,
                                      horizontalalignment='center',
                                      verticalalignment='bottom'
                                      )
                )
                # Every cluster here is valid (see clusters_to_plot), so the
                # grid cells are always drawn.
                for i in range(nlevel):
                    plot_objects.extend(
                        axes_cluster.plot(xdata[i], ydata[i], marks[4 - nlevel + i], markersize=marker_size)
                    )

            # Legends
            plot_objects.append(
                axes_legend.text(0.5, 0.85, description1,
                                 horizontalalignment='center',
                                 verticalalignment='baseline', size=8)
            )
            plot_objects.append(
                axes_legend.text(0.5, 0.0, description2 + scale_msg,
                                 horizontalalignment='center',
                                 verticalalignment='baseline', size=8)
            )

            if ShowPlot:
                plt.draw()

            plotfile = os.path.join(
                self.stage_dir,
                'cluster_group_%s_spw%s_iter%s_%s.png' % (self.group_id, self.spw, self.iteration, mode))
            plt.savefig(plotfile, format='png', dpi=DPISummary)

            # Remove this stage's artists so the panels can be reused.
            for obj in plot_objects:
                obj.remove()

            plot = self._create_plot(plotfile, 'clustering_%s' % (mode),
                                     'R.A.', 'Dec.')
            yield plot

    def __set_size(self, num_panel_h, num_panel_v):
        """Return ``(tick_size, label_size, title_size)`` for the panel grid.

        ``1 // num_panel_h`` is 1 only for a single column, so tick_size is 8
        there and 6 otherwise (num_panel_h >= 1).
        """
        tick_size = 6 + (1 // num_panel_h) * 2
        if num_panel_v > 3:
            label_size = tick_size - 1
            title_size = tick_size
        elif num_panel_h > 3:
            label_size = tick_size
            title_size = tick_size
        else:
            label_size = tick_size
            title_size = tick_size + 1
        return tick_size, label_size, title_size

    def __marker_size(self, axes, nx, ny, tile_gap=0.0):
        """Return a marker size (points) that tiles an nx x ny grid over ``axes``.

        Args:
            axes: target axes; its figure-fraction bbox and figure size are used.
            nx, ny: number of grid cells along x and y.
            tile_gap: extra gap per cell as a fraction of the cell size.
        """
        axes_bbox = axes.get_position()
        fig_width = axes.get_figure().get_figwidth()
        fig_height = axes.get_figure().get_figheight()
        ppi = 72  # constant for "Points per Inch"
        axes_width = (axes_bbox.x1 - axes_bbox.x0) * fig_width * ppi
        axes_height = (axes_bbox.y1 - axes_bbox.y0) * fig_height * ppi
        size_h = axes_width / (nx * (1.0 + tile_gap))
        size_v = axes_height / (ny * (1.0 + tile_gap))
        marker_size = min(size_h, size_v)
        return marker_size

    def __stages(self):
        """Yield ``(stage, flag, threshold, desc1, desc2)`` per key of flag_digits.

        ``flag`` is the stage's decimal digit of ``cluster['cluster_flag']``;
        ``threshold`` is ``cluster['<stage>_threshold']``. For 'validation'
        the legend template is filled with the first three thresholds.
        """
        if 'cluster_flag' not in self.cluster:
            return
        all_flags = self.cluster['cluster_flag']
        for key in self.flag_digits.keys():
            # Pick up target digit
            flag = (all_flags // self.flag_digits[key]) % 10
            # Lazy %-args: the array is only formatted if DEBUG is enabled.
            LOG.debug('flag=%s', flag)
            threshold = self.cluster[key + '_threshold']
            desc1 = self.Description1[key]
            desc2 = self.Description2[key]
            if key == 'validation':
                template = string.Template(desc2)
                valid = '%.1f' % (threshold[0])
                marginal = '%.1f' % (threshold[1])
                questionable = '%.1f' % (threshold[2])
                desc2 = template.safe_substitute(valid=valid,
                                                 marginal=marginal,
                                                 questionable=questionable)
            yield (key, flag, threshold, desc1, desc2)

    def __line_property(self, icluster):
        """Return ``(center_frequency, width_in_velocity)`` of cluster ``icluster``.

        The line center/width (channels) come from ``self.lines``; channels
        map linearly to frequency from channel 0's frequency and width. The
        center is returned in GHz. The width is converted to velocity as
        width / rest_frequency * LightSpeed (titles label it km/s), where
        rest_frequency is REST_FREQUENCY from the MS SOURCE table, or the
        channel 0 frequency if no matching row/value exists.
        """
        reduction_group = self.context.observing_run.ms_reduction_group[self.group_id]
        # NOTE(review): the source id comes from member 0 of the reduction
        # group, which need not belong to self.vis; confirm SOURCE_IDs agree
        # across MSes.
        field = reduction_group[0].field
        source_id = field.source_id
        ms = self.context.observing_run.get_ms(self.vis)
        real_spw = self.context.observing_run.virtual2real_spw_id(self.spw, ms)
        spectral_window = ms.get_spectral_window(real_spw)
        refpix = 0
        refval = spectral_window.channels.chan_freqs[0]
        increment = spectral_window.channels.chan_widths[0]
        with casa_tools.TableReader(os.path.join(self.vis, 'SOURCE')) as tb:
            tsel = tb.query('SOURCE_ID == %s && SPECTRAL_WINDOW_ID == %s' % (source_id, real_spw))
            try:
                if tsel.nrows() == 0:
                    rest_frequency = refval
                elif tsel.iscelldefined('REST_FREQUENCY', 0):
                    rest_frequency = tsel.getcell('REST_FREQUENCY', 0)[0]
                else:
                    rest_frequency = refval
            finally:
                tsel.close()

        # line property in channel
        line_center = self.lines[icluster][0]
        line_width = self.lines[icluster][1]
        center_frequency = refval + (line_center - refpix) * increment
        width_in_frequency = abs(line_width * increment)
        center_frequency *= 1.0e-9  # Hz -> GHz
        width_in_velocity = width_in_frequency / rest_frequency * LightSpeed
        return center_frequency, width_in_velocity

    def __scale_msg(self, scale_ra, scale_dec, aspect_ratio):
        """Return the legend text giving the grid cell size.

        The unit (degree/arcmin/arcsec) is chosen from ``scale_ra``; the R.A.
        size is divided by ``aspect_ratio`` before printing.
        """
        if scale_ra >= 1.0:
            unit = 'degree'
            scale_factor = 1.0
        elif scale_ra * 60.0 >= 1.0:
            unit = 'arcmin'
            scale_factor = 60.0
        else:
            unit = 'arcsec'
            scale_factor = 3600.0
        ra_text = scale_ra / aspect_ratio * scale_factor
        dec_text = scale_dec * scale_factor
        return 'Scale of the Square (Grid): %.1f x %.1f (%s)' % (ra_text, dec_text, unit)