# -*- coding: utf-8 -*-
"""
@author: Timothy A Wiseman

The article incorporates many of these tests.  However, not all are 
discussed in the article.

Not using timeit here since the cache has to be cleared before every single
run, and the time spent clearing it should not be included in the
measurements; timeit is less than optimal for that pattern.  A simple timing
wrapper is used instead.

This script is meant to be run after the tables are set up.

There is more on clearing the cache for SQL Server at:
    http://www.mssqltips.com/sqlservertip/1360/clearing-cache-for-sql-server-performance-testing/


Note that this script can take a long time to run.
"""

#import useful modules
import functools #keep wrapped-function metadata in time_wrapper
import time
import timeit #default_timer: most precise wall-clock timer per platform

import matplotlib.pyplot as plt #graph the results
import numpy as np
import pyodbc #to connect to SQL Server

##########################################################
#Establish parameters that will be used throughout the running of the script.
#These are, in effect, settings.  If this were meant as a complete program
#I would shunt these off into a configuration file or make them user
#selectable.  As a custom test script it makes sense to just include them here.

#pyodbc connection string.  Change DRIVER / Server / Database to match your
#environment; Trusted_Connection uses Windows authentication.
sqlConnStr = ('DRIVER={SQL Server Native Client 11.0};'
              'Server=YOURSERVER;'
              'Database=Test;'
              'Trusted_Connection=YES')

#Tables to benchmark.  The schema is deliberately omitted so the bare names
#can double as graph labels.
tableNames = [
    'IntTbl',
    'BigIntTbl',
    'DecimalTbl',
]

#########################################################
#Create the utility functions

def time_wrapper(func):
    """Wrap func so that every call also reports how long it took.

    The returned wrapper accepts exactly the same arguments as func and
    returns the tuple ``(elapsedSeconds, result)``.

    timeit.default_timer is used instead of time.time because it selects the
    most precise wall-clock timer for the running platform / Python version
    (time.clock on Windows under Python 2, time.perf_counter on Python 3).
    functools.wraps keeps func's __name__ and __doc__ on the wrapper.
    """
    @functools.wraps(func)
    def wrapper(*arg, **kw):
        start = timeit.default_timer()
        result = func(*arg, **kw)
        end = timeit.default_timer()
        return (end - start), result
    return wrapper
    
def clearCache(curs):
    """Ask SQL Server to flush its clean data pages before a timed run.

    Issues CHECKPOINT followed by DBCC DROPCLEANBUFFERS on curs, in that
    order (see the mssqltips link in the module docstring for why both are
    needed).  Returns None.
    """
    for command in ('checkpoint', 'dbcc dropcleanbuffers'):
        curs.execute(command)
    
def getRowCount(curs, tableName):
    """Run ``select count(*)`` against tableName.

    Returns the single row from curs.fetchone(); the count is its first
    column.
    """
    template = """select count(*) 
            from {}"""
    curs.execute(template.format(tableName))
    countRow = curs.fetchone()
    return countRow
    
def getAggregates(curs, tableName):
    """Compute several aggregates of col1 over rows with col1 <= 65000.

    The query returns avg, count, max, min, sum, stdev and var of col1 in a
    single row, which is returned as given by curs.fetchone().
    """
    #Leaving off columns 2 and 3 to avoid arithmetic overflow errors
    template = """select 
                avg(col1) as avg1,
                count(col1) as count1,
                max(col1) as max1,
                min(col1) as min1,
                sum(col1) as sum1,
                stdev(col1) as stdev1,
                var(col1) as var1
            from {}
            where col1 <= 65000"""
    curs.execute(template.format(tableName))
    aggregateRow = curs.fetchone()
    return aggregateRow
    
def getEvery100(curs, tableName):
    """Fetch every row whose col1 value is a multiple of 100.

    The modulo filter has to be evaluated against each candidate row.
    Returns the full list of (col1, col2, col3) rows from curs.fetchall().
    """
    template = """select col1, col2, col3
            from {}
            where col1%100 = 0"""
    curs.execute(template.format(tableName))
    matchingRows = curs.fetchall()
    return matchingRows
    
