# ipv4map.py - no copyright 2011 0x2620.org - public domain

from __future__ import division
import Image
import ImageDraw
import math
import os
import ox
import pygeoip

# GeoIP city database, opened with pygeoip's MEMORY_CACHE flag
geoip = pygeoip.GeoIP('../dat/GeoLiteCity.dat', pygeoip.MEMORY_CACHE)
# directory of flag images, laid out as <path><flag size>/<country code>.png
path = '../../../../oxjs.org/source/Ox.Geo/png/flags/'
# lookup table that drives the address-to-pixel mapping (see get_xy_by_ip):
# each key is an orientation letter, each value is a pair of
#   [0] for each of the 4 consecutive quarters of an address range, the cell
#       (0 = top left, 1 = top right, 2 = bottom left, 3 = bottom right)
#       of the enclosing square that this quarter is drawn in
#   [1] for each of these quarters, the orientation letter to use inside it
# NOTE(review): this looks like a Hilbert-style space-filling curve - confirm
projection = {
    'U': [[0, 2, 3, 1], 'DUUC'], # Square 0 1 Segment U = 0 3 = 0 1 E F = D C
    'D': [[0, 1, 3, 2], 'UDDA'], #        2 3             1 2   3 2 D C   U U
    'C': [[3, 2, 0, 1], 'ACCU'], #                              4 7 8 B
    'A': [[3, 1, 0, 2], 'CAAD']  #                              5 6 9 A
}

def get_country_code_by_ip(ip):
    # returns the country code for a dotted ip address string
    # addresses starting with '0.' are never looked up; they, and addresses
    # without a 'country_code' in their GeoIP record, yield 'NTHH'
    # a few codes are translated to the codes used for the flag images
    aliases = {'A1': 'NTHH', 'A2': 'NTHH', 'AN': 'ANHH', 'AP': 'FM', 'O1': 'NTHH'}
    record = None
    if not ip.startswith('0.'):
        record = geoip.record_by_addr(ip)
    if record and 'country_code' in record:
        code = record['country_code']
    else:
        code = 'NTHH'
    return aliases.get(code, code)

def get_flag_image(county_code, size):
    # returns the flag image of a country, as a size x size pixel image
    # county_code: country code, used as the file name (<code>.png)
    # size: requested edge length in pixels
    # the flag is read from the directory named after the smallest power of
    # four that is not smaller than size (4, 16, 64, ...), below path
    flag_size = int(pow(4, math.ceil(math.log(size, 2) / 2)))
    flag_image = Image.open(path + str(flag_size) + '/' + county_code + '.png')
    # compare as integers: flag_size used to be a string here, so the test was
    # always true and flags of exactly the requested size were resampled too
    if size != flag_size:
        flag_image = flag_image.resize((size, size), Image.ANTIALIAS)
    return flag_image

def get_ip_by_n(n):
    # returns the dotted ip address string for a numeric address
    # each of the four parts is one base-256 digit of n, highest digit first
    return '.'.join(
        str(int(n / pow(256, exponent)) % 256) for exponent in (3, 2, 1, 0)
    )

def get_n_by_ip(ip):
    # returns the numeric address for a dotted ip address string
    # the parts are read as base-256 digits, the first part weighing 256^3
    return sum(
        int(part) * pow(256, 3 - position)
        for position, part in enumerate(ip.split('.'))
    )

def get_tile_by_ip(ip, z):
    # returns the file name of the tile that starts at ip at zoom level z
    # a tile spans 4^(16 - z) addresses; the name is made of its first and
    # last address, inside a directory named after the first part of ip
    first_part = ip.split('.')[0]
    span = pow(4, 16 - z)
    last_ip = get_ip_by_n(get_n_by_ip(ip) + span - 1)
    return '../png/tiles/%s/%s-%s.png' % (first_part, ip, last_ip)

def get_xy_by_ip(ip, z, w='U'):
    # returns [x, y], the position of an ip on a grid that is 2^(8 + z) cells
    # wide, by descending through 8 + z levels of the projection table
    # ip: dotted ip address string
    # z: zoom level
    # w: orientation letter (a key of projection) of the outermost square
    remainder = int(get_n_by_ip(ip) / pow(4, 8 - z))
    column, row = 0, 0
    for level in range(7 + z, -1, -1):
        side = 2 ** level    # edge length of one quarter at this level
        cells = 4 ** level   # number of grid cells in one quarter
        cell_order, next_orientations = projection[w]
        quarter = int(remainder / cells)
        cell = cell_order[quarter]
        # bit 0 of the cell selects the right half, bit 1 the bottom half
        column += cell % 2 * side
        row += int(cell / 2) * side
        remainder -= quarter * cells
        w = next_orientations[quarter]
    return [column, row]

