# -------------------------------------------------------------------------
#     This file is part of mMass - the spectrum analysis tool for MS.
#     Copyright (C) 2005-07 Martin Strohalm <mmass@biographics.cz>

#     This program is based on the library named PyPlot, originally developed
#     by Gordon Williams and Jeff Grimmett. Thank you!

#     This program is free software; you can redistribute it and/or modify
#     it under the terms of the GNU General Public License as published by
#     the Free Software Foundation; either version 2 of the License, or
#     (at your option) any later version.

#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#     GNU General Public License for more details.

#     Complete text of GNU GPL can be found in the file LICENSE in the
#     main directory of the program
# -------------------------------------------------------------------------

# Function: Basic definitions for plot objects and container.

# load libs
import wx
import numpy as num


class peaklistObject:
    """ Base class for peaklist lines and labels drawing.

    Each peak is an indexable record. As used by this class: items 0 and 1
    are the plotted x and y values, item 2 is the match annotation (peaks
    with a true value here are drawn with 'matchcol' and the value is
    appended to the label when 'showannots' is set), item 3 equal to 1
    enables an extra horizontal marker built from item 4 (width) and
    item 5 (height).
    """

    # ----
    def __init__(self, peaks, **attr):
        """ Store the peaks, build the points array and set drawing attributes.

        peaks: sequence of peak records (see class docstring).
        attr: keyword values overriding the default drawing attributes.
        """
        self.peaks = peaks

        # only the first two items (x, y) of every peak are plotted
        self.points = num.array([peak[0:2] for peak in self.peaks])
        self.cropped = self.points
        self.scaled = self.cropped

        # set default params
        self.currentCrop = (0, 0)
        self.currentScale = (1, 1)
        self.currentShift = (0, 0)
        self.attributes = {
                            'show':True,
                            'xoffset': 0,
                            'yoffset': 0,
                            'textcol': (0, 0, 0),
                            'missedcol': (255, 0, 0),
                            'matchcol': (0, 255, 0),
                            'showlabels': 1,
                            'showannots': 0,
                            'showticks': 1,
                            'width': 1,
                            'angle': 90,
                            'digits': 2,
                            'style': wx.SOLID
                            }

        # get new attributes
        self.attributes.update(attr)

        # count data offset (an empty array cannot be broadcast with the
        # two-item offset, so it is left untouched)
        xOffset = self.attributes['xoffset']
        yOffset = self.attributes['yoffset']
        if len(self.points) != 0 and (xOffset != 0 or yOffset != 0):
            self.points = self.points + num.array([xOffset, yOffset])

        # set colours
        for name in ('missedcol', 'matchcol'):
            colour = self.attributes[name]
            self.attributes[name] = wx.Colour(colour[0], colour[1], colour[2])
    # ----


    # ----
    def setParameters(self, **attr):
        """ Set new object attributes (values are stored as given). """

        self.attributes.update(attr)
    # ----


    # ----
    def getBoundingBox(self):
        """ Get peaklist bounding box.

        Returns False when there are no points or nothing is to be shown,
        otherwise [minXY, maxXY] as two float arrays. The Y minimum is never
        above zero; X is padded on both sides and Y on top to leave room for
        the labels.
        """

        attrs = self.attributes
        if len(self.points) == 0 or not attrs['show']:
            return False
        if not attrs['showlabels'] and not attrs['showticks']:
            return False

        # work on float copies - padding an integer array in place would
        # silently truncate the fractional part
        minXY = num.array(num.minimum.reduce(self.points), dtype=float)
        maxXY = num.array(num.maximum.reduce(self.points), dtype=float)
        if minXY[1] > 0:
            minXY[1] = 0

        # extend X values to fit labels
        xExtend = (maxXY[0] - minXY[0]) * 0.02
        if not xExtend:
            xExtend = 0.5
        minXY[0] -= xExtend
        maxXY[0] += xExtend

        # extend Y values to fit labels
        yExtend = (maxXY[1] - minXY[1]) * 0.15
        if not yExtend:
            yExtend = 1
        maxXY[1] += yExtend

        return [minXY, maxXY]
    # ----


    # ----
    def cropPoints(self, minX, maxX):
        """ Remember crop (the points themselves are not cropped here). """
        self.currentCrop = (minX, maxX)
    # ----


    # ----
    def scaleAndShift(self, scale=(1, 1), shift=(0, 0)):
        """ Scale and shift peaklist points.

        The remembered crop range is scaled and shifted along X as well, so
        that draw() can compare it with the scaled points.
        """

        if len(self.points) == 0:
            return

        # identity (not equality) check: the work is skipped only when the
        # very same scale and shift objects are passed again
        if scale is self.currentScale and shift is self.currentShift:
            return

        self.scaled = scale * self.points + shift
        self.currentScale = scale
        self.currentShift = shift

        # scale crop
        minX = scale[0] * self.currentCrop[0] + shift[0]
        maxX = scale[0] * self.currentCrop[1] + shift[0]
        self.currentCrop = (minX, maxX)
    # ----


    # ----
    def filterPoints(self, filterSize=1):
        """ Filter peaklist points for printing and exporting - currently not needed. """
        return
    # ----


    # ----
    def _labelFormat(self):
        """ Get '%' format string for labels with 'digits' decimal places. """
        # repr() replaces the backtick syntax, which Python 3 does not accept
        return '%0.' + repr(self.attributes['digits']) + 'f'
    # ----


    # ----
    def _visiblePeaks(self, noIntensity):
        """ Yield (peak, xPos, yPos) for each scaled point inside current crop. """

        for index in range(len(self.scaled)):
            xPos = self.scaled[index][0]
            yPos = self.scaled[index][1]

            # escape invisible labels (overcome unix drawing error)
            if not (self.currentCrop[0] < xPos < self.currentCrop[1]):
                continue

            # scale lines if no intensities set in the peaklist
            if noIntensity:
                yPos = yPos*0.3

            yield self.peaks[index], xPos, yPos
    # ----


    # ----
    def draw(self, dc, printerScale=1):
        """ Define how to draw peaklist labels.

        dc: device context used for all drawing calls.
        printerScale: multiplier applied to pen width and pixel distances.
        All ticks are drawn first, all label texts afterwards.
        """

        attrs = self.attributes

        # escape if hidden object
        if not attrs['show']:
            return

        # set pen params
        penWidth = attrs['width']*printerScale
        missedPen = wx.Pen(attrs['missedcol'], penWidth, attrs['style'])
        matchPen = wx.Pen(attrs['matchcol'], penWidth, attrs['style'])
        dc.SetPen(missedPen)
        dc.SetTextForeground(attrs['textcol'])
        labelFormat = self._labelFormat()

        # check if no intensities set
        noIntensity = True
        for peak in self.peaks:
            if peak[1]:
                noIntensity = False
                break

        # draw labels lines
        if attrs['showticks'] or attrs['showlabels']:
            for peak, xPos, yPos in self._visiblePeaks(noIntensity):

                # set colour according to match status
                if peak[2]:
                    dc.SetPen(matchPen)
                else:
                    dc.SetPen(missedPen)

                # draw big label ticks
                if attrs['showticks']:
                    dc.DrawLine(xPos, self.currentShift[1], xPos, yPos)

                # draw small label ticks
                elif attrs['showlabels']:
                    dc.DrawLine(xPos, yPos-10*printerScale, xPos, yPos-20*printerScale)

                # draw peak height if centroid label type
                if attrs['showlabels'] and peak[3] == 1:
                    height = self.currentScale[1] * peak[5] + self.currentShift[1]
                    width = self.currentScale[0] * peak[4]
                    dc.DrawLine(xPos-width/2, height, xPos+width/2, height)

        # draw labels
        if attrs['showlabels']:
            for peak, xPos, yPos in self._visiblePeaks(noIntensity):

                # set colour according to match status
                if peak[2]:
                    dc.SetPen(matchPen)
                else:
                    dc.SetPen(missedPen)

                # get label
                label = labelFormat % peak[0]

                # add annotations to label
                if attrs['showannots'] and peak[2]:
                    label = label + ' - ' + peak[2]

                # shift labels according to text size and angle
                labelSize = dc.GetTextExtent(label)

                if attrs['angle'] == 90:
                    xPos -= labelSize[1]*0.5
                    yPos -= 2*printerScale
                elif attrs['angle'] == 45:
                    xPos -= labelSize[1]*0.5
                    yPos -= labelSize[1]*0.5
                elif attrs['angle'] == 0:
                    xPos -= labelSize[0]*0.5
                    yPos -= labelSize[1]

                # leave space for the tick drawn next to the label
                if attrs['showticks']:
                    yPos -= 5*printerScale
                else:
                    yPos -= 23*printerScale

                # draw labels
                dc.DrawRotatedText(label, xPos, yPos, attrs['angle'])
    # ----


