#include "stdafx.h"
#include "StackTrace.h"
#include "CppInfo.h"
#include "StackInfoSet.h"

#include <iomanip>

#ifdef POSIX
#include <unwind.h>
#include <ucontext.h>
#endif

// System specific headers.
#include "DbgHelper.h"

// Create an empty trace. No frame array exists until the first call to 'push'.
StackTrace::StackTrace()
	: frames(null),
	  size(0),
	  capacity(0) {}

// Create a trace that holds 'n' default-constructed frames. The freshly allocated array is
// reported to the stack info set, using 'n' as the element count.
StackTrace::StackTrace(nat n)
	: frames(new StackFrame[n]),
	  size(n),
	  capacity(n) {
	StackInfoSet &info = stackInfo();
	info.alloc(frames, size);
}

// Copy 'o'. The copy gets an array that is exactly 'o.size' elements large (any spare capacity in
// 'o' is not carried over). If 'o' has no array, the copy will not have one either.
StackTrace::StackTrace(const StackTrace &o) : frames(null), size(o.size), capacity(o.size) {
	if (!o.frames)
		return;

	frames = new StackFrame[size];
	for (nat at = 0; at < size; at++)
		frames[at] = o.frames[at];

	// Report the new array to the stack info set once it has been filled.
	stackInfo().alloc(frames, size);
}

// Assign 'o' to this trace using copy-and-swap: a temporary copy of 'o' is created, its contents
// are exchanged with ours, and our previous contents are released when the temporary goes out of
// scope. This also makes self-assignment safe.
StackTrace &StackTrace::operator =(const StackTrace &o) {
	StackTrace tmp(o);

	swap(capacity, tmp.capacity);
	swap(size, tmp.size);
	swap(frames, tmp.frames);

	return *this;
}

// Release the frame array, if any.
//
// The array is reported as freed with 'capacity' elements, since that is the element count that
// was passed to 'alloc' when the array was created: 'push' registers arrays with their full
// capacity, and the constructors create arrays where 'size == capacity'. Passing 'size' here would
// disagree with 'alloc' whenever the array is not full.
StackTrace::~StackTrace() {
	if (frames)
		stackInfo().free(frames, capacity);
	delete []frames;
}

void StackTrace::push(const StackFrame &frame) {
	if (size >= capacity) {
		if (capacity == 0)
			capacity = 8;
		else
			capacity *= 2;

		StackFrame *n = new StackFrame[capacity];
		stackInfo().alloc(n, capacity);

		if (frames)
			for (nat i = 0; i < size; i++)
				n[i] = frames[i];

		swap(n, frames);
		if (n)
			stackInfo().free(n, size);
		delete []n;
	}

	frames[size++] = frame;
}

// Write a compact representation of the trace to 'to'. Each frame starts on a new line and
// consists of its index (padded to a width of 3), the result of 'fn()' for the frame, and the
// return location if the frame has one.
void StackTrace::output(wostream &to) const {
	nat index = 0;
	while (index < count()) {
		StackFrame &entry = frames[index];

		to << endl << std::setw(3) << index;
		to << L": " << entry.fn();

		// Frames without a known return location only get the first part.
		if (entry.returnLocation) {
			to << L" (return @" << (void *)entry.returnLocation << L")";
		}

		index++;
	}
}

// Produce a string representation of 't'. Each frame is handed to the stack info set, which
// writes its description through a 'StdOutput' that wraps a wide string stream. A final newline
// is added after the last frame.
String format(const StackTrace &t) {
	std::wostringstream buffer;
	StdOutput sink(buffer);
	StackInfoSet &info = stackInfo();

	nat at = 0;
	while (at < t.count()) {
		const StackFrame &current = t[at];

		// Tell the output that a new frame starts before formatting it.
		sink.nextFrame();
		info.format(sink, current.id, current.fnBase, current.offset);

		at++;
	}

	buffer << endl;
	return buffer.str();
}

/**
 * System specific code for collecting the stack trace itself.
 */

// Windows stack traces using DbgHelp. Relies on debug information.
// Note: We could do this for X64 as well, but on X64 we don't have to rely on dbgHelp as we have other metadata!
#if !defined(STANDALONE_STACKWALK) && defined(X86)

#if defined(X86)
// Machine type passed to StackWalk64 on 32-bit X86.
static const DWORD machineType = IMAGE_FILE_MACHINE_I386;

