#!/usr/bin/env python

import subprocess
import re
import os
import socket
import pkg_resources
from xml.dom.minidom import parse as parseXML
from jinja2 import Environment, PackageLoader

# Running counter used to hand out a unique id to every Binary that is
# created (incremented in Binary.__init__ via `global b_id`).
b_id = 0
# NOTE(review): s_id is not referenced anywhere in the visible source --
# presumably a matching counter for symbols; confirm before removing.
s_id = 0

# Expose the installed 'hiprofile' distribution version; fall back to the
# literal 'unknown' when the lookup fails for any reason.
try:
    __version__ = pkg_resources.get_distribution('hiprofile').version
except Exception:
    __version__ = 'unknown'

def _get_count(node):
    """ Utility function: read the sample count stored under a node.

        Looks up the 'count' elements below `node`, takes the first one in
        document order and returns its text content converted to an int."""
    first_count = node.getElementsByTagName('count').item(0)
    text_node = first_count.firstChild
    return int(text_node.data)

def _threshold_list(list, fn, threshold):
    """ Utility function: given a list and an ordering function, return the
        elements of the list that are above the threshold, largest first.

        If threshold is a percentage (ie, ends with '%'), the items for which
        fn returns more than that percentage of the total are returned. If
        it's an absolute number, then that number of items is returned.

        list      -- the items to filter (never modified; a new list is
                     returned in both modes)
        fn        -- function returning the numeric weight of an item
        threshold -- string such as '5%' or '20'

        Raises ValueError if threshold is not a valid number."""

    if threshold.endswith('%'):
        percentage = float(threshold[:-1]) / 100
        # sum() copes with an empty list (total 0), unlike reduce(), and is
        # available as a builtin on every Python version.
        cutoff = sum(map(fn, list)) * percentage
        selected = [x for x in list if fn(x) > cutoff]
        selected.sort(key = fn, reverse = True)
        return selected

    count = int(threshold)
    # sorted() builds a new list, so the caller's list keeps its order.
    return sorted(list, key = fn, reverse = True)[0:count]

class Connection(object):
    """Runs commands either locally or, when a host is given, through ssh."""

    def __init__(self, host):
        # Falsy host (None or '') means "run on the local machine".
        self.host = host

    def execute(self, command, input = None):
        """Start `command` and return a file object for its standard output.

        command -- list of program arguments; prefixed with
                   ['ssh', host] when this connection has a host
        input   -- optional data written to the child's stdin, after which
                   stdin is closed so the child sees end-of-file

        The child's stderr is discarded. The caller is responsible for
        reading and closing the returned pipe.
        """
        if self.host:
            command = ['ssh', self.host] + command

        stdin = subprocess.PIPE if input is not None else None

        # The child receives its own copy of the descriptor, so the parent's
        # handle can be released as soon as Popen returns.
        with open(os.devnull, 'w') as null_fd:
            proc = subprocess.Popen(command, stdin = stdin,
                    stdout = subprocess.PIPE, stderr = null_fd)

        if input is not None:
            # NOTE(review): everything is written before any output is read,
            # so a very large input combined with a chatty child could block.
            try:
                proc.stdin.write(input)
            finally:
                # Without closing stdin a child reading to EOF never finishes.
                proc.stdin.close()

        return proc.stdout

class SymbolInstruction(object):
    """One disassembled instruction of a symbol, with optional sample data.

    Attributes:
        addr       -- instruction address as it appeared in the listing
        asm        -- assembly text of the instruction
        source     -- source lines preceding the instruction ('' if none)
        percentage -- share of the symbol's samples hitting this instruction
        samples    -- raw sample count; only set once set_samples() is called
    """

    def __init__(self, addr, asm, source = None):
        self.addr = addr
        self.asm = asm
        self.percentage = 0
        self.source = '' if source is None else source

    def set_samples(self, samples, total):
        """Record the sample count and derive the percentage of `total`."""
        self.samples = samples
        if total:
            self.percentage = 100 * float(samples) / total
        else:
            # No samples in the enclosing symbol: avoid dividing by zero.
            self.percentage = 0.0

    def colour(self):
        """Return an HTML colour from '#ffc000' (no samples) to '#ff0000'."""
        if not self.percentage:
            return '#ffc000'

        heat = self.percentage * 40
        if heat > 0xc0:
            return '#ff0000'

        # percentage is a float and '%x' only accepts integers on Python 3,
        # so truncate explicitly.
        return '#ff%02x00' % int(0xc0 - heat)


