#------------------------------------------------------------------------#
# Title: 	INMSAngularResponseDA.py
# Purpose:  Provide tools for reading, analyzing, and plotting INMS data.
#			
#------------------------------------------------------------------------#
#
# How to use this software: 
# Requires Python 2.7.5 or later with IPython installed. Start the IPython
# QtConsole interpreter and run the following commands:
# 
# cd /Users/jgrimes/Documents/Ion Gun/data or relevant user directory 
# from INMS_REU_Data_Analysis import *
# ion()
#
# All required modules and packages are now loaded and the working
# directory is set to the data location, so data files can be opened by
# name. To run a function, type the following at the interpreter prompt:
# ...: functionName(arg1, arg2, ..., argN) 
#------------------------------------------------------------------------#

#begin import packages
#from pylab import * #all of numpy and matplotlib
from numpy import *
from scipy import *
from datetime import datetime
from datetime import timedelta
from matplotlib.pyplot import *
from matplotlib.dates import DateFormatter
from scipy.interpolate import griddata

import os
import csv
import json
import glob

# Hard-coded, user-specific path to the Cassini INMS data directory.
inmsDataDir = '/Users/jgrimes/Documents/Cassini/INMS Data/'
# Hour:minute date formatter, presumably for time-axis tick labels.
formatter = DateFormatter('%H:%M')

# 'Scan' header fields used by get_crosstalk_data.  Passed to
# get_INMS_scan_data as ``sfields`` they become columns 0-5 of every scan
# row (run=0, mode=3).
sfld = ['run', 'dataset', 'scan', 'mode',
        'initial_time', 'final_time']

# 'Monitors' fields used by get_crosstalk_data.  Passed as ``mfields`` they
# become columns 6-12 of every scan row (ip=6, cs_fil1_emis_iset=7,
# cs_fil1_emis_imon=8, os_fil1_emis_imon=11, chamber_ion_pmon=12), followed
# by the counts (-2) and id0 (-1) columns.  sum_sets reads rows by these
# positions, so the order of this list matters.
mfld = ['ip','cs_fil1_emis_iset','cs_fil1_emis_imon', 'cs_anode1_imon',
        'cs_ionrep_imon', 'os_fil1_emis_imon', 'chamber_ion_pmon'] 

def get_INMS_data(datapath, srchstr):
    """Read every file in ``datapath`` whose name matches the glob ``srchstr``.

    Side effect: the process working directory is changed to ``datapath``
    and left there.  Globbing from inside the directory yields bare file
    names, which become the keys of the returned dict (get_datasets later
    parses these names by splitting on '_').

    Returns a dict mapping each matching file name to the value returned by
    read_INMS_scan_file (defined outside this section) for that file.
    """
    os.chdir(datapath)
    matching = glob.glob(srchstr)
    return {name: read_INMS_scan_file(name) for name in matching}

def get_INMS_scan_data(data, id0='lens_vset', sfields='', mfields=''):
    """Flatten per-file INMS records into one list ("row") per file.

    ``data`` maps a file name to a record with 'Scan' and 'Monitors'
    mappings and a 'Data' pair ``[column_titles, table]``, where ``table``
    is indexed as ``table[:, column]``.  Records are visited in sorted key
    order and each produces the row::

        [Scan[f] for f in sfields] + [Monitors[f] for f in mfields]
        + [table column 'counter1', table column id0]

    When ``sfields``/``mfields`` are left as '' short built-in field lists
    are used.  A record that raises TypeError while being read (presumably
    e.g. a file whose record is None) is skipped and its key is printed.
    """
    sfields = (['run', 'dataset', 'scan',
                'initial_time', 'final_time']
               if sfields == '' else sfields)
    mfields = (['cs_fil1_emis_imon', 'cs_anode1_imon',
                'cs_ionrep_imon', 'chamber_ion_pmon']
               if mfields == '' else mfields)

    rows = []
    for key in sorted(data.keys()):
        try:
            record = data[key]
            scan_info = record['Scan']
            row = [scan_info[field] for field in sfields]
            monitor_info = record['Monitors']
            row += [monitor_info[field] for field in mfields]
            titles = record['Data'][0]
            table = record['Data'][1]
            row += [table[:, titles.index('counter1')],
                    table[:, titles.index(id0)]]
        except TypeError:
            # Unreadable record: report which file and move on.
            print(key)
            continue
        rows.append(row)
    return rows

