import os
import pipeline.h.tasks.restoredata.restoredata as restoredata
import pipeline.infrastructure as infrastructure
import pipeline.infrastructure.basetask as basetask
import pipeline.infrastructure.vdp as vdp
from pipeline.hsd.tasks.applycal import applycal
from pipeline.infrastructure import casa_tools
from pipeline.infrastructure import task_registry
from . import ampcal
from ..importdata import importdata as importdata
# Module-level logger obtained from the pipeline infrastructure, named after this module.
LOG = infrastructure.get_logger(__name__)
class NRORestoreDataResults(restoredata.RestoreDataResults):
    """Results of the NRO restoredata task.

    Extends the generic restoredata results with the outcome of the NRO
    amplitude calibration (``ampcal``) step. On merge, per-MS ``k2jy_factor``
    dictionaries are read from the applied amplitude/gaincal tables and
    attached to the MS domain objects held by the context.
    """

    def __init__(self, importdata_results=None, applycal_results=None, ampcal_results=None):
        """
        Initialise the results objects.

        Args:
            importdata_results: results of the import step; forwarded to the
                parent class.
            applycal_results: results of the applycal step; forwarded to the
                parent class. May be a single results object or a
                ``basetask.ResultsList`` (see ``merge_with_context``).
            ampcal_results: results of the amplitude calibration step.
        """
        super(NRORestoreDataResults, self).__init__(importdata_results, applycal_results)
        self.ampcal_results = ampcal_results

    def merge_with_context(self, context):
        """Merge into ``context`` and attach k2jy factors to MS domain objects.

        The parent merge runs first. Then every applycal result (either the
        single result or each element of a ``basetask.ResultsList``) is
        inspected for applied amplitude/gaincal tables, whose contents are
        stored as ``k2jy_factor`` on the corresponding MS objects.

        Args:
            context: pipeline context to merge into.
        """
        super(NRORestoreDataResults, self).merge_with_context(context)
        if isinstance(self.applycal_results, basetask.ResultsList):
            applycal_list = self.applycal_results
        else:
            applycal_list = [self.applycal_results]
        for result in applycal_list:
            self._merge_ampcal(context, result)

    def _merge_ampcal(self, context, applycal_results):
        """Attach k2jy factors to the MSes targeted by ``applycal_results``.

        For each applied calibration whose MS does not yet carry a
        ``k2jy_factor`` attribute, each calfrom entry with caltype ``'amp'``
        or ``'gaincal'`` is read and its factor dictionary is set on the MS.
        If several such entries exist, the last one read wins.

        Args:
            context: pipeline context providing ``observing_run.get_ms``.
            applycal_results: a single applycal results object. ``None``
                (the constructor default) is ignored instead of raising
                AttributeError.
        """
        if applycal_results is None:
            return
        for calapp in applycal_results.applied:
            msobj = context.observing_run.get_ms(name=os.path.basename(calapp.vis))
            if not hasattr(msobj, 'k2jy_factor'):
                for _calfrom in calapp.calfrom:
                    if _calfrom.caltype in ('amp', 'gaincal'):
                        LOG.debug('Adding k2jy factor to {0}'.format(msobj.basename))
                        # k2jy gaincal table
                        msobj.k2jy_factor = self._read_k2jy_factor(msobj, _calfrom.gaintable)
            LOG.debug('msobj.k2jy_factor = {0}'.format(getattr(msobj, 'k2jy_factor', 'N/A')))

    @staticmethod
    def _read_k2jy_factor(msobj, k2jytable):
        """Build a k2jy factor dictionary from a gaincal table.

        Reads SPECTRAL_WINDOW_ID, ANTENNA1 and the real part of CPARAM for
        every row. The result maps ``(spw_id, antenna_name, polarization_label)``
        to ``1 / gain**2``. Rows whose spw has no data description in
        ``msobj`` are skipped.

        Args:
            msobj: MS domain object used to resolve antenna names and
                polarization labels.
            k2jytable: path of the gaincal table to read.

        Returns:
            dict: the k2jy factor mapping described above.
        """
        k2jy_factor = {}
        with casa_tools.TableReader(k2jytable) as tb:
            spws = tb.getcol('SPECTRAL_WINDOW_ID')
            antennas = tb.getcol('ANTENNA1')
            params = tb.getcol('CPARAM').real
            nrow = tb.nrows()
        for irow in range(nrow):
            spwid = spws[irow]
            dd = msobj.get_data_description(spw=int(spwid))
            if dd is None:
                # No data description for this spw: nothing to label, so skip
                # before doing the antenna lookup.
                continue
            antname = msobj.get_antenna(antennas[irow])[0].name
            # Indexing implies CPARAM is rank 3 with rows on the last axis;
            # only the first channel (index 0 on the middle axis) is used.
            param = params[:, 0, irow]
            for ipol, gain in enumerate(param):
                polname = dd.get_polarization_label(ipol)
                k2jy_factor[(spwid, antname, polname)] = 1.0 / (gain * gain)
        return k2jy_factor