# Patterns for the three kinds of line handled by SymbolReference.annotate.
# They are written as raw strings so the regex escapes (\s, \d, \S) are not
# interpreted as (invalid) string escapes.

# Instruction with samples: sample count, a second column, address, assembly.
# 13283  0.7260 :1031b6e0:       stwu    r1,-48(r1)
sampled_asm_re = re.compile(
    r'^\s*(?P<samples>\d+)\s+\S+\s+:\s*(?P<addr>[0-9a-f]+):\s*(?P<asm>.*)\s*$')

# Instruction without samples: address and assembly only.
#               :1031b72c:       addi    r0,r9,4
unsampled_asm_re = re.compile(
    r'^\s*:\s*(?P<addr>[0-9a-f]+):\s*(?P<asm>.*)\s*$')

# Source line: everything after the leading colon.
#               :AllocSetAlloc(MemoryContext context, Size size)
source_re = re.compile(
    r'^\s*:(?P<source>.*)$')

class SymbolReference(object):
    """A symbol sampled inside a binary (or one of its modules), together
    with the annotated instructions collected for it."""

    def __init__(self, id, name, module):
        self.id, self.name, self.module = id, name, module
        self.count = 0
        self.annotations = None

        # State used while annotate() consumes the listing line by line:
        # source text seen since the last instruction, and the instructions.
        self.source_buf = ''
        self.insns = []

    def module_name(self):
        """Return the last '/'-separated component of the module path."""
        return self.module.rsplit('/', 1)[-1]

    def filename(self):
        """Return the name of the HTML page generated for this symbol."""
        return 'symbol-%s.html' % self.id

    def annotate(self, line):
        """Consume one line of annotated output for this symbol.

        Instruction lines become SymbolInstruction entries carrying the
        source text buffered so far; source lines are added to the buffer;
        any other line is ignored.
        """
        sampled = sampled_asm_re.match(line)
        unsampled = unsampled_asm_re.match(line)
        asm_match = unsampled or sampled

        if asm_match is None:
            src = source_re.match(line)
            if src:
                self.source_buf += src.group('source') + '\n'
            return

        insn = SymbolInstruction(asm_match.group('addr'),
                                 asm_match.group('asm'),
                                 self.source_buf)
        if sampled:
            insn.set_samples(int(asm_match.group('samples')), self.count)
        self.insns.append(insn)
        # The buffered source now belongs to the instruction just stored.
        self.source_buf = ''

    @staticmethod
    def parse(report, node, module):
        """Build a SymbolReference from a symbol node: the name comes from
        the report's symbol table (indexed by the node's 'idref') and the
        sample count from the node's 'count' element."""
        symbol_id = int(node.getAttribute('idref'))
        symbol = SymbolReference(symbol_id, report.symtab[symbol_id], module)
        symbol.count = _get_count(node)
        return symbol

