# Source code for pipeline.infrastructure.utils.utils_test

from typing import Union, List, Dict, Tuple
from unittest.mock import Mock

import numpy as np
import pytest

from pipeline import domain
from .utils import find_ranges, dict_merge, are_equal, approx_equal, flagged_intervals, \
    get_casa_quantity, get_num_caltable_polarizations, fieldname_for_casa, fieldname_clean, \
    get_field_accessor, get_field_identifiers, get_receiver_type_for_spws

# (data, expected) pairs for test_find_ranges().
params_find_ranges = [
    ('', ''),                                     # empty string
    ([], ''),                                     # empty list
    ('1:2', '1:2'),                               # string returned unchanged
    ([1, 2, 3], '1~3'),                           # single contiguous run
    (['5~12', '14', '16:17'], '5~12,14,16:17'),   # pre-formatted string items
    ([1, 2, 3, 6, 7], '1~3,6~7'),                 # two separate runs
    ([1, 2, 3, '6', '7'], '1~3,6~7'),             # mixed int / str items
]


@pytest.mark.parametrize('data, expected', params_find_ranges)
def test_find_ranges(data: Union[str, list], expected: str):
    """Test find_ranges()

    This utility function takes a string or a list of spectral window
    identifiers (integers or strings) and returns a string containing the
    identified ranges, e.g. [1, 2, 3] -> '1~3'.
    """
    assert find_ranges(data) == expected


# (a, b, expected) triples for test_dict_merge().
params_dict_merge = [
    ({}, {}, {}),
    ({}, 1, 1),                                                   # non-dict b returned as-is
    ({'a': 1}, {}, {'a': 1}),
    ({'a': {'b': 1}}, {'c': 2}, {'a': {'b': 1}, 'c': 2}),
    ({'a': {'b': 1}}, {'a': {'b': 2}}, {'a': {'b': 2}}),
    ({'a': {'b': 1}}, {'a': 2}, {'a': 2}),                        # non-dict value in b wins
    ({'a': {'b': {'c': 1}}}, {'a': {'b': {'c': 2}}}, {'a': {'b': {'c': 2}}}),  # deep nesting
]
@pytest.mark.parametrize('a, b, expected', params_dict_merge)
def test_dict_merge(a: Dict, b: Union[Dict, int], expected: Union[Dict, int]):
    """Test dict_merge()

    This utility function recursively merges dictionaries. If the second
    argument (b) is a dictionary, a copy of the first argument (dictionary a)
    is created and the elements of b are merged into the new dictionary.
    Otherwise, b is returned.

    For matching keys with non-dictionary values, the value from b overwrites
    the value from a. If the matching key value is a dictionary, merging
    continues recursively.
    """
    assert dict_merge(a, b) == expected


# (a, b, expected) triples for test_are_equal(); covers lists and numpy arrays.
params_are_equal = [
    ([1, 2, 3], [1, 2, 3], True),
    ([1, 2.5, 3], [1, 2, 3], False),
    (np.ones(2), np.zeros(2), False),
    (np.ones(2), np.ones(2), True),
    (np.ones(2), np.ones(3), False),  # different number of elements
]
@pytest.mark.parametrize('a, b, expected', params_are_equal)
def test_are_equal(a: Union[List, np.ndarray], b: Union[List, np.ndarray], expected: bool):
    """Test are_equal()

    This utility function checks the equivalence of array-like objects. Two
    arrays are equal if they have the same number of elements and the
    elements at the same index are equal.
    """
    assert are_equal(a, b) == expected


# (x, y, tol, expected) tuples for test_approx_equal().
params_approx_equal = [
    (1.0e-2, 1.2e-2, 1e-2, True),
    (1.0e-2, 1.2e-2, 1e-3, False),
    (1.0, 2.0, 0.1, False),
    (1, 2, 10, True),
]


@pytest.mark.parametrize('x, y, tol, expected', params_approx_equal)
def test_approx_equal(x: float, y: float, tol: float, expected: bool):
    """Test approx_equal()

    This utility function returns True if two numbers are equal within the
    given tolerance.
    """
    assert approx_equal(x, y, tol=tol) == expected


