"""
The imaging module contains utility functions used by the imaging tasks.
TODO These utility functions should migrate to hif.tasks.common
"""
import re
from .. import casa_tools
from .. import logging
import numpy
from typing import Union, Tuple, List, Dict, Any, Generator
from .. import utils
# Module-level logger, named after this module.
LOG = logging.get_logger(__name__)
# Public API of this module; the private helper _get_cube_freq_axis is
# intentionally not listed.
__all__ = ['chan_selection_to_frequencies', 'freq_selection_to_channels', 'spw_intersect', 'update_sens_dict',
           'update_beams_dict', 'set_nested_dict', 'intersect_ranges', 'intersect_ranges_by_weight', 'merge_ranges', 'equal_to_n_digits']
def _get_cube_freq_axis(img: str) -> Tuple[float, float, str, float, int]:
    """
    Get CASA image/cube frequency axis.

    Args:
        img: CASA image/cube name

    Returns:
        Tuple of frequency axis components
        (reference frequency, delta frequency per channel, frequency unit,
         reference pixel, number of pixels of frequency axis)

    Raises:
        ValueError: If the image summary has no axis named 'Frequency'.
    """
    iaTool = casa_tools.image

    iaTool.open(img)
    try:
        imInfo = iaTool.summary()
    finally:
        # Release the image even if summary() fails, so that the tool
        # is not left attached to this image.
        iaTool.close()

    # Locate the frequency axis; list.index raises ValueError if absent.
    fIndex = imInfo['axisnames'].tolist().index('Frequency')

    refFreq = imInfo['refval'][fIndex]
    deltaFreq = imInfo['incr'][fIndex]
    freqUnit = imInfo['axisunits'][fIndex]
    refPix = imInfo['refpix'][fIndex]
    numPix = imInfo['shape'][fIndex]

    return refFreq, deltaFreq, freqUnit, refPix, numPix
def chan_selection_to_frequencies(img: str, selection: str, unit: str = 'GHz') -> Union[List[Tuple[float, float]], List[str]]:
    """
    Convert channel selection to frequency tuples for a given CASA cube.

    Args:
        img:       CASA cube name
        selection: Channel selection string using CASA selection syntax,
                   i.e. ';'-separated '<chan0>~<chan1>' ranges
        unit:      Frequency unit

    Returns:
        List of pairs (low, high) of frequency values (float) in the desired
        units. The special selections 'NONE' and 'ALL' are passed through as
        a one-element list; an empty selection or a cube whose frequency
        axis cannot be read yields ['NONE'].
    """
    if selection in ('NONE', 'ALL'):
        return [selection]

    if selection == '':
        return ['NONE']

    qaTool = casa_tools.quanta

    # Get frequency axis (the number of channels is not needed here)
    try:
        refFreq, deltaFreq, freqUnit, refPix, _ = _get_cube_freq_axis(img)
    except Exception:
        # Deliberately best-effort, but do not swallow KeyboardInterrupt
        # or SystemExit as a bare "except:" would.
        LOG.error('No frequency axis found in %s.' % (img))
        return ['NONE']

    frequencies = []
    for crange in selection.split(';'):
        c0, c1 = list(map(float, crange.split('~')))

        # Make sure c0 is the lower channel so that the +/-0.5 channel
        # adjustments below go in the right direction.
        if c1 < c0:
            c0, c1 = c1, c0

        # Convert the channel range (c0-c1) to the corresponding frequency range
        # that spans between the outer edges of this channel range. I.e., from
        # the lower frequency edge of c0 to the upper frequency edge of c1.
        f0 = qaTool.convert({'value': refFreq + (c0 - 0.5 - refPix) * deltaFreq, 'unit': freqUnit}, unit)
        f1 = qaTool.convert({'value': refFreq + (c1 + 0.5 - refPix) * deltaFreq, 'unit': freqUnit}, unit)

        # A negative channel width reverses the order; always report (low, high).
        if qaTool.lt(f0, f1):
            frequencies.append((f0['value'], f1['value']))
        else:
            frequencies.append((f1['value'], f0['value']))

    return frequencies