def getEvery100FromNonPk(curs, tableName):
    """Fetch every row whose col2 value is a multiple of 100.

    Same query shape as getEvery100 but filtering on col2.  Going by the
    function name, col2 is presumably not the primary key / not indexed;
    confirm against the table setup script.  Returns curs.fetchall().
    """
    template = """select col1, col2, col3
            from {}
            where col2%100 = 0"""
    curs.execute(template.format(tableName))
    matchingRows = curs.fetchall()
    return matchingRows
    
def get25Selected(curs, tableName):
    """Fetch the rows whose col1 matches one of 25 hard-coded values.

    Returns the full list of (col1, col2, col3) rows from curs.fetchall().
    """
    template = """select col1, col2, col3
            from {}
            where col1 in (1, 12, 123, 1234, 12345, 123456, 1234567,
                           2, 23, 234, 2345, 2456, 234567, 2345678,
                           3, 34, 345, 3456, 34567, 345678, 3456789,
                           4, 45, 456, 4567)"""
    curs.execute(template.format(tableName))
    matchingRows = curs.fetchall()
    return matchingRows
    
def get100Selected(curs, tableName):
    """Fetch the rows whose col1 matches one of 100 distinct hard-coded values.

    The first 25 values are the same ones used by the 25-value test, so the
    two tests are comparable.  Returns the full list of (col1, col2, col3)
    rows from curs.fetchall().
    """
    # Every value in the IN list must be distinct: a repeated key silently
    # shrinks the lookup set below 100.  Earlier revisions repeated 23, 234,
    # 2345, 234567, 2345678 and 2456, giving only 94 distinct keys.  The
    # "27..." row and 2467 replace those duplicates.
    sql = """select col1, col2, col3
            from {}
            where col1 in (1, 12, 123, 1234, 12345, 123456, 1234567,
                           2, 23, 234, 2345, 2456, 234567, 2345678,
                           3, 34, 345, 3456, 34567, 345678, 3456789,
                           4, 45, 456, 4567, 4678, 46789, 467890,
                           5, 56, 567, 5678, 56789, 567890, 5678901,
                           6, 67, 678, 6789, 67890, 678901, 6789012,
                           7, 78, 789, 7890, 78901, 789012, 7890123,
                           8, 89, 890, 8901, 89012, 890123, 8901234,
                           9, 90, 901, 9012, 90123, 901234, 9012345,
                           21, 212, 2123, 21234, 212345, 2123456, 21234567,
                           22, 223, 2234, 22345, 22456, 2234567, 22345678,
                           27, 278, 2789, 27890, 278901, 2789012, 27890123,
                           24, 245, 2467, 24567, 24678, 246789, 2467890,
                           25, 256, 2567, 25678, 256789, 2567890, 25678901,
                           26, 267)""".format(tableName)
    curs.execute(sql)
    return curs.fetchall()
    

#The Top-N queries are produced by a factory (below) rather than taking the
#row count as an extra argument, so that each generated function keeps the
#(curs, tableName) signature that makeGraphForTimes expects.

def makeGetTopXOrderDesc(numRows):
    """Build a query function that fetches the top numRows rows by col1 desc.

    The row count is baked into the generated function, rather than passed
    as a parameter, so that the result has the (curs, tableName) signature
    that makeGraphForTimes takes as a parameter.

    numRows is coerced with int() because it is interpolated directly into
    the SQL text.  A ValueError is raised if it is negative, since
    ``TOP -n`` is not valid SQL.  Each generated function is named
    ``getTop<numRows>OrderDesc`` so that tracebacks identify the variant.
    """
    rowLimit = int(numRows)
    if rowLimit < 0:
        raise ValueError('numRows must be non-negative, got {}'.format(numRows))

    def getTopXOrderDesc(curs, tableName):
        """Return curs.fetchall() for the top rows of tableName by col1 desc."""
        sql = """select top {} *
                 from {}
                 order by col1 desc""".format(rowLimit, tableName)
        curs.execute(sql)
        return curs.fetchall()

    getTopXOrderDesc.__name__ = 'getTop{}OrderDesc'.format(rowLimit)
    return getTopXOrderDesc

