# Source code for pipeline.infrastructure.utils.imaging

"""
The imaging module contains utility functions used by the imaging tasks.

TODO These utility functions should migrate to hif.tasks.common
"""
import re

from .. import casa_tools
from .. import logging
import numpy
from typing import Union, Tuple, List, Dict, Any, Generator

from .. import utils

LOG = logging.get_logger(__name__)

# Public API of this module (private helpers with a leading underscore are
# intentionally not exported).
__all__ = [
    'chan_selection_to_frequencies',
    'freq_selection_to_channels',
    'spw_intersect',
    'update_sens_dict',
    'update_beams_dict',
    'set_nested_dict',
    'intersect_ranges',
    'intersect_ranges_by_weight',
    'merge_ranges',
    'equal_to_n_digits',
]


def _get_cube_freq_axis(img: str) -> Tuple[float, float, str, float, int]:
    """
    Get the frequency axis description of a CASA image/cube.

    Args:
        img: CASA image/cube name

    Returns:
        Tuple of frequency axis components
        (reference frequency, delta frequency per channel, frequency unit,
         reference pixel, number of pixels of frequency axis)

    Raises:
        ValueError: If the image has no axis named 'Frequency'.
        Errors raised by the CASA image tool while opening or summarizing the
        image are propagated to the caller.
    """
    iaTool = casa_tools.image

    # Read the image summary. The tool is closed even if summary() fails so
    # that the image is not left open on the shared tool instance.
    iaTool.open(img)
    try:
        imInfo = iaTool.summary()
    finally:
        iaTool.close()

    # list.index raises ValueError when there is no 'Frequency' axis.
    fIndex = imInfo['axisnames'].tolist().index('Frequency')
    refFreq = imInfo['refval'][fIndex]
    deltaFreq = imInfo['incr'][fIndex]
    freqUnit = imInfo['axisunits'][fIndex]
    refPix = imInfo['refpix'][fIndex]
    numPix = imInfo['shape'][fIndex]

    return refFreq, deltaFreq, freqUnit, refPix, numPix


def chan_selection_to_frequencies(img: str, selection: str,
                                  unit: str = 'GHz') -> Union[List[Tuple[float, float]], List[str]]:
    """
    Convert channel selection to frequency tuples for a given CASA cube.

    Each channel range ``c0~c1`` is mapped to the frequency interval spanning
    from the outer edge of its lowest channel to the outer edge of its highest
    channel. A bare channel number ``c`` is treated as ``c~c``.

    Args:
        img: CASA cube name
        selection: Channel selection string using CASA selection syntax,
            e.g. '10~20;30~40'. The special values 'NONE' and 'ALL' are
            returned unchanged (wrapped in a list).
        unit: Frequency unit of the returned values

    Returns:
        List of (lower, upper) frequency pairs (float) in the desired units.
        ['NONE'] is returned for an empty selection or when the frequency
        axis of the cube cannot be read; [selection] for 'NONE' and 'ALL'.
    """
    if selection in ('NONE', 'ALL'):
        return [selection]

    if selection == '':
        return ['NONE']

    qaTool = casa_tools.quanta

    # Get frequency axis. Catch Exception rather than using a bare except so
    # that KeyboardInterrupt / SystemExit still propagate.
    try:
        refFreq, deltaFreq, freqUnit, refPix, _ = _get_cube_freq_axis(img)
    except Exception:
        LOG.error('No frequency axis found in %s.' % (img))
        return ['NONE']

    frequencies = []
    for crange in selection.split(';'):
        # Tolerate empty segments, e.g. from a trailing ';'.
        if not crange.strip():
            continue

        limits = [float(c) for c in crange.split('~')]
        if len(limits) == 1:
            # A single channel number selects just that channel.
            c0 = c1 = limits[0]
        else:
            c0, c1 = limits

        # Make sure c0 is the lower channel so that the +/-0.5 channel
        # adjustments below go in the right direction.
        if c1 < c0:
            c0, c1 = c1, c0

        # Convert the channel range (c0-c1) to the corresponding frequency
        # range that spans between the outer edges of this channel range,
        # i.e. from the lower pixel edge of c0 (c0 - 0.5) to the upper pixel
        # edge of c1 (c1 + 0.5). Channel centers are at
        # refFreq + (c - refPix) * deltaFreq.
        f0 = qaTool.convert({'value': refFreq + (c0 - 0.5 - refPix) * deltaFreq, 'unit': freqUnit}, unit)
        f1 = qaTool.convert({'value': refFreq + (c1 + 0.5 - refPix) * deltaFreq, 'unit': freqUnit}, unit)

        # With a negative deltaFreq, f1 < f0; always store (lower, upper).
        if qaTool.lt(f0, f1):
            frequencies.append((f0['value'], f1['value']))
        else:
            frequencies.append((f1['value'], f0['value']))

    return frequencies


