I have developed a working solution that I wanted to run by your brilliant minds for comments, feedback, or any other considerations I might have missed.
I used Archer's recommendation to get the return address (via _ReturnAddress) and worked my way from there by creating a code cave.
Code:
// Signature of a DLL entry point as it is invoked by LdrpCallInitRoutine below:
// module handle, notification reason, and an optional context pointer.
typedef BOOLEAN(NTAPI *PDLL_INIT_ROUTINE)(IN PVOID DllHandle, IN ULONG Reason, IN PCONTEXT Context OPTIONAL);
// Replacement call target installed by NtDllPatch: receives the three original
// entry-point arguments plus the entry point itself, and forwards the call.
//
// BaseAddress - module handle passed on to the entry point
// Reason      - notification reason passed on to the entry point
// Context     - opaque context pointer, forwarded as PCONTEXT
// EntryPoint  - the DLL entry point to invoke
//
// Returns whatever the entry point returns.
BOOLEAN WINAPI LdrpCallInitRoutine(IN PVOID BaseAddress, IN ULONG Reason, IN PVOID Context, IN PDLL_INIT_ROUTINE EntryPoint)
{
#ifdef _DEBUG
    // Debug builds look up the path of the module being notified. The value is
    // not consumed by the code; presumably it is there to be inspected in a debugger.
    TCHAR szModulePath[MAX_PATH] = { 0 };
    GetModuleFileName((HMODULE)BaseAddress, szModulePath, _countof(szModulePath));
#endif
    PCONTEXT pContext = (PCONTEXT)Context;
    const BOOLEAN bResult = EntryPoint(BaseAddress, Reason, pContext);
    return bResult;
}
// Searches the memory region containing pAddress for a run of dwSize zero
// bytes (a "code cave") that can host the detour.
//
// The scan walks backwards from the end of the region reported by VirtualQuery
// and stops either when a long enough run of 0x00 bytes has been collected or
// when it reaches pAddress itself.
//
// pAddress - an address inside the region to search
// dwSize   - number of consecutive zero bytes required
//
// Returns the start of the zero run, or NULL if VirtualQuery fails or the scan
// reached pAddress without finding one.
PBYTE NtDllFindDetourAddress(const PBYTE pAddress, SIZE_T dwSize)
{
    MEMORY_BASIC_INFORMATION mbi = { 0 };
    if (!VirtualQuery(pAddress, &mbi, sizeof(mbi)))
        return NULL;

    // pRunEnd marks one past the last byte of the current zero run,
    // pCursor is the lowest byte examined so far.
    PBYTE pRunEnd = (PBYTE)mbi.BaseAddress + mbi.RegionSize;
    PBYTE pCursor = pRunEnd;

    for (;;) {
        if ((SIZE_T)(pRunEnd - pCursor) >= dwSize)
            break;                      // run is long enough
        if (pCursor == pAddress)
            break;                      // ran into the caller's address
        --pCursor;
        if (*pCursor != 0x00)
            pRunEnd = pCursor;          // non-zero byte: restart the run below it
    }

    return (pCursor != pAddress) ? pCursor : NULL;
}
// Scans the 31 bytes preceding pAddress, nearest first, for the instruction
// that starts the parameter set-up of the call which returns to pAddress.
//
//   32-bit: FF 75 14        => push [ebp][14h]  (parameters are pushed to the stack)
//   64-bit: xx 8B yy        => mov r8,...       (parameters are moved to r8, rdx, rcx)
//           where xx has no bits set outside 0x4D and yy has no bits set in 0x38
//
// Returns the address of the first match, or NULL if none is found.
PBYTE NtDllFindParamAddress(const PBYTE pAddress)
{
    for (int nOffset = 1; nOffset < 0x20; ++nOffset) {
        PBYTE pCandidate = pAddress - nOffset;
#ifdef _WIN64
        // (b & mask) == b is the same as "no bits outside mask":
        // ~0x4D == 0xB2 and ~0xC7 == 0x38 within a byte.
        const BOOL bMatch = ((pCandidate[0] & 0xB2) == 0x00)
                         && (pCandidate[1] == 0x8B)
                         && ((pCandidate[2] & 0x38) == 0x00);
#else
        const BOOL bMatch = (pCandidate[0] == 0xFF)
                         && (pCandidate[1] == 0x75)
                         && (pCandidate[2] == 0x14);
#endif
        if (bMatch)
            return pCandidate;
    }
    return NULL;
}
// Scans the 31 bytes preceding pAddress, nearest first, for the call
// instruction that returns to pAddress.
//
//   32-bit: FF 55 08        => call [ebp][08h]
//   64-bit: FF yy           => call <register>, where yy has no bits set in 0x28;
//           if the byte in front of FF has no bits set outside 0x41 it is
//           treated as a prefix and the returned address includes it
//
// Returns the address of the first match, or NULL if none is found.
PBYTE NtDllFindCallAddress(const PBYTE pAddress)
{
    for (int nOffset = 1; nOffset < 0x20; ++nOffset) {
        PBYTE pCandidate = pAddress - nOffset;
        if (pCandidate[0] != 0xFF)
            continue;
#ifdef _WIN64
        // ~0xD7 == 0x28 within a byte, so this is (b & 0xD7) == b
        if ((pCandidate[1] & 0x28) != 0x00)
            continue;
        // ~0x41 == 0xBE within a byte, so this is (b & 0x41) == b
        if ((pCandidate[-1] & 0xBE) == 0x00)
            return pCandidate - 1;
        return pCandidate;
#else
        if ((pCandidate[1] == 0x55) && (pCandidate[2] == 0x08))
            return pCandidate;
#endif
    }
    return NULL;
}
// Bookkeeping for one installed patch; filled in by NtDllPatch and consumed by
// NtDllRestore.
typedef struct _NTDLL_LDR_PATCH {
PBYTE pPatchAddress; // start of the patched bytes (result of NtDllFindParamAddress)
SIZE_T nPatchSize; // number of bytes from pPatchAddress up to the return address
BYTE pBackup[0x20]; // copy of the original bytes at pPatchAddress
PBYTE pDetourAddress; // start of the code cave (result of NtDllFindDetourAddress)
SIZE_T nDetourSize; // number of bytes written into the code cave
BOOL bState; // TRUE once NtDllPatch has applied the patch
} NTDLL_LDR_PATCH, *PNTDLL_LDR_PATCH;
// The single patch instance used by DllEntryPoint.
NTDLL_LDR_PATCH patch;
// Redirects the call that returns to pReturnAddress so that it goes through
// LdrpCallInitRoutine (with the original call target as an extra 4th argument).
//
// Layout after patching:
//   patch site (pPatchAddress .. pReturnAddress):
//       nop ...                         ; padding
//       jmp  detour
//       call eax/rax                    ; ends exactly at pReturnAddress on x86
//       xchg r.., r9                    ; x64 only, undoes the exchange below
//   detour (code cave found by NtDllFindDetourAddress):
//       x86: push [ebp][08h] ; <original parameter instructions>
//       x64: <original parameter instructions> ; xchg r.., r9
//       mov  eax/rax, LdrpCallInitRoutine
//       jmp  back to the "call eax/rax" above
//
// pReturnAddress - return address of the call to patch (from _ReturnAddress)
// NtDllPatch     - patch record; filled in on success, used by NtDllRestore
//
// Returns TRUE if the patch is (already) installed, FALSE otherwise. On failure
// the code at the call site is left unmodified.
BOOL NtDllPatch(const PBYTE pReturnAddress, NTDLL_LDR_PATCH &NtDllPatch)
{
    // Already patched: nothing to do.
    if (NtDllPatch.bState != FALSE)
        return NtDllPatch.bState;
    if (pReturnAddress == NULL)
        return FALSE;

#ifdef _WIN64
    BYTE ptr[] = { '?', 0x87, '?' };                                        // xchg r.., r9
    BYTE mov[] = { 0x48, 0xB8, '?', '?', '?', '?', '?', '?', '?', '?' };    // mov rax, 0x0000000000000000
    BYTE call[] = { 0xFF, 0xD0, '?', 0x87, '?' };                           // call rax // xchg r.., r9
#else
    BYTE ptr[] = { 0xFF, 0x75, 0x08 };                                      // push [ebp][08h]
    BYTE mov[] = { 0x90, 0xB8, '?', '?', '?', '?' };                        // nop // mov eax, 0x00000000
    BYTE call[] = { 0xFF, 0xD0 };                                           // call eax
#endif
    BYTE jmp[] = { 0xE9, '?', '?', '?', '?' };                              // jmp 0x00000000

    // Locate both instructions and validate them BEFORE doing any pointer
    // arithmetic: a missing call, or a call located in front of the parameter
    // instructions, would make the sizes below wrap around.
    NtDllPatch.pPatchAddress = NtDllFindParamAddress(pReturnAddress);
    PBYTE pCallAddress = NtDllFindCallAddress(pReturnAddress);
    if (NtDllPatch.pPatchAddress == NULL || pCallAddress == NULL || pCallAddress < NtDllPatch.pPatchAddress)
        return FALSE;

    // The patch site must fit in the backup buffer and must have room for the
    // jmp + call sequence that replaces it.
    NtDllPatch.nPatchSize = (SIZE_T)(pReturnAddress - NtDllPatch.pPatchAddress);
    if (NtDllPatch.nPatchSize > sizeof(NtDllPatch.pBackup) || NtDllPatch.nPatchSize < (_countof(jmp) + _countof(call)))
        return FALSE;

    const SIZE_T nParamSize = (SIZE_T)(pCallAddress - NtDllPatch.pPatchAddress);
    NtDllPatch.nDetourSize = _countof(ptr) + nParamSize + _countof(mov) + _countof(jmp);
    NtDllPatch.pDetourAddress = NtDllFindDetourAddress(pReturnAddress, NtDllPatch.nDetourSize);
    if (NtDllPatch.pDetourAddress == NULL)
        return FALSE;

    PBYTE pDetour = NtDllPatch.pDetourAddress;
    PBYTE pNewCall = pReturnAddress - _countof(call);   // where "call eax/rax" is written
    PBYTE pNewJmp = pNewCall - _countof(jmp);           // where "jmp detour" is written

    memcpy(NtDllPatch.pBackup, NtDllPatch.pPatchAddress, NtDllPatch.nPatchSize);

#ifdef _WIN64
    // Work out which register holds the EntryPoint. pCallAddress[0] is either a
    // prefix byte or the 0xFF opcode itself; only a real REX prefix (0100WRXB)
    // with the B bit set selects r8..r15. Testing just (b & 0x41) == 0x41 would
    // also be true for the 0xFF opcode byte of an unprefixed call.
    const BOOL bRexB = ((pCallAddress[0] & 0xF1) == 0x41);
    const BYTE reg = (BYTE)((bRexB ? 0x08 : 0x00) + (pCallAddress[pReturnAddress - pCallAddress - 1] & 0x07));
    ptr[0] = (BYTE)(0x4C + ((reg & 0x08) ? 0x01 : 0x00));
    ptr[2] = (BYTE)(0xC8 + (reg & 0x07));
#endif

    // ---- Build the detour in the code cave ----
    DWORD dwProtect = 0;
    if (!VirtualProtect(pDetour, NtDllPatch.nDetourSize, PAGE_EXECUTE_READWRITE, &dwProtect))
        return FALSE;

    memset(pDetour, 0x90, NtDllPatch.nDetourSize);
#ifdef _WIN64
    // Copy original param instructions
    memcpy(&pDetour[0], NtDllPatch.pPatchAddress, nParamSize);
    // Exchange the register that holds the EntryPoint with r9
    memcpy(&pDetour[nParamSize], ptr, _countof(ptr));
#else
    // Push EntryPoint as last parameter
    memcpy(&pDetour[0], ptr, _countof(ptr));
    // Copy original param instructions
    memcpy(&pDetour[_countof(ptr)], NtDllPatch.pPatchAddress, nParamSize);
#endif
    // Move LdrpCallInitRoutine to eax/rax (memcpy avoids an unaligned, type-punned store)
    const SIZE_T nTarget = (SIZE_T)LdrpCallInitRoutine;
    memcpy(&mov[2], &nTarget, sizeof(nTarget));
    memcpy(&pDetour[_countof(ptr) + nParamSize], mov, _countof(mov));
    // Jump back to the rewritten call; rel32 is relative to the end of the detour
    DWORD dwRel = (DWORD)(pNewCall - (pDetour + NtDllPatch.nDetourSize));
    memcpy(&jmp[1], &dwRel, sizeof(dwRel));
    memcpy(&pDetour[_countof(ptr) + nParamSize + _countof(mov)], jmp, _countof(jmp));

    VirtualProtect(pDetour, NtDllPatch.nDetourSize, dwProtect, &dwProtect);
    FlushInstructionCache(GetCurrentProcess(), pDetour, NtDllPatch.nDetourSize);

    // ---- Rewrite the call site ----
    if (!VirtualProtect(NtDllPatch.pPatchAddress, NtDllPatch.nPatchSize, PAGE_EXECUTE_READWRITE, &dwProtect))
        return FALSE;

    memset(NtDllPatch.pPatchAddress, 0x90, NtDllPatch.nPatchSize);
    // Jump to detour address; rel32 is relative to the instruction after the jmp
    dwRel = (DWORD)(pDetour - pNewCall);
    memcpy(&jmp[1], &dwRel, sizeof(dwRel));
    memcpy(pNewJmp, jmp, _countof(jmp));
#ifdef _WIN64
    // Exchange r9 with the register that originally held the EntryPoint
    memcpy(&call[2], ptr, _countof(ptr));
#endif
    // Call LdrpCallInitRoutine from eax/rax
    memcpy(pNewCall, call, _countof(call));

    VirtualProtect(NtDllPatch.pPatchAddress, NtDllPatch.nPatchSize, dwProtect, &dwProtect);
    FlushInstructionCache(GetCurrentProcess(), NtDllPatch.pPatchAddress, NtDllPatch.nPatchSize);

    NtDllPatch.bState = TRUE;
    return NtDllPatch.bState;
}
// Undoes NtDllPatch: copies the backed-up bytes over the patch site and zeroes
// the code cave again.
//
// NtDllPatch - patch record previously filled in by NtDllPatch()
//
// Returns TRUE only if both the patch site was restored and the code cave was
// cleared. bState is reset as soon as the original bytes are back in place, so
// a later NtDllPatch() call will install the patch again instead of reporting a
// patch that no longer exists.
//
// NOTE(review): on x64 NtDllPatch writes "call rax" followed by a 3-byte xchg
// that ends at the original return address, so a call made through the patch
// returns to an address inside the region restored here. If this function is
// reached from such a call, execution resumes in the middle of the restored
// bytes - confirm this path is safe on x64.
BOOL NtDllRestore(NTDLL_LDR_PATCH &NtDllPatch)
{
    // Nothing to restore, or a record that cannot have come from NtDllPatch().
    if (!NtDllPatch.bState || NtDllPatch.pPatchAddress == NULL)
        return FALSE;
    if (NtDllPatch.nPatchSize == 0 || NtDllPatch.nPatchSize > sizeof(NtDllPatch.pBackup))
        return FALSE;

    // Restore patched bytes
    DWORD dwProtect = 0;
    if (!VirtualProtect(NtDllPatch.pPatchAddress, NtDllPatch.nPatchSize, PAGE_EXECUTE_READWRITE, &dwProtect))
        return FALSE;
    memcpy(NtDllPatch.pPatchAddress, NtDllPatch.pBackup, NtDllPatch.nPatchSize);
    VirtualProtect(NtDllPatch.pPatchAddress, NtDllPatch.nPatchSize, dwProtect, &dwProtect);
    FlushInstructionCache(GetCurrentProcess(), NtDllPatch.pPatchAddress, NtDllPatch.nPatchSize);

    // The call site no longer points at the detour.
    NtDllPatch.bState = FALSE;

    // Clear the code cave so it is all zero bytes again.
    if (NtDllPatch.pDetourAddress == NULL)
        return FALSE;
    if (!VirtualProtect(NtDllPatch.pDetourAddress, NtDllPatch.nDetourSize, PAGE_EXECUTE_READWRITE, &dwProtect))
        return FALSE;
    memset(NtDllPatch.pDetourAddress, 0x00, NtDllPatch.nDetourSize);
    VirtualProtect(NtDllPatch.pDetourAddress, NtDllPatch.nDetourSize, dwProtect, &dwProtect);
    FlushInstructionCache(GetCurrentProcess(), NtDllPatch.pDetourAddress, NtDllPatch.nDetourSize);

    return TRUE;
}
#define _DECL_DLLMAIN // for _CRT_INIT
#include <process.h> // for _CRT_INIT
#pragma comment(linker, "/entry:DllEntryPoint")
// Raw DLL entry point (selected with /entry:DllEntryPoint).
//
// On DLL_PROCESS_ATTACH it patches the code that called this function (located
// through _ReturnAddress, which is why the function must not be inlined), then
// forwards every attach/detach notification to _CRT_INIT. The patch is removed
// again on DLL_PROCESS_DETACH, and also when process attach fails, so that the
// patched call site is never left pointing at LdrpCallInitRoutine in this
// module after the entry point has reported failure or the module is detached.
//
// Returns FALSE if _CRT_INIT fails, TRUE otherwise.
__declspec(noinline)
BOOL WINAPI DllEntryPoint(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved)
{
    // Patch ntdll address that calls the dll entry point
    if (fdwReason == DLL_PROCESS_ATTACH) {
        NtDllPatch((PBYTE)_ReturnAddress(), patch);
    }

    BOOL bResult = TRUE;
    const BOOL bNotifyCrt = (fdwReason == DLL_PROCESS_ATTACH || fdwReason == DLL_THREAD_ATTACH ||
                             fdwReason == DLL_PROCESS_DETACH || fdwReason == DLL_THREAD_DETACH);
    if (bNotifyCrt && !_CRT_INIT(hinstDLL, fdwReason, lpReserved)) {
        bResult = FALSE;
    }

    // Restore ntdll. This must not be skipped when _CRT_INIT fails: returning
    // early would leave the call site redirected into this module.
    if (fdwReason == DLL_PROCESS_DETACH || (fdwReason == DLL_PROCESS_ATTACH && !bResult)) {
        NtDllRestore(patch);
    }

    return bResult;
}