class Binary(object):
    """A profiled binary and the symbols sampled inside it.

    Attributes:
        name           -- name of the binary as given in the profile
        count          -- number of samples attributed to the binary
        references     -- list of SymbolReference objects
        reference_dict -- symbol name -> SymbolReference, rebuilt by
                          threshold()
        id             -- unique number taken from the module-level b_id
                          counter
    """

    def __init__(self, name, count):
        global b_id
        self.name = name
        self.count = count
        self.references = []
        # Created up front so the attribute exists before threshold() runs.
        self.reference_dict = {}
        self.id = b_id = b_id + 1

    def __str__(self):
        lines = ['%s: %d' % (self.name, self.count)]
        lines.extend(str(ref) for ref in self.references)
        return '\n'.join(lines)

    def shortname(self):
        """Return the last '/'-separated component of the binary name."""
        return self.name.split('/')[-1]

    def filename(self):
        """Return the name of the HTML page generated for this binary."""
        return 'binary-%d.html' % self.id

    def threshold(self, thresholds):
        """Keep only the symbols selected by thresholds['symbol'] and
        rebuild the name lookup table."""
        self.references = _threshold_list(self.references,
                lambda r: r.count, thresholds['symbol'])
        self.reference_dict = dict((r.name, r) for r in self.references)

    def annotate(self, report, conn, options):
        """Run opannotate for this binary and feed its output to the symbols.

        report  -- the owning report (not used here)
        conn    -- object whose execute(command) returns a readable file
        options -- object providing the `opannotate` program name

        Every output line is handed to the symbol named by the most recent
        function header; lines under a header naming an unknown symbol are
        dropped.
        """
        fn_re = re.compile(r'^[0-9a-f]+\s+<[^>]+>: /\* (\S+) total:')

        symbols = [ s for s in self.references if s.name != '(no symbols)' ]

        if not symbols:
            return

        # Built from the current references so that annotate() also works
        # when threshold() has not been called beforehand.
        lookup = dict((s.name, s) for s in symbols)

        command = [options.opannotate, '--source', '--assembly',
            '--include-file=' + self.name,
            '-i', ','.join([ s.name for s in symbols ])]

        fd = conn.execute(command)

        symbol = None

        try:
            for line in fd:
                # A subprocess pipe yields bytes on Python 3, while the
                # patterns are text; on Python 2 bytes is str and this
                # branch is never taken.
                if isinstance(line, bytes) and not isinstance(line, str):
                    line = line.decode('utf-8', 'replace')

                header = fn_re.match(line)
                if header:
                    # A header starts a new function: switch to it without
                    # passing the header to the previous symbol.
                    symbol = lookup.get(header.group(1))

                if symbol is not None:
                    symbol.annotate(line)
        finally:
            fd.close()

    def parse_symbol(self, report, node, module = None):
        """Parse one symbol node, compute its share of this binary's samples
        and append it to `references`. `module` defaults to the binary's own
        name."""
        if module is None:
            module = self.name

        ref = SymbolReference.parse(report, node, module)
        if self.count:
            ref.percentage = 100 * float(ref.count) / self.count
        else:
            # A binary without samples cannot give a meaningful ratio.
            ref.percentage = 0.0
        self.references.append(ref)


    @staticmethod
    def parse(report, node):
        """Build a Binary from a binary node, including the symbols found
        directly under it and those nested inside 'module' children."""
        name = node.getAttribute('name')

        binary = Binary(name, _get_count(node))

        for child_node in node.childNodes:
            if child_node.nodeType != node.ELEMENT_NODE:
                continue

            if child_node.nodeName == 'symbol':
                binary.parse_symbol(report, child_node, None)

            elif child_node.nodeName == 'module':
                module_name = child_node.getAttribute('name')
                for child_sym_node in child_node.getElementsByTagName('symbol'):
                    binary.parse_symbol(report, child_sym_node, module_name)

        return binary

