import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as pltim
import scipy.cluster.vq as scikm
import scipy.misc as scimc
from mpl_toolkits.mplot3d import Axes3D
# from PIL import Image
import sys
import time
import platform
import re
from math import sqrt
from itertools import *
from random import *
from fractions import Fraction


class MacsLib():
    """Plotting helpers and playing-card probability utilities.

    Every ``draw*`` method renders into Matplotlib figure 1 and ends with
    ``finishDraw``, so any extra keyword arguments (``title``, ``xaxis``,
    ``xlabel``, ``ymin`` ...) are forwarded there.

    Cards are ``(face, suit)`` tuples: faces 1-13 (1 = ace, 11 = jack,
    12 = queen, 13 = king) and suits 1-4 (1 = clubs, 2 = diamonds,
    3 = hearts, 4 = spades), as encoded by the maps in ``queryForCards``.
    """

    def startDraw(self, clear=False, type=None):
        """Return an Axes on figure 1 to draw into.

        :param clear: when true, clear the current figure and create fresh axes.
        :param type: Matplotlib projection name (``None`` for 2-D, ``'3d'``).
        """
        if clear: plt.clf()
        fig = plt.figure(1)
        if not clear:
            # Overlay plots (clear=False) must land on the axes that is already
            # there.  Older Matplotlib returned an existing Axes from
            # add_subplot when called with the same arguments; newer releases
            # always create a new one, which would cover the earlier drawing.
            wanted = 'rectilinear' if type is None else type
            for ax in fig.axes:
                if ax.name == wanted:
                    return ax
        return fig.add_subplot(111, projection=type)


    def finishDraw(self, pltax, title=None, xaxis=None, yaxis=None, zaxis=None, xlabel='', xmin=None, xmax=None, ylabel='', ymin=None, ymax=None, zlabel='', zmin=None, zmax=None):
        """Apply title, labels and limits to *pltax*, then show it non-blocking.

        ``xaxis``/``yaxis``/``zaxis`` are optional ``(label, min, max)`` tuples
        that override the corresponding separate arguments.  The z axis is only
        configured when a non-empty ``zlabel`` is available.
        """
        if title is not None:
            pltax.set_title(title)

        if xaxis is not None: xlabel, xmin, xmax = xaxis
        if yaxis is not None: ylabel, ymin, ymax = yaxis
        if zaxis is not None: zlabel, zmin, zmax = zaxis

        pltax.set_xlim(xmin, xmax)
        pltax.set_xlabel(xlabel)

        pltax.set_ylim(ymin, ymax)
        pltax.set_ylabel(ylabel)

        if zlabel:
            pltax.set_zlim(zmin, zmax)
            pltax.set_zlabel(zlabel)

        plt.draw()
        plt.show(block=False)


    def drawImage(self, image, **kwargs):
        """Display *image* in grayscale with values clipped to [0, 255].

        Images at most 64 pixels wide get a pixel grid with per-pixel tick
        labels.  The caller's array is not modified.
        """
        # np.clip returns a new array; the old masked assignment clipped the
        # caller's image in place as a side effect.
        image = np.clip(image, 0, 255)

        pltax = self.startDraw(clear=True)
        pltax.imshow(image, interpolation='nearest', cmap='gray')

        if image.shape[1] <= 64:
            xtick = np.arange(0, image.shape[1])
            ytick = np.arange(0, image.shape[0])
            pltax.grid()
            # Ticks at +0.5 put the grid lines on pixel borders, while the
            # labels still show the integer pixel index.
            pltax.set_xticks(xtick+0.5)
            pltax.set_xticklabels(xtick)
            pltax.set_yticks(ytick+0.5)
            pltax.set_yticklabels(ytick)

        self.finishDraw(pltax, **kwargs)


    def drawHistogram(self, data, ymax=None, **kwargs):
        """Draw a bar chart of ``data[1:]`` (the first entry is skipped).

        Bars sit at positions 0..n-1 and are labelled 1..n.  ``ymax`` defaults
        to 10 above the tallest bar.
        """
        datalen = len(data)-1
        data = np.asarray(data[1:])
        left = np.arange(datalen)

        pltax = self.startDraw(clear=True)
        pltax.bar(left, data)

        tickdir = 'horizontal' if datalen < 10 else 'vertical'
        # NOTE(review): tick offsets differ per platform, presumably to match
        # bar alignment in the Matplotlib builds this was written for — confirm.
        if platform.system() == 'Windows':
            pltax.set_xticks(left)
            pltax.set_xticklabels(left+1, rotation=tickdir)
            pltax.set_xlim(-0.5, datalen-0.5)
        else:
            pltax.set_xticks(left+0.4)
            pltax.set_xticklabels(left+1, rotation=tickdir)

        if ymax is None: ymax = data.max()+10

        self.finishDraw(pltax, ymax=ymax, **kwargs)


    def drawFunction(self, funcs, xmin=-10, xmax=10, ymin=None, ymax=None, **kwargs):
        """Overlay plots of formulas in ``x`` (e.g. ``'2x^2 + sin(x)'``).

        Formulas accept ``^`` for powers, implicit multiplication after a
        number (``3x``), ``x!``, and sin/cos/tan/log (base 10)/sqrt/abs.
        Non-finite samples are dropped.  The y range covers all plotted
        values unless ``ymin``/``ymax`` are given.
        """
        pltax = self.startDraw(clear=False)

        # Start from an empty range (+inf, -inf) that each formula widens.
        fymin = np.inf if ymin is None else ymin
        fymax = -np.inf if ymax is None else ymax

        # Replacement order matches the original chain of str.replace calls.
        rewrites = (('sin(', 'np.sin('), ('cos(', 'np.cos('),
                    ('tan(', 'np.tan('), ('log(', 'np.log10('),
                    ('sqrt(', 'np.sqrt('), ('abs(', 'np.absolute('),
                    ('x!', 'scimc.factorial(x)'))

        for origfunc in funcs:
            func = origfunc.replace('^', '**')
            for plain, numpyForm in rewrites:
                func = func.replace(plain, numpyForm)
            func = re.sub(r'([0-9]+)x', r'\1*x', func)
            # Make constant formulas evaluate to an array the same size as x.
            if 'x' not in func: func = func + '+0*x'

            x = np.linspace(xmin, xmax, 100)
            # SECURITY: eval() executes arbitrary Python; only pass formulas
            # from a trusted user.
            y = eval(func)

            finite = np.isfinite(y)
            x = x[finite]
            y = y[finite]

            # A formula may have no finite samples in range; skip its bounds.
            if y.size:
                if ymin is None and y.min() < fymin: fymin = y.min()
                if ymax is None and y.max() > fymax: fymax = y.max()
            if fymin == fymax: fymin = fymin - 1; fymax = fymax + 1

            pltax.plot(x, y, label='y = '+origfunc)

        pltax.set_xlim(xmin, xmax)
        # Only apply y limits once they are real numbers (something was plotted
        # or the caller supplied them).
        if np.isfinite(fymin) and np.isfinite(fymax):
            pltax.set_ylim(fymin, fymax)
        pltax.legend()

        self.finishDraw(pltax, **kwargs)


    def drawPoints(self, xpoints, ypoints=None, zpoints=None, connect=True, rescale=False, **kwargs):
        """Plot points in 2-D (or 3-D when ``zpoints`` is given).

        With only ``xpoints``, they are plotted as y values against their
        index.  ``connect`` joins the points with lines.  In 2-D, exactly two
        points or ``rescale=True`` gives a square view, 10% larger than the
        larger of the x and y spans, centred on the points.
        """
        if zpoints is None:
            pltax = self.startDraw(clear=True)

            if ypoints is None:
                ypoints = xpoints
                xpoints = np.arange(0, len(ypoints))

            pltax.plot(xpoints, ypoints, 'o-' if connect else 'o')

            if len(xpoints) == 2 or rescale:
                # Frame ALL points; previously only the first two were used,
                # so rescale=True cut off the rest.  For two points the result
                # is unchanged.
                xs = np.asarray(xpoints, dtype=float)
                ys = np.asarray(ypoints, dtype=float)
                axisRange = max(xs.max() - xs.min(), ys.max() - ys.min()) * 1.1
                xmin = ((xs.min() + xs.max()) / 2.0) - (axisRange / 2.0)
                ymin = ((ys.min() + ys.max()) / 2.0) - (axisRange / 2.0)
                pltax.set_xlim(xmin, xmin + axisRange)
                pltax.set_ylim(ymin, ymin + axisRange)

        else:
            pltax = self.startDraw(clear=True, type='3d')
            pltax.scatter(xpoints, ypoints, zpoints)
            if connect:
                pltax.plot(xpoints, ypoints, zpoints, '--')

        self.finishDraw(pltax, **kwargs)


    def drawFitLine(self, xpoints, ypoints, degree, xmin=None, xmax=None, ymin=None, ymax=None, **kwargs):
        """Scatter the points and overlay a least-squares polynomial fit.

        :param degree: polynomial degree passed to ``np.polyfit``.
        The legend shows the fitted polynomial with coefficients rounded to
        one decimal.  Axis limits default to the data's extent.
        """
        xpoints = np.asarray(xpoints)
        ypoints = np.asarray(ypoints)
        pltax = self.startDraw(clear=True)

        fxmin = xpoints.min() if xmin is None else xmin
        fxmax = xpoints.max() if xmax is None else xmax
        fymin = ypoints.min() if ymin is None else ymin
        fymax = ypoints.max() if ymax is None else ymax

        pltax.scatter(xpoints, ypoints)
        coeffs = np.polyfit(xpoints, ypoints, degree)
        poly = np.poly1d(coeffs)
        x = np.linspace(fxmin, fxmax, 100)

        # polyfit returns coefficients from the highest power down to x^0.
        label = 'y = '
        terms = 0
        powers = range(len(coeffs) - 1, -1, -1)
        for power, coeff in zip(powers, np.around(coeffs, 1)):
            if coeff == 0.0:
                continue
            terms += 1
            if coeff > 0: label += '+'
            # A unit coefficient is implicit on x terms, but the constant term
            # must still be written (a constant of 1.0 used to render as "+").
            if coeff != 1.0 or power == 0: label += str(coeff)
            if power > 1: label += 'x^' + str(power) + ' '
            elif power == 1: label += 'x '
        if not terms: label += '0'

        pltax.plot(x, poly(x), 'r-', label=label)
        pltax.legend()

        self.finishDraw(pltax, xmin=fxmin, xmax=fxmax, ymin=fymin, ymax=fymax, **kwargs)


    def drawKMeans(self, xpoints, ypoints, zpoints, numclust, **kwargs):
        """Cluster 3-D points with k-means and draw each non-empty cluster.

        Each cluster gets its points, its centroid ('+'), and a translucent
        ellipsoid whose semi-axes are 1.2x the cluster's largest per-axis
        distance from the centroid.  Colours cycle after nine clusters.
        """
        colors = ['b', 'r', 'g', 'm', 'y', '#4B0082', '#00BFFF', '#FF8C00', '#8B4513']
        xpoints = np.asarray(xpoints)
        ypoints = np.asarray(ypoints)
        zpoints = np.asarray(zpoints)
        # column_stack builds the (n, 3) observation matrix on Python 2 and 3
        # (np.array(zip(...)) gives a useless 0-d object array on Python 3).
        data = np.column_stack((xpoints, ypoints, zpoints))
        centroids, labels = scikm.kmeans2(data, numclust, iter=500)

        pltax = self.startDraw(clear=True, type='3d')

        # Ellipsoid surface parameters, shared by every cluster.
        u = np.linspace(0, 2 * np.pi, 100)
        v = np.linspace(0, np.pi, 100)

        for centid in range(numclust):
            members = labels == centid
            if not members.any():
                continue
            color = colors[centid % len(colors)]
            cx, cy, cz = centroids[centid][0], centroids[centid][1], centroids[centid][2]
            px, py, pz = xpoints[members], ypoints[members], zpoints[members]
            rx = np.absolute(px - cx).max()
            ry = np.absolute(py - cy).max()
            rz = np.absolute(pz - cz).max()

            pltax.scatter(px, py, pz, color=color)
            pltax.scatter([cx], [cy], [cz], marker='+', color=color)

            # add spheres - based on mplot3d tutorial
            x = rx * 1.2 * np.outer(np.cos(u), np.sin(v)) + cx
            y = ry * 1.2 * np.outer(np.sin(u), np.sin(v)) + cy
            z = rz * 1.2 * np.outer(np.ones(np.size(u)), np.cos(v)) + cz
            pltax.plot_surface(x, y, z, rstride=4, cstride=4, alpha=0.2, linewidths=0, color=color)

        self.finishDraw(pltax, **kwargs)


    def drawTSPPoints(self, points, **kwargs):
        """Plot labelled points from a mapping of ``label -> (x, y)``.

        Each label is written 0.2 units above its point.
        """
        pltax = self.startDraw(clear=True)

        xpoints = []
        ypoints = []
        for lbl, pt in points.items():
            xpoints.append(pt[0])
            ypoints.append(pt[1])
            pltax.annotate(lbl, xy=pt, xytext=(pt[0], pt[1]+0.2))

        pltax.plot(xpoints, ypoints, 'o')

        self.finishDraw(pltax, **kwargs)


    def drawTSPPath(self, points, label='', ls='r-', **kwargs):
        """Overlay a closed tour through *points* and report its length.

        :param points: ordered sequence of ``(x, y)`` points; the last point
            is joined back to the first.
        :param label: legend text; the Euclidean tour length is appended.
        :param ls: Matplotlib line format for the segments.
        """
        pltax = self.startDraw(clear=False)

        xpoints = [pt[0] for pt in points]
        ypoints = [pt[1] for pt in points]
        count = len(points)

        pathLength = 0
        for i in range(count):
            j = (i + 1) % count
            x1, y1 = xpoints[i], ypoints[i]
            x2, y2 = xpoints[j], ypoints[j]
            pathLength += sqrt((x1-x2)**2 + (y1-y2)**2)

            # Only the final segment carries the legend entry, once the total
            # length is known.
            if i == count-1:
                pltax.plot([x1,x2], [y1,y2], ls, label=label+', length= '+str(pathLength))
            else:
                pltax.plot([x1,x2], [y1,y2], ls)

        pltax.legend()

        self.finishDraw(pltax, **kwargs)


    def drawCTendency(self, xpoints, top=True, flabel=None, mean=None, median=None, mode=None, meanguess=None, medianguess=None, modeguess=None, **kwargs):
        """Draw a dot plot of *xpoints* with central-tendency markers.

        ``top=True`` clears the figure and stacks points upward; ``top=False``
        draws on the existing axes, stacked downward, so two data sets can be
        compared.  Optional mean/median/mode values (and the "guess" variants)
        are drawn as markers above the stacks; the non-guess values also get
        vertical lines.  ``mode``/``modeguess`` are sequences.  ``flabel``
        adds a rotated y-tick label for the data set.
        """
        pltax = self.startDraw(top==True)

        side = 1 if top else -1
        xpoints = np.array(xpoints)
        # Repeated values are stacked: the k-th occurrence is drawn at height k.
        ypoints = np.zeros(xpoints.size)
        count = {}
        for i in range(xpoints.size):
            val = xpoints[i]
            count[val] = count.get(val, 0) + 1
            ypoints[i] = side*count[val]

        handle, = pltax.plot(xpoints, ypoints, 'kx', mew=2)
        # Strip labels from the currently labelled handles so the legend built
        # below lists only this call's markers.
        for handle in pltax.get_legend_handles_labels()[0]:
            handle.set_label(None)

        ypos = ypoints.max()+2 if side > 0 else ypoints.min()-2
        m = 'v' if side > 0 else '^'
        mg = '1' if side > 0 else '2'
        # axvline takes axes-fraction limits; with the y range fixed to
        # (-15, 15) below, 0.5 + ypos/31.0 approximately maps height ypos to
        # that fraction.
        if meanguess is not None:
            pltax.plot(meanguess, ypos+0.2*side, 'y'+mg, mew=1, ms=8, mec='y', fillstyle='none', label='Mean Guess: '+str(round(meanguess,3)))
        if mean is not None:
            pltax.plot(mean, ypos, 'y'+m, mew=1, ms=10, mec='y', fillstyle='none', label='Mean: '+str(round(mean,3)))
            pltax.axvline(x=mean, ymin=(0.5 if side>0 else (0.5+ypos/31.0)), ymax=((0.5+ypos/31.0) if side>0 else 0.5), color='y')
        if medianguess is not None:
            pltax.plot(medianguess, ypos+0.2*side, 'r'+mg, mew=1, ms=8, mec='r', fillstyle='none', label='Median Guess: '+str(medianguess))
        if median is not None:
            pltax.plot(median, ypos, 'r'+m, mew=1, ms=10, mec='r', fillstyle='none', label='Median: '+str(median))
            pltax.axvline(x=median, ymin=(0.5 if side>0 else (0.5+ypos/31.0)), ymax=((0.5+ypos/31.0) if side>0 else 0.5), color='r')
        if modeguess is not None:
            pltax.plot(modeguess, [ypos+0.2*side]*len(modeguess), 'b'+mg, mew=1, ms=8, mec='b', fillstyle='none', label='Mode Guess: '+str(modeguess))
        if mode is not None:
            pltax.plot(mode, [ypos]*len(mode), 'b'+m, mew=1, ms=10, mec='b', fillstyle='none', label='Mode: '+str(mode))
            for md in mode: pltax.axvline(x=md, ymin=(0.5 if side>0 else (0.5+ypos/31.0)), ymax=((0.5+ypos/31.0) if side>0 else 0.5), color='b')

        # add_artist keeps this legend visible when a later call creates another.
        leg = pltax.legend(loc=(1 if top else 4))
        pltax.add_artist(leg)

        r = (xpoints.max()-xpoints.min())/40
        # The first call on a fresh figure fixes the y range to (-15, 15);
        # a later overlay only widens the x range.
        if pltax.get_ylim()[0] != -15:
            pltax.set_xlim(xpoints.min()-r, xpoints.max()+r)
            pltax.set_ylim(-15, 15)
            pltax.set_yticks([0])
            pltax.set_yticklabels([''])
        else:
            pltax.set_xlim(min(xpoints.min()-r,plt.xlim()[0]), max(xpoints.max()+r,plt.xlim()[1]))

        pltax.plot(pltax.get_xlim(), (0,0), 'k-')
        if flabel is not None:
            labels = [item.get_text() for item in pltax.get_yticklabels()]
            pltax.set_yticks(list(plt.yticks()[0])+[3*side])
            pltax.set_yticklabels(labels+[flabel], rotation='vertical', size=14)

        self.finishDraw(pltax, **kwargs)


    def generateDeck(self, cardQuery='any'):
        """Return the sorted list of cards matching *cardQuery*.

        *cardQuery* is one query string or a list of them (see
        ``queryForCards``); the default gives a full 52-card deck.
        """
        if isinstance(cardQuery, str): cardQuery = [cardQuery]
        return sorted(self.queryForCards(cardQuery))


    def cardProbability(self, deck, *cardArgs, **kwargs):
        """Probability of drawing cards matching each query, in order.

        Each positional argument after *deck* is a query (string or list of
        strings) for one successive draw.  Keyword options:
        ``replace=True`` puts each drawn card back before the next draw;
        ``resolve=True`` prints the cards each query resolves to.
        With no queries the probability is 1.0.
        """
        drawCards = []
        for cardQuery in cardArgs:
            if isinstance(cardQuery, str): cardQuery = [cardQuery]
            cards = self.queryForCards(cardQuery)
            drawCards.append(cards)
            if kwargs.get('resolve', False):
                # Single-argument print: same output on Python 2 and 3.
                print("%s resolves to %d card(s): %s" % (cardQuery, len(cards), cards))

        if kwargs.get('replace', False):
            return self.cardProbReplace(deck, drawCards)
        else:
            return self.cardProbRemove(deck, drawCards, 0)


    def cardProbReplace(self, deck, drawCards):
        """Probability of the draw sequence when each card is replaced.

        Draws are independent, so this is the product of each draw's matching
        fraction of *deck*.  An empty deck cannot satisfy any draw (0.0);
        an empty draw list is certain (1.0).
        """
        probability = 1.0
        for cards in drawCards:
            if not deck:
                return 0.0
            matches = sum(deck.count(card) for card in cards)
            probability *= matches / float(len(deck))
        return probability


    def cardProbRemove(self, deck, drawCards, drawDepth):
        """Probability of draws ``drawCards[drawDepth:]`` without replacement.

        Sums, over each acceptable card for the current draw, the chance of
        drawing it times the chance of the remaining draws from the deck with
        that card removed.
        """
        if drawDepth >= len(drawCards):
            # Nothing (left) to draw: certain.
            return 1.0
        if not deck:
            return 0.0

        deckLen = float(len(deck))
        isLastDraw = drawDepth + 1 >= len(drawCards)
        probability = 0.0

        for card in drawCards[drawDepth]:
            matches = deck.count(card)
            if not matches:
                continue
            drawProb = matches / deckLen
            if isLastDraw:
                probability += drawProb
            else:
                # Always recurse when more draws remain: if the deck is now
                # empty the follow-up probability is 0.  (The old deckLen > 1
                # shortcut treated the remaining draws as certain.)
                remaining = deck[:]
                remaining.remove(card)
                probability += drawProb * self.cardProbRemove(remaining, drawCards, drawDepth+1)

        return probability


    def queryForCards(self, cardQuery):
        """Return the set of ``(face, suit)`` cards matching any sub-query.

        Each sub-query is ``"face:suit"``, ``"face"`` or ``"suit"``
        (case-insensitive, e.g. ``"ace:spades"``, ``"face"``, ``"red"``).
        An optional ``<``, ``<=``, ``>=`` or ``>`` prefix selects faces
        relative to the first face value (``"<5:h"``); ``=`` and ``==``
        prefixes are accepted and ignored.

        :raises ValueError: for a non-string, empty or unrecognised sub-query.
        """
        faceMap = { "any": [1,2,3,4,5,6,7,8,9,10,11,12,13],
            "e": [2,4,6,8,10], "even": [2,4,6,8,10],
            "o": [3,5,7,9], "odd": [3,5,7,9],
            "1": [1], "a": [1], "ace": [1], "2": [2], "two": [2],
            "3": [3], "three": [3], "4": [4], "four": [4],
            "5": [5], "five": [5], "6": [6], "six": [6],
            "7": [7], "seven": [7], "8": [8], "eight": [8],
            "9": [9], "nine": [9], "10": [10], "ten": [10],
            "j": [11], "jack": [11], "q": [12], "queen": [12],
            "k": [13], "king": [13], "face": [11,12,13] }

        suitMap = { "any": [1,2,3,4],
            "r": [2,3], "red": [2,3], "b": [1,4], "black": [1,4],
            "c": [1], "club": [1], "clubs": [1],
            "d": [2], "diamond": [2], "diamonds": [2],
            "h": [3], "heart": [3], "hearts": [3],
            "s": [4], "spade": [4], "spades": [4] }

        cards = set()

        for subQuery in cardQuery:
            # Wrap in a tuple so tuple-valued bad queries still format.
            if not isinstance(subQuery, str) or not subQuery:
                raise ValueError("Bad card query: %s" % (subQuery,))

            face = "any"
            suit = "any"
            faceRange = None

            # Slicing (not indexing) keeps a bare "<" from raising IndexError.
            if subQuery[0] in "<=>":
                width = 2 if subQuery[1:2] == "=" else 1
                faceRange = subQuery[:width]
                subQuery = subQuery[width:]

            subQueryParts = subQuery.lower().split(":")
            if len(subQueryParts) == 2:
                face, suit = subQueryParts
            elif len(subQueryParts) == 1:
                if subQueryParts[0] in faceMap:
                    face = subQueryParts[0]
                else:
                    suit = subQueryParts[0]
            else:
                raise ValueError("Bad card query: %s" % subQuery)

            if face not in faceMap or suit not in suitMap:
                raise ValueError("Bad card query: %s" % subQuery)

            faceVals = faceMap[face]
            if faceRange == "<": faceVals = range(1, faceVals[0])
            elif faceRange == "<=": faceVals = range(1, faceVals[0]+1)
            elif faceRange == ">=": faceVals = range(faceVals[0], 14)
            elif faceRange == ">": faceVals = range(faceVals[0]+1, 14)

            cards.update(product(faceVals, suitMap[suit]))

        return cards


