/*
 * Copyright (c) 2020-2025 Eugene Larchenko <el6345@gmail.com>. All rights reserved.
 * See the attached LICENSE.txt file.
 */

#include "compress.h"

//#include <stdlib.h>
#include <exception>
#include <algorithm> // min,max
#include "types.h"
#include "longest_match.h"

using namespace std;

// Number of elements in a statically sized array.
// Only meaningful for real arrays: given a pointer it silently yields a wrong value.
// The argument is parenthesized so that any array-valued expression can be passed.
#define ARRAYLEN(a) (sizeof(a) / sizeof((a)[0]))

// for sanity checks: throws std::exception when a "cannot happen" state is detected.
// Wrapped in do/while(0) so that `throw_impossible();` is exactly one statement
// and therefore stays valid inside an un-braced if/else.
#define throw_impossible() do { throw std::exception(); } while (0)

// Input buffer. Compress() appends a second copy of the first `size` bytes
// at data + size, so the buffer must have room for at least 2 * size bytes.
byte* data; // input data
int size; // data size (number of input bytes)
// Output buffer; filled from the start by EmitCompressed()
byte packedData[MAX_INPUT_SIZE * 9/8 + 100];
int packedBits; // compressed size in bits (set by Compress() from FindOptimalSolution())

// clock() readings taken right before and right after the main loop of FindOptimalSolution()
clock_t starttime, endtime;


// forward declarations
// Both functions are defined further down in this file and are called, in this order, by Compress():
// FindOptimalSolution() returns the compressed size in bits, EmitCompressed() returns it in bytes.
int FindOptimalSolution(const bool isReverseMode, const bool isFastMode, const pReportProgressFn reportPosition);
int EmitCompressed();


void Z_Function(const byte* zdata, int startpos, int len, int* matchLen)
{
	if (startpos >= len) {
		throw std::exception();
	}

	int* z = matchLen;

	const byte* s = zdata;
	const int sstart = startpos;
	const int n = (len - startpos) + len;

	int l = 0, r = 0;
	for (int i = 1; i < len; i++)
	{
		int zi;
		if (i > r)
			zi = 0;
		else
			zi = min(z[i - l], r + 1 - i);

		while (i + zi < n && s[sstart + i + zi] == s[sstart + zi]) {
			zi++;
		}

		if (i + zi - 1 > r)	{
			r = i + zi - 1;
			l = i;
		}

		z[i] = zi;
	}

	// z[0] is undefined
}

// computes floor(log2(x)) for x >= 1
// Throws std::exception when x < 1 (the logarithm is undefined there).
int log2i(int x)
{
	if (x < 1) {
		// Note: a bare `throw;` must not be used here. With no exception being
		// handled it calls std::terminate() instead of reporting an error.
		throw std::exception();
	}
	// Scan down from bit 30 (the highest value bit of a positive 32-bit int)
	// until the topmost set bit of x is found; x >= 1 guarantees termination.
	int r = 30;
	for (int b = 1 << r; (x & b) == 0; b >>= 1) {
		r--;
	}
	return r;
}

// Allocates an array of `cnt` elements of type T and fills its storage with zero bytes.
// The result is owned by the caller and must be released with delete[].
// Zeroing is done with memset on purpose: it does not depend on what T's
// default constructor initializes.
template<class T>
T* makearr(int cnt) {
	T* const arr = new T[cnt];
	if (arr == nullptr) { // doublecheck
		throw std::bad_alloc();
	}
	memset(arr, 0, cnt * sizeof(T));
	return arr;
}


// LWM state carried from one operation to the next.
// EmitCompressed() sets LWM_TRUE after MatchShort, ReuseOffset and MatchLong, and
// LWM_FALSE after Put1byte, Put0 and Match1byte; ReuseOffset is only valid in LWM_FALSE.
// The value minus one is used as the first index of dp_op.
#define LWM_FALSE 2
#define LWM_TRUE 1

// Operation kinds stored in Op::Type (bit layouts as written by EmitCompressed()):
//   Put0        - bits 1,1,1 then four 0 bits; outputs a zero byte
//   Put1byte    - bit 0 then the raw byte
//   Match1byte  - bits 1,1,1 then a 4-bit offset (1..15); copies one byte
//   MatchShort  - bits 1,1,0 then one byte holding offset*2 + (len-2); len 2..3, offset 1..127
//   ReuseOffset - bits 1,0,0,0 then gamma(len); copies from the last used offset
//   MatchLong   - bits 1,0 then gamma(high offset part), low offset byte, gamma(adjusted len)
// Zero is not a valid type: zero-filled Op entries are rejected by EmitCompressed().
#define Put0 1
#define Put1byte 2
#define Match1byte 3
#define MatchShort 4
#define ReuseOffset 5
#define MatchLong 6