// Fill in the starting point of a stack walk in 'frame' from the registers in 'context':
// program counter from EIP, frame address from EBP, and stack address from ESP.
static void initFrame(CONTEXT &context, STACKFRAME64 &frame) {
	// All three addresses use flat addressing.
	frame.AddrPC.Mode = AddrModeFlat;
	frame.AddrFrame.Mode = AddrModeFlat;
	frame.AddrStack.Mode = AddrModeFlat;

	frame.AddrPC.Offset = context.Eip;
	frame.AddrFrame.Offset = context.Ebp;
	frame.AddrStack.Offset = context.Esp;
}

#elif defined(X64)
// Machine type passed to StackWalk64 on X64.
static const DWORD machineType = IMAGE_FILE_MACHINE_AMD64;

// Fill in the starting point of a stack walk in 'frame' from the registers in 'context':
// program counter from RIP, and both the frame address and the stack address from RSP.
static void initFrame(CONTEXT &context, STACKFRAME64 &frame) {
	// All three addresses use flat addressing.
	frame.AddrPC.Mode = AddrModeFlat;
	frame.AddrFrame.Mode = AddrModeFlat;
	frame.AddrStack.Mode = AddrModeFlat;

	frame.AddrPC.Offset = context.Rip;
	// Open question from the original author: is RSP really the right value for the frame address?
	frame.AddrFrame.Offset = context.Rsp;
	frame.AddrStack.Offset = context.Rsp;
}
#else
#error "Unknown windows platform!"
#endif

// Warning about not being able to protect from stack-overruns...
#pragma warning ( disable : 4748 )

// Collect a stack trace using StackWalk64 from DbgHelp and feed it to 'gen'.
//
// gen   - receives one 'init' call followed by one 'put' call for each collected frame.
// skip  - number of frames to drop, counted from the start of the walk.
// state - if non-null, interpreted as a pointer to a CONTEXT to start walking from. Otherwise, the
//         walk starts from the current location in this function.
void createStackTrace(TraceGen &gen, nat skip, void *state) {
	// Initialize the library if it is not already done.
	dbgHelp();

	// Note: When the context is captured by the inline assembly below, only Ebp, Esp and Eip are
	// filled in. The remainder of the structure is left uninitialized.
	CONTEXT context;
	if (state) {
		context = *(CONTEXT *)state;
	} else {
#ifdef X64
		RtlCaptureContext(&context);
#else
		// Sometimes RtlCaptureContext crashes for X86, so we do it with inline-assembly instead!
		__asm {
		label:
			mov [context.Ebp], ebp;
			mov [context.Esp], esp;
			mov eax, [label];
			mov [context.Eip], eax;
		}
#endif
	}

	HANDLE process = GetCurrentProcess();
	HANDLE thread = GetCurrentThread();
	STACKFRAME64 frame;
	zeroMem(frame);
	initFrame(context, frame);

	// True until the first frame has been seen (regardless of whether it was skipped or not).
	bool first = true;

	// The number of frames is not known up front, so zero is passed to 'init'.
	gen.init(0);
	StackInfoSet &s = stackInfo();
	while (StackWalk64(machineType, process, thread, &frame, &context, NULL, NULL, NULL, NULL)) {
		if (skip > 0) {
			skip--;
			first = false;
			continue;
		}

		StackFrame f;
		f.id = s.translate((void *)frame.AddrPC.Offset, f.fnBase, f.offset);

		if (!first) {
			// Since this is the SP at the point of the call, right above it will be the return
			// address to this frame! This does not work for the first frame, which is why we don't
			// generate anything for the first frame.
			f.returnLocation = (void *)(frame.AddrStack.Offset - sizeof(void *));
		}

		gen.put(f);
		first = false;
	}
}

#endif

// The stand-alone stack walk for X86 windows. Fails when frame pointer is omitted.
#if defined(STANDALONE_STACKWALK) && defined(X86) && defined(WINDOWS)

// Check if 'ptr' lies inside the range [stackMin, stackMax]. Both ends are inclusive.
static bool onStack(void *stackMin, void *stackMax, void *ptr) {
	if (ptr < stackMin)
		return false;

	return !(ptr > stackMax);
}

// Read the first pointer-sized slot at 'bp'. The stack walk in this section treats this slot as
// the frame pointer of the calling function.
static void *prevFrame(void *bp) {
	return *(void **)bp;
}

// Read the second pointer-sized slot at 'bp'. The stack walk in this section treats this slot as
// the return address into the calling function.
static void *prevIp(void *bp) {
	void **slots = (void **)bp;
	return *(slots + 1);
}