@task_registry.set_equivalent_casa_task('hsdn_restoredata')
class NRORestoreData(restoredata.RestoreData):
    """Restore calibrated NRO data.

    Specialises the generic restoredata task. Import goes through
    ``NROImportData``, and an NRO-specific amplitude calibration
    (``SDAmpCal``) runs before ``SDApplycal``. The results are wrapped in
    ``NRORestoreDataResults``.
    """
    Inputs = NRORestoreDataInputs

    def prepare(self):
        """Run the parent restore sequence and wrap its results.

        Returns:
            NRORestoreDataResults: the parent's importdata/applycal results
            plus the SDAmpCal results (``None`` if ``_do_applycal`` never ran).
        """
        inputs = self.inputs
        LOG.debug('prepare inputs = {0}'.format(inputs))
        # Run the prepare method of the parent class. It presumably drives the
        # _do_importasdm/_do_applycal hooks overridden below.
        results = super(NRORestoreData, self).prepare()
        # ampcal_results is only assigned by _do_applycal. Fall back to None
        # rather than raising AttributeError if that hook was not executed.
        ampcal_results = getattr(self, 'ampcal_results', None)
        return NRORestoreDataResults(results.importdata_results, results.applycal_results, ampcal_results)

    def _do_importasdm(self, sessionlist, vislist):
        """Import ``vislist`` with NROImportData and merge the result into the context.

        Args:
            sessionlist: not used by this implementation.
            vislist: data sets to import.

        Returns:
            The results of executing NROImportData (merged into the context).
        """
        inputs = self.inputs
        # NROImportDataInputs operate in the scope of a single measurement set.
        # To operate in the scope of multiple MSes we must use an
        # InputsContainer.
        LOG.debug('_do_importasdm inputs = {0}'.format(inputs))
        container = vdp.InputsContainer(importdata.NROImportData, inputs.context, vis=vislist,
                                        output_dir=None)
        importdata_task = importdata.NROImportData(container)
        return self._executor.execute(importdata_task, merge=True)

    def _do_applycal(self):
        """Run SDAmpCal and then SDApplycal, merging both into the context.

        The SDAmpCal results are stored on ``self.ampcal_results`` for use by
        ``prepare``.

        Returns:
            The results of executing SDApplycal (merged into the context).
        """
        inputs = self.inputs
        LOG.debug('_do_applycal inputs = {0}'.format(inputs))
        # Before applycal, perform the sensitivity (amplitude) correction with
        # SDAmpCal, using the scale file (reffile) given by the observatory.
        # This is a special operation for NRO data. If the scale file does not
        # exist (a relative path resolves against the current working
        # directory), SDAmpCal still runs, but without a reffile.
        reffile = inputs.reffile
        # Guard against a None/empty reffile: os.path.exists(None) raises TypeError.
        if reffile and os.path.exists(reffile):
            container = vdp.InputsContainer(ampcal.SDAmpCal, inputs.context, reffile=reffile)
        else:
            LOG.info('No scale factor file exists. Skip scaling.')
            container = vdp.InputsContainer(ampcal.SDAmpCal, inputs.context)
        LOG.debug('ampcal container = {0}'.format(container))
        ampcal_task = ampcal.SDAmpCal(container)
        self.ampcal_results = self._executor.execute(ampcal_task, merge=True)
        # SDApplyCalInputs operates in the scope of a single measurement set.
        # To operate in the scope of multiple MSes we must use an
        # InputsContainer.
        container = vdp.InputsContainer(applycal.SDApplycal, inputs.context)
        applycal_task = applycal.SDApplycal(container)
        LOG.debug('_do_applycal container = {0}'.format(container))
        return self._executor.execute(applycal_task, merge=True)