class spectrumObject:
    """ Base class for spectrum lines drawing.

    Points are kept as an array of (x, y) rows. getPoint() and cropPoints()
    search by X value and therefore expect the rows ordered by ascending X.
    """

    # ----
    def __init__(self, points, **attr):
        """ Store spectrum points and set drawing attributes.

        points: sequence of (x, y) pairs.
        attr: keyword values overriding the default drawing attributes.
        """

        # convert spectrum points to array
        self.points = num.array(points)
        self.cropped = self.points
        self.scaled = self.cropped

        # set default params
        self.currentCrop = (0, 0)
        self.currentScale = (1, 1)
        self.currentShift = (0, 0)
        self.attributes = {
                            'show':True,
                            'xoffset': 0,
                            'yoffset': 0,
                            'colour': (0, 0, 0),
                            'width': 1,
                            'style': wx.SOLID,
                            'legend': ''
                            }

        # get new attributes
        self.attributes.update(attr)

        # count spectrum offset (an empty array cannot be broadcast with the
        # two-item offset, so it is left untouched)
        xOffset = self.attributes['xoffset']
        yOffset = self.attributes['yoffset']
        if len(self.points) != 0 and (xOffset != 0 or yOffset != 0):
            self.points = self.points + num.array([xOffset, yOffset])

        # set spectrum colour
        colour = self.attributes['colour']
        self.attributes['colour'] = wx.Colour(colour[0], colour[1], colour[2])
    # ----


    # ----
    def setParameters(self, **attr):
        """ Set new object attributes (values are stored as given). """

        self.attributes.update(attr)
    # ----


    # ----
    def getBoundingBox(self):
        """ Get spectrum bounding box.

        Returns None when there are no points or the object is hidden,
        otherwise [minXY, maxXY] as two float arrays. An axis with no range
        is widened so the box never has zero size.
        """

        if len(self.points) == 0 or not self.attributes['show']:
            return None

        # work on float copies - widening an integer array in place would
        # silently truncate the fractional part
        minXY = num.array(num.minimum.reduce(self.points), dtype=float)
        maxXY = num.array(num.maximum.reduce(self.points), dtype=float)

        # check if no range in one axis
        if minXY[0] == maxXY[0]:
            minXY[0] -= 0.5
            maxXY[0] += 0.5
        if minXY[1] == maxXY[1]:
            maxXY[1] += 0.5

        return [minXY, maxXY]
    # ----


    # ----
    def getLegend(self):
        """ Get spectrum legend as (legend, show, colour) tuple. """

        attrs = self.attributes
        return (attrs['legend'], attrs['show'], attrs['colour'])
    # ----


    # ----
    def getPoint(self, xPos, userCoord=False):
        """ Get interpolated Y position from X (Interpolated simply from line).

        xPos: X value in data units.
        userCoord: if set, the result is converted by current scale and shift.
        Returns [x, y]. For xPos outside of the data range the Y value of the
        nearest end point is used. With no points at all [xPos, 0] is
        returned, or [xPos, 0, 0, 0] if userCoord is set.
        """

        pointsLen = len(self.points)

        # no points in object
        if pointsLen == 0:
            if userCoord:
                return [xPos, 0, 0, 0]
            return [xPos, 0]

        # get index of nearest higher point (binary search on X values)
        index = int(num.searchsorted(self.points[:, 0], xPos, side='right'))

        if index == 0:
            # xPos lies before the first point
            yPos = self.points[0][1]
        elif index == pointsLen:
            # xPos lies on or behind the last point
            yPos = self.points[-1][1]
        else:
            # interpolate between two points; x2 > xPos >= x1 is guaranteed
            # here, so the divisor is never zero
            x1 = self.points[index-1][0]
            y1 = self.points[index-1][1]
            x2 = self.points[index][0]
            y2 = self.points[index][1]
            yPos = y1 + ((xPos - x1) * (y2 - y1)/float(x2 - x1))

        # get point values and coordinations
        if userCoord:
            return [
                self.currentScale[0] * xPos + self.currentShift[0],
                self.currentScale[1] * yPos + self.currentShift[1],
                ]
        return [xPos, yPos]
    # ----


    # ----
    def offsetPoints(self, xoffset, yoffset):
        """ Offset points.

        The previously applied offset is removed first. Crop, scale and shift
        state is cleared, so the next cropPoints() and scaleAndShift() calls
        recalculate their data.
        """

        if len(self.points) != 0:

            # reverse offset
            offset = num.array([self.attributes['xoffset'], self.attributes['yoffset']])
            self.points = self.points - offset

            # set new offset
            offset = num.array([xoffset, yoffset])
            self.points = self.points + offset

        # store params
        self.attributes['xoffset'] = xoffset
        self.attributes['yoffset'] = yoffset

        # clear scaling params
        self.currentCrop = None
        self.currentScale = None
        self.currentShift = None
    # ----


    # ----
    def cropPoints(self, minX, maxX):
        """ Crop spectrum points to current view coordinations.

        Nothing is done if the same range is already applied. The slice is
        taken with a margin, so neighbouring points outside of the range are
        kept as well.
        """

        if len(self.points) == 0 or self.currentCrop == (minX, maxX):
            return

        length = len(self.points)

        # interval halving - lower bound
        low, high = 0, length
        while high - low > 1:
            middle = (low + high) // 2
            if self.points[middle][0] < minX:
                low = middle
            else:
                high = middle
        i1 = low

        # interval halving - upper bound
        low, high = 0, length
        while high - low > 1:
            middle = (low + high) // 2
            if self.points[middle][0] > maxX:
                high = middle
            else:
                low = middle
        i2 = high + 1

        # set crop
        self.cropped = self.points[i1:i2]
        self.currentCrop = (minX, maxX)
    # ----


    # ----
    def scaleAndShift(self, scale=(1, 1), shift=(0, 0)):
        """ Scale and shift cropped spectrum points. """

        if len(self.points) == 0:
            return

        # identity (not equality) check: the work is skipped only when the
        # very same scale and shift objects are passed again
        if scale is self.currentScale and shift is self.currentShift:
            return

        self.scaled = scale * self.cropped + shift
        self.currentScale = scale
        self.currentShift = shift
    # ----


    # ----
    def filterPoints(self, filterSize=1):
        """ Filter spectrum points for printing and exporting - delete all
        points which will not be visible in current resolution.

        filterSize: minimal rounded X distance between two kept points;
        None is taken as 1. self.scaled is replaced by a plain list.
        """

        #check data
        if len(self.scaled) == 0:
            return

        # check filtersize
        if filterSize is None:
            filterSize = 1

        firstPoint = self.scaled[0]
        filteredPoints = [firstPoint]
        lastX = firstPoint[0]
        lastY = firstPoint[1]
        minY = firstPoint[1]
        maxY = firstPoint[1]

        # filter points
        for point in self.scaled:

            # far enough from the last kept point - store the extremes found
            # since then and start a new group with this point
            if round(point[0] - lastX) >= filterSize:
                if lastY != maxY:
                    filteredPoints.append([lastX, maxY])
                if lastY != minY:
                    filteredPoints.append([lastX, minY])
                filteredPoints.append(point)
                lastX = point[0]
                lastY = point[1]
                maxY = point[1]
                minY = point[1]

            # remember extremes differing by at least one unit
            elif point[1] - lastY <= -1:
                minY = min(point[1], minY)
            elif point[1] - lastY >= 1:
                maxY = max(point[1], maxY)

        self.scaled = filteredPoints
    # ----


    # ----
    def draw(self, dc, printerScale=1):
        """ Define how to draw spectrum lines. """

        # escape if hidden object
        if not self.attributes['show'] or len(self.scaled) == 0:
            return

        # set pen params
        penWidth = self.attributes['width']*printerScale
        dc.SetPen(wx.Pen(self.attributes['colour'], penWidth, self.attributes['style']))

        # draw lines
        dc.DrawLines(self.scaled)
    # ----


    # ----
    def drawGel(self, dc, gelCoords, gelHeight, printerScale):
        """ Define how to draw spectrum gel view.

        gelCoords: five items (gelY1, plotX1, plotY1, plotX2, plotY2).
        gelHeight: height of the gel strip.
        Returns True if the gel was drawn, False if the object is hidden,
        has no data or the colour step is zero.
        """

        # escape if hidden object or no data
        if not self.attributes['show'] or len(self.scaled) == 0:
            return False

        # get plot coordinates
        gelY1, plotX1, plotY1, plotX2, plotY2 = gelCoords
        yRange = plotY2 - plotY1

        # set color step
        # NOTE(review): with integer coordinates under Python 2 this is an
        # integer division, so a range below 255 gives zero - confirm intent
        step = yRange / 255
        if step == 0:
            return False

        # init pen
        pen = wx.Pen((255,255,255), printerScale, wx.SOLID)
        dc.SetPen(pen)

        # get first point and color (256 is outside of the 0-255 range used
        # below, so it marks that no colour was set yet)
        lastX = round(self.scaled[0][0])
        lastY = 256
        maxY = 256

        # draw lines
        for point in self.scaled:

            # get point
            xPos = round(point[0])
            intens = round((point[1] - plotY1)/step)

            # check color range
            intens = max(min(intens, 255), 0)

            # filter points
            if xPos-lastX >= printerScale:

                # set color if different
                if lastY != maxY:
                    pen.SetColour(wx.Colour(maxY, maxY, maxY))
                    dc.SetPen(pen)

                # draw gel line
                dc.DrawLine(lastX, gelY1, lastX, gelY1 + gelHeight)

                # save last
                lastX = xPos
                lastY = maxY

                # set current color
                maxY = intens

            # keep the lowest value seen for the current column
            maxY = min(intens, maxY)

        # draw legend rectangle
        wh = 5 * printerScale
        x = plotX2 - 12 * printerScale
        y = gelY1 + (gelHeight - wh)/2
        dc.SetPen(wx.TRANSPARENT_PEN)
        dc.SetBrush(wx.Brush(self.attributes['colour'], wx.SOLID))
        dc.DrawRectangle(x, y, wh, wh)

        return True
    # ----