getTop1000OrderDesc = makeGetTopXOrderDesc(1000)
getTop10000OrderDesc = makeGetTopXOrderDesc(10000)
getTop100000OrderDesc = makeGetTopXOrderDesc(100000)

   
def makeGraphForTimes(funcToTime, curs, tableNames, graphFileName, graphTitle='', numReps = 6):
    """Time funcToTime against each table and save a bar chart of the averages.

    For every table in tableNames, funcToTime(curs, tableName) is run numReps
    times.  Before each run the cache is cleared with clearCache and the
    script sleeps for two seconds; only the call to funcToTime itself is timed.

    Parameters:
        funcToTime: callable taking (cursor, tableName); its return value is
            discarded.
        curs: cursor used both for clearing the cache and for the query.
        tableNames: tables to test; bars are drawn left-to-right in this order.
        graphFileName: path the chart is saved to; also labels the console
            summary.
        graphTitle: chart title.
        numReps: timed runs per table; must be at least 1.

    Raises:
        ValueError: if numReps < 1 (max/min of an empty run list would fail
            anyway, but only after all the slow work was done).
    """
    if numReps < 1:
        raise ValueError('numReps must be at least 1, got {}'.format(numReps))

    timedFunc = time_wrapper(funcToTime)
    resultsDict = {}
    orderedNames = []  # first-seen order, so graph and printout follow tableNames
    for tableName in tableNames:
        if tableName not in resultsDict:
            orderedNames.append(tableName)
        runTimes = []
        for _ in range(numReps):
            clearCache(curs)
            time.sleep(2) #give computer time for any background processes
            thisIterTime, _result = timedFunc(curs, tableName)
            runTimes.append(thisIterTime)
        resultsDict[tableName] = runTimes

    avgs = [np.mean(resultsDict[name]) for name in orderedNames]
    positions = np.arange(len(orderedNames))
    width = .6
    fig = plt.figure()
    try:
        # Bars are centred on integer positions, so the tick labels must use
        # exactly the same positions: one tick per bar, no extra slot.
        plt.bar(positions, avgs, width=width, align='center')
        plt.xticks(positions, orderedNames, rotation=17, size='small')
        plt.ylabel('Time in Seconds')
        plt.title(graphTitle)
        plt.savefig(graphFileName)
    finally:
        # Close only the figure created here, even if saving fails.
        plt.close(fig)

    #Some print statements to help verify results make sense.
    print('Results for {}'.format(graphFileName))
    for tableName in orderedNames:
        runTimes = resultsDict[tableName]
        print('For {} max {} min {} median {}'.format(tableName,
                    max(runTimes), min(runTimes), np.median(runTimes)))
    print('')
    
            


########################################################
#Execute the main script

if __name__ == '__main__':
    mainStart = time.time()

    defaultReps = 50
    # (query function, output image file, graph title), run in this order.
    benchmarks = [
        (get25Selected, '25Selected.jpg', 'Get 25 Specific Values'),
        (get100Selected, '100Selected.jpg', 'Get 100 Specific Values'),
        (getRowCount, 'RowCountTest.jpg', 'Row Count Execution Time'),
        (getAggregates, 'AggTest.jpg', 'Execution Time For Aggregates'),
        (getEvery100, 'DataFromIndex.jpg', 'Every 100th Value'),
        (getEvery100FromNonPk, 'DataWithoutIndex.jpg', 'Every 100th Value Without Index'),
        (getTop1000OrderDesc, 'Top1000OrderDesc.Jpg', 'Top 1000 Descending'),
        (getTop10000OrderDesc, 'Top10000OrderDesc.Jpg', 'Top 10000 Descending'),
        (getTop100000OrderDesc, 'Top100000OrderDesc.Jpg', 'Top 100000 Descending'),
    ]

    #autocommit so each statement (including CHECKPOINT / DBCC) is committed
    #immediately instead of being held in pyodbc's default open transaction.
    sqlConn = pyodbc.connect(sqlConnStr, autocommit = True)
    try:
        curs = sqlConn.cursor()
        for funcToTime, graphFileName, graphTitle in benchmarks:
            makeGraphForTimes(funcToTime, curs, tableNames, graphFileName,
                              graphTitle, numReps = defaultReps)
    finally:
        # Release the server connection even if a benchmark fails part-way.
        sqlConn.close()

    mainEnd = time.time()
    print('Running time approx {} seconds'.format(mainEnd - mainStart))