// Read pointer-sized slot number 'id' counted from the third slot at 'bp', i.e. the slots that
// follow the two that are read by 'prevFrame' and 'prevIp'.
static void *prevParam(void *bp, nat id) {
	void **slots = (void **)bp;
	void **wanted = &slots[2 + id];
	return *wanted;
}

// Get a pointer to the NT_TIB structure of the current thread by reading the pointer stored at
// offset 0x18 from the FS segment. The assertion checks that the 'Self' member of the result
// points back at the structure itself.
static NT_TIB *getTIB() {
	NT_TIB *tib;
	__asm {
		// read 'self' from 'fs:0x18'
		mov eax, fs:0x18;
		mov tib, eax;
	}
	assert(tib == tib->Self);
	return tib;
}

// Collect a stack trace by following the chain of saved frame pointers on the stack, and feed the
// result to 'gen'. This only works when the code on the stack maintains frame pointers.
//
// gen   - receives one 'init' call with the number of frames, followed by one 'put' per frame.
// skip  - number of frames to drop from the start of the trace. If the stack does not contain more
//         than 'skip' frames, nothing is skipped and the entire trace is reported instead.
// state - if non-null, interpreted as a pointer to a CONTEXT whose Ebp is the starting point.
//         Otherwise, the walk starts from the frame pointer of this function.
void createStackTrace(TraceGen &gen, nat skip, void *state) {
	NT_TIB *tib = getTIB();
	void *stackMax = tib->StackBase;
	void *stackMin = tib->StackLimit;

	void *base = null;
	if (state) {
		// Ebp is an integer (DWORD) in the X86 CONTEXT, so it has to be cast to a pointer explicitly.
		base = (void *)(size_t)((CONTEXT *)state)->Ebp;
	} else {
		__asm {
			mov base, ebp;
		}
	}

	// Count frames. The walk stops as soon as a saved frame pointer refers outside of the stack of
	// the current thread. We also verify the starting point itself before reading from it, since a
	// supplied context may contain something that is not a frame pointer.
	nat frames = 0;
	if (onStack(stackMin, stackMax, base)) {
		for (void *now = base; onStack(stackMin, stackMax, prevFrame(now)); now = prevFrame(now))
			frames++;
	}

	if (frames > skip)
		frames -= skip;
	else
		skip = 0;

	// Collect the trace itself.
	gen.init(frames);
	StackInfoSet &s = stackInfo();
	void *now = base;
	for (nat i = 0; i < skip; i++)
		now = prevFrame(now);

	for (nat i = 0; i < frames; i++) {
		StackFrame f;
		f.id = s.translate(prevIp(now), f.fnBase, f.offset);
		TODO(L"Populate SP and BP!");
		now = prevFrame(now);

		gen.put(f);
	}
}

#endif


#if defined(WINDOWS) && defined(X64)

// Declarations needed to walk the stack on 64-bit Windows without going through DbgHelp: local
// copies of system structures, declarations of the system unwind functions, and a description of
// the header of the unwind data.
namespace internal {

	// Copied from winternl.h to make it possible to reserve enough stack space for the history table.
	// Note: We have some leeway in the history size just in case.
#define UNWIND_HISTORY_SIZE (12 + 4)

	// One entry in the history table below.
	struct UNWIND_HISTORY_ENTRY {
		ULONG64 base;
		RUNTIME_FUNCTION *entry;
	};

	// Local declaration of the history table that is passed to RtlLookupFunctionEntry. The caller
	// only zero-initializes it; its contents are managed by the system.
	struct UNWIND_HISTORY_TABLE {
		ULONG count;
		BYTE localHint;
		BYTE globalHint;
		BYTE search;
		BYTE once;
		ULONG64 lowAddr;
		ULONG64 highAddr;
		UNWIND_HISTORY_ENTRY entries[UNWIND_HISTORY_SIZE];
	};

	// Only used as a pointer type below (we always pass NULL), so a forward declaration suffices.
	struct KNONVOLATILE_CONTEXT_POINTERS;

	// Look up the function table entry for the code at 'pc'. 'base' receives a base address that
	// RUNTIME_FUNCTION offsets are relative to. A null result is treated as "no metadata" by
	// 'createStackTrace'.
	extern "C"
	NTSYSAPI PRUNTIME_FUNCTION RtlLookupFunctionEntry(DWORD64 pc, PDWORD64 base, UNWIND_HISTORY_TABLE *history);