#===============DA tools for INMS Mass Scan Data============================#

def get_MS_scan_times(jn_file):
    """Group the scan timestamps recorded in a mass-scan journal file.

    Journal lines are expected to look like ``[<timestamp>] <message>``.
    Lines containing 'Scan' are classified by their message:

    * 'Background' scans: each contiguous run of background scans becomes a
      single ``{'Bkg': [timestamps]}`` entry;
    * 'SID' and 'Thermal' scans: collected until a line containing
      'SID measurement completed' closes the group as
      ``{'SID': [...], 'Therm': [...], 'End': <timestamp>}``.

    Other 'Scan' lines are ignored (but end a background run).

    Returns the list of group dicts in file order.
    """
    groups = []
    bkg_run = None          # timestamp list of the background run in progress
    sid_times, therm_times = [], []
    with open(jn_file) as journal:
        for line in journal:
            if 'Scan' in line:
                # maxsplit=1 so a ']' inside the message cannot break the
                # two-way unpacking.
                stamp, message = line.strip('[').split(']', 1)
                if 'Background' in message:
                    # Open one group per contiguous background run instead of
                    # appending a new dict for every background line.
                    if bkg_run is None:
                        bkg_run = []
                        groups.append({'Bkg': bkg_run})
                    bkg_run.append(stamp)
                else:
                    bkg_run = None
                    if 'SID' in message:
                        sid_times.append(stamp)
                    elif 'Thermal' in message:
                        therm_times.append(stamp)
            elif 'SID measurement completed' in line:
                bkg_run = None
                end_stamp = line.strip('[').split(']')[0]
                groups.append({'SID': sid_times, 'Therm': therm_times,
                               'End': end_stamp})
                sid_times, therm_times = [], []
    return groups

def get_datasets(msFiles):
    """Group mass-scan file names by dataset.

    Each file name (a leading directory, if any, is ignored) is expected to
    have exactly six '_'-separated fields once a trailing '.txt' is removed;
    the fifth is the dataset name and the sixth the scan id, e.g.
    ``FM_MS_2014_352_bkgc_s01.txt`` -> dataset 'bkgc', scan 's01'.

    Returns a dict mapping dataset name to the list of scan ids, in input
    order.  Raises ValueError if a name does not have six fields.
    """
    dsets = {}
    for msf in msFiles:
        name = os.path.basename(msf)
        # Remove only the '.txt' suffix; str.strip('.txt') would strip any
        # of the characters '.', 't', 'x' from both ends of the name.
        if name.endswith('.txt'):
            name = name[:-len('.txt')]
        _, _, dataset, scan = name.split('_')[2:]
        dsets.setdefault(dataset, []).append(scan)
    return dsets
    

def scale_ms_data(data, aveP=0, avEmis=0):
    """Normalise each scan's counts to a reference pressure and emission.

    Each row in ``data`` supplies its emission at index 5, its pressure at
    index 8 and its count array at index -2.  Every count array is
    multiplied by ``aveP / pressure`` and ``avEmis / emission``.  When
    ``aveP`` or ``avEmis`` is 0 (the default), the mean over ``data`` is
    used as the reference.

    Returns a list with one scaled count array per row.
    """
    pressures = array([row[8] for row in data], dtype=float)
    emissions = array([row[5] for row in data], dtype=float)
    ref_press = pressures.mean() if aveP == 0 else aveP
    ref_emis = emissions.mean() if avEmis == 0 else avEmis

    spectra = array([row[-2] for row in data])
    return [spectrum * (ref_press / press) * (ref_emis / emis)
            for spectrum, press, emis in zip(spectra, pressures, emissions)]

