import gzip
import numpy as np
import time

from numba import jit

"""
                             7-Card Hand Evaluator                     
########################################################################
"""
'''
 Original Java version: Helmuth Melcher <helmuth@holdemresources.net>
 Ported to Python: Christoph Stich <christoph.n.stich@gmail.com>

 Hand evaluator for 7-card hands using a 512 KiB lookup table (2^18 int16
 entries). An evaluation consists of calculating the index plus a single
 lookup.

 Hand format:
 The evaluation functions take a single 64-bit integer (a Java long in the
 original; an int64 here) and use its lowest 52 bits, one bit per card.

 bits 0-3: 2s2h2d2c
 bits 4-7: 3s3h3d3c
 etc.


'''

# CARDS[n] is the single-bit mask 2**n for card index n (0..51). Summing
# distinct entries builds a hand bitmask in the format described above.
CARDS = np.asarray([2**n for n in range(52)], dtype=np.int64)

# Lowest bit of each of the 13 rank nibbles; `hand >> i & RANKS` keeps one
# bit per rank for the cards of suit i.
RANKS = 0x1111111111111
# Number of entries in the rank lookup table (2**18 int16 values).
LOOKUP_SIZE = 0x1 << 18
# Bit 2 of each rank nibble. In a per-rank card count (0-4) this bit is set
# only when the count is 4, i.e. four of a kind.
QUADS = 0x4444444444444
# XOR offset applied to flush indices in calculate_index.
# NOTE(review): presumably places flush hands in their own region of the
# lookup table - confirm against the original Java source.
FLUSH = 111541

# 64-entry XOR table indexed by bits 18-23 of the intermediate index built in
# calculate_index.
# NOTE(review): presumably reduces that index into [0, LOOKUP_SIZE) and is
# paired with the generated lookup file; do not edit without regenerating it.
REDUCE = np.asarray(
    [
        248844, 344509, 578436, 1006297, 1049810, 1537951, 1627467, 2077693,
        2185318, 2557799, 2785929, 2922336, 3334005, 3411511, 3690340, 4129348,
        4333610, 4593825, 4863681, 5008913, 5489331, 5728048, 5791243, 6279771,
        6503942, 6765490, 6822030, 7296176, 7361105, 7762452, 7999853, 8324112,
        8546647, 8844882, 9137396, 9228405, 9648941, 9747526, 10147779,
        10344751, 10662433, 10823273, 11151028, 11441832, 11638513, 11998481,
        12240373, 12517381, 12589624, 13082003, 13144746, 13502680, 13824205,
        13924060, 14405355, 14620367, 14845254, 15120896, 15307881, 15601796,
        15922340, 16193284, 16422142, 16715797
    ],
    dtype='int32')

def load_rank_lookup(path='lookup7card.dat.gz', size=None):
    '''
    Load the gzip-compressed rank lookup table.

    The decompressed file is decoded as little-endian 16-bit integers. This
    matches the original per-entry decoding ``b[0] | b[1] << 8``.

    :path: Path of the gzip-compressed table. Relative paths resolve against
           the current working directory. Defaults to the original file name.
    :size: Number of int16 entries to read. Defaults to LOOKUP_SIZE.
    :return: A numpy int16 array with ``size`` entries. Any data beyond
             ``size`` entries is ignored.
    :raises ValueError: If the decompressed data holds fewer than ``size``
                        entries.
    '''
    if size is None:
        size = LOOKUP_SIZE
    nbytes = 2 * size
    # `with` guarantees the file is closed even if decompression fails.
    with gzip.open(path, 'rb') as f:
        data = f.read(nbytes)
    if len(data) < nbytes:
        raise ValueError(
            'lookup table %r is truncated: expected %d bytes, got %d'
            % (path, nbytes, len(data)))
    # Decode everything at once instead of 2**18 Python-level reads.
    # astype() copies into a native-byte-order, writable int16 array.
    return np.frombuffer(data, dtype='<i2', count=size).astype(np.int16)


# Rank table loaded once, at import time. It reads lookup7card.dat.gz
# relative to the current working directory, so importing this module fails
# if the file is not there. Read as a global by evaluate7cards below.
LOOKUP_TABLE = load_rank_lookup()