	// Unwind one stack frame by updating 'context' in place.
	extern "C"
	NTSYSAPI void *RtlVirtualUnwind(DWORD handlerType, DWORD64 base, DWORD64 pc,
									PRUNTIME_FUNCTION entry, PCONTEXT context,
									PVOID *handlerData, PDWORD64 establisher,
									KNONVOLATILE_CONTEXT_POINTERS *nv);

	// Not in the Win32 headers, but described here:
	// https://learn.microsoft.com/en-us/cpp/build/exception-handling-x64?view=msvc-170
	// This is the fixed-size header of the unwind data that RUNTIME_FUNCTION::UnwindData refers to.
	struct UnwindInfo {
		byte version : 3;
		byte flags : 5;
		byte prologSize;
		byte unwindCount;
		byte frameRegister : 4;
		byte frameOffset : 4;
	};

	// Read the register with number 'regNumber' from 'from'. Returns 0 for numbers above 15.
	// NOTE(review): the numbering presumably matches the register numbers used in the unwind data
	// ('frameRegister' above) - confirm against the documentation linked above. Nothing in the
	// visible part of this file calls this function.
	static DWORD64 readRegister(byte regNumber, CONTEXT &from) {
		switch (regNumber) {
		case 0: return from.Rax;
		case 1: return from.Rcx;
		case 2: return from.Rdx;
		case 3: return from.Rbx;
		case 4: return from.Rsp;
		case 5: return from.Rbp;
		case 6: return from.Rsi;
		case 7: return from.Rdi;
		case 8: return from.R8;
		case 9: return from.R9;
		case 10: return from.R10;
		case 11: return from.R11;
		case 12: return from.R12;
		case 13: return from.R13;
		case 14: return from.R14;
		case 15: return from.R15;
		}
		return 0;
	}
}

// Collect a stack trace on 64-bit Windows by repeatedly looking up the function table entry for
// the current instruction pointer and asking the system to unwind one frame at a time.
//
// gen   - receives one 'init' call followed by one 'put' call for each collected frame.
// skip  - number of frames to drop from the start of the trace.
// state - if non-null, interpreted as a pointer to a CONTEXT to start unwinding from. Otherwise,
//         the context is captured inside this function.
void createStackTrace(TraceGen &gen, nat skip, void *state) {
	using namespace internal;

	StackInfoSet &s = stackInfo();
	// The number of frames is not known up front, so zero is passed to 'init'.
	gen.init(0);

	CONTEXT context;
	if (state) {
		context = *(CONTEXT *)state;
	} else {
		RtlCaptureContext(&context);
	}

	// Note: It is fine if we don't support leaf functions. We know that this is not a leaf function
	// (we call other functions), so we will have metadata, and all preceding functions need
	// metadata.
	DWORD64 base = null;
	UNWIND_HISTORY_TABLE history;
	zeroMem(history);
	RUNTIME_FUNCTION *entry;
	// True during the first iteration only (regardless of whether that frame is skipped or not).
	bool first = true;
	// Note: The assignment in the condition is intentional. The loop ends when there is no function
	// table entry for the current instruction pointer.
	while (entry = RtlLookupFunctionEntry(context.Rip, &base, &history)) {
		// Add 'frame' if we need to.
		if (skip > 0) {
			skip--;
		} else {
			StackFrame frame;
			frame.id = s.translate((void *)context.Rip, frame.fnBase, frame.offset);

			// We can find the return location by looking just above RSP. This does not work for the
			// first stack frame, though.
			if (!first)
				frame.returnLocation = (void *)(context.Rsp - sizeof(void *));

			gen.put(frame);
		}

		// Step to the next frame:
		UnwindInfo *uwInfo = (UnwindInfo *)(base + entry->UnwindData);
		// The flags from the unwind data are passed on as the handler type below.
		DWORD type = uwInfo->flags;

		// If bit 3 is set, then this is a nested information, and we might want to "fold" the call
		// into the previous entry?

		DWORD64 oldIp = context.Rip;

		// Outputs from RtlVirtualUnwind that are not used here.
		void *handler = 0;
		DWORD64 establisher = 0;
		RtlVirtualUnwind(type, base, oldIp, entry, &context, &handler, &establisher, NULL);
		// Note: There is not really a good way to see if RtlVirtualUnwind succeeded or not.

		// End of the stack trace?
		if (context.Rip == 0)
			break;

		first = false;
	}
}