def plot_MS_bar(M, D, sttl, ttl, yl='Counts', save=False, inline=True):
    """Draw a bar chart of a mass spectrum and save it as ``<sttl>.png``.

    M      -- mass-per-charge values; must support ``M - .5`` (e.g. a numpy
              array, presumably the 'masses' entry produced by sum_sets).
    D      -- y values (counts by default) for each mass in M.
    sttl   -- figure super-title; also used as the PNG file name, written
              relative to the current working directory.
    ttl    -- axes title.
    yl     -- y-axis label.
    save, inline -- NOTE(review): currently unused; the figure is always
              saved regardless of ``save``.
    """
    figure()
    # Shift by half a mass unit; presumably to centre bars on integer masses
    # under the edge-aligned bar default of older matplotlib -- verify.
    bar(M-.5, D)
    suptitle(sttl, fontsize=16)
    gca().set_title(ttl, fontsize=12)
    xlabel('Mass per Charge')
    ylabel(yl)
    # Fixed x window covering masses 1 through 25.
    xlim(.5, 25.5)
    # 5% headroom above the tallest bar.
    ylim(0, 1.05*max(D))
    draw()
    savefig(sttl+'.png')

def plot_crosstalk(dsets, mscans):
    """Sum every dataset and draw/save one bar chart per dataset.

    ``dsets`` is the dict of dataset dicts built by get_datasets and
    get_ds_ranges (each entry needs 'srange').  sum_sets fills in the
    'masses', 'sum', 'sttl' and 'ttl' keys that are plotted here.  Datasets
    are plotted in sorted key order.
    """
    sum_sets(dsets, mscans)
    # Iterate the dataset dicts, not the keys: indexing a key string with
    # 'masses' would raise TypeError.
    for key in sorted(dsets.keys()):
        ds = dsets[key]
        plot_MS_bar(ds['masses'], ds['sum'], ds['sttl'], ds['ttl'])


def get_ds_ranges(dst):
    """Replace each dataset's scan-id list with its numeric scan range, in place.

    ``dst`` maps dataset name to a list of scan ids such as ``'s01'`` (one
    leading character followed by an integer).  Each value is replaced by
    ``{'srange': (first, last)}`` holding the smallest and largest scan
    number.
    """
    for key in list(dst.keys()):
        # Compare scan numbers as integers: a plain string sort would order
        # 's10' before 's2' and give the wrong first/last scan.
        numbers = [int(scan_id[1:]) for scan_id in dst[key]]
        dst[key] = {'srange': (min(numbers), max(numbers))}


def sum_sets(dsets, mscans):
    """Build a composite (summed) mass scan and summary arrays per dataset.

    For every dataset in ``dsets`` (sorted key order), the scan rows
    ``mscans[first-1:last]`` selected by its ``'srange'`` are combined, and
    these keys are written back into the dataset dict:

    sum      -- element-wise sum of the per-scan count arrays (row[-2])
    h2cts    -- per-scan counts at index 1 of the count array
    wcts     -- per-scan counts at index 14 of the count array
    masses   -- the id0 column (row[-1]) of the last scan in the range
    cemis    -- per-scan row[8]
    oemis    -- per-scan row[11]
    chPress  -- per-scan row[12]
    sttl     -- plot super-title, e.g. 'OSNB - CS On - R352'
    ttl      -- plot subtitle with IP, scan count and summed H2 counts

    Rows are assumed to follow the sfld + mfld + [counts, id0] layout built
    by get_INMS_scan_data, and scan number n must sit at ``mscans[n-1]``.
    """
    for name in sorted(dsets.keys()):
        entry = dsets[name]
        first, last = entry['srange']
        window = mscans[first - 1:last]

        spectra = array([row[-2] for row in window])
        total = sum(spectra, 0)
        h2_counts = array([row[-2][1] for row in window])
        water_counts = array([row[-2][14] for row in window])
        cs_emission = array([row[8] for row in window])
        os_emission = array([row[11] for row in window])
        chamber_press = array([row[12] for row in window])

        final = mscans[last - 1]
        run, mode, ip = final[0:7:3]   # columns 0, 3 and 6
        masses = final[-1]
        n_scans = last - first + 1

        # Title depends on background/mode and, for OSNB, on the CS
        # filament set-point in column 7.
        if 'bkg' in name:
            template = '{} Background - R{}'
        elif mode != 'OSNB':
            template = '{} - R{}'
        elif float(final[7]) >= 1:
            template = '{} - CS On - R{}'
        else:
            template = '{} - CS Off - R{}'
        title = template.format(mode, run)
        if 'v_comp' in entry:
            title += ' {}kms'.format(entry['v_comp'])

        subtitle = 'IP={}ms, {} scan summation, H2counts={:.0f}'.format(
            ip, n_scans, total[1])

        entry.update(
            sum=total,
            h2cts=h2_counts,
            wcts=water_counts,
            masses=masses,
            cemis=cs_emission,
            oemis=os_emission,
            chPress=chamber_press,
            sttl=title,
            ttl=subtitle,
        )
        

