#include "KeyFramesExtractor.h"


namespace vmatch {


// Creates an extractor that picks keyframes at a fixed frame interval.
// The interval is validated and stored by setStep().
LinearKeyFramesExtractor::LinearKeyFramesExtractor(int step) {
	this->setStep(step);
}


// Walks through the whole sequence and emits a keyframe for every frame whose
// number is a multiple of `step`. Once the sequence is exhausted, one closing
// keyframe is emitted for the frame the sequence reports at that point
// (presumably the last frame read -- confirm against VideoSequence::next()).
//
// Each keyframe gets its own, consecutively increasing segment id.
// The closing keyframe is skipped when no frame was read at all, or when the
// loop has already emitted a keyframe with the same frame number; emitting it
// unconditionally would create a duplicate entry.
KeyFrames LinearKeyFramesExtractor::extract(VideoSequencePtr sequence) {
	KeyFrames keyframes(sequence->getDescriptorExtractor()->getType());

	int currentSegmentId = 0;
	bool anyFrameRead = false;
	bool anyKeyFrame = false;
	int lastKeyFrameNumber = 0;

	while(sequence->next()) {
		anyFrameRead = true;

		const int frameNumber = sequence->getFrameNumber();
		if(frameNumber % step == 0) {
			keyframes.push_back(new KeyFrame(sequence->getFrameDescriptor(), frameNumber, currentSegmentId++));
			lastKeyFrameNumber = frameNumber;
			anyKeyFrame = true;
		}
	}

	if(anyFrameRead) {
		const int closingFrameNumber = sequence->getFrameNumber();
		const bool alreadyEmitted = anyKeyFrame && closingFrameNumber == lastKeyFrameNumber;

		if(!alreadyEmitted) {
			keyframes.push_back(new KeyFrame(sequence->getFrameDescriptor(), closingFrameNumber, currentSegmentId++));
		}
	}

	return keyframes;
}


// Sets the frame interval used by extract().
// Raises a CV_StsBadArg error when the interval is zero or negative.
void LinearKeyFramesExtractor::setStep(int step) {
	const bool isPositive = step > 0;
	if(!isPositive) {
		CV_Error(CV_StsBadArg, "Invalid frame step.");
	}

	this->step = step;
}


// Creates an extractor that starts a new segment whenever the descriptor
// distance to the last keyframe exceeds `threshold`.
// Both arguments are validated by their respective setters.
ThresholdKeyFramesExtractor::ThresholdKeyFramesExtractor(FrameDescriptorExtractorPtr extractor, double threshold) {
	this->setExtractor(extractor);
	this->setThreshold(threshold);
}


// Splits the sequence into segments by comparing each frame's descriptor
// (computed with this object's `extractor`) against the descriptor of the
// keyframe that opened the current segment.
//
// A new segment starts at the very first frame and whenever the distance
// exceeds `threshold`. For each segment the first frame is emitted as a
// keyframe and, if the segment is longer than one frame, its last frame as
// well. Segment ids start at 1 because the id is incremented before the first
// keyframe is pushed.
//
// Closing keyframes carry the descriptor and number of the frame they
// actually represent (the previous frame, or the last frame of the sequence),
// not the descriptor of the segment's opening keyframe.
KeyFrames ThresholdKeyFramesExtractor::extract(VideoSequencePtr sequence) {
	KeyFrames keyframes(sequence->getDescriptorExtractor()->getType());

	FrameDescriptorPtr descriptor;       // frame being processed
	FrameDescriptorPtr prevDescriptor;   // frame processed in the previous iteration
	FrameDescriptorPtr prevKFdescriptor; // keyframe that opened the current segment

	int frameNumber = 0;
	int prevFrameNumber = 0;
	int prevKFframeNumber = 0;
	int segmentId = 0;

	while(sequence->next()) {
		cv::Mat frame = sequence->getFrame();
		descriptor = extractor->extract(frame);
		frameNumber = sequence->getFrameNumber();

		// No keyframe yet means this is the first frame we see; testing the
		// pointer (rather than frameNumber == 0) avoids dereferencing a null
		// descriptor when the sequence does not start at frame 0.
		const bool isFirstFrame = (prevKFdescriptor == NULL);

		if(isFirstFrame || descriptor->compare(*prevKFdescriptor) > threshold) {
			// push the last frame of the previous segment, unless it is the
			// keyframe that opened that segment
			if(!isFirstFrame && prevFrameNumber > prevKFframeNumber) {
				keyframes.push_back(new KeyFrame(prevDescriptor, prevFrameNumber, segmentId));
			}

			// push the first frame of the new segment
			segmentId++;
			keyframes.push_back(new KeyFrame(descriptor, frameNumber, segmentId));
			prevKFframeNumber = frameNumber;
			prevKFdescriptor = descriptor;
		}

		prevDescriptor = descriptor;
		prevFrameNumber = frameNumber;
	}

	// push the last frame of the whole sequence; `frameNumber` and
	// `descriptor` still describe the last frame read by the loop
	if(frameNumber > prevKFframeNumber) {
		keyframes.push_back(new KeyFrame(descriptor, frameNumber, segmentId));
	}

	return keyframes;
}


// Sets the descriptor extractor used to compare frames in extract().
// Raises a CV_StsBadArg error when a null pointer is passed.
void ThresholdKeyFramesExtractor::setExtractor(FrameDescriptorExtractorPtr extractor) {
	const bool isMissing = (extractor == NULL);
	if(isMissing) {
		CV_Error(CV_StsBadArg, "Null pointer to FrameDescriptorExtractor given.");
	}

	this->extractor = extractor;
}


// Sets the descriptor distance above which extract() opens a new segment.
// Raises a CV_StsBadArg error for negative values.
void ThresholdKeyFramesExtractor::setThreshold(double threshold) {
	const bool isNegative = threshold < 0;
	if(isNegative) {
		CV_Error(CV_StsBadArg, "Invalid threshold value.");
	}

	this->threshold = threshold;
}


// Creates a two-level (cascade) keyframe extractor.
// Every argument goes through its setter for validation. The similar
// threshold has to be stored before the different threshold, because the
// latter is validated against the former.
CascadeKeyFramesExtractor::CascadeKeyFramesExtractor(FrameDescriptorExtractorPtr firstExtractor, double thrSimilar, double thrDifferent, double thrAccept, int maxStep) : KeyFramesExtractor() {
	this->setFirstExtractor(firstExtractor);

	this->setThresholdSimilar(thrSimilar);
	this->setThresholdDifferenr(thrDifferent);
	this->setThresholdAccept(thrAccept);

	this->setMaxStep(maxStep);
}


// Two-level segment boundary detection.
//
// Level 1 compares consecutive frames using descriptors from `firstExtractor`:
//   - distance <= thrSimilar   : same segment; a keyframe is only forced once
//                                maxStep-1 frames have passed without one,
//   - distance >  thrDifferent : new segment,
//   - otherwise                : undecided, go to level 2.
// Level 2 compares the descriptors of the sequence's own extractor and opens
// a new segment when their distance exceeds thrAccept.
//
// When a new segment starts, the previous frame is emitted as the closing
// keyframe of the old segment (unless it already is a keyframe) and the
// current frame as the opening keyframe of the new one. The first frame read
// is always a keyframe.
KeyFrames CascadeKeyFramesExtractor::extract(VideoSequencePtr sequence) {
	KeyFrames keyframes(sequence->getDescriptorExtractor()->getType());

	cv::Mat currentFrame;
	cv::Mat previousFrame;
	FrameDescriptorPtr currentDescriptor1;  // level-1 descriptors
	FrameDescriptorPtr previousDescriptor1;
	FrameDescriptorPtr currentDescriptor2;  // level-2 descriptors, computed only on demand
	FrameDescriptorPtr previousDescriptor2;
	int currentSegmentId = 0;
	int lastKfPosition = 0;
	int currentStep = 0;

	while(sequence->next()) {
		const int frameNumber = sequence->getFrameNumber();

		currentFrame = sequence->getFrame();
		currentDescriptor1 = firstExtractor->extract(currentFrame);

		if(previousDescriptor1 == NULL) {
			// First frame of the sequence is always a keyframe. Detecting it
			// through the missing previous descriptor (instead of
			// frameNumber == 0) keeps the comparison below from dereferencing
			// a null pointer when numbering does not start at 0.
			currentDescriptor2 = sequence->getFrameDescriptor();
			keyframes.push_back(new KeyFrame(currentDescriptor2, frameNumber, currentSegmentId));
			lastKfPosition = frameNumber;
		}
		else {
			const double dist1 = currentDescriptor1->compare(*previousDescriptor1);

			if(dist1 > thrSimilar) {
				currentDescriptor2 = sequence->getFrameDescriptor();

				// totally different frames => new segment right away
				bool newSegment = dist1 > thrDifferent;

				if(!newSegment) {
					// maybe different => second-level comparison decides
					if(previousDescriptor2 == NULL) {
						previousDescriptor2 = sequence->getDescriptorExtractor()->extract(previousFrame);
					}
					newSegment = currentDescriptor2->compare(*previousDescriptor2) > thrAccept;
				}

				if(newSegment) {
					// last keyframe of the previous segment, if not already
					// present; the guard applies to both ways of reaching a
					// new segment so the previous frame is never pushed twice
					if(frameNumber-1 > lastKfPosition) {
						if(previousDescriptor2 == NULL) {
							previousDescriptor2 = sequence->getDescriptorExtractor()->extract(previousFrame);
						}
						keyframes.push_back(new KeyFrame(previousDescriptor2, frameNumber-1, currentSegmentId));
					}

					currentSegmentId++;
					keyframes.push_back(new KeyFrame(currentDescriptor2, frameNumber, currentSegmentId));
					lastKfPosition = frameNumber;
					currentStep = 0;
				}
				else {
					currentStep++;
				}
			}
			// forced keyframe in the current segment
			else if(currentStep >= maxStep-1) {
				currentDescriptor2 = sequence->getFrameDescriptor();

				keyframes.push_back(new KeyFrame(currentDescriptor2, frameNumber, currentSegmentId));
				lastKfPosition = frameNumber;
				currentStep = 0;
			}
			// no keyframe
			else {
				currentStep++;
			}
		}

		// Deep copy: the previous frame's pixels may be needed in the next
		// iteration to compute its level-2 descriptor.
		previousFrame = currentFrame.clone();
		previousDescriptor1 = currentDescriptor1;
		previousDescriptor2 = currentDescriptor2;
		currentDescriptor1 = NULL;
		currentDescriptor2 = NULL;
	}

	return keyframes;
}


// Sets the descriptor extractor used for the first-level comparison.
// Raises a CV_StsBadArg error when a null pointer is passed.
void CascadeKeyFramesExtractor::setFirstExtractor(FrameDescriptorExtractorPtr firstExtractor) {
	const bool isMissing = (firstExtractor == NULL);
	if(isMissing) {
		CV_Error(CV_StsBadArg, "Null pointer to cascade first-level FrameDescriptorExtractor given.");
	}

	this->firstExtractor = firstExtractor;
}


// Sets the first-level distance up to which two consecutive frames are
// treated as belonging to the same segment.
// Raises a CV_StsBadArg error for negative values.
void CascadeKeyFramesExtractor::setThresholdSimilar(double thrSimilar) {
	const bool isNegative = thrSimilar < 0;
	if(isNegative) {
		CV_Error(CV_StsBadArg, "Invalid similar-threshold value.");
	}

	this->thrSimilar = thrSimilar;
}


// Sets the first-level distance above which two consecutive frames are
// treated as belonging to different segments without a second comparison.
// The value must be strictly greater than the currently stored similar
// threshold, so setThresholdSimilar() has to be called first; otherwise a
// CV_StsBadArg error is raised.
void CascadeKeyFramesExtractor::setThresholdDifferenr(double thrDifferent) {
	if(thrDifferent <= this->thrSimilar) {
		CV_Error(CV_StsBadArg, "Invalid different-threshold value.");
	}
	this->thrDifferent = thrDifferent;
}


// Sets the second-level distance above which an undecided frame pair is
// accepted as a segment boundary.
// Raises a CV_StsBadArg error for negative values.
void CascadeKeyFramesExtractor::setThresholdAccept(double thrAccept) {
	if(thrAccept < 0) {
		CV_Error(CV_StsBadArg, "Invalid accept-threshold value.");
	}
	this->thrAccept = thrAccept;
}


// Sets the step count after which extract() forces a keyframe inside a
// segment. Raises a CV_StsBadArg error when the value is zero or negative.
void CascadeKeyFramesExtractor::setMaxStep(int maxStep) {
	const bool isPositive = maxStep > 0;
	if(!isPositive) {
		CV_Error(CV_StsBadArg, "Invalid frame step.");
	}

	this->maxStep = maxStep;
}


// Creates a keyframe extractor driven by feature tracking.
// The parameters are validated by their setters; afterwards the fixed
// tracking pipeline (detector, descriptor extractor, matcher, track repairer)
// is instantiated and the cached grayscale frames are reset to empty.
TrackingBasedKeyFramesExtractor::TrackingBasedKeyFramesExtractor(double lostRatioThreshold, int maxStep) : KeyFramesExtractor() {
	this->setLostRatioThreshold(lostRatioThreshold);
	this->setMaxStep(maxStep);

	// tracking pipeline components
	this->detector = new GoodFeaturesToTrackDetector(100, 0.04, 1);
	this->extractor = new Motion::NoneDescriptorExtractor;
	this->matcher = new Motion::OpticalFlowMatcher(40, 0.6f, -1);
	this->repairer = new Motion::NoneTrackRepairer;

	// no frame has been processed yet
	this->gray = Mat();
	this->prevGray = Mat();
}


// Splits the sequence into segments according to the ratio of feature tracks
// lost between consecutive frames (as reported by update()).
//
// Frame 0 is always a keyframe. A lost ratio above `lostRatioThreshold`
// starts a new segment: the previous frame closes the old segment (unless it
// is already a keyframe) and the current frame opens the new one. Within a
// segment a keyframe is forced once `maxStep` frames have passed since the
// last one.
KeyFrames TrackingBasedKeyFramesExtractor::extract(VideoSequencePtr sequence) {
	KeyFrames keyframes(sequence->getDescriptorExtractor()->getType());

	cv::Mat prevFrame;
	int currentStep = 0;
	int currentSegmentId = 0;
	int lastKfPosition = 0;

	while(sequence->next()) {
		Mat frame = sequence->getFrame();
		const int frameNumber = sequence->getFrameNumber();

		double lostRatio = 0;
		update(frame, lostRatio);

		if(frameNumber == 0) {
			keyframes.push_back(new KeyFrame(sequence->getFrameDescriptor(), frameNumber, currentSegmentId));
		}
		else {
			// new segment
			if(lostRatio > lostRatioThreshold) {
				if(frameNumber-1 > lastKfPosition) {
					keyframes.push_back(new KeyFrame(sequence->getDescriptorExtractor()->extract(prevFrame), frameNumber-1, currentSegmentId));
				}

				currentSegmentId++;
				keyframes.push_back(new KeyFrame(sequence->getFrameDescriptor(), frameNumber, currentSegmentId));
				lastKfPosition = frameNumber;
				currentStep = 0;
			}
			// same segment
			else if(currentStep >= maxStep) {
				keyframes.push_back(new KeyFrame(sequence->getFrameDescriptor(), frameNumber, currentSegmentId));
				lastKfPosition = frameNumber;
				currentStep = 0;
			}
		}

		currentStep++;

		// Deep copy: cv::Mat assignment only shares the pixel buffer, so a
		// plain assignment could leave prevFrame pointing at data that the
		// sequence overwrites with the next frame. This matches what
		// CascadeKeyFramesExtractor::extract does with its previous frame.
		prevFrame = frame.clone();
	}

	return keyframes;
}


// Feeds one frame into the tracker and reports how many of the tracks that
// were alive before this frame could not be matched in it.
//
// @param frame      frame to process; converted to grayscale for detection
//                   and matching.
// @param lostRatio  output: lost / (lost + still tracked), in [0, 1]. It is 0
//                   for the first processed frame (nothing to match against)
//                   and when there were no tracks at all, where the plain
//                   division would have produced NaN.
void TrackingBasedKeyFramesExtractor::update(const Mat & frame, double & lostRatio) {
	int minDistance = 3; // TODO move elsewhere

	// always leave the output parameter in a defined state
	lostRatio = 0;

	Motion::KeyPoints keypoints;
	Mat descriptors;

	prevGray = gray;
	gray = lazyConvertToGray(frame);

	// find features
	detector->detect(gray, keypoints);
	extractor->compute(frame, keypoints, descriptors);

	if(prevGray.empty()) { // (processing first frame, no matching can be done)
		Motion::TrackContainer initialTracks(keypoints, descriptors, 0);
		currentTracks.addContent(initialTracks, minDistance);
		return;
	}

	// convert found features to tracks
	Motion::TrackContainer newTracks(keypoints, descriptors, 0);

	// track features => extend their trajectories
	matcher->match(currentTracks, newTracks, gray, prevGray);
	Motion::TrackContainer currentLost = currentTracks.removeContent(Motion::TrackContainer::NotMatched());

	const double lostCount = static_cast<double>(currentLost.size());
	const double totalCount = lostCount + static_cast<double>(currentTracks.size());
	if(totalCount > 0) {
		lostRatio = lostCount / totalCount;
	}

	// short-range lost track repair
	lostTracks.addContent(currentLost);
	repairer->repair(currentTracks, lostTracks, newTracks, gray, prevGray);
	currentTracks.addContent(newTracks.removeContent(Motion::TrackContainer::NotMatched(), false), minDistance);

	allTracks.clear();
	allTracks.addContent(currentTracks);
}


// Draws every track stored in `allTracks` into `frame`: the trajectory as a
// green polyline through consecutive keypoints, and a blue circle at the
// position returned by getKeypoint() without an index (colours are BGR).
void TrackingBasedKeyFramesExtractor::render(Mat & frame) {
	for(unsigned int j = 0; j < allTracks.size(); j++) {
		Motion::TrackPtr track = allTracks[j];

		// nothing to draw for a track without keypoints
		if(track->getLength() == 0) {
			continue;
		}

		// `k + 1 < length` instead of `k < length - 1`: the latter wraps
		// around in unsigned arithmetic when the length is 0
		for(unsigned int k = 0; k + 1 < track->getLength(); k++) {
			line(frame, track->getKeypoint(k).pt, track->getKeypoint(k+1).pt, Scalar(0, 255, 0), 1, CV_AA);
		}

		circle(frame, track->getKeypoint().pt, 2, Scalar(255, 0, 0), 2, CV_AA);
	}
}

// Returns a grayscale version of `frame`.
// Multi-channel input is converted with CV_BGR2GRAY; single-channel input is
// returned as a deep copy, so the result never shares data with `frame`.
Mat TrackingBasedKeyFramesExtractor::lazyConvertToGray(const Mat & frame) const {
	if(frame.channels() != 1) {
		Mat converted;
		cvtColor(frame, converted, CV_BGR2GRAY);
		return converted;
	}

	// already gray, no conversion needed
	return frame.clone();
}


// Sets the lost-track ratio above which extract() opens a new segment.
// Raises a CV_StsBadArg error when the value lies outside [0, 1].
void TrackingBasedKeyFramesExtractor::setLostRatioThreshold(double lostRatioThreshold) {
	const bool belowRange = lostRatioThreshold < 0;
	const bool aboveRange = lostRatioThreshold > 1;

	if(belowRange || aboveRange) {
		CV_Error(CV_StsBadArg, "Invalid lost tracks ratio threshold.");
	}

	this->lostRatioThreshold = lostRatioThreshold;
}


// Sets the number of frames after which extract() forces a keyframe inside
// a segment. Raises a CV_StsBadArg error when the value is zero or negative.
void TrackingBasedKeyFramesExtractor::setMaxStep(int maxStep) {
	const bool isPositive = maxStep > 0;
	if(!isPositive) {
		CV_Error(CV_StsBadArg, "Invalid frame step.");
	}

	this->maxStep = maxStep;
}


// Creates an extractor that returns keyframes at the given frame numbers.
// The numbers are stored (sorted) by setFrameNumbers().
PresetKeyFramesExtractor::PresetKeyFramesExtractor(PresetKeyFramesExtractor::FrameNumbers frameNumbers) {
	this->setFrameNumbers(frameNumbers);
}


// Emits a keyframe (with segment id KeyFrame::UNKNOWN) for every frame of the
// sequence whose number appears in the preset, sorted `frameNumbers` list.
//
// A keyframe is only created when the current frame number equals a preset
// number. Preset numbers that can no longer be reached (the sequence is
// already past them), duplicated preset numbers, and preset numbers beyond
// the end of the sequence produce no keyframe. Reading stops as soon as all
// preset numbers have been handled.
KeyFrames PresetKeyFramesExtractor::extract(VideoSequencePtr sequence) {
	KeyFrames keyframes;

	if(frameNumbers.empty()) {
		return keyframes;
	}

	size_t kfIdx = 0;

	while(kfIdx < frameNumbers.size() && sequence->next()) {
		const int frameNumber = sequence->getFrameNumber();

		// drop preset positions the sequence has already passed
		while(kfIdx < frameNumbers.size() && frameNumbers[kfIdx] < frameNumber) {
			kfIdx++;
		}

		if(kfIdx < frameNumbers.size() && frameNumbers[kfIdx] == frameNumber) {
			KeyFramePtr keyframe = new KeyFrame(sequence->getFrameDescriptor(), frameNumber, KeyFrame::UNKNOWN);
			keyframes.push_back(keyframe);

			// consume this position together with any duplicates of it
			while(kfIdx < frameNumbers.size() && frameNumbers[kfIdx] == frameNumber) {
				kfIdx++;
			}
		}
	}

	return keyframes;
}


// Stores the preset keyframe positions in ascending order, which is the
// order extract() walks through them.
void PresetKeyFramesExtractor::setFrameNumbers(PresetKeyFramesExtractor::FrameNumbers frameNumbers) {
	this->frameNumbers = frameNumbers;
	std::sort(this->frameNumbers.begin(), this->frameNumbers.end());
}


}
