#!/usr/bin/env python

import glob
import os
import re
import sys
import shutil
import importlib
import imp

from math import*
from collections import Counter
from numpy  import *

from efidir.config import *
from efidir.sws import *
from efidir.utils import listimage, envi

nbr_thumbs = nbr_imagette

thumbs = []


def _read_average_displacements(thumb_dir, component, master_i):
    """Read the average displacement of one component for a (thumb, master) pair.

    Globs ``<thumb_dir>/<component>/crop/*.hdr``, sorts the matches, and parses
    the ``EFIDIR_average_band_0`` entry of each ENVI header as a float. A
    ``0.0`` placeholder is then inserted at index ``master_i`` (presumably the
    master compared with itself, which has no header of its own -- confirm).

    :param thumb_dir: directory holding the ``dx``/``dy`` results of one pair.
    :param component: sub-directory to read, ``"dx"`` or ``"dy"``.
    :param master_i: index at which the ``0.0`` placeholder is inserted.
    :returns: ``(file_names, values)``: the sorted header paths and the
        parsed averages (with the placeholder).
    """
    file_names = sorted(glob.glob(thumb_dir + "/" + component + "/crop/*.hdr"))
    values = []
    for name in file_names:
        hdr = envi.readHDR(name)
        # Drop the first and last character of the stored value
        # (presumably the "{...}" braces of an ENVI header field).
        values.append(float(hdr["EFIDIR_average_band_0"][1:-1]))
    # NOTE(review): when master_i is larger than the number of headers found,
    # no 0.0 placeholder is inserted at all (same as the historical behaviour).
    if master_i <= len(values):
        values.insert(master_i, 0.0)
    return file_names, values


# thumbs[thumb_i][master_i] holds the header paths and average displacements
# of thumb `thumb_i` computed against master image `master_i`.
for thumb_i in range(nbr_thumbs):
    thumbs.append([])
    for master_i in range(nbr_images):
        thumb_dir = (main_dir + "/" + "thumb_disp/thumb" + str(thumb_i)
                     + "_master_" + str(master_i))
        names_dx, dxs = _read_average_displacements(thumb_dir, "dx", master_i)
        names_dy, dys = _read_average_displacements(thumb_dir, "dy", master_i)
        thumbs[thumb_i].append({
            "dir": thumb_dir,
            "fileNamesDX": names_dx,
            "DX": dxs,
            "fileNamesDY": names_dy,
            "DY": dys,
        })

#print(thumbs[0])

# Number of unit-width histogram bins used to group master images by how many
# of their displacements equal +/- max_deplacement (historical value: 56).
# NOTE(review): presumably sized to cover the number of images -- confirm.
# Counts above this value fall outside every bin and join no cluster.
_NBR_SATURATION_BINS = 56


def _fmt(item):
    """Format *item* right-aligned in a 4-character field."""
    return '{:4}'.format(item)


def _py_matrix(name, rows):
    """Python list-of-lists assignment, one row per line, then a blank line."""
    body = "],\n [".join(", ".join(_fmt(x) for x in r) for r in rows)
    return name + " = [[" + body + "]]\n\n"


def _py_vector(name, items):
    """Python flat-list assignment on one line, then a blank line."""
    return name + " = [" + ", ".join(_fmt(x) for x in items) + "]\n\n"


def _m_matrix(name, rows):
    """Matlab matrix assignment (rows separated by ';'), then a blank line."""
    body = ";".join(" ".join(_fmt(x) for x in r) for r in rows)
    return name + " = [" + body + "];\n\n"


def _m_column(name, items):
    """Matlab column-vector assignment, then a blank line."""
    return name + " = [" + ";".join(_fmt(x) for x in items) + "];\n\n"


def _m_cell(name, rows):
    """Matlab cell array of row vectors (rows may differ in length)."""
    body = "] [".join(" ".join(_fmt(x) for x in r) for r in rows)
    return name + " = {[" + body + "]};\n\n"


def _saturation_clusters(binary_rows, thumb_i, valid):
    """Cluster master images by their number of saturated displacements.

    ``binary_rows[m][k]`` is 1 when displacement ``k`` of master ``m`` equals
    +/- max_deplacement. Row sums are histogrammed into unit-width bins; the
    masters falling in each bin form one cluster. Masters in the lower half
    of the bins get ``valid[master][thumb_i] = 1`` (``valid`` is mutated).

    :returns: ``(row_sums, hist, clusters)``.
    """
    row_sums = sum(binary_rows, axis=1)
    hist, _ = histogram(row_sums, bins=range(0, _NBR_SATURATION_BINS + 1))
    # argsort orders masters by ascending count; since the counts are integers
    # and bins are unit-width, the masters of bin b occupy the next hist[b]
    # positions of that ordering.
    order = argsort(row_sums)
    clusters = []
    first = 0
    for bin_i, count in enumerate(hist):
        members = [order[j] for j in range(first, first + count)]
        clusters.append(members)
        if bin_i < len(hist) // 2:
            for master in members:
                valid[master][thumb_i] = 1
        first += count
    return row_sums, hist, clusters