#pragma pack(push,1)
// One step of a compressed stream: an operation kind together with the number
// of input bytes it covers and the match offset it uses.
// Declared inside pack(1) because two Op arrays are allocated per input position.
struct Op
{
	ushort Len;  // number of input bytes covered by the operation (1..0xFFFF)
	ushort Ofs;  // match offset; 0 is allowed only for Put0 / Put1byte
	byte Type;   // one of Put0, Put1byte, Match1byte, MatchShort, ReuseOffset, MatchLong

	// Intentionally leaves all fields uninitialized
	Op() {}

	// Validating constructor; throws (via throw_impossible) when
	//   - len is outside 1..0xFFFF,
	//   - o is outside 0..0xFFFF,
	//   - o is 0 for anything but the two literal operations.
	Op(byte type, int len, int o)
	{
		const bool isLiteral = (type == Put0) || (type == Put1byte);
		const bool lenValid = (len >= 1) && (len <= 0xFFFF);
		const bool ofsInRange = (o >= 0) && (o <= 0xFFFF);
		const bool ofsAllowed = (o != 0) || isLiteral;

		if (!lenValid || !ofsInRange || !ofsAllowed) {
			throw_impossible();
		}

		Type = type;
		Len = (ushort)len;
		Ofs = (ushort)o;
	}
};
#pragma pack(pop)


// Tables of FindOptimalSolution(), indexed as [pos][last_offset]; the row for `pos` has `pos` entries
// (the row for pos == size has `size` entries, all zero).
// dp1/dp2 hold costs in bits; their rows are released as soon as they can no longer be referenced
// and are all gone when FindOptimalSolution() returns.
// dp_op holds the chosen operation for every state and is read afterwards by EmitCompressed().
int* dp2[MAX_INPUT_SIZE + 1]; // result for LWM=2=false
int* dp1[MAX_INPUT_SIZE + 1]; // result for LWM=1=true
Op* dp_op[2][MAX_INPUT_SIZE + 1]; // solution for lwm=1 and lwm=2; first index is lwm-1