@jit(nopython=True)
def calculate_index(cards):
    '''
    Calculate the index into the lookup table for a 7-card hand.

    The hand is examined one suit at a time. If some suit holds five or more
    cards, a flush index is built from that suit's ranks alone. Otherwise the
    index is built from the number of cards held of each rank.

    :cards: The 7-card hand bitmask (int64) in the format described in the
            module header, e.g. a sum of seven distinct CARDS entries.
    :return: The index used to look up the hand's rank in LOOKUP_TABLE.
    '''
    t = 0
    r = 0
    f = 0

    # Visit each suit i (bit offset i inside every 4-bit rank nibble).
    for i in range(4):
        # Stop once the previous suit was found to be a flush. t still holds
        # that suit's bits for the flush branch below.
        if f != 0:
            break
        # Cards of suit i, one bit per rank, at bit 4*k for rank nibble k.
        t = cards >> i & RANKS
        # following lines are a replacement for Java's popcnt version
        # Each `x & (x - 1)` clears the lowest set bit. After four clears,
        # f is non-zero only if t had at least five bits set (a flush).
        f = t & (t - 1)
        f &= f - 1
        f &= f - 1
        f &= f - 1
        # All suit masks use the same bit positions. Summing them therefore
        # leaves the number of cards of each rank (0-4) in that rank's nibble
        # of r.
        r += t

    if f != 0:
        # Fold the flush suit's 13 rank bits (at 0, 4, ..., 48) down into
        # the low 13 bits, then offset them with FLUSH.
        t |= t >> 26
        t |= t >> 13
        return FLUSH ^ (t & 0x1FFF)

    # Bit 2 of a rank nibble is set only when all four cards of that rank
    # are present.
    t = r & QUADS
    if t != 0:
        # Four-of-a-kind case: re-encode each nibble from the quads bit and
        # the low count bits, then fold and complement.
        # NOTE(review): presumably chosen to map quad hands into a compact
        # index range - confirm against the original Java source.
        t = t >> 1 | ((r | r >> 1) & RANKS)
        res = (~(t | t >> 26) & 0x1FFFFFF) >> 1
    else:
        # Fold the per-rank count word onto itself by 26 bits and keep the
        # low 24 bits.
        res = ((r | r >> 26) & 0x0FFFFFF)

    # Bits 18-23 of res select a REDUCE entry that is XORed in.
    # NOTE(review): presumably this maps res into [0, LOOKUP_SIZE), since
    # LOOKUP_SIZE is 2**18 - TODO confirm.
    return res ^ REDUCE[res >> 18 & 0x3F]


@jit(nopython=True)
def evaluate7cards(cards):
    '''
    Return the rank of a 7-card hand.

    :cards: The 7-card hand bitmask (int64), e.g. a sum of seven distinct
            CARDS entries.
    :return: The int16 entry of LOOKUP_TABLE at calculate_index(cards).
    '''
    # Indented with 4 spaces like the rest of the file (was a tab).
    return LOOKUP_TABLE[calculate_index(cards)]



########################################################################
#                              Trials                                  #
# ######################################################################
@jit(nopython=True)
def trials():
    '''
    Evaluate every distinct 7-card hand from a 52-card deck exactly once.

    Hands are enumerated as strictly increasing card indices. Each hand mask
    is built from running partial sums of CARDS entries. The resulting ranks
    are not kept.

    :return: The number of hands evaluated (C(52, 7) = 133,784,560).
    '''
    # NOTE(review): the evaluate7cards result is discarded. Confirm the JIT
    # does not optimize the evaluation away; otherwise the benchmark would
    # time only the loop overhead.
    hands = 0
    for a in range(52):
        mask1 = CARDS[a]
        for b in range(a + 1, 52):
            mask2 = mask1 + CARDS[b]
            for c in range(b + 1, 52):
                mask3 = mask2 + CARDS[c]
                for d in range(c + 1, 52):
                    mask4 = mask3 + CARDS[d]
                    for e in range(d + 1, 52):
                        mask5 = mask4 + CARDS[e]
                        for g in range(e + 1, 52):
                            mask6 = mask5 + CARDS[g]
                            for h in range(g + 1, 52):
                                evaluate7cards(mask6 + CARDS[h])
                                hands += 1
    return hands
    
def timedtrials(repeats=10):
    '''
    Run trials() repeatedly and print how long each run took.

    :repeats: Number of timed runs (default 10, as before).

    trials() is a signature-less @jit function, which Numba compiles on its
    first call. The first reported time therefore includes compilation.
    '''
    for _ in range(repeats):
        # perf_counter is a monotonic, high-resolution clock meant for
        # measuring elapsed time. time.time() can jump when the system clock
        # is adjusted.
        start = time.perf_counter()
        counter = trials()
        end = time.perf_counter()
        print(counter, 'hands counted with for loop in', str(end - start), 'seconds')
