/*
 *  Authors:
 *    Christian Schulte <schulte@gecode.org>
 *
 *  Copyright:
 *    Christian Schulte, 2008-2019
 *
 *  Permission is hereby granted, free of charge, to any person obtaining
 *  a copy of this software, to deal in the software without restriction,
 *  including without limitation the rights to use, copy, modify, merge,
 *  publish, distribute, sublicense, and/or sell copies of the software,
 *  and to permit persons to whom the software is furnished to do so, subject
 *  to the following conditions:
 *
 *  The above copyright notice and this permission notice shall be
 *  included in all copies or substantial portions of the software.
 *
 *  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 *  EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 *  MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 *  NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
 *  LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
 *  OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
 *  WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *
 */

#pragma push_macro("slots")
#undef slots
#include "Python.h"
#pragma pop_macro("slots")
#include <gecode/int.hh>
#include <gecode/brancher/data-collector.hh>
#include <gecode/brancher/ml-utils.hh>

#include <algorithm>
#include <cfloat>
#include <chrono>
#include <ctime>
#include <iostream>
#include <random>
#include <vector>

using namespace Gecode;
using namespace std;

/**
 * No-good literal representing the equality x = n.
 *
 * HeuristicSearch::ngl creates one of these for the first (x = val)
 * alternative of a choice. prune() enforces the negation by removing n
 * from x.
 */
class EqNGL : public NGL {
protected:
	/// View the literal talks about
	Int::IntView x;
	/// Value x is compared against
	int n;
public:
	/// Create the literal x0 = n0
	EqNGL(Space& home, Int::IntView x0, int n0)
	: NGL(home), x(x0), n(n0) {}
	/// Cloning constructor: copies n and updates the view into the new space
	EqNGL(Space& home, EqNGL& ngl)
	: NGL(home, ngl), n(ngl.n) {
		x.update(home, ngl.x);
	}
	/// NONE while n is still possible for an unassigned x, SUBSUMED once
	/// x is fixed to n, FAILED as soon as x = n can no longer hold.
	virtual NGL::Status status(const Space& home) const {
		if (!x.assigned())
			return x.in(n) ? NGL::NONE : NGL::FAILED;
		return (x.val() != n) ? NGL::FAILED : NGL::SUBSUMED;
	}
	/// Enforce x != n; reports failure if that empties the domain.
	virtual ExecStatus prune(Space& home) {
		if (me_failed(x.nq(home, n)))
			return ES_FAILED;
		return ES_OK;
	}
	/// Subscribe propagator p to assignment events on x
	virtual void subscribe(Space& home, Propagator& p) {
		x.subscribe(home, p, Int::PC_INT_VAL);
	}
	/// Undo the subscription made in subscribe()
	virtual void cancel(Space& home, Propagator& p) {
		x.cancel(home, p, Int::PC_INT_VAL);
	}
	/// Re-schedule p with the same propagation condition
	virtual void reschedule(Space& home, Propagator& p) {
		x.reschedule(home, p, Int::PC_INT_VAL);
	}
	/// Create a copy of this literal in home
	virtual NGL* copy(Space& home) {
		return new (home) EqNGL(home, *this);
	}
	/// Release base-class resources and report this object's size
	virtual size_t dispose(Space& home) {
		(void) NGL::dispose(home);
		return sizeof(EqNGL);
	}
};

/**
 * Brancher that lets an ML model (queried through DataCollector::predictML)
 * pick the next view to branch on, always trying that view's minimum value.
 *
 * The last view of x is handled separately: it is only branched on once
 * every other view is assigned (presumably it is the objective variable —
 * NOTE(review): confirm with the model setup). Once more than th of the
 * other views are assigned, the model is no longer consulted and the first
 * unassigned view is chosen instead.
 */