candidate = []

# Validity flags: rows are master images, columns are thumbs.
MDXvalidThumb = [[0] * nbr_thumbs for _ in range(nbr_images)]
MDYvalidThumb = [[0] * nbr_thumbs for _ in range(nbr_images)]

with open("matrices.m", "w") as f, open("matrices.py", "w") as fp:
    for thumb_i, thumb in enumerate(thumbs):
        M, MDX, MDY, MDXb, MDYb, sums = [], [], [], [], [], []
        for masterDic in thumb:
            dys = masterDic["DY"]
            # Index dy by dx's positions (not zip) so a dy series shorter than
            # the dx series still raises instead of being silently truncated.
            pairs = [(dx, dys[k]) for k, dx in enumerate(masterDic["DX"])]
            M.append([sqrt(pow(dx, 2) + pow(dy, 2)) for dx, dy in pairs])
            MDX.append([dx for dx, _ in pairs])
            MDY.append([dy for _, dy in pairs])
            # Binary matrices: 1 where the displacement equals +/- max_deplacement.
            MDXb.append([1 if dx in (max_deplacement, -max_deplacement) else 0
                         for dx, _ in pairs])
            MDYb.append([1 if dy in (max_deplacement, -max_deplacement) else 0
                         for _, dy in pairs])
            sums.append(sum(M[-1]))

        MDXb_sum, MDXb_hist, MDX_thumbCluster = _saturation_clusters(
            MDXb, thumb_i, MDXvalidThumb)
        MDYb_sum, MDYb_hist, MDY_thumbCluster = _saturation_clusters(
            MDYb, thumb_i, MDYvalidThumb)

        suffix = str(thumb_i)
        matrices = (("M", M), ("Mdx", MDX), ("Mdy", MDY),
                    ("Mdxb", MDXb), ("Mdyb", MDYb))

        # Save for Python
        for name, rows in matrices:
            fp.write(_py_matrix(name + suffix, rows))
        fp.write(_py_vector("Mdxb_sum" + suffix, MDXb_sum))
        fp.write(_py_vector("Mdyb_sum" + suffix, MDYb_sum))
        fp.write(_py_vector("Mdxb_hist" + suffix, MDXb_hist))
        fp.write(_py_vector("Mdyb_hist" + suffix, MDYb_hist))
        fp.write(_py_matrix("Mdx_thumbCluster" + suffix, MDX_thumbCluster))
        fp.write(_py_matrix("Mdy_thumbCluster" + suffix, MDY_thumbCluster))

        # Save for Matlab
        for name, rows in matrices:
            f.write(_m_matrix(name + suffix, rows))
        f.write(_m_column("Mdxb_sum" + suffix, MDXb_sum))
        f.write(_m_column("Mdyb_sum" + suffix, MDYb_sum))
        f.write(_m_cell("Mdx_thumbCluster" + suffix, MDX_thumbCluster))
        f.write(_m_cell("Mdy_thumbCluster" + suffix, MDY_thumbCluster))

        # Candidate master for this thumb: the image whose displacement
        # magnitudes have the smallest sum (lowest index wins ties).
        candidate.append(min(range(len(sums)), key=sums.__getitem__))
        print("From thumb " + str(thumb_i) + ": " + str(candidate[thumb_i]))

    fp.write(_py_matrix("Mdx_validThumb", MDXvalidThumb))
    fp.write(_py_matrix("Mdy_validThumb", MDYvalidThumb))

    # Valid for a thumb only if both dx and dy are valid. The logical result is
    # cast afterwards: logical_and has no integer-output loop, and the abstract
    # numpy `integer` type is not a usable dtype argument.
    MvalidThumb = logical_and(MDXvalidThumb, MDYvalidThumb).astype(int)
    fp.write(_py_matrix("M_validThumb", MvalidThumb))

    # An image is valid overall only if it is valid for every thumb (an empty
    # row counts as valid). `all` replaces the Python-2-only builtin `reduce`.
    MvalidImage = [int(all(row)) for row in MvalidThumb]
    fp.write("M_validImage = ["
             + ",\n".join(_fmt(item) for item in MvalidImage) + "]\n")

# The master image is the candidate chosen by the most thumbs.
data = Counter(candidate)
master_image = data.most_common(1)[0][0]   # Returns the highest occurring item
print("master image detected: " + str(master_image))