// Finds the cheapest encoding of data[1..size) by dynamic programming, walking
// from the last input position down to position 1 (data[0] is always stored raw).
//
// State is (pos, lwm, last_offset):
//   dp2[pos][last_offset]          - minimal number of bits to encode data[pos..size)
//                                    when LWM=2 (false), i.e. ReuseOffset is allowed;
//   dp1[pos][last_offset]          - the same when LWM=1 (true);
//   dp_op[lwm-1][pos][last_offset] - first operation of that optimal encoding.
// dp_op is kept for all positions because EmitCompressed() replays it; dp1/dp2 rows
// are freed as soon as no remaining position can reference them.
//
// Parameters:
//   isReverseMode  - only changes the value passed to reportPosition
//                    (N - pos instead of pos);
//   isFastMode     - restricts the match lengths that are tried (see lenLow1/lenLow2),
//                    trading optimality for speed;
//   reportPosition - progress callback, called at most about 30 times per second
//                    and once more at the end.
//
// Requires `data` to contain the input twice in a row (see Compress()).
// Returns the compressed size in bits, including the raw first byte and the
// end-of-stream marker.
// Throws std::exception on an unsupported size; allocation failures propagate
// as std::bad_alloc.
int FindOptimalSolution(
	const bool isReverseMode, 
	const bool isFastMode, 
	const pReportProgressFn reportPosition)
{
	// Validate the size before anything uses it.
	// Offsets and lengths are stored in 16-bit fields of Op.
	const int N = size;
	if (N < 1 || N > USHRT_MAX+1) {
		throw_impossible();
	}

	// First, find longest possible match. 
	// Helps reduce memory usage in most practical cases.
	int longestMatch = FindLongestMatch(data, size);
	longestMatch = max(1, longestMatch); // at least one position ahead is required for Put0/Put1Byte ops
	#ifdef _DEBUG
		printf("longest match is %d \n", longestMatch);
	#endif

	// gammalen[i] = number of bits of the gamma code of the value i+2
	sbyte* gammalen = new sbyte[N];
	for (int i = 0; i < N; i++)
		gammalen[i] = sbyte(2 * log2i(i + 2));

	// Scratch buffer for Z_Function; after the fixup below matchLen[i] is the
	// length of the match between data[pos..] and data[i..]
	int* matchLen = new int[N];

	memset(dp1, 0, sizeof(dp1));
	memset(dp2, 0, sizeof(dp2));
	memset(dp_op, 0, sizeof(dp_op));

	// Preallocate memory, so we don't waste time if there is no enough

	int64 memReq = 0;
	for (int pos = N - 1; pos >= 1; pos--) {
		if (pos + longestMatch >= N) {
			memReq += (int64)2 * pos * sizeof(int);
		}
		memReq += (int64)2 * pos * sizeof(Op);
	}
	// %lld with an explicit cast is portable; %I64d is understood by MSVC runtimes only
	printf("Allocating %lld MiB of memory...\n", (long long)(memReq >> 20));

	for (int pos = N - 1; pos >= 1; pos--) {
		// Only the cost rows that are alive at the same time are preallocated;
		// the remaining ones reuse memory released during the main loop
		if (pos + longestMatch >= N) {
			dp2[pos] = makearr<int>(pos);
			dp1[pos] = makearr<int>(pos);
		}
		dp_op[0][pos] = makearr<Op>(pos);
		dp_op[1][pos] = makearr<Op>(pos);
	}
	dp2[N] = makearr<int>(N); // result for pos=N is 0
	dp1[N] = makearr<int>(N); // result for pos=N is 0

	printf("Memory allocated successfully\n");

	printf("Processing position:\n");
	starttime = endtime = clock();
	double lastreport = 0;

	for (int pos = N - 1; pos >= 1; pos--)
	{
		double now = clock() / (double)CLOCKS_PER_SEC;
		if (now < lastreport || now > lastreport + 0.03) { // limit to 30 reports/sec
			reportPosition(isReverseMode ? N-pos : pos);
			lastreport = now;
		}

		// Because the input is stored twice in a row, the Z-values at indices
		// N-pos .. N-1 describe matches of data[pos..] against data[0..pos)
		Z_Function(data, pos, size, matchLen);
		for(int i = 0; i < pos; i++) {
			matchLen[i] = min(matchLen[N-pos+i], N-pos); // move and fixup for convenience
		}
		matchLen[pos] = -1;

		// last_offset will be in [0..pos) range, so need arrays of this size
		if (!dp2[pos]) dp2[pos] = makearr<int>(pos);
		if (!dp1[pos]) dp1[pos] = makearr<int>(pos);
		//dp_op[0][pos] = makearr<Op>(pos); // this was preallocated
		//dp_op[1][pos] = makearr<Op>(pos); // this was preallocated

		// The short and long match candidates do not depend on last_offset,
		// so they are evaluated once per position.

		// Try short match: len=2..3, o=1..127; LWM <- true
		int shortmatch_best = INT_MAX;
		Op shortmatch_bestOp = Op();
		for (int o = 1; o <= 127 && o <= pos; o++)
			for (int l = 2; l <= 3; l++)
			{
				if (pos + l <= N && matchLen[pos - o] >= l)
				{
					int t = 3 + 8 + dp1[pos + l][o]; // LWM_TRUE
					if (t < shortmatch_best)
					{
						shortmatch_best = t; shortmatch_bestOp = Op(MatchShort, l, o);
					}
				}
			}

		// Try long match for lwm=1 & lwm=2
		int longmatch2_best = INT_MAX;
		Op longmatch2_bestOp = Op();
		int longmatch1_best = INT_MAX;
		Op longmatch1_bestOp = Op();
		int lenLow1 = 0;
		{
			for (int o = 1; o <= pos; o++)
			{
				// Length adjustment that depends on the offset range
				// (must stay in sync with EmitCompressed)
				int f = o < 128 ? 2
					: o < 1280 ? 0
					: o < 32000 ? 1
					: 2;
				//if (matchLen[pos - o] > N - pos) throw;
				// Fast mode skips lengths much shorter than the previous offset's match
				int minl = isFastMode 
					? max(2+f, lenLow1 -10)
					: 2+f;
				int maxl = matchLen[pos - o];
				for (int l = minl; l <= maxl; l++)
				{
					// Cost when LWM=1: prefix, gamma(high offset part), low offset byte, gamma(length)
					int t = 2 + gammalen[(o >> 8) + 1 + 1 - 2] + 8 + gammalen[l - f - 2];
					t += dp1[pos + l][o]; // LWM_TRUE
					if (t < longmatch1_best)
					{
						longmatch1_best = t; longmatch1_bestOp = Op(MatchLong, l, o);
					}

					// With LWM=2 the high offset part is encoded one larger
					t += gammalen[(o >> 8) + 2 + 1 - 2] - gammalen[(o >> 8) + 1 + 1 - 2];
					if (t < longmatch2_best)
					{
						longmatch2_best = t; longmatch2_bestOp = Op(MatchLong, l, o);
					}
				}
				lenLow1 = maxl;
			}
		}

		int lenLow2 = 0;
		for (int last_offset = 0; last_offset < pos; last_offset++)
		{
			// Try put 1 byte
			int best1 = 1 + 8 + dp2[pos + 1][last_offset]; // LWM_FALSE
			Op best1Op = Op(Put1byte, 1, 0);

			// Try short match
			if (shortmatch_best < best1)
			{
				best1 = shortmatch_best; best1Op = shortmatch_bestOp;
			}

			// Try put zero byte
			if (data[pos] == 0)
			{
				int t = 3 + 4 + dp2[pos + 1][last_offset]; // LWM_FALSE
				if (t < best1)
				{
					best1 = t; best1Op = Op(Put0, 1, 0);
				}
			}

			// Try 1-byte match
			for (int o = 1; o <= 15 && o <= pos; o++)
				if (matchLen[pos - o] > 0)
				{
					int t = 3 + 4 + dp2[pos + 1][last_offset]; // LWM_FALSE
					if (t < best1)
					{
						best1 = t; best1Op = Op(Match1byte, 1, o);
						break; // there will be no better solution
					}
				}

			// So far solutions for LWM=1 and LWM=2 are the same
			int best2 = best1;
			Op best2Op = best1Op;

			// Try last_offset assuming LWM=2=false
			if (last_offset != 0)
			{
				int o = last_offset;
				//if (matchLen[pos - o] > N - pos) throw;
				int minl = isFastMode
					? max(2, lenLow2 -10)
					: 2;
				int maxl = matchLen[pos - o];
				for (int l = minl; l <= maxl; l++)
				{
					int t = 2 + 2 + gammalen[l - 2];
					t += dp1[pos + l][o]; // LWM_TRUE
					if (t < best2)
					{
						best2 = t; best2Op = Op(ReuseOffset, l, o);
					}
				}
				lenLow2 = maxl;
			}

			// Try long match
			if (longmatch1_best < best1)
			{
				best1 = longmatch1_best; best1Op = longmatch1_bestOp;
			}
			if (longmatch2_best < best2)
			{
				best2 = longmatch2_best; best2Op = longmatch2_bestOp;
			}

			dp2[pos][last_offset] = best2;
			dp1[pos][last_offset] = best1;

			dp_op[2-1][pos][last_offset] = best2Op;
			dp_op[1-1][pos][last_offset] = best1Op;

		} // last_offset++

		if (pos + longestMatch <= N)
		{
			// we don't need results for these positions anymore, let's reuse memory
			delete[] dp1[pos + longestMatch]; dp1[pos + longestMatch] = NULL;
			delete[] dp2[pos + longestMatch]; dp2[pos + longestMatch] = NULL;
		}

	} // pos--

	endtime = clock();
	reportPosition(isReverseMode ? N : 0);

	// return compressed size in bits
	int reslen = dp2[1][0]; // LWM_FALSE
	reslen += 8; // first byte
	reslen += 3+8; // eos

	// We don't need dp1 and dp2 anymore
	for(int i=0; i<=N; i++)
	{
		if (dp1[i]) { delete[] dp1[i]; dp1[i] = NULL; }
		if (dp2[i]) { delete[] dp2[i]; dp2[i] = NULL; }
	}

	// The scratch tables are local to this function and must be released here
	delete[] gammalen;
	delete[] matchLen;

	return reslen;
}