# (caltable, expected) pairs for test_get_num_caltable_polarizations().
# These caltable files are not available to the test suite (see skip marker).
params_test_get_num_caltable_pol = [
    ('uid___A002_Xc46ab2_X15ae_spw16_17_small.ms.hifa_'
     'timegaincal.s17_7.spw0.solintinf.gacal.tbl', 1),
    ('uid___A002_Xc46ab2_X15ae_spw16_17_small.ms.hifa_'
     'timegaincal.s17_2.spw0.solintinf.gpcal.tbl', 2),
]


@pytest.mark.skip(reason="Currently no general online pipeline data storage is available for test datasets.")
@pytest.mark.parametrize('caltable, expected', params_test_get_num_caltable_pol)
def test_get_num_caltable_polarizations(caltable: str, expected: int):
    """Test get_num_caltable_polarizations()

    This utility function is expected to return the number of polarizations
    (an int) in the given calibration table. The test is skipped until the
    calibration tables are available.
    """
    assert get_num_caltable_polarizations(caltable=caltable) == expected


# (vec, expected) pairs for test_flagged_intervals().
params_flagged_intervals = [
    ([], []),
    ([1, 2], [(0, 0)]),
    ([0, 1, 0, 1, 1], [(1, 1), (3, 4)]),
    ([0, 1, 0, 1, 2], [(1, 1), (3, 3)]),  # a value of 2 is not counted as flagged
]
@pytest.mark.parametrize('vec, expected', params_flagged_intervals)
def test_flagged_intervals(vec: Union[List[int], np.ndarray], expected: List[Tuple[int, int]]):
    """Test flagged_intervals()

    This utility function finds islands of ones in the vector provided as
    argument; it is used to find contiguous flagged channels in a given spw.
    It returns a list of (start, end) channel tuples, both ends inclusive.
    Values other than 1 (e.g. 2) are not treated as flagged.
    """
    assert flagged_intervals(vec=vec) == expected


# (field, expected) pairs for test_fieldname_for_casa().
params_fieldname_for_casa = [
    ('', ''),
    ('helm30', 'helm30'),        # plain name left unchanged
    ('helm=30', '"helm=30"'),    # special character -> quoted
    ('1', '"1"'),                # purely numeric name -> quoted
]
@pytest.mark.parametrize('field, expected', params_fieldname_for_casa)
def test_fieldname_for_casa(field: str, expected: str):
    """Test fieldname_for_casa()

    This utility function ensures that the field string can be used as a CASA
    argument. If the field contains special characters, the field string is
    returned enclosed in quotation marks; otherwise it is returned unchanged.
    Purely numeric names (e.g. '1') are quoted as well, presumably so that CASA
    does not interpret them as field IDs.
    """
    assert fieldname_for_casa(field=field) == expected


# (field, expected) pairs for test_fieldname_clean().
params_fieldname_clean = [
    ('', ''),
    ('helm30', 'helm30'),
    ('helm=30', 'helm_30'),   # '=' replaced by underscore
    ('1', '1'),
]
@pytest.mark.parametrize('field, expected', params_fieldname_clean)
def test_fieldname_clean(field: str, expected: str):
    """Test fieldname_clean()

    This utility function replaces special characters in a string with
    underscores, e.g. 'helm=30' -> 'helm_30'.
    """
    assert fieldname_clean(field=field) == expected


# Create mock Fields and MeasurementSets for testing get_field_accessor() and
# get_field_identifiers(). The Field name and id attributes, and the
# MeasurementSet fields attribute and get_fields() method are accessed.
def _mock_field(field_id: int, field_name: str) -> Mock:
    """Return a Mock domain.Field with the given id and name attributes."""
    mock_field = Mock(spec=domain.Field, **{'id': field_id})
    # Mock's ``name`` constructor argument names the mock itself instead of
    # setting an attribute, so the name attribute is assigned explicitly.
    mock_field.name = field_name
    return mock_field