# Shared instance used by the module-level convenience functions in this file.
# (A `global` declaration at module scope is a no-op, so none is needed.)
macs = MacsLib()

def loadData(fileName, delimiter=',', comments='#', unpack=True, **kwargs):
    """Load numeric columns from a delimited text file with ``np.loadtxt``.

    Defaults to comma-separated values, ``#`` comments and ``unpack=True``
    (one returned row per column).  Extra keyword arguments are passed
    through to ``np.loadtxt``.
    """
    options = dict(kwargs, delimiter=delimiter, comments=comments, unpack=unpack)
    return np.loadtxt(fileName, **options)

def loadImage(fileName, **kwargs):
    """Read an image file and reduce it to a single 2-D channel.

    Keyword arguments are passed to ``pltim.imread``.  A 2-D image is
    returned unchanged.  An image with a channel axis keeps only its first
    channel, returned as uint8.  Float channels (values in [0, 1], as imread
    gives for PNG) are scaled to 0..255 first, instead of being truncated to
    0/1 by a direct uint8 cast.
    """
    image = pltim.imread(fileName, **kwargs)

    if image.ndim == 3:
        channel = image[:, :, 0]
        if np.issubdtype(channel.dtype, np.floating):
            channel = np.clip(np.round(channel * 255), 0, 255)
        image = channel.astype(np.uint8)

    return image

