#include "VideoComparer.h"

#include <algorithm>
#include <vector>


namespace vmatch {


/** Default constructor of the comparer base; performs no work. */
VideoComparer::VideoComparer()
{
}


/**
 * Compares a reference keyframe sequence against a query sequence.
 *
 * Both sequences must contain at least one keyframe; an empty sequence is reported
 * with CV_StsBadArg via CV_Error (the reference sequence is checked first). The actual
 * comparison is delegated to compareFunc(), which the derived comparers implement.
 *
 * @param kfRef   keyframes of the reference video
 * @param kfQuery keyframes of the query video
 * @return the similarity matrix produced by compareFunc()
 */
SimilarityMatrix VideoComparer::compare(const KeyFrames & kfRef, const KeyFrames & kfQuery) {
	const bool referenceIsEmpty = (kfRef.size() == 0);
	const bool queryIsEmpty = (kfQuery.size() == 0);

	if(referenceIsEmpty) {
		CV_Error(CV_StsBadArg, "Reference sequence contains no keyframes, no comparison is possible.");
	}
	if(queryIsEmpty) {
		CV_Error(CV_StsBadArg, "Query sequence contains no keyframes, no comparison is possible.");
	}

	return compareFunc(kfRef, kfQuery);
}


/**
 * Allocates the data matrix that backs a SimilarityMatrix for the two sequences.
 *
 * The matrix is created by SimilarityMatrix::createDataMatrix() from the reference
 * and query keyframe counts; the comparers in this file index it as
 * (reference index, query index).
 *
 * @param kfRef   reference keyframes (first dimension)
 * @param kfQuery query keyframes (second dimension)
 * @return the freshly created data matrix
 * @note Reports CV_StsBadArg via CV_Error if either sequence is empty.
 */
cv::Mat VideoComparer::createSimilarityMatrixData(const KeyFrames & kfRef, const KeyFrames & kfQuery) const {
	int descriptorCountA = (int)kfRef.size();
	int descriptorCountB = (int)kfQuery.size();

	// Guards against the (int) cast of an oversized container size wrapping negative.
	CV_Assert(descriptorCountA >= 0);
	CV_Assert(descriptorCountB >= 0);

	if(descriptorCountA <= 0 || descriptorCountB <= 0) {
		CV_Error(CV_StsBadArg, "One of the video sequences contains no frames.");
	}

	return SimilarityMatrix::createDataMatrix(descriptorCountA, descriptorCountB);
}


/**
 * Creates an LSH-based comparer.
 *
 * @param countToFind number of nearest reference keyframes looked up per query keyframe;
 *                    validated (must be > 0) by setCountToFind().
 */
LSHVideoComparer::LSHVideoComparer(int countToFind)
	: VideoComparer()
{
	// Route through the setter so construction applies the same validation as later updates.
	this->setCountToFind(countToFind);
}


/**
 * Sets how many nearest reference keyframes are requested from the LSH index for each
 * query keyframe.
 *
 * @param countToFind neighbour count; reports CV_StsBadArg via CV_Error unless > 0
 */
void LSHVideoComparer::setCountToFind(int countToFind)
{
	const bool isPositive = (countToFind > 0);
	if(!isPositive) {
		CV_Error(CV_StsBadArg, "LSH count to find has to be > 0.");
	}

	this->countToFind = countToFind;
}


/**
 * Packs the descriptor vector of every keyframe into a single matrix (one row per
 * keyframe, in the iteration order of kf); this is the data the LSH index is built on.
 *
 * Only the first row of each descriptor vector is copied. Every descriptor must be a
 * CV_8U matrix with at least one row and exactly as many columns as the first keyframe's
 * descriptor; otherwise CV_StsBadArg is reported via CV_Error.
 *
 * @param kf keyframes to pack; must not be empty
 * @return descriptorCount x descriptorLength CV_8U matrix
 */
cv::Mat LSHVideoComparer::getDescriptorsMat(const KeyFrames & kf) const {
	const int descriptorCount = (int)kf.size();
	// Must be checked before dereferencing kf.begin(): on an empty container begin() == end().
	CV_Assert(descriptorCount > 0);

	const int descriptorLength = (*kf.begin())->getFrameDescriptor()->getVector().cols;
	CV_Assert(descriptorLength > 0);

	cv::Mat descriptors = cv::Mat::zeros(descriptorCount, descriptorLength, CV_8U);

	int i = 0;
	for(KeyFrames::const_iterator it = kf.begin(); it != kf.end(); ++it, ++i) {
		const cv::Mat vec = (*it)->getFrameDescriptor()->getVector();

		// copyTo() reallocates its destination when size or type differ, which would detach
		// `row` from `descriptors` and silently leave that row all zeros. Reject such input.
		if(vec.rows < 1 || vec.cols != descriptorLength || vec.type() != CV_8U) {
			CV_Error(CV_StsBadArg, "LSHVideoComparer requires single-channel CV_8U descriptors of equal length.");
		}

		cv::Mat row = descriptors.row(i);
		vec.row(0).copyTo(row);
	}

	return descriptors;
}


/**
 * Approximate comparison: for each query keyframe, the countToFind nearest reference
 * keyframes are located with a FLANN LSH index, and only those pairs are scored with
 * the descriptor's own compare(). All other cells keep whatever value
 * createSimilarityMatrixData() initialised them with.
 *
 * Reports CV_StsBadArg via CV_Error if the two sequences use different descriptor types
 * or if the (first) descriptors are not vector descriptors.
 *
 * @param kfRef   reference keyframes (matrix rows)
 * @param kfQuery query keyframes (matrix columns)
 * @return similarity matrix indexed as (reference index, query index)
 */
SimilarityMatrix LSHVideoComparer::compareFunc(const KeyFrames & kfRef, const KeyFrames & kfQuery) {
	cv::Mat similarity = createSimilarityMatrixData(kfRef, kfQuery);

	if((*kfRef.begin())->getFrameDescriptor()->getType() != (*kfQuery.begin())->getFrameDescriptor()->getType()) {
		CV_Error(CV_StsBadArg, "Descriptor type differs for both video sequences.");
	}

	if(!(*kfRef.begin())->getFrameDescriptor()->isVector() || !(*kfQuery.begin())->getFrameDescriptor()->isVector()) {
		CV_Error(CV_StsBadArg, "LSHVideoComparer can handle vector descriptors only.");
	}

	cv::Mat descriptorsRef = getDescriptorsMat(kfRef);
	const int refCount = descriptorsRef.rows;

	// Index-addressable view of the reference keyframes. KeyFrames iterators are not assumed
	// to be random access, so advancing from begin() for every neighbour would be linear.
	std::vector<KeyFrames::const_iterator> refFrames;
	refFrames.reserve(static_cast<size_t>(refCount));
	for(KeyFrames::const_iterator it = kfRef.begin(); it != kfRef.end(); ++it) {
		refFrames.push_back(it);
	}

	// LshIndexParams(table_number = 100, key_size = 20, multi_probe_level = 2).
	cv::flann::LshIndexParams params(100, 20, 2);
	// Direct initialisation instead of copy-initialising from a temporary, which pre-C++17
	// needs a copy of an object that owns a native FLANN index.
	cv::flann::Index lshHash(descriptorsRef, params);

	// There can never be more neighbours than reference frames.
	const int knn = std::min(countToFind, refCount);

	int b = 0;
	for(KeyFrames::const_iterator it = kfQuery.begin(); it != kfQuery.end(); ++it, ++b) {
		cv::Mat indices, dists;
		lshHash.knnSearch((*it)->getFrameDescriptor()->getVector(), indices, dists, knn);

		CV_Assert(!indices.empty());
		const int *indPtr = indices.ptr<int>(0);

		for(int a = 0; a < indices.cols; a++) {
			const int refIdx = indPtr[a];
			// LSH may find fewer than knn candidates; unfilled slots are reported with a
			// negative index and must not be used to address the matrix or the keyframes.
			if(refIdx < 0 || refIdx >= refCount) {
				continue;
			}

			double distance = (*refFrames[refIdx])->getFrameDescriptor()->compare(*(*it)->getFrameDescriptor());
			CV_Assert(distance >= 0);

			similarity.at<float>(refIdx, b) = static_cast<float>(distance);
		}
	}

	return SimilarityMatrix(similarity);
}


/** Creates an exhaustive (all-pairs) comparer; performs no work beyond base construction. */
BruteForceVideoComparer::BruteForceVideoComparer()
	: VideoComparer()
{
}


/**
 * Exhaustive comparison: every reference keyframe is scored against every query keyframe
 * with the descriptor's own compare(), filling the whole matrix.
 *
 * Reports CV_StsBadArg via CV_Error if the two sequences use different descriptor types
 * (checked on their first keyframes).
 *
 * @param kfRef   reference keyframes (matrix rows)
 * @param kfQuery query keyframes (matrix columns)
 * @return similarity matrix indexed as (reference index, query index)
 */
SimilarityMatrix BruteForceVideoComparer::compareFunc(const KeyFrames & kfRef, const KeyFrames & kfQuery) {
	cv::Mat similarity = createSimilarityMatrixData(kfRef, kfQuery);

	const bool typesDiffer =
		(*kfRef.begin())->getFrameDescriptor()->getType() != (*kfQuery.begin())->getFrameDescriptor()->getType();
	if(typesDiffer) {
		CV_Error(CV_StsBadArg, "Descriptor type differs for both video sequences.");
	}

	int row = 0;
	KeyFrames::const_iterator refIt = kfRef.begin();
	while(refIt != kfRef.end()) {
		int col = 0;
		for(KeyFrames::const_iterator queryIt = kfQuery.begin(); queryIt != kfQuery.end(); ++queryIt, ++col) {
			const double distance = (*refIt)->getFrameDescriptor()->compare(*(*queryIt)->getFrameDescriptor());
			similarity.at<float>(row, col) = static_cast<float>(distance);
		}
		++refIt;
		++row;
	}

	return SimilarityMatrix(similarity);
}


}
