﻿# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl


def _prepareSeries(args):
  """
    Validate the series dictionaries passed to createGraph.

    Returns (series, groupCount): series is a list of shallow copies of args
    with "color" filled in where it was missing, groupCount is the common
    number of values of all series. The caller's dictionaries are not modified.

    Raises TypeError / ValueError for malformed series.
  """
  # predefined colors, handed out from the end of the list
  colors = ["#FF8800", "#1CED77", "#F00C0C", "#D1DB00", "#1A9E00", "#1CCAED"]
  groupCount = None
  series = []
  for arg in args:
    if not isinstance(arg, dict):
      raise TypeError("Expecting in params dictionary")
    if "values" not in arg:
      raise ValueError("Dictionary with data must contain \"values\" item.")
    if "name" not in arg:
      raise ValueError("Dictionary with data must contain \"name\" item.")
    if not isinstance(arg["name"], str):
      raise TypeError("Expecting string for \"name\" item")
    if not isinstance(arg["values"], list):
      raise TypeError("Expecting list for \"values\" item")

    # work on a copy so that assigning a default color does not leak
    # into the caller's dictionary
    data = dict(arg)
    if "color" in data:
      if not isinstance(data["color"], str):
        raise TypeError("Expecting string for \"color\" item")
    elif colors:
      data["color"] = colors.pop()
    else:
      # predefined colors are exhausted
      data["color"] = "#0066ff"

    # all series must have the same number of values (one per group)
    valueCount = len(data["values"])
    if groupCount is None:
      groupCount = valueCount
    elif groupCount != valueCount:
      raise ValueError("Expecting " + str(groupCount) + " values, get " + str(valueCount) + "!")
    series.append(data)

  if not groupCount:
    raise ValueError("Dictionary with data must contain non-empty \"values\" list.")
  return series, groupCount


def createGraph(filename, *args, title=None, xdesc=None, width=0.5, space=0.1, groupSpace=5, picSize=(11,4), groupSpaceMult=True, barDescOri="horizontal", ylabel="Čas [s]", imgFormat="svg"):
  """
    Create a grouped bar chart and save it in filename.

    The i-th value of every series is drawn in the i-th group of bars.

    filename - path to the new file with the graph
    args - one or more dictionaries, each describing one series:
             "name" (str, required) - label shown in the legend
             "values" (list, required) - bar heights; all series must have
                                         the same, non-zero number of values
             "color" (str, optional) - bar color; a predefined color is
                                       used when missing
             "stdDeviation" (optional) - per-value error bar sizes
           The dictionaries are not modified.
    title - title of the graph
    xdesc - labels of the groups on the X axis; used only when it contains
            exactly one label per group
    width - width of one bar
    space - space between bars inside a group
    groupSpace - space between groups of bars
    picSize - size of the graph in inches, tuple (width, height)
    groupSpaceMult - if true then groupSpace = space * groupSpace
    barDescOri - orientation of the value labels above the bars
                 (passed to matplotlib as text rotation)
    ylabel - label of the Y axis
    imgFormat - output format passed to savefig (default "svg", regardless
                of the filename extension)

    Raises TypeError / ValueError for malformed arguments.
  """
  if groupSpaceMult:
    groupSpace = space * groupSpace

  # check parameters
  if len(args) == 0:
    raise TypeError("Missing 1 required positional argument!")
  if not isinstance(picSize, tuple):
    raise TypeError("Expecting tuple for picSize parameter.")
  if len(picSize) != 2:
    raise ValueError("picSize must have 2 items")
  if not isinstance(filename, str):
    raise TypeError("Expecting string for filename")

  series, groupCount = _prepareSeries(args)
  groupWidth = len(series) * (width + space) - space

  # upper limit of the data: tallest bar, including its error bar if any
  maxVal = max(series[0]["values"])
  for data in series:
    tops = np.asarray(data["values"])
    if data.get("stdDeviation") is not None:
      tops = tops + np.asarray(data["stdDeviation"])
    maxVal = max(maxVal, tops.max())

  fig = plt.figure(figsize=picSize, dpi=100)
  try:
    # NOTE: this changes matplotlib's global rcParams; it persists after the call
    mpl.rcParams['font.family'] = "Arial"

    if title:
      fig.suptitle(title)

    # left edge of each group of bars on the X axis
    X = np.arange(groupCount) * (groupWidth + groupSpace)
    # value labels are lifted slightly above the bar (or its error bar)
    labelLift = 0.02 * maxVal
    for i, data in enumerate(series):
      stdDeviation = data.get("stdDeviation")
      left = X + (width + space) * i
      # align='edge' puts the bar's left edge at `left`; label, tick and
      # xlim positions below rely on this (matplotlib >= 2.0 centers by default)
      plt.bar(left, data["values"], width=width, align='edge', color=data["color"], label=data["name"], yerr=stdDeviation, ecolor='black')
      errs = stdDeviation if stdDeviation is not None else [0] * groupCount
      for x, y, err in zip(left, data["values"], errs):
        plt.text(x + width / 2, y + err + labelLift, '%.2f' % y, ha='center', va='bottom', size=9, rotation=barDescOri)

    # black border around the plot area
    ax = plt.gca()
    for side in ('right', 'left', 'top', 'bottom'):
      ax.spines[side].set_color('black')
    # remove tick marks ("spikes") on the X axis
    ax.xaxis.set_ticks_position('none')
    plt.ylabel(ylabel)
    ax.yaxis.grid(True, color="#AFAFAF")
    ax.set_axisbelow(True)
    ax.xaxis.grid(False)

    # ticks in the middle of each group, with descriptions if provided
    groupCenters = X + groupWidth / 2
    if xdesc is not None and len(xdesc) == groupCount:
      plt.xticks(groupCenters, xdesc, size=9)
    else:
      plt.xticks(groupCenters, size=9)

    lgd = plt.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
                     ncol=2, mode="expand", borderaxespad=0., borderpad=1, fontsize=10)

    # set limits of the X and Y axis
    plt.ylim(0, maxVal * 1.2)
    plt.xlim(-groupSpace, (groupWidth + groupSpace) * groupCount)
    # bbox_extra_artists keeps the legend (placed above the axes) in the image
    fig.savefig(filename, bbox_extra_artists=(lgd, ), bbox_inches='tight', format=imgFormat)
  finally:
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)

# Example usage (the graph is saved in SVG format):
#createGraph("a.svg",{"name":"PHP", "values":[3.5,3.4,3.3,1.2,3.5]},{"name":"Python2", "values":[3.6,4.4,2.3,3.2,1.1]},{"name":"Python3", "values":[3.8,2.4,3.9,4.4,2.5]},{"name":"Lua", "values":[3.1,2.34,3.32,4.21,3.1]}, title="Graph", xdesc=["Sort", "Search", "Tree", "Other", "Crypto"], width=3, space=1)
