from collections import Counter
from sklearn.cluster import KMeans
import numpy as np
import scipy as sp
import networkx as nx
import itertools
from operator import itemgetter
import time
import subprocess
import os
import sklearn
import cPickle, re
import datetime
from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer
import argparse
import fastcluster
import json
from sklearn.metrics.pairwise import cosine_similarity

###Helper Functions
def norm(a):
    """Return the Euclidean (L2) length of the vector ``a``."""
    squared = np.square(a)
    return np.sqrt(squared.sum())


def cosine(a,b):
    """Return the cosine *distance* between ``a`` and ``b``.

    That is ``1 - cos(angle(a, b))``: 0 for parallel vectors, 1 for
    orthogonal ones and 2 for opposite ones.
    """
    numerator = np.dot(a, b)
    denominator = np.sqrt(np.sum(a**2) * np.sum(b**2))
    return 1 - numerator / denominator

def l1(a,b):
    """Return the Manhattan (L1) distance between arrays ``a`` and ``b``."""
    difference = a - b
    return abs(difference).sum()

def l2(a,b):
    """Return the Euclidean (L2) distance between arrays ``a`` and ``b``."""
    difference = a - b
    return np.sqrt(np.square(difference).sum())


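### Build the list of words (and their vectors) to cluster from a model: words are filtered by frequency, length and an l2_threshold on the vector norm; vectors can optionally be normalized, and each word is emitted either once per occurrence (repeat) or once overall.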
def create_word_list(model,vocab,features,Texts,repeat=True,l2_threshold=0,normalized=True,min_count=100,min_length=0):
    """Collect the word vectors to be clustered, plus the matching words.

    A word is kept when all of the following hold:
    - it is in ``vocab``;
    - it occurs more than ``min_count`` times across ``Texts``;
    - it is longer than ``min_length`` characters;
    - the L2 norm of ``model[w]`` is greater than ``l2_threshold``.

    Args:
        model: mapping-like object; ``model[w]`` returns the vector for ``w``.
        vocab: words known to the model. A list/tuple is converted to a set
            for membership tests; other containers are used as-is.
        features: vector length, used to build the zero vector that norms
            are measured from.
        Texts: tokenized texts (each an iterable of words). Iterated twice,
            so pass a re-iterable sequence rather than a generator.
        repeat: if True, emit one entry per token occurrence, in text order.
            If False, emit each qualifying word once, in ``vocab`` order.
        l2_threshold: exclusive lower bound on the vector norm.
        normalized: if True, divide each emitted vector by its L2 norm.
        min_count: exclusive lower bound on corpus frequency.
        min_length: exclusive lower bound on word length.

    Returns:
        ``(data_d2v, word_d2v)``: parallel lists of vectors and words.
    """
    count = Counter(w for text in Texts for w in text)
    # Lists/tuples make `in` O(len(vocab)) per token; other containers
    # (sets, dicts, custom vocab objects) keep their own membership test.
    known = set(vocab) if isinstance(vocab, (list, tuple)) else vocab
    origin = np.zeros(features)  # built once instead of once per norm
    norms = {}

    def _norm_of(w):
        # Same arithmetic as l2(model[w], np.zeros(features)), cached per
        # word so repeated tokens do not recompute it.
        if w not in norms:
            norms[w] = np.sqrt(np.square(model[w] - origin).sum())
        return norms[w]

    if repeat:
        candidates = (w for text in Texts for w in text)
    else:
        candidates = (w for w in vocab if w in count)

    data_d2v = []
    word_d2v = []
    for w in candidates:
        if w not in known or count[w] <= min_count or len(w) <= min_length:
            continue
        length = _norm_of(w)
        # Written as `not (>)` so a NaN norm is rejected, as before.
        if not (length > l2_threshold):
            continue
        data_d2v.append(model[w] / length if normalized else model[w])
        word_d2v.append(w)

    return data_d2v, word_d2v



def calculate_depth(spcluster,words, num_points):
    """Count, per word, how many merges of a hierarchical clustering include it.

    ``spcluster`` is read as a linkage matrix (presumably SciPy/fastcluster
    output). Row ``i`` merges clusters ``int(row[0])`` and ``int(row[1])``
    into the new cluster ``num_points + i``. Indices below ``num_points`` are
    the original points, and point ``j`` is the word ``words[j]``. Each leaf
    is counted once per merge above it, which is its depth in the dendrogram.
    Words that appear at several leaves accumulate the counts of all of them.
    Rows may be NumPy rows or plain sequences.

    Returns:
        collections.Counter mapping word -> merge count.
    """
    members = [[i] for i in range(num_points)] + [[] for _ in range(num_points)]
    depth = Counter()
    for step, row in enumerate(spcluster):
        left, right = int(row[0]), int(row[1])
        merged = members[left] + members[right]
        members[num_points + step] = merged
        for idx in merged:
            depth[words[idx]] += 1
        # The children are never referenced again; release their lists.
        members[left] = []
        members[right] = []
    return depth

# Named display colors, and the color used for each cluster rank
# (kmeans_label_ranked value 0 -> 'red', 1 -> 'green', ...).
_VIS_COLORS = {
    'blue': 'rgb(54,9,237)',
    'orange': 'rgb(237,145,9)',
    'green': 'rgb(9,237,100)',
    'brightgreen': 'rgb(54,237,9)',
    'lightgreen': 'rgb(191,237,9)',
    'pink': 'rgb(237,9,145)',
    'purple': 'rgb(191,9,237)',
    'red': 'rgb(226,18,18)',
    'lightblue': 'rgb(18,226,226)',
    'mauve': 'rgb(31,150,210)',
}
_RANKED_COLOR = ['red', 'green', 'blue', 'orange', 'pink', 'purple', 'mauve',
                 'lightblue', 'lightgreen', 'brightgreen']


def _shell_run(args):
    """Run the argv list ``args``, echo its merged stdout/stderr, return its exit status."""
    import sys
    proc = subprocess.Popen(args, stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT, universal_newlines=True)
    output, _ = proc.communicate()
    sys.stdout.write(output)
    return proc.returncode


def _knn_edges(SP_full, words, labels, row_of, node_id, max_neighbors, skip=frozenset()):
    """Link each word in ``words`` (minus ``skip``) to its top-valued SP_full columns.

    Edges are ``(node_id[word], node_id[labels[col]])`` for the
    ``max_neighbors`` columns of the word's row with the largest values.
    """
    edges = []
    for word in words:
        if word in skip:
            continue
        row = row_of.get(word)
        if row is None:
            # No row for this word: skip it rather than borrowing another
            # word's neighbours.
            continue
        # Column indices ordered by descending value. The word's own column
        # is not excluded, so a dominant diagonal yields a self-loop.
        order = (-SP_full[row].toarray()).argsort()
        edges.extend((node_id[word], node_id[labels[col]])
                     for col in order[0][:max_neighbors])
    return edges


def _labelled_graph(edges, words):
    """Build an undirected networkx graph from ``edges``; node ``i`` gets label ``words[i]``."""
    graph = nx.Graph()
    graph.add_edges_from(edges)
    for idd in list(graph.nodes()):
        # add_node on an existing node only updates its attributes
        # (works on networkx 1.x and 2.x, unlike G.node[...]).
        graph.add_node(idd, label=words[idd])
    return graph


def _node_size(value, alpha):
    """Piecewise display size for a node whose metric is ``value``.

    ``alpha`` is log(0.5)/log(thresh), so ``value ** alpha == 0.5`` exactly
    at the top-words threshold.
    """
    if value ** alpha > 0.5:
        return 100 * value ** 2
    if value > 0.5:
        return 50 * (value - (np.power(0.5, 1 / alpha) - 0.5))
    if np.sqrt(value) > np.power(0.5, 1 / alpha):
        return 25 * np.power(value, 1 / alpha)
    return 10


def _build_graph(SP_full, sort_ids, id2word, metric, kmeans_label_ranked,
                 max_nodes, max_neighbors):
    """Do the work of grapher(); any exception propagates to the caller."""
    # Row/column i of SP_full corresponds to the word id2word[sort_ids[i]].
    labels = [id2word[i] for i in sort_ids]
    row_of = {label: i for i, label in enumerate(labels)}

    # Words ranked by metric, best first; a word's node id is its rank.
    ranked = sorted(metric, key=metric.get, reverse=True)
    node_id = {word: i for i, word in enumerate(ranked)}
    seeds = ranked[:max_nodes]

    graph = _labelled_graph(
        _knn_edges(SP_full, seeds, labels, row_of, node_id, max_neighbors), ranked)

    # Words in small detached components (<= 11 nodes) stop being seeds.
    # They may still appear as a neighbour of a kept word.
    removed = set()
    if not nx.is_connected(graph):
        print("Graph is not connected...")
        for component in nx.connected_components(graph):
            if len(component) <= 11:
                removed.update(ranked[i] for i in component)
    if removed:
        graph = _labelled_graph(
            _knn_edges(SP_full, seeds, labels, row_of, node_id, max_neighbors, removed),
            ranked)

    input_file = 'graph.gexf'
    output_file = 'gephi.json'
    nx.write_gexf(graph, input_file)
    status = _shell_run(['java', '-jar', 'gexf2json/gexf2json.jar',
                         '-i', input_file, '-n', '5000', '-o', output_file])
    if status != 0:
        # Do not go on to read a stale gephi.json left by an earlier run.
        raise RuntimeError('gexf2json exited with status %d' % status)

    with open(output_file, 'rb') as infile:
        gephi_in = json.load(infile)
    nodes = gephi_in['nodes']
    present = set(node['label'] for node in nodes)
    out = {'edges': gephi_in['edges'], 'nodes': nodes}

    # Percentile above which roughly num_words_top nodes lie. It is clamped at
    # 0 because a graph with fewer nodes would give a negative percentile.
    num_words_top = 30
    scores = [value for key, value in metric.items() if key in present]
    pct = max(100 * (1 - num_words_top * 1. / len(nodes)), 0.)
    thresh = np.percentile(scores, pct)
    alpha = np.log(0.5) / np.log(thresh)

    red_ys = []
    for node in nodes:
        label = node['label']
        color = _RANKED_COLOR[kmeans_label_ranked[label]]
        node['color'] = _VIS_COLORS[color]
        if color == 'red':
            red_ys.append(node['y'])
        node['size'] = _node_size(metric[label], alpha)
        node['label'] = label.replace('_', ' ')

    # Mirror vertically so the red (rank-0) cluster's mean y is not positive.
    # With no red nodes there is nothing to orient, so leave y untouched.
    if red_ys and sum(red_ys) > 0:
        for node in nodes:
            node['y'] = -node['y']

    with open('static/graph.json', 'w') as outfile:
        json.dump(out, outfile)
    return out


def grapher(SP_full,sort_ids,id2word,metric,kmeans_label_ranked,max_nodes=250,max_neighbors=10):
    """Build a word-neighbourhood graph, convert it with gexf2json and export JSON.

    Args:
        SP_full: sparse matrix (rows support ``.toarray()``) whose row and
            column ``i`` correspond to the word ``id2word[sort_ids[i]]``.
            Larger entries are treated as closer neighbours (presumably
            similarities; confirm with caller).
        sort_ids: sequence of word ids giving the SP_full row/column order.
        id2word: mapping from word id to word.
        metric: dict word -> score. Words are ranked by it, and it drives
            node sizes. Every neighbour word must be a key.
        kmeans_label_ranked: dict word -> cluster rank (0..9) that picks the
            node color; rank 0 is red.
        max_nodes: number of top-ranked words used as edge seeds.
        max_neighbors: edges added per seed word.

    Returns:
        ``(graph, "OK")`` on success. ``graph`` is a dict with ``'edges'`` and
        ``'nodes'`` as read back from gephi.json; each node is annotated with
        ``color`` and ``size``, and has ``_`` replaced by spaces in its
        ``label``. If any step raises, returns ``('', error)`` instead.

    Side effects:
        Writes ``graph.gexf``, ``gephi.json`` (via ``java -jar
        gexf2json/gexf2json.jar``) and ``static/graph.json`` relative to the
        current working directory, and prints the converter's output.
    """
    try:
        return _build_graph(SP_full, sort_ids, id2word, metric,
                            kmeans_label_ranked, max_nodes, max_neighbors), "OK"
    except Exception as e:
        # Existing contract: errors are returned as a value, not raised.
        return '', e
