#!/usr/bin/env python

import os, sys, traceback
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pylab as pylab
import numpy as np
import pylab as pl
from math import exp
from sklearn import svm
from sklearn.metrics.pairwise import euclidean_distances as skdist

import transport
reload(transport)

pl.rcParams.update({'figure.figsize': (10.0, 8.0)})  # default size (inches) of every figure below

from cvxopt import matrix, spmatrix, solvers, printing
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.metrics.pairwise import linear_kernel

# ---- Experiment parameters ------------------------------------------------
gamma, ga_SVM = 2, 4  # gamma: passed to rbf_kernel to build the sample weights; ga_SVM: not referenced below
discret = 200j        # not referenced below; presumably an np.mgrid step (200 grid points) -- TODO confirm
reg, eta = 1, 10      # forwarded to the transport solvers (eta only to the label-regularized one)
np.random.seed(100)   # fixed seed so the synthetic samples, and hence the figures, are reproducible

###-------------------------------- DATA part ----------------------------------------------
# Source samples. Class 1 is one isotropic Gaussian at the origin. Class 2 is a
# mixture of two Gaussians, centred at (2, 3) and (1, 2).
# Mixture sizes use floor division: plain ``/`` yields a float under true
# division (Python 3), which standard_normal rejects as a shape. The second
# component takes the remainder, so the total stays N2 even when N2 is odd.
N1 = 50
N2 = 50
Nini = N1 + N2
class1 = 0.6 * np.random.standard_normal((N1, 2))
class2 = 1.2 * np.random.standard_normal((N2 // 2, 2)) + np.array([2, 3])
class2 = np.concatenate((class2, 0.6 * np.random.standard_normal((N2 - N2 // 2, 2)) + np.array([1, 2])))

Xini = np.concatenate((class1, class2))
# Source labels: 0 for the first N1 rows of Xini, 1 for the remaining N2 rows.
Y = [0] * N1 + [1] * N2

# Target samples: both classes are resampled with shifted means and spreads
# (class 2 is again a two-component mixture). Target labels are not used.
N1bis = 100
N2bis = 100
Nfin = N1bis + N2bis
class1p = 0.8 * np.random.standard_normal((N1bis, 2)) + np.array([-1.8, 0.8])
class2p = 1.2 * np.random.standard_normal((N2bis // 2, 2)) + np.array([2, 4])
class2p = np.concatenate((class2p, 0.7 * np.random.standard_normal((N2bis - N2bis // 2, 2)) + np.array([2, 2])))

Xfin = np.concatenate((class1p, class2p))


### ------------------------------- Optimal Transport ---------------------------------------

# Linear (Gram) kernels: source/source, target/target and source/target.
# Their entries are inner products, later combined into the squared-distance cost.
K1 = linear_kernel(Xini, Xini)
K2 = linear_kernel(Xfin, Xfin)
K12 = linear_kernel(Xini, Xfin)

# Sample weights. Each weight is the row sum of an RBF kernel (parameter
# ``gamma``), normalized to total 1, i.e. a kernel density estimate up to a
# constant. The *Scaled copies (x10000) serve as scatter marker sizes.
K1rbf = rbf_kernel(Xini, Xini, gamma)
Wini = K1rbf.sum(axis=1)
Wini = Wini / sum(Wini)
WiniScaled = Wini * 10000

K2rbf = rbf_kernel(Xfin, Xfin, gamma)
Wfin = K2rbf.sum(axis=1)
Wfin = Wfin / sum(Wfin)
WfinScaled = Wfin * 10000

## display the problem's data
pl.figure()
# 2D scatter of source (class 1 white, class 2 red) and target ('+') samples.
# Marker area is proportional to each sample's normalized weight.
# WiniScaled[N1:] (not [N1:-1]) gives each of the N2 class-2 points its own size;
# the old slice dropped the last weight.
t1 = pl.scatter(class1[:, 0], class1[:, 1], zorder=10, s=WiniScaled[:N1], color='w', edgecolors='k')
t2 = pl.scatter(class2[:, 0], class2[:, 1], zorder=10, s=WiniScaled[N1:], color='r', edgecolors='k')
t3 = pl.scatter(Xfin[:, 0], Xfin[:, 1], zorder=5, s=WfinScaled, marker='+')
pl.axis('tight')
pl.title('Initial distributions', fontsize=14)
pl.legend((t1, t2, t3), ('Source distribution: class 1', 'Source distribution: class 2', 'Target distribution'), loc='lower right', fontsize=14)
#pl.savefig('images/p_ini.pdf', bbox_inches="tight")
pl.axis('off')
pl.show()



# Squared Euclidean cost via the kernel trick, shape (Nfin, Nini):
#   M[j, i] = <x_i, x_i> + <y_j, y_j> - 2 <x_i, y_j>
# Callers pass M.T to obtain the (source x target) cost.
M = np.diag(K1)[np.newaxis, :] + np.diag(K2)[:, np.newaxis] - 2 * K12.T



##----------------------------------------------------------------------------- LABEL ------------------------------------
# Label-regularized transport from the source samples (weights Wini, labels Y)
# to the target samples (weights Wfin), with cost M.T (source x target).
# reg and eta are forwarded to transport.computeTransportSinkhornLabelsLpL1.
transp = np.asarray(transport.computeTransportSinkhornLabelsLpL1(Wini, Y, Wfin, M.T, reg, eta))

# Row-normalize the coupling: row i sums to 1 and gives the share of source
# sample i's mass sent to each target. Broadcast division avoids building a
# dense Nini x Nini diagonal matrix and multiplying by it.
transp1 = transp / np.sum(transp, 1)[:, np.newaxis]


## display the relations
pl.figure()
# 2D view: one segment per (source, target) pair carrying more than 1e-4 of mass.
# Black for class-1 sources, red for class-2 sources; opacity = row-normalized share.
for i, j in zip(*np.nonzero(transp > 0.0001)):
    pl.plot([Xini[i, 0], Xfin[j, 0]], [Xini[i, 1], Xfin[j, 1]],
            color='k' if i < N1 else 'r', lw=3, alpha=transp1[i, j])
# WiniScaled[N1:] (not [N1:-1]) so every class-2 point gets its own marker size.
t1 = pl.scatter(class1[:, 0], class1[:, 1], zorder=10, s=WiniScaled[:N1], color='w', edgecolors='k')
t2 = pl.scatter(class2[:, 0], class2[:, 1], zorder=10, s=WiniScaled[N1:], color='r', edgecolors='k')
t3 = pl.scatter(Xfin[:, 0], Xfin[:, 1], zorder=5, s=WfinScaled, marker='+')
pl.axis('tight')
#pl.legend((t1,t2,t3),('Source distribution: class 1','Source distribution: class 2','Target distribution'),loc='lower right',fontsize=14)
#pl.title('Sinkhorn Label Regularized Transport. Gamma='+str(reg),fontsize=14)
#pl.savefig('images/p_sinkhorn_label.pdf', bbox_inches="tight")
pl.axis('off')
pl.show()

## display the relations source -> target
pl.figure()
# Each source sample moves to the barycentre of the targets, weighted by its
# row of the normalized coupling.
Xini_interpolated = np.dot(transp1, Xfin)
for (x0, y0), (x1, y1) in zip(Xini, Xini_interpolated):
    pl.arrow(x0, y0, x1 - x0, y1 - y0, fc="grey", length_includes_head=True, head_width=.1)

# plot 2D relations
t1 = pl.scatter(class1[:, 0], class1[:, 1], zorder=10, s=WiniScaled[:N1], color='w', edgecolors='k')
t2 = pl.scatter(class2[:, 0], class2[:, 1], zorder=10, s=WiniScaled[N1:], color='r', edgecolors='k')
t3 = pl.scatter(Xfin[:, 0], Xfin[:, 1], zorder=5, s=WfinScaled, marker='+')
pl.axis('tight')
#pl.legend((t1,t2,t3),('Source distribution: class 1','Source distribution: class 2','Target distribution'),loc='lower right',fontsize=14)
#pl.title('Sinkhorn Label Regularized Transport. Transported Source',fontsize=14)
#pl.savefig('images/p_sinkhorn_label_arrow.pdf', bbox_inches="tight")
pl.axis('off')
pl.show()


# display the transport matrix as a heat map (colour scale capped at 2e-4);
# uncomment savefig to write it to disk
fig = pl.figure()
implot = pl.imshow(transp, interpolation="nearest", aspect='auto', vmin=0, vmax=0.0002)
implot.set_cmap('Blues')
fig.colorbar(implot)
#pl.savefig('images/p_sinkhorn_label_transp.pdf', bbox_inches="tight")
pl.axis('off')
pl.show()
##----------------------------------------------------------------------------- SINKHORN ------------------------------------
# Entropic (Sinkhorn) transport without label information, same weights and cost.
transp_not = np.asarray(transport.computeTransportSinkhorn(Wini, Wfin, M.T, reg))
#transp_not = transport.computeTransportLP(Wini,Wfin,M.T)

# Row-normalize the coupling (row i sums to 1) with broadcast division.
transp2 = transp_not / np.sum(transp_not, 1)[:, np.newaxis]

## display the relations
pl.figure()
# 2D view: one segment per (source, target) pair carrying more than 1e-4 of mass.
# Black for class-1 sources, red for class-2 sources; opacity = row-normalized share.
for i, j in zip(*np.nonzero(transp_not > 0.0001)):
    pl.plot([Xini[i, 0], Xfin[j, 0]], [Xini[i, 1], Xfin[j, 1]],
            color='k' if i < N1 else 'r', lw=3, alpha=transp2[i, j])
# WiniScaled[N1:] (not [N1:-1]) so every class-2 point gets its own marker size.
t1 = pl.scatter(class1[:, 0], class1[:, 1], zorder=10, s=WiniScaled[:N1], color='w', edgecolors='k')
t2 = pl.scatter(class2[:, 0], class2[:, 1], zorder=10, s=WiniScaled[N1:], color='r', edgecolors='k')
t3 = pl.scatter(Xfin[:, 0], Xfin[:, 1], zorder=5, s=WfinScaled, marker='+')
pl.axis('tight')
#pl.legend((t1,t2,t3),('Source distribution: class 1','Source distribution: class 2','Target distribution'),loc='lower right',fontsize=14)
#pl.title('Sinkhorn Transport. Gamma='+str(reg),fontsize=14)
#pl.savefig('images/p_Sinkhorn.pdf', bbox_inches="tight")
pl.axis('off')
pl.show()

## display the relations source -> target
pl.figure()
# Each source sample moves to the coupling-weighted barycentre of the targets.
Xini_interpolated = np.dot(transp2, Xfin)
for (x0, y0), (x1, y1) in zip(Xini, Xini_interpolated):
    pl.arrow(x0, y0, x1 - x0, y1 - y0, fc="grey", length_includes_head=True, head_width=.1)

# plot 2D relations
t1 = pl.scatter(class1[:, 0], class1[:, 1], zorder=10, s=WiniScaled[:N1], color='w', edgecolors='k')
t2 = pl.scatter(class2[:, 0], class2[:, 1], zorder=10, s=WiniScaled[N1:], color='r', edgecolors='k')
t3 = pl.scatter(Xfin[:, 0], Xfin[:, 1], zorder=5, s=WfinScaled, marker='+')
pl.axis('tight')
#pl.legend((t1,t2,t3),('Source distribution: class 1','Source distribution: class 2','Target distribution'),loc='lower right',fontsize=14)
#pl.title('Sinkhorn Transport. Transported Source',fontsize=14)
#pl.savefig('images/p_Sinkhorn_arrow.pdf', bbox_inches="tight")
pl.axis('off')
pl.show()


# display the transport matrix as a heat map (colour scale capped at 2e-4);
# uncomment savefig to write it to disk
fig = pl.figure()
implot = pl.imshow(transp_not, interpolation="nearest", aspect='auto', vmin=0, vmax=0.0002)
implot.set_cmap('Blues')
fig.colorbar(implot)
#pl.savefig('images/p_sinkhorn_transp.pdf', bbox_inches="tight")
pl.axis('off')
pl.show()



##----------------------------------------------------------------------------- LP ------------------------------------
# Exact (linear-programming) transport, same weights and cost, no regularization.
transp_not = np.asarray(transport.computeTransportLP(Wini, Wfin, M.T))

# Row-normalize the coupling (row i sums to 1) with broadcast division.
transp2 = transp_not / np.sum(transp_not, 1)[:, np.newaxis]

## display the relations
pl.figure()
# 2D view: one segment per (source, target) pair carrying more than 1e-4 of mass.
# Black for class-1 sources, red for class-2 sources; opacity = row-normalized share.
for i, j in zip(*np.nonzero(transp_not > 0.0001)):
    pl.plot([Xini[i, 0], Xfin[j, 0]], [Xini[i, 1], Xfin[j, 1]],
            color='k' if i < N1 else 'r', lw=3, alpha=transp2[i, j])
# WiniScaled[N1:] (not [N1:-1]) so every class-2 point gets its own marker size.
t1 = pl.scatter(class1[:, 0], class1[:, 1], zorder=10, s=WiniScaled[:N1], color='w', edgecolors='k')
t2 = pl.scatter(class2[:, 0], class2[:, 1], zorder=10, s=WiniScaled[N1:], color='r', edgecolors='k')
t3 = pl.scatter(Xfin[:, 0], Xfin[:, 1], zorder=5, s=WfinScaled, marker='+')
pl.axis('tight')
#pl.legend((t1,t2,t3),('Source distribution: class 1','Source distribution: class 2','Target distribution'),loc='lower right',fontsize=14)
#pl.title('LP Transport',fontsize=14)
#pl.savefig('images/p_LP.pdf', bbox_inches="tight")
pl.axis('off')
pl.show()

## display the relations source -> target
pl.figure()
# Each source sample moves to the coupling-weighted barycentre of the targets.
Xini_interpolated = np.dot(transp2, Xfin)
for (x0, y0), (x1, y1) in zip(Xini, Xini_interpolated):
    pl.arrow(x0, y0, x1 - x0, y1 - y0, fc="grey", length_includes_head=True, head_width=.1)

# plot 2D relations
t1 = pl.scatter(class1[:, 0], class1[:, 1], zorder=10, s=WiniScaled[:N1], color='w', edgecolors='k')
t2 = pl.scatter(class2[:, 0], class2[:, 1], zorder=10, s=WiniScaled[N1:], color='r', edgecolors='k')
t3 = pl.scatter(Xfin[:, 0], Xfin[:, 1], zorder=5, s=WfinScaled, marker='+')
pl.axis('tight')
#pl.legend((t1,t2,t3),('Source distribution: class 1','Source distribution: class 2','Target distribution'),loc='lower right',fontsize=14)
#pl.title('LP Transport. Transported Source',fontsize=14)
#pl.savefig('images/p_LP_arrow.pdf', bbox_inches="tight")
pl.axis('off')
pl.show()



# display the transport matrix as a heat map (auto colour scale);
# uncomment savefig to write it to disk
fig = pl.figure()
implot = pl.imshow(transp_not, interpolation="nearest", aspect='auto')
implot.set_cmap('Blues')
fig.colorbar(implot)
#pl.savefig('images/p_LP_transp.pdf', bbox_inches="tight")
pl.axis('off')
pl.show()