def compute_stats(dsets):
    """Compute and print H2-count statistics for every summed dataset.

    For each dataset dict that holds an ``'h2cts'`` array (set by sum_sets),
    the sum, mean, standard deviation and variance of the per-scan H2 counts
    are stored under ``'h2sum'``, ``'mean'``, ``'std'`` and ``'var'``.  One
    line per dataset is printed in sorted key order, labelled by its
    ``'sttl'`` and preceded by a header row.  Datasets without ``'h2cts'``
    are skipped.
    """
    # Keep the placeholder count equal to the number of values passed
    # (label + four statistics): str.format silently ignores surplus
    # arguments, which would shift values under the wrong labels.
    hfmat = '{:<28}: {:>12}{:>12}{:>12}{:>12}'
    pfmat = ('{:<28}: sum={:12.0f}, mean={:12.2f}, '
             'std={:12.2f}, var={:12.2f}')
    print(hfmat.format('Dataset', 'H2 Sum', 'Avg. H2', 'Std. Dev.', 'Variance'))
    for dk in sorted(dsets.keys()):
        ds = dsets[dk]
        if 'h2cts' not in ds:
            continue
        counts = ds['h2cts']
        ds['h2sum'] = counts.sum()
        ds['mean'] = counts.mean()
        ds['std'] = counts.std()
        ds['var'] = counts.var()
        print(pfmat.format(ds['sttl'], ds['h2sum'], ds['mean'],
                           ds['std'], ds['var']))


def compute_h2ratios(dsets):
    """Compute, store and print OSNB/CSN mean-H2 ratios with uncertainties.

    ``dsets`` must hold exactly five data sets whose sorted keys correspond,
    in order, to: CSN background, OSNB background, OSNB with CS filament on,
    CSN, and OSNB with CS filament off.  The two OSNB sets and the CSN set
    need ``'mean'`` (mean H2 counts per scan, as set by compute_stats); the
    CSN set also needs ``'sttl'``.  The background sets only fill their
    positions and are not used in the ratio.

    Results are stored on the two OSNB sets as ``'ratio'`` and ``'sigma'``
    and printed.  If ``dsets`` does not hold exactly five sets a message is
    printed and nothing is computed.
    """
    if len(dsets) != 5:
        print('Compute_h2ratios requires exactly 5 data sets.')
        return

    # Roles are assigned purely by sorted key order.
    _bkgc, _bkgo, osnb1, csn, osnb2 = [dsets[sk] for sk in sorted(dsets.keys())]

    mc = csn['mean']
    results = []
    for ds in (osnb1, osnb2):
        m = ds['mean']
        ratio = m / mc
        # Ratio uncertainty with the relative errors added in quadrature,
        # presumably assuming Poisson counting: r * sqrt(1/a + 1/b).
        sigma = ratio * sqrt((1.0 / m) + (1.0 / mc))
        ds['ratio'] = ratio
        ds['sigma'] = sigma
        results.append((ratio, sigma))

    pfmat = '\tCS{}: \tR(H2)={:.2e}, uncertainty={:.2e}'
    print('\tH2 ratio OSNB/{}'.format(csn['sttl']))
    print(pfmat.format('On', *results[0]))
    print(pfmat.format('Off', *results[1]))