fields = [_mock_field(i + 1, fn) for i, fn in enumerate(['Mars', 'Jupiter', 'Mars'])]

# get_fields() is called only once in this test, therefore set return_value.
params_get_field_accessor = [
    # All field names are unique
    (Mock(spec=domain.MeasurementSet, **{
        'get_fields.return_value': [fields[1]]
    }), fields[1], 'Jupiter'),
    # Field name 'Mars' repeats
    (Mock(spec=domain.MeasurementSet, **{
        'get_fields.return_value': [fields[0], fields[2]]
    }), fields[2], '3'),
]


@pytest.mark.parametrize('ms, field, expected', params_get_field_accessor)
def test_get_field_accessor(ms, field, expected):
    """Test get_field_accessor()

    This utility function returns an attribute getter. If the field specified
    in the argument is unique in the MeasurementSet, the getter accesses the
    field name (name attribute); otherwise it accesses the field id (id
    attribute).
    """
    assert get_field_accessor(ms, field)(field) == expected


# get_fields() returns all fields with the name given in argument; mock this
# behaviour. The method is called multiple times, therefore set side_effect.
# NOTE(review): an iterable side_effect is consumed call by call, so each mock
# below supports only a single run of the parametrized test.
params_get_field_ids = [
    # All field names are unique
    (Mock(spec=domain.MeasurementSet, **{
        'fields': fields[0:2],
        'get_fields.side_effect': [[f] for f in fields[0:2]],
    }), {1: 'Mars', 2: 'Jupiter'}),
    # Field name 'Mars' repeats
    (Mock(spec=domain.MeasurementSet, **{
        'fields': fields,
        'get_fields.side_effect': [[fields[0], fields[2]], [fields[1]], [fields[0], fields[2]]],
    }), {1: '1', 2: 'Jupiter', 3: '3'}),
]


@pytest.mark.parametrize('ms, expected', params_get_field_ids)
def test_get_field_identifiers(ms, expected):
    """Test get_field_identifiers()

    This utility function returns a dictionary with field ID keys and either
    the field name or str(field ID) as values. The latter happens when a field
    name occurs more than once.
    """
    assert get_field_identifiers(ms=ms) == expected


# The mocked get_spectral_windows() returns None on its first call and a list
# holding one spw mock with receiver 'fake' on its second call; presumably it
# is called once per spw id, so spw 1 is "not found" and spw 2 is found.
params_get_receiver_type_for_spws = [
    (
        Mock(
            spec=domain.MeasurementSet,
            **{'get_spectral_windows.side_effect': [None, [Mock(receiver='fake')]]}
        ),
        [1, 2],
        {1: 'N/A', 2: 'fake'},
    ),
]
@pytest.mark.parametrize('ms, spwids, expected', params_get_receiver_type_for_spws)
def test_get_receiver_type_for_spws(ms, spwids, expected):
    """Test get_receiver_type_for_spws()

    This utility function returns a dictionary with the spectral window IDs
    (spwids argument) as keys and the associated receiver strings in the
    MeasurementSet as values. If a spectral window ID is not found in the
    MeasurementSet, the associated value is set to 'N/A'.
    """
    assert get_receiver_type_for_spws(ms=ms, spwids=spwids) == expected


# (value, expected) pairs for test_get_casa_quantity().
params_get_casa_quantity = [
    (None, {'unit': '', 'value': 0.0}),                     # None -> zero, unitless
    ('10klambda', {'unit': 'klambda', 'value': 10.0}),      # string with unit
    (10.0, {'unit': '', 'value': 10.0}),                    # bare number
]
@pytest.mark.parametrize('value, expected', params_get_casa_quantity)
def test_get_casa_quantity(value: Union[str, float, Dict, None], expected: Dict):
    """Test get_casa_quantity()

    This utility function handles None values when calling the CASA
    quanta.quantity() tool method: None yields a unitless zero quantity,
    while strings and numbers are converted to {'unit': ..., 'value': ...}
    dictionaries.
    """
    assert get_casa_quantity(value) == expected