class Report(object):
    """A whole profile: machine details, the symbol table and the binaries.

    Attributes:
        host          -- name of the profiled host
        arch          -- value of the profile's 'cputype' attribute
        cpu           -- description built from 'processor' and 'mhz'
        binaries      -- list of Binary objects
        symtab        -- list of symbol names indexed by symbol id
        total_samples -- sum of the sample counts of all added binaries
    """

    def __init__(self, host, arch, cpu):
        self.host = host
        self.arch = arch
        self.cpu = cpu
        self.binaries = []
        self.symtab = []
        self.total_samples = 0

    def add_binary(self, binary):
        """Append a binary and add its samples to the running total."""
        self.binaries.append(binary)
        self.total_samples += binary.count

    def threshold(self, thresholds):
        """Keep the binaries selected by thresholds['binary'], then apply
        the thresholds to each remaining binary."""
        self.binaries = _threshold_list(self.binaries,
                lambda b: b.count, thresholds['binary'])

        for binary in self.binaries:
            binary.threshold(thresholds)

    def annotate(self, conn, options):
        """Annotate every binary of the report through `conn`."""
        for binary in self.binaries:
            binary.annotate(self, conn, options)

    @staticmethod
    def parse(doc, hostname):
        """Build a Report from a parsed profile document.

        doc      -- DOM document of the profile XML
        hostname -- name stored as the report's host
        """
        root = doc.documentElement

        cpu = '%s (%s MHz)' % (
                   root.getAttribute('processor'),
                   root.getAttribute('mhz'))

        report = Report(hostname, root.getAttribute('cputype'), cpu)

        # parse symbol table (tolerate a document that has none)
        symtab_node = doc.getElementsByTagName('symboltable').item(0)
        entries = symtab_node.childNodes if symtab_node is not None else []

        for entry in entries:
            if entry.nodeType != entry.ELEMENT_NODE:
                continue
            index = int(entry.getAttribute('id'))
            # Store each name at exactly its id, padding with None, so that
            # sparse or out-of-order ids still resolve through symtab[id].
            missing = index + 1 - len(report.symtab)
            if missing > 0:
                report.symtab.extend([None] * missing)
            report.symtab[index] = entry.getAttribute('name')

        # parse each binary node
        for binary_node in doc.getElementsByTagName('binary'):
            report.add_binary(Binary.parse(report, binary_node))

        # calculate percentages
        for binary in report.binaries:
            if report.total_samples:
                binary.percentage = \
                    100 * float(binary.count) / report.total_samples
            else:
                # Nothing was sampled at all; avoid dividing by zero.
                binary.percentage = 0.0

        return report

    @staticmethod
    def extract(connection, options):
        """Run opreport through `connection` and parse its XML output. The
        report's host is the connection's host, or the local hostname."""
        fd = connection.execute([options.opreport, '--xml'])
        try:
            doc = parseXML(fd)
        finally:
            fd.close()

        if connection.host:
            hostname = connection.host
        else:
            hostname = socket.gethostname()

        return Report.parse(doc, hostname)

    def __str__(self):
        # The report has no 'machine' attribute; the host identifies it.
        return '%s\n' % (self.host,) + '\n'.join(map(str, self.binaries))

def write_report(report, outdir):
    """Render `report` as a set of HTML pages inside a new directory.

    report -- the Report to render
    outdir -- directory to create; os.mkdir raises OSError if it cannot be
              created (for instance because it already exists)

    Writes index.html, one page per binary, one page per symbol, and copies
    the static resources the pages use.
    """

    os.mkdir(outdir)

    # set up template engine
    env = Environment(loader = PackageLoader(__name__, 'resources'),
                      autoescape = True)
    templates = dict((name, env.get_template('%s.html' % name))
                     for name in ('report', 'binary', 'symbol'))

    # copy required files over from resources
    resources = ('style.css', 'hiprofile.js', 'bar.png',
                 'jquery-1.3.1.min.js')
    for resource in resources:
        data = pkg_resources.resource_string(__name__,
                                             'resources/' + resource)
        # Binary mode: the resources (bar.png in particular) must be copied
        # verbatim, and `with` closes the file even if the write fails.
        with open(os.path.join(outdir, resource), 'wb') as out:
            out.write(data)

    reportfile = os.path.join(outdir, 'index.html')
    templates['report'].stream(version = __version__,
                               report = report).dump(reportfile)

    for binary in report.binaries:
        binaryfile = os.path.join(outdir, binary.filename())
        templates['binary'].stream(version = __version__,
                                   report = report,
                                   binary = binary).dump(binaryfile)

        for symbol in binary.references:
            symbolfile = os.path.join(outdir, symbol.filename())
            templates['symbol'].stream(version = __version__,
                                       report = report, binary = binary,
                                       symbol = symbol).dump(symbolfile)