def render_flags():
    # renders an image of 256 flags
    # the flags of all countries with a two-letter code (except 'UK' and
    # 'XK') are placed on a 16 x 16 grid of 64 pixel cells, sorted by code,
    # and labelled with their code; the result is saved as ../png/flags.png
    image = Image.new('RGB', (1024, 1024))
    font = '../ttf/DejaVuSansMonoBold.ttf'
    skipped = ['UK', 'XK']
    records = ox.file.read_json(path.replace('png/flags/', 'json/Ox.Geo.json'))
    countries = [
        record for record in records
        if len(record['code']) == 2 and record['code'] not in skipped
    ]
    countries.sort(key=lambda record: record['code'])
    for index, country in enumerate(countries):
        code = country['code']
        grid_row, grid_column = divmod(index, 16)
        x, y = grid_column * 64, grid_row * 64
        image.paste(get_flag_image(code, 64), (x, y, x + 64, y + 64))
        # dark text first, light text on top, shifted by one pixel
        ox.image.drawText(image, (x + 5, y + 5), code, font, 8, (64, 64, 64))
        ox.image.drawText(image, (x + 4, y + 4), code, font, 8, (192, 192, 192))
    image.save('../png/flags.png')

def render_map():
    # renders images of the full map
    # does nothing if ../png/ipv4map16.png already exists; otherwise the 16
    # images in ../png/map/ are scaled to 4096 pixels, assembled into one
    # 16384 pixel image, and saved together with three smaller versions
    if os.path.exists('../png/ipv4map16.png'):
        return
    mapfile = '../png/ipv4map16384.png'
    mapimage = Image.new('RGB', (16384, 16384))
    for a in range(0, 256, 16):
        print('%d.0.0.0' % a)
        block = Image.open('../png/map/%d.0.0.0-%d.255.255.255.png' % (a, a + 15))
        # snap the position at zoom level 6 to the 4096 pixel grid
        x, y = [int(v / 4096) * 4096 for v in get_xy_by_ip('%d.0.0.0' % a, 6)]
        mapimage.paste(block.resize((4096, 4096), Image.ANTIALIAS), (x, y))
    mapimage.save(mapfile)
    for s in (4096, 1024, 16):
        scaled = mapimage.resize((s, s), Image.ANTIALIAS)
        scaled.save(mapfile.replace('16384', str(s)))

def render_square(values, maxlen, w):
    # recursively renders a square region, given a list of country codes
    # values: list of country codes; lists of more than 4 values are split
    #         into 4 quarters, which are rendered first
    # maxlen: a list of this length is always rendered as an image
    # w: orientation letter (a key of projection) of this square
    # returns the common value if all 4 quarters are equal and the list is
    # shorter than maxlen, otherwise an image with an edge length of
    # sqrt(len(values)) * 16 pixels
    length = len(values)
    if length > 4:
        quarter = int(length / 4)
        orientations = projection[w][1]
        values = [
            render_square(values[k * quarter:(k + 1) * quarter], maxlen, orientations[k])
            for k in range(4)
        ]
    equal = values[0] == values[1] == values[2] == values[3]
    if equal and length < maxlen:
        # uniform region: let the caller draw one larger flag instead
        return values[0]
    image_size = int(math.sqrt(length) * 16)
    if equal:
        flag_size, cells = image_size, values[:1]
    else:
        flag_size, cells = int(image_size / 2), values
    image = Image.new('RGB', (image_size, image_size))
    for position, cell in zip(projection[w][0], cells):
        # strings are country codes, anything else is an already rendered image
        flag = get_flag_image(cell, flag_size) if isinstance(cell, str) else cell
        top, left = divmod(position, 2)
        image.paste(flag, (left * flag_size, top * flag_size))
    if length == pow(4, 4):
        image.save('../png/_tmp/' + str(values) + '.png')
    return image