// Builds final compressed block using precalculations from dp_op array
int EmitCompressed()
{
	int rpos = 0;

	auto emitByte = [&rpos](byte b) {
		packedData[rpos++] = b;
	};

	int apos = -1;
	int acnt = 8;
	auto emitBit = [&apos, &acnt, &rpos](int bit) {
		if (acnt == 8) {
			apos = rpos;
			packedData[rpos++] = 0;
			acnt = 0;
		}
		packedData[apos] = packedData[apos] * 2 + (bit & 1);
		acnt++;
	};

	auto emitGamma = [emitBit](int x) {
		if (x < 2) {
			throw_impossible();
		}
		int b = 1 << 30;
		while ((x & b) == 0) {
			b >>= 1;
		}
		while ((b >>= 1) != 0) {
			emitBit((x & b) == 0 ? 0 : 1);
			emitBit(b == 1 ? 0 : 1);
		}
	};

	int pos = 0;
	int lwm = LWM_FALSE;
	int last_offset = 0;

	if (size < 1) {
		throw_impossible();
	}
	emitByte(data[pos++]); // first byte is simply copied

	while (pos != size)
	{
		if (pos >= size)
			throw_impossible(); // something is wrong
	
		Op op = dp_op[lwm-1][pos][last_offset];
		switch (op.Type)
		{
			case Put1byte: {
				if (op.Len != 1) throw_impossible();
				emitBit(0);
				emitByte(data[pos]);
				lwm = LWM_FALSE;
				break;
			}
			case Put0: {
				if (op.Len != 1) throw_impossible();
				if (data[pos] != 0) throw_impossible();
				emitBit(1);
				emitBit(1);
				emitBit(1);
				emitBit(0); emitBit(0); emitBit(0); emitBit(0);
				lwm = LWM_FALSE;
				break;
			}
			case Match1byte: {
				if (op.Len != 1) throw_impossible();
				if (op.Ofs < 1 || op.Ofs > 15) throw_impossible();
				emitBit(1);
				emitBit(1);
				emitBit(1);
				for (int i = 3; i >= 0; i--) emitBit(op.Ofs >> i & 1);
				lwm = LWM_FALSE;
				break;
			}
			case MatchShort: {
				if (op.Len < 2 || op.Len > 3) throw_impossible();
				if (op.Ofs < 1 || op.Ofs > 127) throw_impossible();
				emitBit(1);
				emitBit(1);
				emitBit(0);
				emitByte((op.Ofs * 2 + (op.Len - 2)));
				lwm = LWM_TRUE;
				last_offset = op.Ofs;
				break;
			}
			case ReuseOffset: {
				if (lwm != LWM_FALSE || last_offset <= 0) throw_impossible();
				if (op.Len < 2) throw_impossible();
				if (op.Ofs != last_offset) throw_impossible();
				emitBit(1);
				emitBit(0);
				emitBit(0); emitBit(0);
				emitGamma(op.Len);
				lwm = LWM_TRUE;
				break;
			}
			case MatchLong: {
				int f = op.Ofs < 128 ? 2
					: op.Ofs < 1280 ? 0
					: op.Ofs < 32000 ? 1
					: 2;
				if (op.Len - f < 2) throw_impossible();
				if (op.Ofs < 1) throw_impossible();
				emitBit(1);
				emitBit(0);
				emitGamma((op.Ofs >> 8) + lwm + 1);
				emitByte((byte)op.Ofs);
				emitGamma(op.Len - f);
				lwm = LWM_TRUE;
				last_offset = op.Ofs;
				break;
			}
			default: {
				throw_impossible();
			}
		}
		pos += op.Len;
	}

	if (pos != size) {
		throw_impossible();
	}

	// end of stream marker
	emitBit(1);
	emitBit(1);
	emitBit(0);
	emitByte(0);

	// final check
	if (packedBits != rpos * 8 - (8 - acnt)) {
		throw_impossible();
	}

	// finalize bit stream
	while (acnt != 8) {
		emitBit(0);
	}

	int packedSize = rpos;
	if (packedSize > ARRAYLEN(packedData)) {
		throw_impossible(); // buffer overflow?
	}

	return packedSize;
}

// Compresses the `size` bytes at `data` into packedData.
//
// `data` must have room for 2 * size bytes: the input is duplicated right after
// itself because FindOptimalSolution() runs Z_Function over the doubled buffer.
//
// isReverseMode and isFastMode are forwarded to FindOptimalSolution();
// reportProgress is its progress callback.
// Also sets packedBits. Returns the compressed size in bytes.
// Throws std::exception on missing or empty input and on any failed sanity check.
int Compress(bool isReverseMode, bool isFastMode, pReportProgressFn reportProgress) {
	// Reject bad input before memcpy: a negative size would be converted
	// to a huge unsigned byte count
	if (!data || size < 1) {
		throw std::exception();
	}

	// repeat data; required for z_function
	memcpy(data + size, data, size);

	packedBits = FindOptimalSolution(isReverseMode, isFastMode, reportProgress);
	const int packedSize = EmitCompressed();
	return packedSize;
}