#endif

#ifdef POSIX
#include <execinfo.h>
// #include <sys/auxv.h>

#if defined(ARM64)

#if defined(ARM_USE_PAC)

// static int supportsPAC = -1;

// static bool hasPAC() {
// 	if (supportsPAC < 0) {
// 		// Check if the current CPU supports PAC. Note 'getauxval' seems to be fairly cheap, so it
// 		// does not hurt too much to call it more than once, but reading a global is likely much
// 		// cheaper.
// 		unsigned long aux = getauxval(AT_HWCAP);
// 		supportsPAC = (aux & HWCAP_PACA) ? 1 : 0;
// 	}
// 	return supportsPAC > 0;
// }

// Strip the PAC from the instruction pointer.
// Strip the PAC from the instruction pointer.
// The pointer is moved into x30, 'xpaclri' is executed, and the result is moved back out. This is
// why x30 is listed as clobbered.
static void *stripPAC(void *from) {
	void *result = from;
	// Note: We *could* use 'xpaci' here, but that would require targeting armv8.3 or above.
	// That also causes GCC to use other new instructions (e.g. retaa) that make the program
	// crash on lower armv8 CPUs. So instead we abuse the fact that 'xpaclri' is a NOOP on lower
	// ARM chips, so that we can always run it!
	__asm__(
		"mov x30, %0\n"
		"xpaclri\n"
		"mov %0, x30\n"
		: "+r" (result) : : "r30"
		);

	return result;
}

#define UPDATE_RETURN_LOC_DEFINED
// Replace the return address stored at 'returnLocation' with 'target'. Since return addresses are
// signed when PAC is used, 'target' is signed before it is written to the stack.
void StackFrame::updateReturnLocation(const void *target) const {
	// On Aarch64, stack frames look as follows:
	// --- end of stack frame (high address) ---
	// (data)
	// saved x30 (return address)
	// saved x29 (frame ptr for prev. function)
	// --- start of stack frame (low address), sp points here ---

	// Note that x30 is authenticated with sp as it is after the stack
	// frame is popped, i.e. with the saved x29 (according to the
	// ABI). In particular, we can rely on this for Storm, since we
	// will only attempt to replace return pointers of functions that
	// were called by Storm-generated code. As such, as long as Storm
	// ensures that x29 is sp at the start of the function, everything
	// is good.

	// For additional robustness, we could find the stack frame size
	// from the function metadata for this stack frame and use
	// that. It is, however, not currently easily accessible for us
	// unless we store it in the StackFrame struct.

	const void **returnLoc = (const void **)this->returnLocation;
	// Variable to sign from ASM.
	const void *sign = target;
	// Find the old SP. It is the saved x29, which is stored in the slot just below the return
	// address (see the layout above).
	const void *oldSP = returnLoc[-1];

	__asm__(
		"mov x17, %[ptr]\n" // address
		"mov x16, %[sp]\n" // sign with (sp)
		"pacia1716\n"   // sign the address in x17, using x16 as the modifier
		"mov %[ptr], x17\n" // write back address
		: [ptr]"+r" (sign)
		: [sp]"r" (oldSP)
		: "r16", "r17"
		);

	*returnLoc = sign;
}

#else

// Variant used when ARM_USE_PAC is not defined: there is nothing to strip, so the pointer is
// returned as it is.
static void *stripPAC(void *from) {
	void *unchanged = from;
	return unchanged;
}

#endif

