#ifndef __BITWISE_LOAD_STORE_MACROS_INCLUDED
#define __BITWISE_LOAD_STORE_MACROS_INCLUDED

/**
 *	@file gpgpu/kernel_utils/BitLoadStore.h
 *	@brief bitwise load / store / pack primitives for tiled computing
 *	@date 2016
 *	@author -tHE SWINe-
 */

/**
 *	@def GLOBAL_TO_LOCAL_DECODE_PACK_BITARRAY
 *
 *	@brief cooperatively read bit array (unpacked) from global memory (using coalesced accesses)
 *		and pack it to shared memory
 *
 *	In each of the n_bit_per_thread_num iterations, a thread decodes the input element
 *	<tt>n_thread_num * i + n_tid</tt> to a single flag (nonzero -> 1, zero -> 0), the flags
 *	are combined using n_warp_ballot() and the resulting word is stored to
 *	<tt>p_shared_packed_bits[n_src / word size]</tt>. The size of the elements of
 *	p_shared_packed_bits in bits must equal WARP_SIZE (checked at compile time).
 *
 *	@param[out] p_shared_packed_bits is storage for the packed bit array in natural order
 *	@param[in] n_tid is thread id
 *	@param[in] n_thread_num is number of threads
 *	@param[in] n_bit_per_thread_num is number of bits per thread
 *	@param[in] p_data is pointer to input values (any data type, nonzero decoded as a 1, zero as a 0 bit)
 *	@param[in] n_data_num is number of input values in the buffer (but only up to
 *		n_thread_num * n_bit_per_thread_num are actually read)
 *	@param[in] n_padding_value is value of the input buffer assumed after the end of the array
 *		(still passes through the decoding - is not a bit value!)
 *
 *	@note The arguments are expanded several times; pass expressions without side effects.
 *	@note This macro issues no barrier itself; presumably the caller synchronizes
 *		before the packed bits are read by other warps (to be confirmed against the callers).
 *	@note All the lanes store the ballot word (the lane-0-only store is disabled
 *		as it was seen to break the code on K40).
 */
#define GLOBAL_TO_LOCAL_DECODE_PACK_BITARRAY(p_shared_packed_bits,n_tid,n_thread_num,n_bit_per_thread_num,p_data,n_data_num,n_padding_value) \
	do { \
		enum { \
			n_word_size = 8 * sizeof(*(p_shared_packed_bits)), /* not p_data, that can be int, uint8_t or e.g. double */ \
			n_log_word_size = n_Log2(n_word_size) \
		}; \
		\
		STATIC_ASSERT(n_word_size == WARP_SIZE, DECODE_PACK_BITARRAY_WARP_SIZE_MISMATCH); /* ballot wouldn't do the job */ \
		\
		if((n_data_num) >= (n_bit_per_thread_num) * (n_thread_num)) { \
			/* fast path: the whole tile is inside the buffer, no bounds checks needed */ \
			_Pragma("unroll") \
			for(int i = 0; i < (n_bit_per_thread_num); ++ i) { \
				unsigned int n_src = (n_thread_num) * i + (n_tid), n_dest = n_src >> n_log_word_size; \
				uintwarp_t n_ballot = n_warp_ballot((p_data)[n_src] != 0); /* coalesced access */ \
				/* all lanes store; a lane-0-only store was seen to break the code on K40 */ \
				(p_shared_packed_bits)[n_dest] = n_ballot; \
			} \
		} else { \
			/* tail path: elements past the end of the buffer decode n_padding_value instead */ \
			_Pragma("unroll") \
			for(int i = 0; i < (n_bit_per_thread_num); ++ i) { \
				unsigned int n_src = (n_thread_num) * i + (n_tid), n_dest = n_src >> n_log_word_size; \
				bool b_head_flag = (n_padding_value) != 0; /* shorter branch; the argument must be parenthesized, "!=" binds tighter than e.g. "&" */ \
				if(n_src < (n_data_num)) \
					b_head_flag = (p_data)[n_src] != 0; \
				uintwarp_t n_ballot = n_warp_ballot(b_head_flag); \
				/* all lanes store; a lane-0-only store was seen to break the code on K40 */ \
				(p_shared_packed_bits)[n_dest] = n_ballot; \
			} \
		} \
	} while (0)

/**
 *	@def REGISTER_TO_LOCAL_DECODE_PACK_STRIDED_BITARRAY
 *
 *	@brief decode bit array stored (unpacked) in registers and write strided packed bit array to shared memory
 *
 *	For each bit index i, the flags <tt>p_thread_data[i] != 0</tt> are combined using
 *	n_warp_ballot() and the resulting word is stored to
 *	<tt>p_shared_packed_bits[n_warp + (n_thread_num / WARP_SIZE) * i]</tt>.
 *
 *	@param[out] p_shared_packed_bits is storage for the packed bit array in strided order
 *		(all threads first bit, then all threads second bit and so on) in shared memory
 *	@param[in] n_warp is thread warp index (<tt>tid / WARP_SIZE</tt>)
 *	@param[in] n_lane is thread warp lane (<tt>tid % WARP_SIZE</tt>; currently unused
 *		as the lane-0-only store is disabled)
 *	@param[in] n_thread_num is number of threads (must be a compile-time constant)
 *	@param[in] n_bit_per_thread_num is number of bits per thread (must be a compile-time constant)
 *	@param[in] p_thread_data is pointer to input values in registers (any data type, nonzero decoded as a 1, zero as a 0 bit)
 *
 *	@note The arguments are expanded several times; pass expressions without side effects.
 */
#define REGISTER_TO_LOCAL_DECODE_PACK_STRIDED_BITARRAY(p_shared_packed_bits,n_warp,n_lane,n_thread_num,n_bit_per_thread_num,p_thread_data) \
	do { \
		enum { \
			_n_bit_num = (n_bit_per_thread_num), /* make sure it is compile time const */ \
			n_warp_num = (n_thread_num) / WARP_SIZE /* number of words per one bit plane */ \
		}; \
		\
		_Pragma("unroll") \
		for(int i = 0; i < (n_bit_per_thread_num); ++ i) { \
			uintwarp_t n_ballot = n_warp_ballot((p_thread_data)[i] != 0); /* decode the flag */ \
			/* all lanes store; a lane-0-only store was seen to cause trouble on K40 */ \
			(p_shared_packed_bits)[(n_warp) + n_warp_num * i] = n_ballot; /* n_warp must be parenthesized, it can be e.g. "tid >> 5" */ \
		} \
	} while(0)

/**
 *	@def REGISTER_TO_LOCAL_DECODE_PACK_AND_READ_STRIDED_BITARRAY
 *
 *	@brief decode bit array stored (unpacked) in registers, write strided packed
 *		bit array to shared memory and write a thread portion of the bit array into a register
 *
 *	For each bit index i, the flag <tt>p_thread_data[i] != 0</tt> is written to bit i of n_dest,
 *	the flags are combined using n_warp_ballot() and the resulting word is stored to
 *	<tt>p_shared_packed_bits[n_warp + (n_thread_num / WARP_SIZE) * i]</tt>.
 *
 *	@param[out] p_shared_packed_bits is storage for the packed bit array in strided order
 *		(all threads first bit, then all threads second bit and so on) in shared memory
 *	@param[out] n_dest is a small packed bit array with up to 32-bits per thread
 *	@param[in] n_warp is thread warp index (<tt>tid / WARP_SIZE</tt>)
 *	@param[in] n_lane is thread warp lane (<tt>tid % WARP_SIZE</tt>; currently unused
 *		as the lane-0-only store is disabled)
 *	@param[in] n_thread_num is number of threads (must be a compile-time constant)
 *	@param[in] n_bit_per_thread_num is number of bits per thread (must be a compile-time constant)
 *	@param[in] p_thread_data is pointer to input values in registers (any data type, nonzero decoded as a 1, zero as a 0 bit)
 *
 *	@note The arguments are expanded several times; pass expressions without side effects.
 */
#define REGISTER_TO_LOCAL_DECODE_PACK_AND_READ_STRIDED_BITARRAY(p_shared_packed_bits,n_dest,n_warp,n_lane,n_thread_num,n_bit_per_thread_num,p_thread_data) \
	do { \
		enum { \
			_n_bit_num = (n_bit_per_thread_num), /* make sure it is compile time const */ \
			n_warp_num = (n_thread_num) / WARP_SIZE /* number of words per one bit plane */ \
		}; \
		\
		(n_dest) = 0; \
		_Pragma("unroll") \
		for(int i = 0; i < (n_bit_per_thread_num); ++ i) { \
			bool b_head_flag = (p_thread_data)[i] != 0; /* decode the flag */ \
			(n_dest) |= (unsigned int)b_head_flag << i; /* shift as unsigned, a signed 1 << 31 is undefined */ \
			/* collect flags */ \
			\
			uintwarp_t n_ballot = n_warp_ballot(b_head_flag); \
			/* all lanes store; a lane-0-only store was seen to cause trouble on K40 */ \
			(p_shared_packed_bits)[(n_warp) + n_warp_num * i] = n_ballot; /* n_warp must be parenthesized, it can be e.g. "tid >> 5" */ \
			/* pack the head flags */ \
		} \
	} while(0)

/**
 *	@def REGISTER_DECODE_PACK_BITARRAY
 *
 *	@brief pack bit array stored (unpacked) in registers
 *
 *	Bit i of n_dest is set if <tt>p_thread_data[i]</tt> is nonzero, all the other bits are cleared.
 *
 *	@param[out] n_dest is a small packed bit array with up to 32-bits per thread
 *	@param[in] n_bit_per_thread_num is number of bits per thread
 *	@param[in] p_thread_data is pointer to input values in registers (any data type, nonzero decoded as a 1, zero as a 0 bit)
 *
 *	@note The arguments are expanded several times; pass expressions without side effects.
 */
#define REGISTER_DECODE_PACK_BITARRAY(n_dest,n_bit_per_thread_num,p_thread_data) \
	do { \
		(n_dest) = 0; \
		_Pragma("unroll") \
		for(int i = 0; i < (n_bit_per_thread_num); ++ i) { \
			bool b_head_flag = (p_thread_data)[i] != 0; /* decode the flag */ \
			(n_dest) |= (unsigned int)b_head_flag << i; /* shift as unsigned, a signed 1 << 31 is undefined */ \
			/* collect flags */ \
		} \
	} while(0)

/**
 *	@def LOCAL_READ_STRIDED_BITARRAY
 *
 *	@brief gathers the bits of one thread from a strided packed bit array in shared memory
 *
 *	Bit i of the result is taken from bit n_lane of the word
 *	<tt>p_shared_packed_bits[n_warp + (n_thread_num / WARP_SIZE) * i]</tt>.
 *
 *	@param[out] n_dest is a small packed bit array with up to 32-bits per thread
 *	@param[in] n_warp is thread warp index (<tt>tid / WARP_SIZE</tt>)
 *	@param[in] n_lane is thread warp lane (<tt>tid % WARP_SIZE</tt>)
 *	@param[in] n_thread_num is number of threads (must be a compile-time constant)
 *	@param[in] n_bit_per_thread_num is number of bits per thread
 *	@param[in] p_shared_packed_bits is packed bit array in strided order (all
 *		threads first bit, then all threads second bit and so on) in shared memory
 *
 *	@note The arguments are expanded several times; pass expressions without side effects.
 */
#define LOCAL_READ_STRIDED_BITARRAY(n_dest,n_warp,n_lane,n_thread_num,n_bit_per_thread_num,p_shared_packed_bits) \
	do { \
		enum { \
			_n_plane_stride = (n_thread_num) / WARP_SIZE /* words per bit plane, one word per warp */ \
		}; \
		\
		(n_dest) = 0; \
		_Pragma("unroll") \
		for(int _n_plane = 0; _n_plane < (n_bit_per_thread_num); ++ _n_plane) { \
			(n_dest) |= (((p_shared_packed_bits)[_n_plane_stride * _n_plane + (n_warp)] >> (n_lane)) & 1) << _n_plane; \
			/* pick the bit of this lane from the word of this warp in the current plane */ \
		} \
	} while(0)

/**
 *	@def LOCAL_READ_STRIDED_BITARRAY_OVERLAP_1
 *
 *	@brief reads strided packed bit array stored in shared memory, with 1 bit overlap between adjacent threads
 *
 *	Bits 0 to n_bit_per_thread_num - 1 of the result are the bits of this thread, bit
 *	n_bit_per_thread_num is the first bit of the next thread. The last thread of the block
 *	takes its overlap bit from bit 0 of the word at index
 *	<tt>(n_thread_num / WARP_SIZE) * n_bit_per_thread_num</tt>, i.e. from one extra word
 *	right after the strided array, which must therefore be allocated and initialized.
 *
 *	@param[out] n_dest is a small packed bit array with up to 32-bits per thread
 *		(n_bit_per_thread_num + 1 bits are written)
 *	@param[in] n_warp is thread warp index (<tt>tid / WARP_SIZE</tt>)
 *	@param[in] n_lane is thread warp lane (<tt>tid % WARP_SIZE</tt>)
 *	@param[in] n_thread_num is number of threads (must be a compile-time constant)
 *	@param[in] n_bit_per_thread_num is number of bits per thread (not counting the 1 bit overlap)
 *	@param[in] p_shared_packed_bits is packed bit array in strided order (all
 *		threads first bit, then all threads second bit and so on) in shared memory
 *
 *	@note This is much like calling:
 *		@code
 *		LOCAL_READ_BITARRAY_OVERLAP(n_dest, n_tid, n_warp, n_lane, n_bit_per_thread_num, 1, p_shared_packed_bits);
 *		@endcode
 *		with the exception that this is intended for the strided arrays. Extension to multiple overlapping
 *		bits would require extra instructions and was not implemented as it was not yet needed. The trick is to
 *		correctly modulo the warp / lane shift (after skipping all the warps, skip by the whole block rather than
 *		by another warp).
 *	@note The arguments are expanded several times; pass expressions without side effects.
 */
#define LOCAL_READ_STRIDED_BITARRAY_OVERLAP_1(n_dest,n_warp,n_lane,n_thread_num,n_bit_per_thread_num,p_shared_packed_bits) \
	do { \
		enum { \
			n_warp_num = (n_thread_num) / WARP_SIZE /* number of words per one bit plane */ \
		}; \
		\
		(n_dest) = 0; \
		_Pragma("unroll") \
		for(int i = 0; i < (n_bit_per_thread_num); ++ i) \
			(n_dest) |= (((p_shared_packed_bits)[(n_warp) + n_warp_num * i] >> (n_lane)) & 1) << i; \
		/* read the bits of this thread */ \
		\
		/* the local must not be called n_tid: a caller may pass n_warp / n_lane as expressions */ \
		/* of its own n_tid and those would then refer to the (uninitialized) local instead */ \
		const unsigned int _n_flat_tid = ((n_warp) << LOG_WARP_SIZE) | (n_lane); \
		if((n_bit_per_thread_num) == 1 || _n_flat_tid < (n_thread_num) - 1) { \
			const unsigned int _n_next_warp = (n_warp) + (((n_lane) + 1) >> LOG_WARP_SIZE); /* the *first* bit of the *next* thread is in the first bit plane */ \
			const unsigned int _n_next_lane = ((n_lane) + 1) & (WARP_SIZE - 1); \
			(n_dest) |= (((p_shared_packed_bits)[_n_next_warp] >> _n_next_lane) & 1) << (n_bit_per_thread_num); \
		} else { \
			/* the last thread of the block: its successor is the first thread of the next tile, */ \
			/* the bit of which is stored in the lowest bit of the word right after the strided array */ \
			(n_dest) |= ((p_shared_packed_bits)[n_warp_num * (n_bit_per_thread_num)] & 1) << (n_bit_per_thread_num); \
		} \
	} while(0)

/**
 *	@def LOCAL_READ_BITARRAY
 *
 *	@brief reads packed bit array in natural order stored in shared memory
 *
 *	If n_bit_per_thread_num is greater than 1, the thread reads the bits starting at the bit
 *	index <tt>n_tid * n_bit_per_thread_num</tt>, which may straddle two consecutive words.
 *	The word following the first one is always read (even if no bits are needed from it),
 *	that is why the array needs one extra padding element. Otherwise a single word at
 *	index n_warp is read and shifted by n_lane.
 *
 *	@param[out] n_dest is a small packed bit array with up to 32-bits per thread
 *	@param[in] n_tid is thread id (only used if n_bit_per_thread_num > 1)
 *	@param[in] n_warp is thread warp index (<tt>tid / WARP_SIZE</tt>; only used if n_bit_per_thread_num <= 1)
 *	@param[in] n_lane is thread warp lane (<tt>tid % WARP_SIZE</tt>; only used if n_bit_per_thread_num <= 1)
 *	@param[in] n_bit_per_thread_num is number of bits per thread (a compile-time constant, at most 32)
 *	@param[in] p_shared_packed_bits is packed bit array in natural order (needs one extra padding element)
 *
 *	@note The arguments are expanded several times; pass expressions without side effects.
 *	@note NOTE(review): if n_bit_per_thread_num equals the word size and the thread bits
 *		start at a word boundary, n_first_bits equals the word size and the second word is
 *		shifted left by the full word width. That is undefined in ISO C and presumably relies
 *		on the shift semantics of the target - to be confirmed.
 */
#define LOCAL_READ_BITARRAY(n_dest,n_tid,n_warp,n_lane,n_bit_per_thread_num,p_shared_packed_bits) \
	do { \
		enum { \
			_n_bit_num = (n_bit_per_thread_num), /* make sure it is compile time const */ \
			n_word_size = 8 * sizeof(*(p_shared_packed_bits)), \
			n_log_word_size = n_Log2(n_word_size), \
			n_word_mask = n_word_size - 1 \
		}; \
		\
		if((n_bit_per_thread_num) > 1) { \
			STATIC_ASSERT(_n_bit_num <= 32, LOCAL_READ_BITARRAY_CAN_ONLY_UNPACK_32_BITS_PER_THREAD); /* would have to write a code using 3 words or would have to use 64-bit ints */ \
			\
			unsigned int n_index = (n_tid) * (n_bit_per_thread_num); /* index of the first bit of this thread */ \
			unsigned int n_first_word = (n_index) >> n_log_word_size; \
			unsigned int n_first_shift = (n_index) & n_word_mask; \
			unsigned int n_first_bits = min((unsigned int)(n_bit_per_thread_num), n_word_size - n_first_shift); /* number of bits available in the first word */ \
			unsigned int n_second_word = n_first_word + 1; /* (n_index + n_first_bits) >> n_log_word_size; // the 1 padding word makes this simpler */ \
			unsigned int n_second_shift = ((n_index) + n_first_bits) & n_word_mask; \
			\
			(n_dest) = ((((p_shared_packed_bits)[n_first_word] >> n_first_shift) & n_Mask_32(n_first_bits)) | \
						(((p_shared_packed_bits)[n_second_word] >> n_second_shift) << n_first_bits)) & /* note that n_second_shift - n_first_bits could be negative, have to do both shifts */ \
						  n_Mask_32(n_bit_per_thread_num); /* using const mask if n_bit_per_thread_num is const */ \
		} else \
			(n_dest) = ((p_shared_packed_bits)[(n_warp)] >> (n_lane)) & n_Mask_32(n_bit_per_thread_num); \
	} while(0)

/**
 *	@def LOCAL_READ_BITARRAY_OVERLAP
 *
 *	@brief reads packed bit array in natural order stored in shared memory, with overlap
 *
 *	If n_bit_per_thread_num + n_overlap_bit_num is greater than 1, the thread reads that
 *	many bits starting at the bit index <tt>n_tid * n_bit_per_thread_num</tt>, which may
 *	straddle two consecutive words. The word following the first one is always read (even
 *	if no bits are needed from it), that is why the array needs one extra padding element.
 *	Otherwise a single word at index n_warp is read and shifted by n_lane.
 *
 *	@param[out] n_dest is a small packed bit array with up to 32-bits per thread
 *	@param[in] n_tid is thread id (only used if n_bit_per_thread_num + n_overlap_bit_num > 1)
 *	@param[in] n_warp is thread warp index (<tt>tid / WARP_SIZE</tt>; only used
 *		if n_bit_per_thread_num + n_overlap_bit_num <= 1)
 *	@param[in] n_lane is thread warp lane (<tt>tid % WARP_SIZE</tt>; only used
 *		if n_bit_per_thread_num + n_overlap_bit_num <= 1)
 *	@param[in] n_bit_per_thread_num is number of bits per thread (a compile-time constant)
 *	@param[in] n_overlap_bit_num is number of extra bits per thread which overlap the low bits of the next thread
 *		(a compile-time constant; the sum with n_bit_per_thread_num must be at most 32)
 *	@param[in] p_shared_packed_bits is packed bit array in natural order (needs one extra padding element)
 *
 *	@note The arguments are expanded several times; pass expressions without side effects.
 *	@note NOTE(review): if the total number of bits equals the word size and the thread bits
 *		start at a word boundary, n_first_bits equals the word size and the second word is
 *		shifted left by the full word width. That is undefined in ISO C and presumably relies
 *		on the shift semantics of the target - to be confirmed.
 */
#define LOCAL_READ_BITARRAY_OVERLAP(n_dest,n_tid,n_warp,n_lane,n_bit_per_thread_num,n_overlap_bit_num,p_shared_packed_bits) \
	do { \
		enum { \
			_n_bit_num = (n_bit_per_thread_num) + (n_overlap_bit_num), /* make sure it is compile time const */ \
			n_word_size = 8 * sizeof(*(p_shared_packed_bits)), \
			n_log_word_size = n_Log2(n_word_size), \
			n_word_mask = n_word_size - 1 \
		}; \
		\
		if((n_bit_per_thread_num) + (n_overlap_bit_num) > 1) { \
			STATIC_ASSERT(_n_bit_num <= 32, LOCAL_READ_BITARRAY_CAN_ONLY_UNPACK_32_BITS_PER_THREAD); /* would have to write a code using 3 words or would have to use 64-bit ints */ \
			\
			unsigned int n_index = (n_tid) * (n_bit_per_thread_num); /* index of the first bit of this thread (the overlap does not affect the stride) */ \
			unsigned int n_first_word = (n_index) >> n_log_word_size; \
			unsigned int n_first_shift = (n_index) & n_word_mask; \
			unsigned int n_first_bits = min((unsigned int)((n_bit_per_thread_num) + (n_overlap_bit_num)), n_word_size - n_first_shift); /* number of bits available in the first word */ \
			unsigned int n_second_word = n_first_word + 1; /* (n_index + n_first_bits) >> n_log_word_size; // the 1 padding word makes this simpler */ \
			unsigned int n_second_shift = ((n_index) + n_first_bits) & n_word_mask; \
			\
			(n_dest) = ((((p_shared_packed_bits)[n_first_word] >> n_first_shift) & n_Mask_32(n_first_bits)) | \
						(((p_shared_packed_bits)[n_second_word] >> n_second_shift) << n_first_bits)) & /* note that n_second_shift - n_first_bits could be negative, have to do both shifts */ \
						  n_Mask_32((n_bit_per_thread_num) + (n_overlap_bit_num)); /* using const mask if n_bit_per_thread_num is const */ \
		} else \
			(n_dest) = ((p_shared_packed_bits)[(n_warp)] >> (n_lane)) & n_Mask_32((n_bit_per_thread_num) + (n_overlap_bit_num)); \
	} while(0)

/*
p_shared_packed_bits,       n_tid,              n_thread_num,n_bit_per_thread_num,p_data,n_data_num,n_padding_value // GLOBAL_TO_LOCAL_DECODE_PACK_BITARRAY
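p_shared_packed_bits,             n_warp,n_lane,n_thread_num,n_bit_per_thread_num,p_thread_data // REGISTER_TO_LOCAL_DECODE_PACK_STRIDED_BITARRAY
p_shared_packed_bits,n_dest,      n_warp,n_lane,n_thread_num,n_bit_per_thread_num,p_thread_data // REGISTER_TO_LOCAL_DECODE_PACK_AND_READ_STRIDED_BITARRAY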
                     n_dest,                                 n_bit_per_thread_num,p_thread_data // REGISTER_DECODE_PACK_BITARRAY
                     n_dest,      n_warp,n_lane,n_thread_num,n_bit_per_thread_num,p_shared_packed_bits // LOCAL_READ_STRIDED_BITARRAY
                     n_dest,n_tid,n_warp,n_lane,             n_bit_per_thread_num,p_shared_packed_bits // LOCAL_READ_BITARRAY
*/

#endif // !__BITWISE_LOAD_STORE_MACROS_INCLUDED