def compute_water_ratios(dsets):
    """Print H2/H2O ratios and uncertainties for a five-set crosstalk group.

    ``dsets`` must hold exactly five data sets whose sorted keys correspond,
    in order, to: CSN background, OSNB background, OSNB with CS filament on,
    CSN, and OSNB with CS filament off.  Only the last three are used.  Each
    of those needs ``'mean'`` (mean H2 counts per scan, from compute_stats)
    and ``'sum'`` (summed spectrum, from sum_sets).  The H2O value is the
    summed spectrum at index 14 divided by the number of scans, taken from
    ``len(ds['wcts'])`` (the per-scan index-14 counts from sum_sets), or 40
    if ``'wcts'`` is absent.

    Raises ValueError if ``dsets`` does not hold exactly five sets.
    """
    skeys = sorted(dsets.keys())
    if len(skeys) != 5:
        raise ValueError('compute_water_ratios requires exactly 5 data sets.')
    osnb1, csn, osnb2 = [dsets[sk] for sk in skeys[2:]]

    def water_mean(ds):
        # Per-scan mean at spectrum index 14, comparable with the H2 'mean'.
        nscans = len(ds['wcts']) if 'wcts' in ds else 40
        return ds['sum'][14] / float(nscans)

    def ratio_and_sigma(num, den):
        # Ratio with relative errors added in quadrature, presumably
        # assuming Poisson counting: r * sqrt(1/num + 1/den).  The error
        # terms must use the same two quantities as the ratio itself.
        return num / den, num / den * sqrt((1.0 / num) + (1.0 / den))

    mh1, mhc, mh2 = osnb1['mean'], csn['mean'], osnb2['mean']
    mw1, mwc, mw2 = water_mean(osnb1), water_mean(csn), water_mean(osnb2)

    r1, sig1 = ratio_and_sigma(mh1, mw1)
    r2, sig2 = ratio_and_sigma(mh2, mw2)
    r3, sig3 = ratio_and_sigma(mhc, mwc)
    r4, sig4 = ratio_and_sigma(mh1, mwc)
    r5, sig5 = ratio_and_sigma(mh2, mwc)

    pfmat = '\t{}: \tR({})={:.2e}, uncertainty={:.2e}'
    # NOTE(review): the v_comp in this header is hard-coded, not read from dsets.
    print('\tH2/H2O ratio OSNB v_comp = 8.506 km/s')
    print(pfmat.format('OSNB CS-On', 'H2/H2O', r1, sig1))
    print(pfmat.format('OSNB CS-Off', 'H2/H2O', r2, sig2))
    print(pfmat.format('CS Only', 'H2/H2O', r3, sig3))
    print(pfmat.format('OSNB CS_On', 'H2(OSNB)/H2O(CSN)', r4, sig4))
    print(pfmat.format('OSNB CS_Off', 'H2(OSNB)/H2O(CSN)', r5, sig5))

# The scan/monitor field lists ``sfld`` and ``mfld`` used by
# get_crosstalk_data are defined once near the top of this module.


def get_crosstalk_data(dpath):
    """Load, group, sum and summarise every mass-scan file in ``dpath``.

    Files matching '*_MS_*' are read (this changes the working directory to
    ``dpath``), flattened with the module-level ``sfld``/``mfld`` field
    lists and 'mass' as the id column, grouped into datasets by file name,
    summed and given H2 statistics (printed).

    Returns ``(dst, msd, msc)``: the per-dataset dict, the raw per-file data
    keyed by file name, and the list of flattened scan rows.
    """
    raw_files = get_INMS_data(dpath, '*_MS_*')
    scan_rows = get_INMS_scan_data(raw_files, id0='mass',
                                   sfields=sfld, mfields=mfld)
    datasets = get_datasets(raw_files)
    get_ds_ranges(datasets)
    sum_sets(datasets, scan_rows)
    compute_stats(datasets)
    return datasets, raw_files, scan_rows


#sttl = 'CSN Background - R{}'
#ttlc = 'IP=1000ms, 40 scan summation, H2counts={:.0f}'.format(cts[1])

#sttl = 'OSNB Background - R{}'
#ttlc = 'IP=1000ms, 40 scan summation, H2counts={:.0f}'.format(cts[1])

#sttl = 'OSNB - CS On - R{}'
#ttlc = 'IP=1000ms, 40 scan summation, H2counts={:.0f}'.format(cts[1])

#sttl = 'CSN - R{}'
#ttlc = 'IP=1000ms, 40 scan summation, H2counts={:.0f}'.format(cts[1])

#sttl = 'OSNB - CS Off - R{}'
#ttlc = 'IP=1000ms, 40 scan summation, H2counts={:.0f}'.format(cts[1])

#sttl = 'Crosstalk Ratio, CS On - R{}'
#ttlr = 'IP=1000ms, 40 scan summation, H2 ratio = {:.3f}'.format(cts[1])

#sttl = 'Crosstalk Ratio, CS Off - R{}'
#ttlr = 'IP=1000ms, 40 scan summation, H2 ratio = {:.3f}'.format(cts[1])


#ct352 = [{dset:'bkgc', rrange}]