// Note: Calling "backtrace()" on Arm64 sometimes crashes internally if there are functions without
// unwind info on the stack. However, the calling convention requires storing stack- and base
// pointers on the stack. That makes it very easy to traverse the stack anyway!
// Collect a stack trace on Arm64 by following the chain of frame pointers (x29) until a null
// frame pointer is found, and feed the result to 'gen'.
//
// gen   - receives one 'init' call with the number of frames, followed by one 'put' per frame.
// skip  - number of frames to drop from the start of the trace. If it is larger than the depth of
//         the stack, nothing is skipped and the entire trace is reported instead.
// state - if non-null, interpreted as an ucontext_t from which x29 is read. Otherwise, the frame
//         pointer of this function is used as the starting point.
void createStackTrace(TraceGen &gen, nat skip, void *state) {
	void *start = null;
	if (!state) {
		__asm__ volatile ("mov %0, x29\n"
						: "=r" (start)
						: : );
	} else {
		// Note: State is an ucontext_t on Linux.
		ucontext_t *ctx = (ucontext_t *)state;
		start = (void *)(ctx->uc_mcontext.regs[29]);
	}

	// First pass: find out how deep the stack is, so that 'gen' knows what to expect.
	nat total = 0;
	void *walker = start;
	while (walker) {
		total++;
		walker = *(void **)walker;
	}

	if (total < skip)
		skip = 0;
	else
		total -= skip;

	gen.init(total);

	// Second pass: report all frames after the ones that are to be skipped. Each frame record
	// consists of the previous frame pointer followed by the return address.
	StackInfoSet &info = stackInfo();
	nat index = 0;
	for (void *fp = start; fp; index++) {
		void **record = (void **)fp;
		void **returnLoc = &record[1];
		void *returnIp = stripPAC(*returnLoc);
		fp = record[0];

		if (index < skip)
			continue;

		StackFrame found;
		found.id = info.translate(returnIp, found.fnBase, found.offset);
		found.returnLocation = returnLoc;
		gen.put(found);
	}
}

#else

// State handed to 'unwindCallback' through the 'data' pointer of _Unwind_Backtrace.
// Note: This struct is initialized with a braced list, so the order of the members matters.
struct UnwindState {
	// Receiver of the frames that are found.
	TraceGen &gen;
	// Used to translate instruction pointers into frame information.
	StackInfoSet &s;
	// Number of frames that remain to be skipped before frames are passed on to 'gen'.
	nat skip;
};

static _Unwind_Reason_Code unwindCallback(struct _Unwind_Context *context, void *data) {
	UnwindState *state = (UnwindState *)data;

	if (state->skip > 0) {
		state->skip--;
		return _URC_NO_REASON;
	}

	// Note: Interestingly, it seems like IP is "out of sync" one step with the rest of the members
	// here (CFA in particular). That actually makes it more convenient for us, though.
	void *pc = (void *)_Unwind_GetIP(context);

	StackFrame frame;
	frame.id = state->s.translate(pc, frame.fnBase, frame.offset);
	// Note: The CFA is defined as the location of SP when the "CALL" instruction begins executing,
	// so before the return address has been pushed. Note that this appears to be the frame below
	// whatever IP refers to.
	void **cfa = (void **)_Unwind_GetCFA(context);
	frame.returnLocation = &cfa[-1];
	// Note: Interestingly enough, getting the stack pointer causes an infinite loop at times, so we
	// don't do that.

	state->gen.put(frame);

	return _URC_NO_REASON;
}

// Collect a stack trace using _Unwind_Backtrace and feed the result to 'gen'.
//
// gen   - receives one 'init' call followed by one 'put' call for each collected frame.
// skip  - number of frames to drop from the start of the trace.
// state - currently ignored. The trace always starts from the current location.
void createStackTrace(TraceGen &gen, nat skip, void *state) {
	// Note: State is an ucontext_t on Linux.
	// TODO: Extract state from here if it exists.
	(void) state;

	// The depth is not known until the walk is complete.
	gen.init(0);

	UnwindState walk = { gen, stackInfo(), skip };

	// Walk the stack. Anything other than reaching the end of the stack is reported as a warning.
	const _Unwind_Reason_Code outcome = _Unwind_Backtrace(&unwindCallback, &walk);
	if (outcome == _URC_END_OF_STACK)
		return;

	WARNING(L"Possibly incomplete stack trace. Received " << outcome << L" from _Unwind_Backtrace.");
}

#endif

#endif

#ifndef UPDATE_RETURN_LOC_DEFINED

// Default implementation, used when no platform-specific version was defined above: overwrite the
// pointer stored at 'returnLocation' with 'target'.
void StackFrame::updateReturnLocation(const void *target) const {
	const void **slot = (const void **)this->returnLocation;
	*slot = target;
}

#endif


// TraceGen implementation that stores all frames it receives in a StackTrace.
class STGen : public TraceGen {
public:
	// The frames collected so far.
	StackTrace trace;

	// The frame count is not needed: 'trace' grows on demand.
	void init(size_t) {}

	// Add 'frame' to the end of 'trace'.
	void put(const StackFrame &frame) {
		trace.push(frame);
	}
};

// Capture a stack trace from the current location, leaving out the first 'skip' frames. The frames
// are gathered through an STGen, and a copy of its trace is returned.
StackTrace stackTrace(nat skip) {
	STGen collector;
	createStackTrace(collector, skip);
	return collector.trace;
}