def freq_selection_to_channels(img: str, selection: str) -> Union[List[Tuple[int, int]], List[str]]:
    """
    Convert frequency selection to channel tuples for a given CASA cube.

    Args:
        img:       CASA cube name
        selection: Frequency selection string using CASA syntax,
                   i.e. ';'-separated '<f0>~<f1><unit>' ranges

    Returns:
        List of pairs (low, high) of channel values (int), clipped to the
        channel range of the cube. The special selections 'NONE' and 'ALL'
        are passed through as a one-element list; an empty selection or a
        cube whose frequency axis cannot be read yields ['NONE'].
    """
    if selection in ('NONE', 'ALL'):
        return [selection]

    if selection == '':
        return ['NONE']

    qaTool = casa_tools.quanta

    # Get frequency axis
    try:
        refFreq, deltaFreq, freqUnit, refPix, numPix = _get_cube_freq_axis(img)
    except Exception:
        # Deliberately best-effort, but do not swallow KeyboardInterrupt
        # or SystemExit as a bare "except:" would.
        LOG.error('No frequency axis found in %s.' % (img))
        return ['NONE']

    max_chan = numPix - 1
    channels = []
    p = re.compile(r'([\d.]*)(~)([\d.]*)(\D*)')
    for frange in p.findall(selection.replace(';', '')):
        f0 = qaTool.convert('%s%s' % (frange[0], frange[3]), freqUnit)['value']
        f1 = qaTool.convert('%s%s' % (frange[2], frange[3]), freqUnit)['value']

        # Fractional channel coordinates of the two range edges. The
        # reference frequency belongs to the reference pixel, so refPix
        # must be added back (inverse of chan_selection_to_frequencies).
        c0 = refPix + (f0 - refFreq) / deltaFreq
        c1 = refPix + (f1 - refFreq) / deltaFreq

        # Make sure c0 is the lower channel coordinate so that the +/-0.5
        # channel adjustments below go in the right direction, also for
        # cubes with a negative channel width.
        if c1 < c0:
            c0, c1 = c1, c0

        # It is assumed here that the frequency ranges are given from
        # the outer edge of the first channel to the outer edge of the
        # last channel, while the reference frequency is specified at the
        # center of the reference pixel (channel). To calculate the
        # corresponding channel range, we need to add 0.5 to the lower
        # channel, and subtract 0.5 from the upper channel.
        #
        # Avoid stepping outside possible channel range, both before and
        # after the half-channel adjustment.
        # NOTE(review): because of the clamp *before* the shift, a lower
        # edge at or below the outer edge of channel 0 becomes 0 + 0.5; if
        # utils.round_half_up rounds 0.5 up to 1, channel 0 is dropped.
        # Confirm whether this is intended.
        c0 = min(max(c0, 0), max_chan)
        c0 = int(utils.round_half_up(c0 + 0.5))
        c0 = min(max(c0, 0), max_chan)

        c1 = min(max(c1, 0), max_chan)
        c1 = int(utils.round_half_up(c1 - 0.5))
        c1 = min(max(c1, 0), max_chan)

        # Ranges narrower than one channel can cross over after the
        # adjustments; always report (low, high).
        if c0 < c1:
            channels.append((c0, c1))
        else:
            channels.append((c1, c0))

    return channels