class HeuristicSearch : public Brancher {
protected:
	/// Views to branch on; the last one is branched on last.
	ViewArray<Int::IntView> x;
	/// Lower bound on the index of the first unassigned view (set by status()).
	mutable int start;
	/// Threshold at which to quit ML and start normal heuristic search:
	/// once more than th of the non-last views are assigned, plain
	/// "first unassigned, min value" selection is used.
	int th;
	/// Choice x[pos] = val (alternative 0) versus x[pos] != val (alternative 1).
	class PosVal : public Choice {
	public:
		int pos; int val;
		PosVal(const HeuristicSearch& b, int p, int v)
			: Choice(b, 2), pos(p), val(v) {}
		virtual void archive(Archive& e) const {
			Choice::archive(e);
			e << pos << val;
		}
	};

public:
	/// Create brancher; initializers follow the member declaration order.
	HeuristicSearch(Home home, ViewArray<Int::IntView>& x0, int threshold)
		: Brancher(home), x(x0), start(0), th(threshold) {}
	/// Post a HeuristicSearch brancher on x.
	static void post(Home home, ViewArray<Int::IntView>& x, int threshold) {
		(void) new (home) HeuristicSearch(home, x, threshold);
	}
	virtual size_t dispose(Space& home) {
		(void)Brancher::dispose(home);
		return sizeof(*this);
	}
	/// Cloning constructor.
	HeuristicSearch(Space& home, HeuristicSearch& b, int threshold)
		: Brancher(home, b), start(b.start), th(threshold) {
		x.update(home, b.x);
	}
	virtual Brancher* copy(Space& home) {
		return new (home) HeuristicSearch(home, *this, th);
	}
	/// True while some view is unassigned. Views before `start` were already
	/// assigned when start was last updated and domains only shrink, so the
	/// scan can resume there instead of at 0.
	virtual bool status(const Space& home) const {
		for (int i = start; i < x.size(); i++)
			if (!x[i].assigned()) {
				start = i;
				return true;
			}
		return false;
	}
	/// Select the next (view, value) pair; see the class comment for the policy.
	virtual Choice* choice(Space& home) {
		// Number of "decision" views: all but the last one.
		const int n_dec = x.size() - 1;

		// Ascending indices of assigned decision views; doubles as cache key.
		vector<int> cur_var_order;
		for (int i = 0; i < n_dec; i++)
			if (x[i].assigned())
				cur_var_order.push_back(i);
		// 1-based position of the view about to be chosen (model feature).
		const int var_pos = static_cast<int>(cur_var_order.size()) + 1;

		// Past the threshold: plain "first unassigned view, min value".
		// Compare as int: a negative th would otherwise be converted to a
		// huge unsigned value and the fallback would never trigger.
		if (static_cast<int>(cur_var_order.size()) > th) {
			for (int i = 0; i < n_dec; i++)
				if (!x[i].assigned())
					return new PosVal(*this, i, x[i].min());
		}

		DataCollector* dc = DataCollector::getInstance();
		const bool high_score_is_good = !(dc->getFunction() == "smallest");

		long degree = MLUtils::computeDomainSize(x);

		// Reuse the view chosen earlier for the same set of assigned views.
		// Bound by reference so the cache is not copied on every node.
		const vector<vector<int>>& orderings = dc->getVarOrderings();
		for (size_t i = 0; i < orderings.size(); i++) {
			if (orderings.at(i) == cur_var_order) {
				const int cached = dc->varsSelected().at(i);
				return new PosVal(*this, cached, x[cached].min());
			}
		}

		// Every decision view is assigned: branch on the last view.
		if (static_cast<int>(cur_var_order.size()) == n_dec)
			return new PosVal(*this, n_dec, x[n_dec].min());

		// One feature row per unassigned decision view.
		vector<long> degrees;
		vector<int> dom_sizes, var_orders, var_posses, values, val_posses,
			min_vals, max_vals, regret_min_vals, regret_max_vals;
		for (int i = 0; i < n_dec; i++) {
			if (x[i].assigned())
				continue;
			dom_sizes.push_back(x[i].size());
			degrees.push_back(degree);
			var_orders.push_back(i);
			var_posses.push_back(var_pos);
			values.push_back(x[i].min());
			val_posses.push_back(0);
			min_vals.push_back(x[i].min());
			max_vals.push_back(x[i].max());
			regret_min_vals.push_back(x[i].regret_min());
			regret_max_vals.push_back(x[i].regret_max());
		}
		vector<double> scores = dc->predictML(dom_sizes, degrees, var_orders, var_posses, values, val_posses,
											  min_vals, max_vals, regret_min_vals, regret_max_vals);

		// Root candidates not tried yet. When exhausted the set is restored;
		// re-read it, otherwise the stale empty copy would make the root
		// loop below skip every candidate.
		vector<int> todo = dc->getRootVars();
		if (todo.empty()) {
			dc->setRootVars(n_dec);
			todo = dc->getRootVars();
		}

		const bool at_root = cur_var_order.empty();
		double compare_score = high_score_is_good ? -DBL_MAX : DBL_MAX;
		// Default to a view that is guaranteed unassigned (0 at the root), so
		// that a run without any usable score cannot yield a no-op choice.
		int var_ind = var_orders.front();
		vector<int> tiebreaker_vals;
		bool break_tie = false;
		for (int i = 0; i < static_cast<int>(scores.size()); ++i) {
			if (at_root && count(todo.begin(), todo.end(), i) == 0)
				continue;
			const double score = scores.at(i);
			const bool better = high_score_is_good ? (score > compare_score)
			                                       : (score < compare_score);
			if (better) {
				compare_score = score;
				tiebreaker_vals.clear();
				break_tie = false;
				var_ind = var_orders.at(i);
			} else if (score == compare_score) {
				// Collect all views sharing the best score (current best first).
				if (!break_tie)
					tiebreaker_vals.push_back(var_ind);
				tiebreaker_vals.push_back(var_orders.at(i));
				break_tie = true;
			}
		}

		if (break_tie) {
			random_device rd;
			mt19937 eng(rd());
			uniform_int_distribution<> distr(0, static_cast<int>(tiebreaker_vals.size()) - 1);
			var_ind = tiebreaker_vals.at(distr(eng));
			dc->incrementTiebreaks();
		}
		if (at_root) {
			// Mark this root candidate as tried.
			dc->removeCheckedVars(var_ind);
			cout << "Chose this var pos: " << var_ind << " with value " << x[var_ind].min() << " and score " << compare_score << endl;
		}

		// Cache the decision for this set of assigned views.
		dc->addVarOrdering(cur_var_order);
		dc->addVar(var_ind);

		return new PosVal(*this, var_ind, x[var_ind].min());
	}
	/// Rebuild a choice from an archive.
	virtual Choice* choice(const Space&, Archive& e) {
		int pos, val;
		e >> pos >> val;
		return new PosVal(*this, pos, val);
	}
	/// Alternative 0 posts x[pos] = val, alternative 1 posts x[pos] != val.
	virtual ExecStatus commit(Space& home,
		const Choice& c,
		unsigned int a) {
		const PosVal& pv = static_cast<const PosVal&>(c);
		int pos = pv.pos, val = pv.val;
		if (a == 0)
			return me_failed(x[pos].eq(home, val)) ? ES_FAILED : ES_OK;
		else
			return me_failed(x[pos].nq(home, val)) ? ES_FAILED : ES_OK;
	}
	/// Print alternative a of choice c.
	virtual void print(const Space& home, const Choice& c,
		unsigned int a,
		ostream& o) const {
		const PosVal& pv = static_cast<const PosVal&>(c);
		int pos = pv.pos, val = pv.val;
		if (a == 0)
			o << "x[" << pos << "] = " << val;
		else
			o << "x[" << pos << "] != " << val;
	}
	/// No-good literal for alternative 0 only; the last alternative needs none.
	virtual NGL* ngl(Space& home, const Choice& c,
		unsigned int a) const {
		const PosVal& pv = static_cast<const PosVal&>(c);
		int pos=pv.pos, val=pv.val;
		if (a == 0)
			return new (home) EqNGL(home, x[pos], val);
		else
			return nullptr;
	}
};
/**
 * Post a HeuristicSearch brancher over x.
 *
 * Does nothing if home is already failed. Otherwise it initialises the
 * DataCollector's root-variable set with x.size() - 1 (every variable
 * except the last) before posting the brancher.
 */
void heuristic_search(Home home, const IntVarArgs& x, int threshold) {
	if (!home.failed()) {
		ViewArray<Int::IntView> views(home, x);
		cout << "Setting root domain for variable ordering." << endl;
		DataCollector* collector = DataCollector::getInstance();
		collector->setRootVars(views.size() - 1);
		HeuristicSearch::post(home, views, threshold);
	}
}