class objectsContainer:
    """Container to hold plot objects and graph labels.

    Most methods simply forward the call to every stored object, so the
    objects are expected to provide the same method names.
    """

    # ----
    def __init__(self, objects):
        """ Store the list of plot objects (kept by reference, not copied). """
        self.objects = objects
    # ----


    # ----
    def getBoundingBox(self):
        """ Get container bounding box.

        Returns [minXY, maxXY] covering all objects which report a box.
        If none does, [[0, 0], [1, 1]] is returned.
        """

        # init values if no data in objects
        rect = [num.array([0, 0]), num.array([1, 1])]

        # get bounding boxes from objects
        have = False
        for item in self.objects:
            oRect = item.getBoundingBox()

            # objects with nothing to show return a false value
            if not oRect:
                continue

            if have:
                rect[0] = num.minimum(rect[0], oRect[0])
                rect[1] = num.maximum(rect[1], oRect[1])
            else:
                rect = oRect
                have = True

        # check scale
        if rect[0][0] == rect[1][0]:
            rect[0][0] -= 0.5
            rect[1][0] += 0.5
        if rect[0][1] == rect[1][1]:
            rect[1][1] += 0.5

        return rect
    # ----


    # ----
    def getLegend(self):
        """Get a list of legend names (spectrum objects only). """

        # get names
        return [
            item.getLegend()
            for item in self.objects
            if isinstance(item, spectrumObject)
            ]
    # ----


    # ----
    def getPoint(self, object, xPos, userCoord=False):
        """ Get point coordinations in selected object.

        object: index of the object within the container.
        """

        return self.objects[object].getPoint(xPos, userCoord)
    # ----


    # ----
    def countVisibleObjects(self):
        """ Get number of visible objects ('show' attribute set). """

        count = 0
        for item in self.objects:
            if item.attributes['show']:
                count += 1

        return count
    # ----


    # ----
    def cropPoints(self, minX, maxX):
        """ Crop points in all objects in container. """

        for item in self.objects:
            item.cropPoints(minX, maxX)
    # ----


    # ----
    def scaleAndShift(self, scale, shift):
        """ Scale and shift points in all objects in container. """

        for item in self.objects:
            item.scaleAndShift(scale, shift)
    # ----


    # ----
    def filterPoints(self, filterSize=1):
        """ Filter points in all objects in container. """

        for item in self.objects:
            item.filterPoints(filterSize)
    # ----


    # ----
    def draw(self, dc, printerScale=1, reverse=False):
        """ Draw each object in container.

        reverse: draw the objects from the last one to the first one.
        """

        # iterate over a reversed view instead of reversing the list in
        # place, so the stored order survives an error raised while drawing
        if reverse:
            sequence = reversed(self.objects)
        else:
            sequence = self.objects

        # draw objects
        for item in sequence:
            item.draw(dc, printerScale)
    # ----


    # ----
    def drawGel(self, dc, gelCoords, gelHeight, printerScale):
        """ Draw gel view for objects in container.

        The first object in the container is skipped. gelCoords is modified
        in place: its first item grows by gelHeight after each drawn gel, so
        a mutable sequence is required.
        """

        # draw objects
        for item in self.objects[1:]:
            if item.drawGel(dc, gelCoords, gelHeight, printerScale):
                gelCoords[0] += gelHeight
    # ----


    # ----
    def __additem__(self, object):
        """ Append object to container. """
        # NOTE(review): not a standard Python protocol method, so it is only
        # reachable by an explicit call - confirm against callers
        self.objects.append(object)
    # ----


    # ----
    def __delitem__(self, index):
        """ Delete object at given index. """
        del self.objects[index]
    # ----


    # ----
    def __setitem__(self, index, object):
        """ Replace object at given index. """
        self.objects[index] = object
    # ----


    # ----
    def __getitem__(self, index):
        """ Get object at given index. """
        return self.objects[index]
    # ----


    # ----
    def __len__(self):
        """ Get number of objects in container. """
        return len(self.objects)
    # ----
