Both answers, from @RbMm and @AhmedAEK, work well — thanks to everyone for their suggestions. I plan to use Method 2, as it seems cleaner. Here are code examples of each:
Method 1: This method uses the DLL-load notification callbacks provided by ntdll.dll to call functions in newly loaded libraries. It has pitfalls if you are not careful: ntdll invokes the callback from inside the loader, before the new library has finished initializing. It will likely require more thorough investigation.
#include <iostream>
#include <Windows.h>
#include "Allocator.hpp"
// Local counted UTF-16 string type whose layout mirrors the NT UNICODE_STRING
// structure. ntdll hands pointers to it to the DLL notification callback below.
// Field order and types must not change, because ntdll writes this layout.
// NOTE: as in the NT UNICODE_STRING, Length/MaximumLength are byte counts and
// pBuffer is not guaranteed to be NUL-terminated (see ntdef.h / winternl.h).
typedef struct _UNICODE_STR
{
USHORT Length;
USHORT MaximumLength;
PWSTR pBuffer;
} UNICODE_STR, * PUNICODE_STR;
// Sources:
// https://shorsec.io/blog/dll-notification-injection/
// https://modexp.wordpress.com/2020/08/06/windows-data-structures-and-callbacks-part-1/
// The types below are local declarations of ntdll's loader-notification
// structures. They are used with LdrRegisterDllNotification, which main()
// resolves at run time. ntdll fills these structures in, so the field order
// must stay exactly as written.
// Payload describing a DLL that has just been loaded.
typedef struct _LDR_DLL_LOADED_NOTIFICATION_DATA {
ULONG Flags; // Reserved.
PUNICODE_STR FullDllName; // The full path name of the DLL module.
PUNICODE_STR BaseDllName; // The base file name of the DLL module.
PVOID DllBase; // A pointer to the base address for the DLL in memory.
ULONG SizeOfImage; // The size of the DLL image, in bytes.
} LDR_DLL_LOADED_NOTIFICATION_DATA, * PLDR_DLL_LOADED_NOTIFICATION_DATA;
// Payload describing a DLL that is being unloaded. Its field layout is
// identical to LDR_DLL_LOADED_NOTIFICATION_DATA; keep the order unchanged.
typedef struct _LDR_DLL_UNLOADED_NOTIFICATION_DATA {
ULONG Flags; // Reserved.
PUNICODE_STR FullDllName; // The full path name of the DLL module.
PUNICODE_STR BaseDllName; // The base file name of the DLL module.
PVOID DllBase; // A pointer to the base address for the DLL in memory.
ULONG SizeOfImage; // The size of the DLL image, in bytes.
} LDR_DLL_UNLOADED_NOTIFICATION_DATA, * PLDR_DLL_UNLOADED_NOTIFICATION_DATA;
// Data passed to the DLL notification callback. Per the LdrDllNotification
// documentation, the callback's NotificationReason argument says which member
// is meaningful: Loaded for load events, Unloaded for unload events.
typedef union _LDR_DLL_NOTIFICATION_DATA {
LDR_DLL_LOADED_NOTIFICATION_DATA Loaded;
LDR_DLL_UNLOADED_NOTIFICATION_DATA Unloaded;
} LDR_DLL_NOTIFICATION_DATA, * PLDR_DLL_NOTIFICATION_DATA;
typedef VOID(CALLBACK* PLDR_DLL_NOTIFICATION_FUNCTION)(
ULONG NotificationReason,
PLDR_DLL_NOTIFICATION_DATA NotificationData,
PVOID Context);
typedef NTSTATUS(NTAPI* _LdrRegisterDllNotification) (
ULONG Flags,
PLDR_DLL_NOTIFICATION_FUNCTION NotificationFunction,
PVOID Context,
PVOID* Cookie);
// Loader notification callback that main() registers via LdrRegisterDllNotification.
//
// When a module whose base name is "Library.dll" (case-insensitive) is loaded,
// this looks up the module's exported "RegisterAllocator" function and passes
// it the current Allocator singleton.
//
//   NotificationReason - 1 for a load (LDR_DLL_NOTIFICATION_REASON_LOADED),
//                        2 for an unload; only loads are handled.
//   NotificationData   - description of the affected module.
//   Context            - registration context pointer (unused).
//
// Declared CALLBACK so it matches PLDR_DLL_NOTIFICATION_FUNCTION exactly; the
// calling conventions differ on 32-bit x86.
//
// NOTE(review): per the LdrDllNotification documentation, ntdll calls this from
// inside the loader, before dynamic linking of the new DLL is complete. Confirm
// that RegisterAllocator does not rely on the DLL's imports, DllMain, or CRT
// initialization having run.
VOID CALLBACK MyCallback(ULONG NotificationReason, const PLDR_DLL_NOTIFICATION_DATA NotificationData, PVOID Context) {
    // Value of LDR_DLL_NOTIFICATION_REASON_LOADED; no header here defines it.
    constexpr ULONG kReasonLoaded = 1;
    (void)Context;

    // The union has the same shape for unloads, but calling into a DLL that is
    // being torn down is unsafe, so react to load events only.
    if (NotificationReason != kReasonLoaded || NotificationData == nullptr) {
        return;
    }

    const PUNICODE_STR baseName = NotificationData->Loaded.BaseDllName;
    if (baseName == nullptr || baseName->pBuffer == nullptr) {
        return;
    }

    // Counted string: Length is in bytes and the buffer may lack a NUL
    // terminator, so compare with an explicit character count. An ordinal,
    // case-insensitive comparison suits file names better than a
    // locale-sensitive one.
    const int nameChars = static_cast<int>(baseName->Length / sizeof(WCHAR));
    if (CompareStringOrdinal(baseName->pBuffer, nameChars, L"Library.dll", -1, TRUE) != CSTR_EQUAL) {
        return;
    }

    // A module handle is the module's base address.
    HMODULE dllHandle = static_cast<HMODULE>(NotificationData->Loaded.DllBase);
    auto fn = reinterpret_cast<void(*)(Allocator*)>(GetProcAddress(dllHandle, "RegisterAllocator"));
    if (!fn) {
        printf("Could not locate the function.\n");
        return;
    }
    fn(Allocator::GetAllocator());
}
int main() {
Allocator allocator;
Allocator::SetAllocator(&allocator);
HMODULE hNtdll = GetModuleHandleA("NTDLL.dll");
if (hNtdll != NULL) {
_LdrRegisterDllNotification pLdrRegisterDllNotification = (_LdrRegisterDllNotification)GetProcAddress(hNtdll, "LdrRegisterDllNotification");
PVOID cookie;
NTSTATUS status = pLdrRegisterDllNotification(0, (PLDR_DLL_NOTIFICATION_FUNCTION)MyCallback, NULL, &cookie);
if (status != 0) {
printf("Failed to load DLL Callback! Exiting\n");
return EXIT_FAILURE;
}
}
else {
printf("Failed to load NTDLL.dll! Exiting\n");
return EXIT_FAILURE;
}
printf("Allocated Size Before: %zu.\n", allocator.GetUsedSize());
HINSTANCE dllHandle = LoadLibrary(L"Library.dll");
if (!dllHandle) {
printf("Could not load the dynamic library.\n");
return EXIT_FAILURE;
}
printf("Allocated Size After: %zu.\n", allocator.GetUsedSize());
return 0;
}
Method 2: This change to Allocator.cpp makes the singleton's getter look up the value in the main executable when it has not been set yet. It calls GetModuleHandle(NULL) to get the main executable's module, then uses GetProcAddress to find an exported non-member function that returns the singleton's value.
// ...
#if MAIN_EXECUTABLE
// Main-executable build: the getter just reports the stored singleton pointer,
// which may be nullptr. Presumably it is installed through SetAllocator.
Allocator* Allocator::GetAllocator() {
    Allocator* const current = allocatorState;
    return current;
}