def drawImage(*args, **kwargs):
    """Shortcut for ``macs.drawImage`` on the shared module instance."""
    return macs.drawImage(*args, **kwargs)


def drawHistogram(*args, **kwargs):
    """Shortcut for ``macs.drawHistogram``."""
    return macs.drawHistogram(*args, **kwargs)


def drawFunction(*funcs, **kwargs):
    """Shortcut for ``macs.drawFunction``; each positional argument is one formula."""
    return macs.drawFunction(funcs, **kwargs)


def drawPoints(*args, **kwargs):
    """Shortcut for ``macs.drawPoints``."""
    return macs.drawPoints(*args, **kwargs)


def drawFitLine(*args, **kwargs):
    """Shortcut for ``macs.drawFitLine``."""
    return macs.drawFitLine(*args, **kwargs)


def drawKMeans(*args, **kwargs):
    """Shortcut for ``macs.drawKMeans``."""
    return macs.drawKMeans(*args, **kwargs)


def drawTSPPoints(*args, **kwargs):
    """Shortcut for ``macs.drawTSPPoints``."""
    return macs.drawTSPPoints(*args, **kwargs)


def drawTSPPath(*args, **kwargs):
    """Shortcut for ``macs.drawTSPPath``."""
    return macs.drawTSPPath(*args, **kwargs)


def drawCTendency(*args, **kwargs):
    """Shortcut for ``macs.drawCTendency``."""
    return macs.drawCTendency(*args, **kwargs)


def generateDeck(*args, **kwargs):
    """Shortcut for ``macs.generateDeck``."""
    return macs.generateDeck(*args, **kwargs)


def cardProbability(*args, **kwargs):
    """Shortcut for ``macs.cardProbability``."""
    return macs.cardProbability(*args, **kwargs)