def freq_selection_to_channels(img: str, selection: str) -> Union[List[Tuple[int, int]], List[str]]:
    """
    Convert frequency selection to channel tuples for a given CASA cube.

    Each frequency range is assumed to run from the outer edge of its lowest
    channel to the outer edge of its highest channel (the inverse of
    chan_selection_to_frequencies). A channel is included when more than
    half of it lies inside the frequency range.

    Args:
        img: CASA cube name
        selection: Frequency selection string using CASA syntax, e.g.
            '100.1~100.5GHz;101.0~101.2GHz'. The special values 'NONE' and
            'ALL' are returned unchanged (wrapped in a list).

    Returns:
        List of (low, high) channel index pairs (int), clamped to the valid
        channel range [0, number of channels - 1] of the cube. ['NONE'] is
        returned for an empty selection or when the frequency axis of the
        cube cannot be read; [selection] for 'NONE' and 'ALL'.
    """
    if selection in ('NONE', 'ALL'):
        return [selection]

    if selection == '':
        return ['NONE']

    qaTool = casa_tools.quanta

    # Get frequency axis. Catch Exception rather than using a bare except so
    # that KeyboardInterrupt / SystemExit still propagate.
    try:
        refFreq, deltaFreq, freqUnit, refPix, numPix = _get_cube_freq_axis(img)
    except Exception:
        LOG.error('No frequency axis found in %s.' % (img))
        return ['NONE']

    def _clamp(chan: int) -> int:
        # Keep a channel index within the cube's channel range.
        return min(max(chan, 0), numPix - 1)

    freq_range_re = re.compile(r'([\d.]*)(~)([\d.]*)(\D*)')

    channels = []
    # Parse each ';'-separated range on its own, so that a range without an
    # explicit unit cannot run into the digits of the following range.
    for subselection in selection.split(';'):
        for frange in freq_range_re.findall(subselection):
            fUnit = frange[3].strip()
            f0 = qaTool.convert('%s%s' % (frange[0], fUnit), freqUnit)['value']
            f1 = qaTool.convert('%s%s' % (frange[2], fUnit), freqUnit)['value']

            # Fractional pixel positions of the two frequency limits. The
            # reference frequency is specified at the center of the reference
            # pixel refPix (same convention as chan_selection_to_frequencies).
            p0 = refPix + (f0 - refFreq) / deltaFreq
            p1 = refPix + (f1 - refFreq) / deltaFreq

            # Order the limits by pixel so that the half-channel adjustments
            # below also go in the right direction for a negative deltaFreq.
            pLow, pHigh = min(p0, p1), max(p0, p1)

            # The limits are outer channel edges while pixel coordinates refer
            # to channel centers, so add 0.5 to the lower limit and subtract
            # 0.5 from the upper one before rounding. Clamp only afterwards:
            # clamping first would turn the lower edge of channel 0 (-0.5)
            # into channel 1.
            cLow = _clamp(int(utils.round_half_up(pLow + 0.5)))
            cHigh = _clamp(int(utils.round_half_up(pHigh - 0.5)))

            # A range narrower than one channel can yield cLow > cHigh.
            channels.append((min(cLow, cHigh), max(cLow, cHigh)))

    return channels