def spw_intersect(spw_range: List[float], line_regions: List[List[float]]) -> List[List[float]]:
    """
    This utility function takes a frequency range (as numbers with arbitrary
    but common units) and computes the intersection with a list of frequency
    ranges defining the regions of spectral lines. It returns the remaining
    ranges excluding the line frequency ranges.

    The line regions are processed in order of increasing start frequency,
    so they do not need to be pre-sorted by the caller.

    Args:
        spw_range:    List of two numbers defining the spw frequency range
        line_regions: List of lists of pairs of numbers defining frequency ranges
                      to be excluded

    Returns:
        List of lists of pairs of numbers defining the remaining frequency ranges
    """
    # Part of the spw range that has not been examined yet; [] once exhausted.
    remaining = spw_range
    kept_intervals = []

    # The sweep below only works from low to high frequency: a region that
    # starts below an already consumed part would otherwise be ignored.
    for line_region in sorted(line_regions, key=lambda region: region[0]):
        line_lo, line_hi = line_region[0], line_region[1]
        spw_lo, spw_hi = remaining[0], remaining[1]

        if line_lo <= spw_lo and line_hi >= spw_hi:
            # Line covers everything that is left: nothing survives.
            kept_intervals = []
            remaining = []
            break
        elif line_lo <= spw_lo and line_hi >= spw_lo:
            # Line overlaps the lower edge: trim the lower edge.
            remaining = [line_hi, spw_hi]
        elif line_lo >= spw_lo and line_hi < spw_hi:
            # Line lies inside: keep the part below it, continue above it.
            kept_intervals.append([spw_lo, line_lo])
            remaining = [line_hi, spw_hi]
        elif line_lo >= spw_hi:
            # Line lies entirely above: the rest of the spw range is line-free.
            kept_intervals.append(remaining)
            remaining = []
            break
        elif line_lo >= spw_lo and line_hi >= spw_hi:
            # Line overlaps the upper edge: keep only the part below it.
            kept_intervals.append([spw_lo, line_lo])
            remaining = []
            break
        # Otherwise the line lies entirely below the remaining range.

    if remaining != []:
        kept_intervals.append(remaining)

    return kept_intervals
def update_sens_dict(dct: Dict, udct: Dict) -> None:
    """
    Update a sensitivity dictionary. All generic solutions
    tried so far did not do the job. So this method assumes
    an explicit dictionary structure of
    ['<MS name>']['<field name>']['<intent>'][<spw>]: {<sensitivity result>}.

    Intermediate levels missing in the main dictionary are created; the
    per-spw results of the update dictionary replace existing ones and are
    stored by reference (not copied).

    Args:
        dct:  Sensitivities dictionary
        udct: Sensitivities update dictionary

    Returns:
        None. The main dictionary is modified in place.
    """
    # Special primary keys that are not MS names
    special_keys = ('recalc', 'robust', 'uvtaper')

    for msname, fields in udct.items():
        if msname in special_keys:
            continue
        ms_entry = dct.setdefault(msname, {})
        for field, intents in fields.items():
            field_entry = ms_entry.setdefault(field, {})
            for intent, spws in intents.items():
                intent_entry = field_entry.setdefault(intent, {})
                for spw, sensitivity in spws.items():
                    intent_entry[spw] = sensitivity
def update_beams_dict(dct: Dict, udct: Dict) -> None:
    """
    Update a beams dictionary. All generic solutions
    tried so far did not do the job. So this method assumes
    an explicit dictionary structure of
    ['<field name>']['<intent>'][<spwids>]: {<beam>}.

    Intermediate levels missing in the main dictionary are created; the
    beams of the update dictionary replace existing ones and are stored
    by reference (not copied).

    Args:
        dct:  Beams dictionary
        udct: Beams update dictionary

    Returns:
        None. The main dictionary is modified in place.
    """
    # Special primary keys that are not field names
    special_keys = ('recalc', 'robust', 'uvtaper')

    for field, intents in udct.items():
        if field in special_keys:
            continue
        field_entry = dct.setdefault(field, {})
        for intent, beams in intents.items():
            intent_entry = field_entry.setdefault(intent, {})
            for spwids, beam in beams.items():
                intent_entry[spwids] = beam
def set_nested_dict(dct: Dict, keys: Tuple[Any, ...], value: Any) -> None:
    """
    Set a hierarchy of dictionaries with given keys and value
    for the lowest level key.

    >>> d = {}
    >>> set_nested_dict(d, ('key1', 'key2', 'key3'), 1)
    >>> print(d)
    {'key1': {'key2': {'key3': 1}}}

    Args:
        dct:   Any dictionary
        keys:  Sequence of keys to build hierarchy
        value: Value for lowest level key

    Returns:
        None. The dictionary is modified in place.

    Raises:
        IndexError: If keys is empty.
    """
    node = dct
    # Walk down to the parent of the leaf, creating levels that are missing
    # and reusing those that already exist.
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = value
def intersect_ranges(ranges: List[Tuple[Union[float, int], ...]]) -> Tuple[Union[float, int], ...]:
    """
    Compute intersection of ranges.

    Args:
        ranges: List of tuples defining (frequency) intervals

    Returns:
        intersect_range: Tuple of two numbers defining the intersection, or
        an empty tuple if there are no ranges or they have no common
        overlap. A single input range is returned unchanged.
    """
    if len(ranges) == 0:
        return ()

    if len(ranges) == 1:
        return ranges[0]

    low, high = ranges[0][0], ranges[0][1]
    for myrange in ranges[1:]:
        low = max(low, myrange[0])
        high = min(high, myrange[1])
        # Ranges that merely touch still intersect in a single point.
        if low > high:
            return ()

    return (low, high)