extern "C" {
// Exported with C linkage under the plain name "GetAllocator", so a DLL can
// find it with GetProcAddress(GetModuleHandle(NULL), "GetAllocator") and share
// the executable's singleton.
__declspec(dllexport) Allocator* GetAllocator() {
    Allocator* const shared = Allocator::GetAllocator();
    return shared;
}
}
#else
#include <Windows.h>
// DLL build of the singleton getter.
//
// Returns this module's allocatorState when it is set. Otherwise it forwards
// to the main executable's exported C function "GetAllocator", so the DLL
// shares the executable's singleton. Returns nullptr if the executable does
// not export that function.
Allocator* Allocator::GetAllocator() {
    if (allocatorState != nullptr) {
        return allocatorState;
    }

    // The main executable's module and its exports do not change during the
    // process lifetime. Resolve the export once, rather than calling
    // GetModuleHandle/GetProcAddress on every call. Function-local static
    // initialization is thread-safe (C++11 "magic statics").
    // The executable's value is still re-queried on each call.
    using GetAllocatorFnType = Allocator* (*)();
    static const GetAllocatorFnType getAllocatorFromExe =
        reinterpret_cast<GetAllocatorFnType>(GetProcAddress(GetModuleHandle(NULL), "GetAllocator"));

    return getAllocatorFromExe != nullptr ? getAllocatorFromExe() : nullptr;
}
#endif