/*
								+---------------------------------+
								|                                 |
								| ***  Expression evaluation  *** |
								|                                 |
								|  Copyright   -tHE SWINe- 2011  |
								|                                 |
								|            ExpEval.h            |
								|                                 |
								+---------------------------------+
*/

#pragma once
#ifndef __EXPRESSION_EVALUATOR_INCLUDED
#define __EXPRESSION_EVALUATOR_INCLUDED

/**
 *	@file ExpEval.h
 *	@author -tHE SWINe-
 *	@brief simple expression evaluation template
 *	@date 2011
 *
 *	@date 2014-04-10
 *
 *	Added integer exception handling, improved documentation.
 */

#include <string>
#include <map>
#include <set>
#include <vector>
#include <utility>
#include <math.h>
#include <stdio.h>
#include <string.h>
#include "Integer.h"

/**
 *	@def __EXPRESSION_EVALUATOR_COMPILE_TESTS
 *	@brief if defined, CExpEvalTester is compiled
 *	@note This header defines the macro unconditionally; comment out the
 *		definition below to leave the tester out of the build.
 */
#define __EXPRESSION_EVALUATOR_COMPILE_TESTS

namespace exp_eval {

/**
 *	@brief expression lexer (forward declaration only; CExpression's parsing
 *		functions take it by reference)
 */
class CExpEvalLexer;

/**
 *	@brief compile-time assertion helper
 *
 *	The primary template is deliberately empty. Only the specialization for a
 *	satisfied condition exposes the SUPPLIED_TYPE_IS_NOT_INTEGER_TYPE tag, so naming
 *	the tag with a false condition fails to compile, and the tag name appears in
 *	the resulting compiler error.
 *
 *	@tparam b_expression is value of the expression being asserted
 */
template <bool b_expression>
struct CStaticAssert {};

/**
 *	@brief compile-time assertion helper (specialization for a satisfied assertion)
 */
template <>
struct CStaticAssert<true> {
	typedef void SUPPLIED_TYPE_IS_NOT_INTEGER_TYPE; /**< @brief static assertion tag; the _TyInt parameter of CExpression is not an integer */
};

} // ~exp_eval

/**
 *	@brief very simple single-type expression evaluation
 *
 *	@tparam _Ty is data type for expression values
 *	@tparam _TyInt is data type for bitwise expressions to be carried out in
 *	@tparam b_integer_exception_handling is explicit integer exception handling flag
 *		(if set, evaluating "1 / 0" throws std::runtime_error, otherwise it is left
 *		to the default exception handling, which should terminate the whole process)
 *
 *	This is implemented using perhaps the simplest recursive descent parser
 *	possible; the generated syntax tree is evaluated directly (each node has
 *	up to three subnodes and a pointer to an evaluation function).
 *
 *	This can be used as follows:
 *
 *@code
 *	CFloatExpression eeval("2 * pi() * r");
 *	if(!eeval.b_Parsed())
 *		fail(); // there was an error in the expression
 *	eeval.SetVariable("r", 1);
 *	float f_circumference = eeval.t_Evaluate();
 *@endcode
 *
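 *	Operator support (along with associativity). The list follows the C++
 *	operator precedence table, from the highest to the lowest precedence.
 *	The operators supported by the evaluator are the parenthesized entries
 *	marked "is here"; the remaining entries are listed only for reference
 *	and are not supported: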
 *
 *	LTR
 *		++	Suffix increment
 *		--	Suffix decrement
 *		[]	Array subscripting
 *
 *		(() function call is here)
 *
 *		.	Element selection by reference
 *		->	Element selection through pointer
 *		typeid()	Run-time type information (C++ only) (see typeid)
 *		const_cast	Type cast (C++ only) (see const cast)
 *		dynamic_cast	Type cast (C++ only) (see dynamic_cast)
 *		reinterpret_cast	Type cast (C++ only) (see reinterpret cast)
 *		static_cast	Type cast (C++ only) (see static cast)
 *
 *	RTL
 *		++	Prefix increment
 *		--	Prefix decrement
 *
 *		(unary plus / minus is here)
 *		(! logical not is here)
 *		(~ bitwise not is here)
 *
 *		(type)	Type cast
 *		*	Indirection (dereference)
 *		&	Address-of
 *		sizeof	Size-of
 *		new, new[]	Dynamic memory allocation (C++ only)
 *		delete, delete[]	Dynamic memory deallocation (C++ only)
 *
 *	LTR
 *		.*	Pointer to member (C++ only)
 *		->*	Pointer to member (C++ only)
 *
 *		(mul / div / mod is here)
 *		(add / sub is here)
 *		(<< / >> is here)
 *		(< / <= / > / >= is here)
 *		(== / != is here)
 *		(& is here)
 *		(^ is here)
 *		(| is here)
 *		(&&	is here)
 *		(^^	logical exclusive or, not a C++ operator, is here)
 *		(||	is here)
 *
 *	RTL
 *		(?: is here)
 *
 *		=	Direct assignment (provided by default for C++ classes)
 *		+=	Assignment by sum
 *		-=	Assignment by difference
 *		*=	Assignment by product
 *		/=	Assignment by quotient
 *		%=	Assignment by remainder
 *		<<=	Assignment by bitwise left shift
 *		>>=	Assignment by bitwise right shift
 *		&=	Assignment by bitwise AND
 *		^=	Assignment by bitwise XOR
 *		|=	Assignment by bitwise OR
 *
 *		throw	Throw operator (exceptions throwing, C++ only)
 */
template <class _Ty, class _TyInt = _Ty, const bool b_integer_exception_handling = true>
class CExpression {
public:
	typedef _Ty TDataType; /**< @brief expression data type */
	typedef _TyInt TIntegerDataType; /**< @brief integer data type (for integer-only operations such as bitwise operations and shifts) */

	/**
	 *	@brief limits, stored as enum
	 */
	enum {
		max_FunctionParam_Num = 3 /**< @brief maximum number of custom function arguments (also the number of subnodes of a TNode) */
	};

	/**
	 *	@brief configuration properties, stored as enum
	 */
	enum {
		type_IsIntegerDataType = _Ty(.5) == _Ty(0), /**< @brief integer main data type flag */
		type_IsSignedDataType = _Ty(-1) < _Ty(0), /**< @brief signed main data type flag */
		typeInt_IsIntegerDataType = _TyInt(.5) == _TyInt(0), /**< @brief integer aux data type flag */
		integer_ExceltionHandling = b_integer_exception_handling /**< @brief explicit integer exception handling flag (used for integer division by zero - only if _Ty is an integer type; the misspelled name is kept for compatibility) */
	};

	typedef typename std::map<std::string, _Ty>::iterator TVariableIter; /**< @brief variable iterator type */
	typedef typename std::map<std::string, _Ty>::const_iterator TVariableConstIter; /**< @brief variable const iterator type */

	/**
	 *	@brief parse tree node
	 */
	struct TNode {
		TNode *p_node[max_FunctionParam_Num]; /**< @brief tree subnodes */
		_Ty (*p_eval)(const TNode &r_this); /**< @brief evaluation function */
		union {
			_Ty t_value; /**< @brief value (const expression) */
			_Ty *p_value; /**< @brief value (pointer to a variable) */
		};

		/**
		 *	@brief default constructor; assigns subnodes
		 *
		 *	@param[in] p_left is the first subnode (left operand in a binary
		 *		operation or operand of unary operation)
		 *	@param[in] p_right is the second subnode (right operand in a binary operation)
		 *	@param[in] p_third is third subnode (second right operand in ternary
		 *		operator or a function argument)
		 *
		 *	@note The node takes ownership of the subnodes; they will be deleted
		 *		by this node's destructor.
		 */
		TNode(TNode *p_left = 0, TNode *p_right = 0, TNode *p_third = 0);

		/**
		 *	@brief destructor; deletes all the subnodes owned by this node
		 *		(and, through their destructors, the whole subtree)
		 */
		~TNode();

		/**
		 *	@brief counts all the nodes in the tree
		 *	@return Returns number of the subnodes plus one for this node.
		 */
		size_t n_Node_Num() const;

		/**
		 *	@brief eliminates constant expressions
		 *
		 *	@param[in] p_eval_var_ptr is pointer for variable evaluation function
		 *		(required for const-ness checking)
		 *	@param[in] p_eval_const_ptr is pointer for constant evaluation function
		 *		(required for evaluation of the optimized nodes)
		 *
		 *	@note This function throws std::runtime_error if explicit integer exception
		 *		handling is enabled (see integer_ExceltionHandling).
		 */
		void Eliminate_ConstExpressions(_Ty (*p_eval_var_ptr)(const TNode &r_this),
			_Ty (*p_eval_const_ptr)(const TNode &r_this)); // throw(std::runtime_error)

		/**
		 *	@brief determines if this node represents a constant expression
		 *	@param[in] p_eval_var_ptr is pointer for variable evaluation function
		 *		(required for const-ness checking)
		 *	@return Returns true if this node evaluates to a constant value,
		 *		otherwise returns false.
		 */
		bool b_IsConstant(_Ty (*p_eval_var_ptr)(const TNode &r_this));

		/**
		 *	@brief evaluates child node
		 *	@param[in] n_child_node_index is zero-based index of the child node
		 *		(must be less than max_FunctionParam_Num, and the child must exist)
		 *	@return Returns numerical value
		 *	@note This function throws std::runtime_error if explicit integer exception
		 *		handling is enabled (see integer_ExceltionHandling).
		 */
		inline _Ty t_EvalChild(size_t n_child_node_index) const // throw(std::runtime_error)
		{
			_ASSERTE(n_child_node_index < max_FunctionParam_Num && p_node[n_child_node_index]);
			return p_node[n_child_node_index]->p_eval(*p_node[n_child_node_index]);
		}

	private:
		TNode(const TNode &r_other); // no-copy
		TNode &operator =(const TNode &r_other); // no-copy
	};

protected:
	typedef typename exp_eval::CStaticAssert<typeInt_IsIntegerDataType>::SUPPLIED_TYPE_IS_NOT_INTEGER_TYPE CAssert0; /**< @brief make sure that _TyInt is really an integer */
	typedef exp_eval::CExpEvalLexer CExpEvalLexer; /**< @brief unwrap the lexer type from its namespace to shorten the code below */

	/**
	 *	@brief user function entry
	 */
	struct TFuncEntry {
		size_t n_length; /**< @brief length of the name (precomputed to make parsing faster) */
		const char *p_s_name; /**< @brief null-terminated string, containing the function name */
		_Ty (*p_func)(const TNode&); /**< @brief pointer to the function implementation */
		size_t n_arg_num; /**< @brief number of function arguments (for semantic checking) */
	};

	TNode *m_p_tree; /**< @brief expression tree (null if no expression was parsed successfully) */
	std::map<std::string, _Ty> m_symbol_table; /**< @brief table of symbols */
	std::vector<TFuncEntry> m_function_table; /**< @brief table of functions */
	std::set<std::string> m_function_name_table; /**< @brief set of function names (for fast checking) */

public:
	/**
	 *	@brief default constructor; has no effect
	 */
	inline CExpression()
		:m_p_tree(0)
	{}

	/**
	 *	@brief constructor; parses a given expression
	 *	@param[in] p_s_expression is a null-terminated string with the expression to be evaluated
	 *	@note This doesn't reflect success of parsing, it is possible to call b_Parsed()
	 *		to see if it was successful.
	 */
	CExpression(const char *p_s_expression);

	/**
	 *	@brief destructor
	 */
	~CExpression();

	/**
	 *	@brief determines whether the expression was parsed successfully
	 *	@return Returns true if the expression was parsed successfully, otherwise returns false.
	 */
	inline bool b_Parsed() const
	{
		return m_p_tree != 0;
	}

	/**
	 *	@brief registers a new callable function
	 *
	 *	This allows users to add new functions for expression evaluation. These functions
	 *	must be deterministic, otherwise n_Optimize() will break the expression. The function
	 *	executes with a const node reference as its environment. Values of the function
	 *	arguments are found by calling TNode::t_EvalChild().
	 *
	 *	@param[in] p_s_name is function name (case sensitive)
	 *	@param[in] n_arg_num is number of arguments (must not exceed max_FunctionParam_Num)
	 *	@param[in] p_func is pointer to a function
	 *
	 *	@return Returns true on success, false on failure (function
	 *		already exists, too many arguments, or not enough memory).
	 *
	 *	@note This must be called before parsing (it is not possible to use parsing in constructor
	 *		in conjunction with user functions).
	 *	@note The functions can shadow the built-in functions, it is possible to redefine e.g. sin().
	 */
	bool RegisterFunction(const char *p_s_name, size_t n_arg_num, _Ty (*p_func)(const TNode&));

	/**
	 *	@brief parses a given expression
	 *	@param[in] p_s_expression is a null-terminated string with the expression to be evaluated
	 *	@return Returns true if the expression was parsed successfully, otherwise returns false.
	 */
	bool Parse(const char *p_s_expression);

	/**
	 *	@brief optimizes the expression tree by removing constant nodes
	 *	@return Returns number of nodes removed by the optimization,
	 *		or size_t(-1) on failure.
	 *	@note This function throws std::runtime_error if explicit integer exception
	 *		handling is enabled (see integer_ExceltionHandling).
	 */
	size_t n_Optimize(); // throw(std::runtime_error)

	/**
	 *	@brief gets number of variables
	 *	@return Returns number of variables found in the parsed expression.
	 *	@note Values of all the variables are initially zero.
	 */
	inline size_t n_Variable_num() const
	{
		return m_symbol_table.size();
	}

	/**
	 *	@brief determines whether the expression has variables
	 *	@return Returns true if the expression has variables, otherwise false.
	 */
	inline bool b_HaveVariables() const
	{
		return !m_symbol_table.empty();
	}

	/**
	 *	@brief gets iterator to the first variable
	 *	@return Returns iterator to the first variable.
	 */
	inline TVariableIter p_First_Variable_it()
	{
		return m_symbol_table.begin();
	}

	/**
	 *	@brief gets iterator to the last variable
	 *	@return Returns iterator pointing one past the last variable.
	 */
	inline TVariableIter p_Last_Variable_it()
	{
		return m_symbol_table.end();
	}

	/**
	 *	@brief gets iterator to the first variable
	 *	@return Returns const iterator to the first variable.
	 */
	inline TVariableConstIter p_First_Variable_it() const
	{
		return m_symbol_table.begin();
	}

	/**
	 *	@brief gets iterator to the last variable
	 *	@return Returns const iterator pointing one past the last variable.
	 */
	inline TVariableConstIter p_Last_Variable_it() const
	{
		return m_symbol_table.end();
	}

	/**
	 *	@brief sets value of a variable
	 *
	 *	@param[in] p_s_symbol_name is name of the variable (case sensitive)
	 *	@param[in] t_value is value to set the variable to
	 *
	 *	@return Returns true if the value was set, false if it was not found.
	 */
	bool SetVariable(const char *p_s_symbol_name, _Ty t_value);

	/**
	 *	@brief evaluates the expression
	 *	@return Returns the value of the expression.
	 *	@note This function throws std::runtime_error if explicit integer exception
	 *		handling is enabled (see integer_ExceltionHandling).
	 */
	_Ty t_Evaluate() const; // throw(std::runtime_error)

protected:
	// recursive descent parser, one function per grammar level (presumably ordered
	// by the operator precedence listed in the class documentation; the definitions
	// are not part of this header body - see ExpEval.inl)
	TNode *p_Parse_La(CExpEvalLexer &r_lexer);
	TNode *p_Parse_Lb(CExpEvalLexer &r_lexer);
	TNode *p_Parse_Lc(CExpEvalLexer &r_lexer);
	TNode *p_Parse_Ld(CExpEvalLexer &r_lexer);
	TNode *p_Parse_Le(CExpEvalLexer &r_lexer);
	TNode *p_Parse_Lf(CExpEvalLexer &r_lexer);
	TNode *p_Parse_Lg(CExpEvalLexer &r_lexer);
	TNode *p_Parse_Lh(CExpEvalLexer &r_lexer);
	TNode *p_Parse_Li(CExpEvalLexer &r_lexer);
	TNode *p_Parse_Lj(CExpEvalLexer &r_lexer);
	TNode *p_Parse_Lk(CExpEvalLexer &r_lexer);
	TNode *p_Parse_Ll(CExpEvalLexer &r_lexer);
	TNode *p_Parse_Lm(CExpEvalLexer &r_lexer);

	// node evaluation functions (stored in TNode::p_eval); operators and built-in
	// functions obtain the values of their operands via TNode::t_EvalChild()
	static _Ty EvalAdd(const TNode &r_node)				{	return r_node.t_EvalChild(0) + r_node.t_EvalChild(1);	}
	static _Ty EvalSub(const TNode &r_node)				{	return r_node.t_EvalChild(0) - r_node.t_EvalChild(1);	}
	static _Ty EvalNeg(const TNode &r_node)				{	return -r_node.t_EvalChild(0);	}
	static _Ty EvalBitNeg(const TNode &r_node)			{	return _Ty(~_TyInt(r_node.t_EvalChild(0)));	}
	static _Ty EvalBoolNeg(const TNode &r_node)			{	return !r_node.t_EvalChild(0);	}
	static _Ty EvalMul(const TNode &r_node)				{	return r_node.t_EvalChild(0) * r_node.t_EvalChild(1);	}
	static _Ty EvalDiv(const TNode &r_node); // throw(std::runtime_error)
	static _Ty EvalMod(const TNode &r_node); // throw(std::runtime_error)
	static _Ty EvalNull(const TNode &UNUSED(r_node))	{	return _Ty(0);	}
	static _Ty EvalConst(const TNode &r_node)			{	return r_node.t_value;	}
	static _Ty EvalTernary(const TNode &r_node)			{	return (r_node.t_EvalChild(0))? r_node.t_EvalChild(1) : r_node.t_EvalChild(2);	}
	static _Ty EvalBoolOr(const TNode &r_node)			{	return r_node.t_EvalChild(0) || r_node.t_EvalChild(1);	}
	static _Ty EvalBoolAnd(const TNode &r_node)			{	return r_node.t_EvalChild(0) && r_node.t_EvalChild(1);	}
	static _Ty EvalBoolXor(const TNode &r_node)			{	return !r_node.t_EvalChild(0) != !r_node.t_EvalChild(1);	}
	static _Ty EvalBitOr(const TNode &r_node)			{	return _Ty(_TyInt(r_node.t_EvalChild(0)) | _TyInt(r_node.t_EvalChild(1)));	}
	static _Ty EvalBitAnd(const TNode &r_node)			{	return _Ty(_TyInt(r_node.t_EvalChild(0)) & _TyInt(r_node.t_EvalChild(1)));	}
	static _Ty EvalBitXor(const TNode &r_node)			{	return _Ty(_TyInt(r_node.t_EvalChild(0)) ^ _TyInt(r_node.t_EvalChild(1)));	}
	static _Ty EvalIsEqual(const TNode &r_node)			{	return r_node.t_EvalChild(0) == r_node.t_EvalChild(1);	}
	static _Ty EvalIsNotEqual(const TNode &r_node)		{	return r_node.t_EvalChild(0) != r_node.t_EvalChild(1);	}
	static _Ty EvalIsLess(const TNode &r_node)			{	return r_node.t_EvalChild(0) < r_node.t_EvalChild(1);	}
	static _Ty EvalIsGreater(const TNode &r_node)		{	return r_node.t_EvalChild(0) > r_node.t_EvalChild(1);	}
	static _Ty EvalIsLessEqual(const TNode &r_node)		{	return r_node.t_EvalChild(0) <= r_node.t_EvalChild(1);	}
	static _Ty EvalIsGreaterEqual(const TNode &r_node)	{	return r_node.t_EvalChild(0) >= r_node.t_EvalChild(1);	}
	static _Ty EvalShL(const TNode &r_node)				{	return _Ty(_TyInt(r_node.t_EvalChild(0)) << _TyInt(r_node.t_EvalChild(1)));	}
	static _Ty EvalShR(const TNode &r_node)				{	return _Ty(_TyInt(r_node.t_EvalChild(0)) >> _TyInt(r_node.t_EvalChild(1)));	}
	static _Ty EvalVar(const TNode &r_node);
	static _Ty EvalPi(const TNode &UNUSED(r_node))		{	return _Ty(3.1415926535897932384626433832795028841971697510);	}
	static _Ty EvalLn(const TNode &r_node)				{	return _Ty(log(double(r_node.t_EvalChild(0))));	}
	static _Ty EvalLog(const TNode &r_node)				{	return _Ty(log10(double(r_node.t_EvalChild(0))));	}
	static _Ty EvalSqr(const TNode &r_node);
	static _Ty EvalAbs(const TNode &r_node);
	static _Ty EvalExp(const TNode &r_node)				{	return _Ty(exp(double(r_node.t_EvalChild(0))));	}
	static _Ty EvalFloor(const TNode &r_node)			{	return _Ty(floor(double(r_node.t_EvalChild(0))));	}
	static _Ty EvalCeil(const TNode &r_node)			{	return _Ty(ceil(double(r_node.t_EvalChild(0))));	}
	// rounds halves up (towards +infinity); the operand is converted to double
	// first, like in the other math wrappers, so that floor() gets a plain double
	static _Ty EvalRound(const TNode &r_node)			{	return _Ty(floor(double(r_node.t_EvalChild(0)) + .5));	}
	static _Ty EvalFract(const TNode &r_node);
	static _Ty EvalMin(const TNode &r_node);
	static _Ty EvalMax(const TNode &r_node);
	static _Ty EvalMix(const TNode &r_node);
	// returns 1 if the second argument is greater than or equal to the first one, otherwise 0
	static _Ty EvalStep(const TNode &r_node)			{	return (r_node.t_EvalChild(1) >= r_node.t_EvalChild(0))? _Ty(1) : _Ty(0);	}
	static _Ty EvalSmoothStep(const TNode &r_node);
	static _Ty EvalClamp(const TNode &r_node);
	static _Ty EvalSin(const TNode &r_node)				{	return _Ty(sin(double(r_node.t_EvalChild(0))));	}
	static _Ty EvalCos(const TNode &r_node)				{	return _Ty(cos(double(r_node.t_EvalChild(0))));	}
	static _Ty EvalTan(const TNode &r_node)				{	return _Ty(tan(double(r_node.t_EvalChild(0))));	}
	static _Ty EvalASin(const TNode &r_node)			{	return _Ty(asin(double(r_node.t_EvalChild(0))));	}
	static _Ty EvalACos(const TNode &r_node)			{	return _Ty(acos(double(r_node.t_EvalChild(0))));	}
	static _Ty EvalATan(const TNode &r_node)			{	return _Ty(atan(double(r_node.t_EvalChild(0))));	}
	static _Ty EvalATan2(const TNode &r_node)			{	return _Ty(atan2(double(r_node.t_EvalChild(0)), double(r_node.t_EvalChild(1))));	}
	static _Ty EvalPow(const TNode &r_node)				{	return _Ty(pow(double(r_node.t_EvalChild(0)), double(r_node.t_EvalChild(1))));	}
	static _Ty EvalSqrt(const TNode &r_node)			{	return _Ty(sqrt(double(r_node.t_EvalChild(0))));	}
	// note that zero is treated as positive here (sign of 0 evaluates to 1)
	static _Ty EvalSign(const TNode &r_node)			{	return (r_node.t_EvalChild(0) < 0)? _Ty(-1) : _Ty(1);	}

private:
	CExpression(const CExpression &r_other); // no-copy, use pointers if needed
	CExpression &operator =(const CExpression &r_other); // no-copy, use pointers if needed
};

#include "ExpEval.inl"
// function definitions go first, before CIntExpression, CFloatExpression
// and CDoubleExpression are declared

/**
 *	@brief double-precision floating-point expression evaluation
 *
 *	Values are evaluated as double; bitwise operations are carried out in int.
 */
typedef CExpression<double, int> CDoubleExpression;

/**
 *	@brief single-precision floating-point expression evaluation
 *
 *	Values are evaluated as float; bitwise operations are carried out in int.
 */
typedef CExpression<float, int> CFloatExpression;

/**
 *	@brief integer expression evaluation
 *
 *	Both the values and the bitwise operations use int (_TyInt defaults to _Ty).
 *	Explicit integer exception handling is left at its default (enabled), so
 *	e.g. evaluating "1 / 0" throws std::runtime_error.
 */
typedef CExpression<int> CIntExpression;

#endif // !__EXPRESSION_EVALUATOR_INCLUDED