def spw_intersect(spw_range: List[float], line_regions: List[List[float]]) -> List[List[float]]:
    """
    Compute the parts of a frequency range that lie outside given line regions.

    This utility function takes a frequency range (as numbers with arbitrary
    but common units) and computes the intersection with a list of frequency
    ranges defining the regions of spectral lines. It returns the remaining
    ranges excluding the line frequency ranges.

    The line regions may be given in any order and may overlap each other;
    they are processed in order of increasing lower edge.

    Args:
        spw_range: List of two numbers defining the spw frequency range
        line_regions: List of lists of pairs of numbers defining frequency
            ranges to be excluded

    Returns:
        List of lists of pairs of numbers defining the remaining frequency
        ranges in increasing order. Empty if the line regions cover the
        whole spw range.
    """
    # Work on a copy so the caller's list is never returned or aliased.
    remaining = list(spw_range)
    if not remaining:
        return []

    spw_sel_intervals = []
    # The walk below stops at the first region starting above the remaining
    # range, so the regions must be visited in ascending order.
    for line_region in sorted(line_regions, key=lambda region: (region[0], region[1])):
        lo, hi = line_region[0], line_region[1]
        if lo <= remaining[0] and hi >= remaining[1]:
            # The line covers the whole remaining range. Intervals found so
            # far are kept (previously they were discarded here).
            remaining = []
            break
        elif lo <= remaining[0] and hi >= remaining[0]:
            # The line overlaps the lower end: trim the remaining range.
            remaining = [hi, remaining[1]]
        elif lo >= remaining[0] and hi < remaining[1]:
            # The line lies inside: keep the part below it, continue above it.
            spw_sel_intervals.append([remaining[0], lo])
            remaining = [hi, remaining[1]]
        elif lo >= remaining[1]:
            # This line (and, being sorted, all later ones) starts above the
            # remaining range.
            break
        elif lo >= remaining[0] and hi >= remaining[1]:
            # The line overlaps the upper end: keep the part below it.
            spw_sel_intervals.append([remaining[0], lo])
            remaining = []
            break
        # Otherwise the line lies entirely below the remaining range.

    if remaining:
        spw_sel_intervals.append(remaining)

    return spw_sel_intervals


def update_sens_dict(dct: Dict, udct: Dict) -> None:
    """
    Merge a sensitivity update dictionary into a sensitivity dictionary.

    Generic merge solutions tried so far did not do the job, so an explicit
    four-level layout is assumed:
    ['<MS name>']['<field name>']['<intent>'][<spw>]: {<sensitivity result>}.
    Missing intermediate levels are created in ``dct``; the per-spw entries of
    ``udct`` replace the corresponding entries of ``dct``, and all other
    entries of ``dct`` are kept. The top-level keys 'recalc', 'robust' and
    'uvtaper' of ``udct`` are not MS names and are skipped.

    Args:
        dct: Sensitivities dictionary
        udct: Sensitivities update dictionary

    Returns:
        None. The main dictionary is modified in place.
    """
    non_ms_keys = ('recalc', 'robust', 'uvtaper')
    for ms_name, ms_update in udct.items():
        if ms_name in non_ms_keys:
            continue
        ms_entry = dct.setdefault(ms_name, {})
        for field_name, field_update in ms_update.items():
            field_entry = ms_entry.setdefault(field_name, {})
            for intent, intent_update in field_update.items():
                # Per-spw results overwrite whatever was stored before.
                field_entry.setdefault(intent, {}).update(intent_update)


def update_beams_dict(dct: Dict, udct: Dict) -> None:
    """
    Merge a beams update dictionary into a beams dictionary.

    Generic merge solutions tried so far did not do the job, so an explicit
    three-level layout is assumed:
    ['<field name>']['<intent>'][<spwids>]: {<beam>}.
    Missing intermediate levels are created in ``dct``; the per-spwids entries
    of ``udct`` replace the corresponding entries of ``dct``, and all other
    entries of ``dct`` are kept. The top-level keys 'recalc', 'robust' and
    'uvtaper' of ``udct`` are not field names and are skipped.

    Args:
        dct: Beams dictionary
        udct: Beams update dictionary

    Returns:
        None. The main dictionary is modified in place.
    """
    non_field_keys = ('recalc', 'robust', 'uvtaper')
    for field_name, field_update in udct.items():
        if field_name in non_field_keys:
            continue
        field_entry = dct.setdefault(field_name, {})
        for intent, intent_update in field_update.items():
            # Per-spwids beams overwrite whatever was stored before.
            field_entry.setdefault(intent, {}).update(intent_update)


def set_nested_dict(dct: Dict, keys: Tuple[Any], value: Any) -> None:
    """
    Store a value at the end of a chain of nested dictionaries.

    A dictionary is created for every key except the last one when that key
    is missing; existing intermediate levels are reused. The last key is
    assigned ``value``, replacing any previous entry.

    >>> d = {}
    >>> set_nested_dict(d, ('key1', 'key2', 'key3'), 1)
    >>> print(d)
    {'key1': {'key2': {'key3': 1}}}

    Args:
        dct: Any dictionary
        keys: Sequence of keys defining the hierarchy, outermost first.
            An empty sequence raises IndexError.
        value: Value for lowest level key

    Returns:
        None. The dictionary is modified in place.
    """
    node = dct
    for parent_key in keys[:-1]:
        if parent_key not in node:
            node[parent_key] = {}
        node = node[parent_key]
    node[keys[-1]] = value


