#!/usr/bin/env python

import os, sys, traceback
import matplotlib.pylab as pylab
import numpy as np
import pylab as pl
from scipy.io import loadmat
from scipy.stats.mstats import zscore
from scipy.spatial.distance import cdist

import random

from sklearn.metrics.pairwise import rbf_kernel
from sklearn.metrics.pairwise import linear_kernel


# Default (width, height) for every figure created through pylab.
pl.rcParams['figure.figsize'] = (10.0, 8.0)

# Project-local module; this script calls its optimal-transport solver.
import transport
# Re-execute the module after importing it -- presumably so that edits to
# transport.py are picked up when the script is re-run in a live session.
# NOTE(review): `reload` is a builtin only in Python 2.
reload(transport)

def indices(a, func):
    """Return the positions of the elements of `a` accepted by `func`.

    `a` is any iterable and `func` a one-argument predicate; the result is a
    list of integer positions, in increasing order, of the elements for which
    `func(element)` is truthy.
    """
    matches = []
    for position, element in enumerate(a):
        if func(element):
            matches.append(position)
    return matches

def split(Y,nPerClass):
    """Randomly split sample indices into two class-stratified groups.

    Parameters
    ----------
    Y : sequence or array of integer class labels
        Labels are expected to lie in 1..max(Y).  Flat sequences and
        column vectors of shape (n, 1) are both accepted.
    nPerClass : int
        Number of samples drawn (at random) from each class for the first
        group.  A class with fewer samples contributes all of them.

    Returns
    -------
    (idx1, idx2) : tuple of lists of int
        `idx1` holds up to `nPerClass` indices per class, `idx2` holds every
        remaining index.  Together they cover each labelled sample exactly
        once.  Both lists are empty when `Y` is empty.
    """
    # Flatten so that an (n, 1) label column behaves like a flat sequence.
    labels = np.asarray(Y).ravel()
    if labels.size == 0:
        return [], []

    idx1 = []
    idx2 = []
    for c in range(1, int(labels.max()) + 1):
        idx = np.flatnonzero(labels == c).tolist()
        random.shuffle(idx)
        idx1 += idx[:nPerClass]
        # Keep everything after the first group; the previous `:-1` upper
        # bound silently discarded the last sample of every class.
        idx2 += idx[nPerClass:]
    return idx1, idx2
###-------------------------------- DATA part ----------------------------------------------

# The four image domains; every ordered pair of distinct domains becomes one
# adaptation test, stored as [source, target, samples per class].
possible_DTS = ['Caltech10', 'amazon', 'webcam', 'dslr']

tests = []
data = {}

for ds in possible_DTS:
    mat = loadmat('data/'+ds+'_SURF_L10.mat')
    raw = mat['fts']

    # Divide every row by its own sum (taken on the raw values), then
    # z-score each column.
    row_sums = np.sum(raw, 1)
    feat = zscore(raw.astype(float) / row_sums[:, np.newaxis], 0)
    labels = mat['labels']
    data[ds] = [feat, labels]

    # 'dslr' as source uses 8 samples per class, every other source uses 20.
    per_class = 8 if ds == 'dslr' else 20
    tests.extend([[ds, dt, per_class] for dt in possible_DTS if dt != ds])

# Experiment settings.
#   round : number of random source splits evaluated for each test
#   eta, reg : forwarded unchanged to the transport solver
#              (NOTE(review): their exact meaning is defined in the
#              `transport` module -- confirm there)
round = 10
eta = 1
reg = 0.5

#-------------------- Main testing loop --------------------------------------
for src, tgt, nPerClass in tests:
    print('testing '+src+' -> '+tgt)

    # Source features/labels and target features/labels for this pair.
    dataS, LabelS = data[src]
    Xt, Yt = data[tgt]

    # One accuracy (in percent) per random split.
    result1 = []

    for trial in range(round):
        # Draw the labelled source subset; the remaining indices are unused.
        train_idx = split(LabelS, nPerClass)[0]
        Xr = dataS[train_idx, :]
        Yr = LabelS[train_idx]

        # Uniform weights over the source subset and over the target set.
        n_src = len(Yr)
        n_tgt = len(Yt)
        W1 = np.ones(n_src) / n_src
        W2 = np.ones(n_tgt) / n_tgt

        # Pairwise Euclidean cost between source and target samples.
        distances = cdist(Xr, Xt, metric='euclidean')

        # Transport plan from the label-regularised (LpL1) Sinkhorn solver
        # [Courty14].  Alternatives offered by the same module:
        #   transport.computeTransportLP(W1,W2,distances)            (exact)
        #   transport.computeTransportSinkhorn(W1,W2,distances,reg)  [Cuturi13]
        transp1 = transport.computeTransportSinkhornLabelsLpL1(
            W1, Yr, W2, distances, reg, eta)

        # Rescale each row of the plan to sum to one, then map every source
        # sample to the matching weighted average of the target samples.
        row_scale = np.diag(1/np.sum(transp1, 1))
        transp1 = np.dot(row_scale, transp1)
        Xrinterpolated1 = np.dot(transp1, Xt)

        # Each target sample takes the label of its nearest mapped source
        # sample (squared Euclidean distance).
        dist = cdist(Xrinterpolated1, Xt, metric='sqeuclidean')
        minIDX = np.argmin(dist, axis=0)
        prediction = Yr[minIDX]

        # Accuracy of this split, in percent.
        n_correct = float(sum(prediction == Yt))
        result1.append(100*n_correct / len(Yt))

    print('   +--Adaptation result - mean OA = '+str(np.mean(result1))+' ('+str(np.std(result1))+')')