def render_tile(ip, z):
    # recursively renders the map tile for a given ip at a given zoom level
    # ip: first address of the tile, as a dotted string
    # z: zoom level of the tile
    # returns the 256 x 256 pixel tile image; an existing tile file is opened
    # and returned as is, otherwise the tile is composed of its 4 sub-tiles
    # at zoom level z + 1, each scaled down to 128 pixels, and saved
    # note that the recursion only ends at tiles whose files already exist
    n = get_n_by_ip(ip)
    tile_file = get_tile_by_ip(ip, z)
    if os.path.exists(tile_file):
        return Image.open(tile_file)
    step = pow(4, 15 - z)
    tiles_ip = [get_ip_by_n(n + k * step) for k in range(4)]
    tiles = [render_tile(tile_ip, z + 1) for tile_ip in tiles_ip]
    image = Image.new('RGB', (256, 256))
    x, y = [int(v / 256) * 256 for v in get_xy_by_ip(ip, z)]
    for tile_ip, tile in zip(tiles_ip, tiles):
        x_, y_ = [int(v / 128) * 128 for v in get_xy_by_ip(tile_ip, z)]
        image.paste(tile.resize((128, 128), Image.ANTIALIAS), (x_ - x, y_ - y))
    # save once, after all 4 sub-tiles are in place: saving inside the loop
    # wrote the file 4 times and could leave an incomplete tile on disk,
    # which would then be taken for a finished one
    image.save(tile_file)
    return image

def render_tiles():
    # renders all tiles at the maximum zoom level
    # for each group of 16 first address parts (a), one map image is rendered
    # to ../png/map/ (unless it exists), and then cut into 256 pixel tiles,
    # one per a.b.0.0-a.b.255.255 range, in ../png/tiles/<a>/
    # w is a string of 16 orientation letters, one per group of 16 values of
    # a: the entries of projection['U'][1], each replaced by its own entry
    w = ''
    for v in projection['U'][1]:
        w += projection[v][1]
    # NOTE(review): file256 is not used below
    file256 = '../png/ipv4map16384.png'
    for a in range(256):
        print '%d.0.0.0' % a
        if a % 16 == 0:
            # start of a new group: reset the list of country codes, which is
            # filled across the next 16 iterations, and name the map file
            values = []
            file = '../png/map/%d.0.0.0-%d.255.255.255.png' % (a, a + 15)
        if not os.path.exists(file):
            # one country code per a.b.c.0, in address order
            for b in range(256):
                for c in range(256):
                    values.append(get_country_code_by_ip('%d.%d.%d.0' % (a, b, c)))
            if a % 16 == 15:
                # end of the group: render and save the map image
                # image is bound to the result of save() here, the rendered
                # image itself is read back from the file below
                image = render_square(values, len(values), w[int(a / 16)]).save(file)
        if a % 16 == 15:
            image = Image.open(file)
            for a_ in range(a - 15, a + 1):
                for b in range(256):
                    # file now names a tile; it is set to the next map file
                    # again when the next group starts
                    file = '../png/tiles/%d/%d.%d.0.0-%d.%d.255.255.png' % (a_, a_, b, a_, b)
                    if not os.path.exists(file):
                        # position at zoom level 8, relative to the map image
                        # (modulo 16384) and snapped to the 256 pixel grid
                        x, y = map(lambda x: int(x % 16384 / 256) * 256, get_xy_by_ip('%d.%d.0.0' % (a_, b), 8))
                        dirs = os.path.split(file)[0]
                        if dirs and not os.path.exists(dirs):
                            os.makedirs(dirs)
                        image.crop((x, y, x + 256, y + 256)).save(file)

def render_projection():
    # renders an image of the map projection
    # each of the 256 first address parts gets a 64 pixel square at its map
    # position, filled with a color derived from the number and labelled with
    # it; a line connects the centers of consecutive squares
    # the result is saved as ../png/projection.png
    image = Image.new('RGB', (1024, 1024))
    draw = ImageDraw.Draw(image)
    font = '../ttf/DejaVuSansCondensedBold.ttf'
    previous = None
    for i in range(256):
        x, y = [int(v / 64) * 64 for v in get_xy_by_ip('%d.0.0.0' % i, 2)]
        rgb = ox.getRGB((i * 360 / 256, 1, 0.5))
        draw.rectangle((x, y, x + 64, y + 64), fill=rgb)
        ox.image.drawText(image, (x + 4, y + 4), str(i), font, 8, (0, 0, 0))
        draw.rectangle((x + 30, y + 30, x + 34, y + 33), fill=0)
        if previous is not None:
            # there is no previous square to connect the first one to
            x_, y_ = previous
            draw.line((x_ + 32, y_ + 32, x + 32, y + 32), fill=0, width=4)
        previous = (x, y)
    image.save('../png/projection.png')

if __name__ == '__main__':
    # the order matters: render_tiles writes the files in ../png/map/ and the
    # zoom level 8 tiles in ../png/tiles/, which render_tile (recursion from
    # zoom level 0 down to existing tile files) and render_map then read
    render_flags()
    render_projection()
    render_tiles()
    render_tile('0.0.0.0', 0)
    render_map()

# to get a sufficiently large "international" icon, run:
# qlmanage -t -s 16384 -o . ../../../../oxjs.org/tools/geo/svg/icons/NTHH.svg