def intersect_ranges(ranges: List[Tuple[Union[float, int]]]) -> Tuple[Union[float, int]]:
    """
    Compute the common part of a collection of ranges.

    Args:
        ranges: List of tuples defining (frequency) intervals as (start, stop)

    Returns:
        intersect_range: Tuple of two numbers defining the intersection, the
        single input range itself when only one is given, or an empty tuple
        when no ranges are given or they do not all overlap.
    """
    if len(ranges) == 0:
        return ()
    if len(ranges) == 1:
        return ranges[0]

    lower, upper = ranges[0][0], ranges[0][1]
    for candidate in ranges[1:]:
        lower = max(lower, candidate[0])
        upper = min(upper, candidate[1])
        if not lower <= upper:
            # Disjoint ranges: there is no common part.
            return ()
    return (lower, upper)


def intersect_ranges_by_weight(ranges: List[Tuple[Union[float, int]]], delta: float, threshold: float) -> Tuple[float]:
    """
    Compute the intersection of ranges from per-sample coverage weights.

    A regular grid with step ``delta`` is laid over the full extent of all
    ranges. Each grid point gets the fraction of ranges that contain it, and
    the result spans the first to the last grid point whose fraction is at
    least ``threshold``.

    Args:
        ranges: List of tuples defining frequency intervals
        delta: Frequency step to be used for the intersection
        threshold: Threshold to be used for the intersection

    Returns:
        intersect_range: Tuple of two numbers defining the intersection, the
        single input range itself when only one is given, or an empty tuple
        when no ranges are given or no grid point reaches the threshold.
    """
    n_ranges = len(ranges)
    if n_ranges == 0:
        return ()
    if n_ranges == 1:
        return ranges[0]

    all_edges = numpy.array(ranges).flatten()
    grid = numpy.arange(min(all_edges), max(all_edges) + delta, delta)

    # Count, for every grid point, how many ranges contain it.
    coverage = numpy.zeros(grid.shape, 'd')
    for candidate in ranges:
        inside = numpy.logical_and(grid >= candidate[0], grid <= candidate[1])
        coverage += numpy.where(inside, 1.0, 0.0)
    coverage /= n_ranges

    selected = numpy.where(coverage >= threshold)[0]
    if selected.shape == (0,):
        return ()
    return (grid[selected[0]], grid[selected[-1]])


def merge_ranges(ranges: List[Tuple[Union[float, int]]]) -> Generator[Tuple[Union[float, int], Union[float, int]], None, None]:
    """
    Merge overlapping and adjacent ranges and yield the merged ranges in order.

    The argument must be an iterable of pairs (start, stop).

    Args:
        ranges: List of tuples of two numbers defining ranges

    Returns:
        Generator yielding (start, stop) tuples of merged ranges, sorted by
        start. Nothing is yielded for an empty input.

    >>> list(merge_ranges([(5,7), (3,5), (-1,3)]))
    [(-1, 7)]
    >>> list(merge_ranges([(5,6), (3,4), (1,2)]))
    [(1, 2), (3, 4), (5, 6)]
    >>> list(merge_ranges([]))
    []

    (c) Gareth Rees 02/2013
    """
    sorted_ranges = iter(sorted(ranges))
    try:
        current_start, current_stop = next(sorted_ranges)
    except StopIteration:
        # Empty input. Returning explicitly is required: a StopIteration
        # escaping a generator is turned into RuntimeError (PEP 479).
        return

    for start, stop in sorted_ranges:
        if start > current_stop:
            # Gap between segments: output current segment and start a new one.
            yield current_start, current_stop
            current_start, current_stop = start, stop
        else:
            # Segments adjacent or overlapping: merge.
            current_stop = max(current_stop, stop)
    yield current_start, current_stop


def equal_to_n_digits(x: float, y: float, numdigits: int = 7) -> bool:
    """
    Approximate equality check up to a given number of significant digits.

    Args:
        x: First floating point number
        y: Second floating point number
        numdigits: Number of significant digits to check

    Returns:
        True if x and y agree to ``numdigits`` significant digits according
        to numpy.testing.assert_approx_equal, False otherwise. Inputs that
        cannot be compared (e.g. non-numeric values) also give False.
    """
    try:
        numpy.testing.assert_approx_equal(x, y, numdigits)
    except Exception:
        # AssertionError signals a mismatch; other errors (e.g. ValueError
        # for non-numeric input) are treated as "not equal" as well. Unlike
        # a bare except, this does not swallow KeyboardInterrupt/SystemExit.
        return False
    return True