def intersect_ranges_by_weight(ranges: List[Tuple[Union[float, int], ...]], delta: float, threshold: float) -> Tuple[float, ...]:
    """
    Compute intersection of ranges through weight arrays and a threshold.

    A grid with step delta is laid over the full extent of all ranges. Each
    grid point is weighted by the fraction of ranges that contain it, and
    the result spans from the first to the last grid point whose weight
    reaches the threshold.

    Args:
        ranges:    List of tuples defining frequency intervals
        delta:     Frequency step to be used for the intersection
        threshold: Threshold to be used for the intersection

    Returns:
        intersect_range: Tuple of two numbers defining the intersection, or
        an empty tuple if there are no ranges or no grid point reaches the
        threshold. A single input range is returned unchanged.
    """
    if len(ranges) == 0:
        return ()

    if len(ranges) == 1:
        return ranges[0]

    # Build the edge array once and let numpy find the overall extent.
    edges = numpy.array(ranges).flatten()
    grid = numpy.arange(edges.min(), edges.max() + delta, delta)

    # Count for every grid point how many ranges contain it ...
    weights = numpy.zeros(grid.shape, 'd')
    for myrange in ranges:
        inside = (grid >= myrange[0]) & (grid <= myrange[1])
        weights += numpy.where(inside, 1.0, 0.0)
    # ... and normalize to the fraction of ranges.
    weights /= len(ranges)

    valid_indices = numpy.where(weights >= threshold)[0]
    if valid_indices.size == 0:
        return ()

    return (grid[valid_indices[0]], grid[valid_indices[-1]])
def merge_ranges(ranges: List[Tuple[Union[float, int], ...]]) -> Generator[Tuple[Union[float, int], Union[float, int]], None, None]:
    """
    Merge overlapping and adjacent ranges and yield the merged ranges
    in order. The argument must be an iterable of pairs (start, stop).

    Args:
        ranges: List of tuples of two numbers defining ranges

    Returns:
        Generator yielding tuples of merged ranges

    >>> list(merge_ranges([(5,7), (3,5), (-1,3)]))
    [(-1, 7)]
    >>> list(merge_ranges([(5,6), (3,4), (1,2)]))
    [(1, 2), (3, 4), (5, 6)]
    >>> list(merge_ranges([]))
    []

    (c) Gareth Rees 02/2013
    """
    ordered = sorted(ranges)

    # Nothing to merge. An unguarded next() on an empty iterator would
    # raise StopIteration, which becomes a RuntimeError inside a
    # generator (PEP 479).
    if not ordered:
        return

    current_start, current_stop = ordered[0]
    for start, stop in ordered[1:]:
        if start > current_stop:
            # Gap between segments: output current segment and start a new one.
            yield current_start, current_stop
            current_start, current_stop = start, stop
        else:
            # Segments adjacent or overlapping: merge.
            current_stop = max(current_stop, stop)
    yield current_start, current_stop
def equal_to_n_digits(x: float, y: float, numdigits: int = 7) -> bool:
    """
    Approximate equality check up to a given number of digits.

    Args:
        x: First floating point number
        y: Second floating point number
        numdigits: Number of significant digits to check

    Returns:
        True if numpy.testing.assert_approx_equal accepts x and y as equal
        to numdigits significant digits, False otherwise (including when
        the comparison itself fails, e.g. for non-numeric input).
    """
    try:
        numpy.testing.assert_approx_equal(x, y, numdigits)
    except Exception:
        # AssertionError signals a mismatch; other errors are treated as
        # "not equal" as before. Unlike a bare "except:", this does not
        # swallow KeyboardInterrupt or SystemExit.
        return False
    return True