Skip to content

Commit

Permalink
Fix faulty memory access in Util's User custom actions
Browse files Browse the repository at this point in the history
Generally, clean up the handling of getting the domain from a server name by
centralizing and simplifying it behind an improved GetDomainFromServerName()
based on the buggy GetServerName().

Fixes 8576
  • Loading branch information
robmen committed Jul 15, 2024
1 parent 733886e commit 9280ac1
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 120 deletions.
3 changes: 2 additions & 1 deletion src/ext/Util/ca/precomp.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

#include <msxml2.h>
#include <Iads.h>
#include <activeds.h>
#include <activeds.h>
#include <lm.h> // NetApi32.lib
#include <Ntsecapi.h>
#include <Dsgetdc.h>
Expand Down Expand Up @@ -50,5 +50,6 @@
#include "scauser.h"
#include "scasmb.h"
#include "scasmbexec.h"
#include "utilca.h"

#include "..\..\caDecor.h"
100 changes: 21 additions & 79 deletions src/ext/Util/ca/scaexec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,7 @@ static HRESULT RemoveUserInternal(
LPWSTR pwz = NULL;
LPWSTR pwzGroup = NULL;
LPWSTR pwzGroupDomain = NULL;
LPCWSTR wz = NULL;
PDOMAIN_CONTROLLER_INFOW pDomainControllerInfo = NULL;
LPWSTR pwzDomainName = NULL;

//
// Remove the logon as service privilege.
Expand Down Expand Up @@ -644,30 +643,10 @@ static HRESULT RemoveUserInternal(
//
if (!(SCAU_DONT_CREATE_USER & iAttributes))
{
if (wzDomain && *wzDomain)
{
er = ::DsGetDcNameW(NULL, (LPCWSTR)wzDomain, NULL, NULL, NULL, &pDomainControllerInfo);
if (RPC_S_SERVER_UNAVAILABLE == er)
{
// MSDN says, if we get the above error code, try again with the "DS_FORCE_REDISCOVERY" flag
er = ::DsGetDcNameW(NULL, (LPCWSTR)wzDomain, NULL, NULL, DS_FORCE_REDISCOVERY, &pDomainControllerInfo);
}
if (ERROR_SUCCESS == er)
{
if (2 <= wcslen(pDomainControllerInfo->DomainControllerName))
{
wz = pDomainControllerInfo->DomainControllerName + 2; // Add 2 so that we don't get the \\ prefix.
// Pass the entire string if it is too short
// to have a \\ prefix.
}
}
else
{
wz = wzDomain;
}
}
hr = GetDomainFromServerName(&pwzDomainName, wzDomain, 0);
ExitOnFailure(hr, "Failed to get domain to remove user from server name: %ls", wzDomain);

er = ::NetUserDel(wz, wzName);
er = ::NetUserDel(pwzDomainName, wzName);
if (NERR_UserNotFound == er)
{
er = NERR_Success;
Expand Down Expand Up @@ -707,52 +686,13 @@ static HRESULT RemoveUserInternal(
}

LExit:
if (pDomainControllerInfo)
{
::NetApiBufferFree(static_cast<LPVOID>(pDomainControllerInfo));
}
ReleaseStr(pwzDomainName);
ReleaseStr(pwzGroupDomain);
ReleaseStr(pwzGroup);

return hr;
}

static void GetServerName(LPWSTR pwzDomain, LPWSTR* ppwzServerName)
{
DWORD er = ERROR_SUCCESS;
PDOMAIN_CONTROLLER_INFOW pDomainControllerInfo = NULL;

if (pwzDomain && *pwzDomain)
{
er = ::DsGetDcNameW(NULL, (LPCWSTR)pwzDomain, NULL, NULL, NULL, &pDomainControllerInfo);
if (RPC_S_SERVER_UNAVAILABLE == er)
{
// MSDN says, if we get the above error code, try again with the "DS_FORCE_REDISCOVERY" flag
er = ::DsGetDcNameW(NULL, (LPCWSTR)pwzDomain, NULL, NULL, DS_FORCE_REDISCOVERY, &pDomainControllerInfo);
}

if (ERROR_SUCCESS == er && pDomainControllerInfo->DomainControllerName)
{
// Skip the \\ prefix if present.
if ('\\' == *pDomainControllerInfo->DomainControllerName && '\\' == *pDomainControllerInfo->DomainControllerName + 1)
{
*ppwzServerName = pDomainControllerInfo->DomainControllerName + 2;
}
else
{
*ppwzServerName = pDomainControllerInfo->DomainControllerName;
}
}
else
{
*ppwzServerName = pwzDomain;
}
}

if (pDomainControllerInfo)
{
::NetApiBufferFree((LPVOID)pDomainControllerInfo);
}
}

/********************************************************************
CreateUser - CUSTOM ACTION ENTRY POINT for creating users
Expand All @@ -776,6 +716,7 @@ extern "C" UINT __stdcall CreateUser(
LPWSTR pwzPassword = NULL;
LPWSTR pwzGroup = NULL;
LPWSTR pwzGroupDomain = NULL;
LPWSTR pwzDomainName = NULL;
int iAttributes = 0;
BOOL fInitializedCom = FALSE;

Expand All @@ -786,7 +727,6 @@ extern "C" UINT __stdcall CreateUser(
USER_INFO_1 userInfo1;
USER_INFO_1* pUserInfo1 = NULL;
DWORD dw;
LPWSTR pwzServerName = NULL;

hr = WcaInitialize(hInstall, "CreateUser");
ExitOnFailure(hr, "failed to initialize");
Expand Down Expand Up @@ -845,9 +785,10 @@ extern "C" UINT __stdcall CreateUser(
//
// Create the User
//
GetServerName(pwzDomain, &pwzServerName);
hr = GetDomainFromServerName(&pwzDomainName, pwzDomain, 0);
ExitOnFailure(hr, "Failed to get domain from server name: %ls", pwzDomain);

er = ::NetUserAdd(pwzServerName, 1, reinterpret_cast<LPBYTE>(pUserInfo1), &dw);
er = ::NetUserAdd(pwzDomainName, 1, reinterpret_cast<LPBYTE>(pUserInfo1), &dw);
if (NERR_UserExists == er)
{
if (SCAU_FAIL_IF_EXISTS & iAttributes)
Expand All @@ -862,7 +803,7 @@ extern "C" UINT __stdcall CreateUser(
if (SCAU_UPDATE_IF_EXISTS & iAttributes)
{
pUserInfo1 = NULL;
er = ::NetUserGetInfo(pwzServerName, pwzName, 1, reinterpret_cast<LPBYTE*>(&pUserInfo1));
er = ::NetUserGetInfo(pwzDomainName, pwzName, 1, reinterpret_cast<LPBYTE*>(&pUserInfo1));
if (ERROR_SUCCESS == er)
{
// There is no rollback scheduled if the key is empty.
Expand Down Expand Up @@ -922,28 +863,28 @@ extern "C" UINT __stdcall CreateUser(

if (ERROR_SUCCESS == er)
{
hr = SetUserPassword(pwzServerName, pwzName, pwzPassword);
hr = SetUserPassword(pwzDomainName, pwzName, pwzPassword);
if (FAILED(hr))
{
WcaLogError(hr, "failed to set user password for user %ls\\%ls, continuing anyway.", pwzServerName, pwzName);
WcaLogError(hr, "failed to set user password for user %ls\\%ls, continuing anyway.", pwzDomainName, pwzName);
hr = S_OK;
}

if (SCAU_REMOVE_COMMENT & iAttributes)
{
hr = SetUserComment(pwzServerName, pwzName, L"");
hr = SetUserComment(pwzDomainName, pwzName, L"");
if (FAILED(hr))
{
WcaLogError(hr, "failed to clear user comment for user %ls\\%ls, continuing anyway.", pwzServerName, pwzName);
WcaLogError(hr, "failed to clear user comment for user %ls\\%ls, continuing anyway.", pwzDomainName, pwzName);
hr = S_OK;
}
}
else if (pwzComment && *pwzComment)
{
hr = SetUserComment(pwzServerName, pwzName, pwzComment);
hr = SetUserComment(pwzDomainName, pwzName, pwzComment);
if (FAILED(hr))
{
WcaLogError(hr, "failed to set user comment to %ls for user %ls\\%ls, continuing anyway.", pwzComment, pwzServerName, pwzName);
WcaLogError(hr, "failed to set user comment to %ls for user %ls\\%ls, continuing anyway.", pwzComment, pwzDomainName, pwzName);
hr = S_OK;
}
}
Expand All @@ -952,10 +893,10 @@ extern "C" UINT __stdcall CreateUser(

ApplyAttributes(iAttributes, &flags);

hr = SetUserFlags(pwzServerName, pwzName, flags);
hr = SetUserFlags(pwzDomainName, pwzName, flags);
if (FAILED(hr))
{
WcaLogError(hr, "failed to set user flags for user %ls\\%ls, continuing anyway.", pwzServerName, pwzName);
WcaLogError(hr, "failed to set user flags for user %ls\\%ls, continuing anyway.", pwzDomainName, pwzName);
hr = S_OK;
}
}
Expand Down Expand Up @@ -1018,6 +959,7 @@ extern "C" UINT __stdcall CreateUser(
ReleaseStr(pwzPassword);
ReleaseStr(pwzGroup);
ReleaseStr(pwzGroupDomain);
ReleaseStr(pwzDomainName)

if (fInitializedCom)
{
Expand Down
48 changes: 8 additions & 40 deletions src/ext/Util/ca/scauser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ HRESULT ScaUserExecute(
{
HRESULT hr = S_OK;
DWORD er = 0;
PDOMAIN_CONTROLLER_INFOW pDomainControllerInfo = NULL;
LPWSTR pwzDomainName = NULL;

LPWSTR pwzBaseScriptKey = NULL;
DWORD cScriptKey = 0;
Expand Down Expand Up @@ -518,36 +518,11 @@ HRESULT ScaUserExecute(
ExitOnFailure(hr, "Failed to add user comment to custom action data: %ls", psu->wzComment);

// Check to see if the user already exists since we have to be very careful when adding
// and removing users. Note: MSDN says that it is safe to call these APIs from any
// user, so we should be safe calling it during immediate mode.
er = ::NetApiBufferAllocate(sizeof(USER_INFO_0), reinterpret_cast<LPVOID*>(&pUserInfo));
hr = HRESULT_FROM_WIN32(er);
ExitOnFailure(hr, "Failed to allocate memory to check existence of user: %ls", psu->wzName);

LPCWSTR wzDomain = psu->wzDomain;
if (wzDomain && *wzDomain)
{
er = ::DsGetDcNameW(NULL, wzDomain, NULL, NULL, NULL, &pDomainControllerInfo);
if (RPC_S_SERVER_UNAVAILABLE == er)
{
// MSDN says, if we get the above error code, try again with the "DS_FORCE_REDISCOVERY" flag
er = ::DsGetDcNameW(NULL, wzDomain, NULL, NULL, DS_FORCE_REDISCOVERY, &pDomainControllerInfo);
}
if (ERROR_SUCCESS == er && pDomainControllerInfo->DomainControllerName)
{
// If the \\ prefix on the queried domain was present, skip it.
if ('\\' == *pDomainControllerInfo->DomainControllerName && '\\' == *pDomainControllerInfo->DomainControllerName + 1)
{
wzDomain = pDomainControllerInfo->DomainControllerName + 2;
}
else
{
wzDomain = pDomainControllerInfo->DomainControllerName;
}
}
}
// and removing users.
hr = GetDomainFromServerName(&pwzDomainName, psu->wzDomain, 0);
ExitOnFailure(hr, "Failed to get domain from server name: %ls", psu->wzDomain);

er = ::NetUserGetInfo(wzDomain, psu->wzName, 0, reinterpret_cast<LPBYTE*>(pUserInfo));
er = ::NetUserGetInfo(pwzDomainName, psu->wzName, 0, reinterpret_cast<LPBYTE*>(&pUserInfo));
if (NERR_Success == er)
{
ueUserExists = USER_EXISTS_YES;
Expand All @@ -560,7 +535,7 @@ HRESULT ScaUserExecute(
{
ueUserExists = USER_EXISTS_INDETERMINATE;
hr = HRESULT_FROM_WIN32(er);
WcaLog(LOGMSG_VERBOSE, "Failed to check existence of domain: %ls, user: %ls (error code 0x%x) - continuing", wzDomain, psu->wzName, hr);
WcaLog(LOGMSG_VERBOSE, "Failed to check existence of domain: %ls, user: %ls (error code 0x%x) - continuing", pwzDomainName, psu->wzName, hr);
hr = S_OK;
er = ERROR_SUCCESS;
}
Expand Down Expand Up @@ -685,26 +660,19 @@ HRESULT ScaUserExecute(
::NetApiBufferFree(static_cast<LPVOID>(pUserInfo));
pUserInfo = NULL;
}
if (pDomainControllerInfo)
{
::NetApiBufferFree(static_cast<LPVOID>(pDomainControllerInfo));
pDomainControllerInfo = NULL;
}
}

LExit:
ReleaseStr(pwzBaseScriptKey);
ReleaseStr(pwzScriptKey);
ReleaseStr(pwzActionData);
ReleaseStr(pwzRollbackData);
ReleaseStr(pwzDomainName);

if (pUserInfo)
{
::NetApiBufferFree(static_cast<LPVOID>(pUserInfo));
}
if (pDomainControllerInfo)
{
::NetApiBufferFree(static_cast<LPVOID>(pDomainControllerInfo));
}

return hr;
}
Expand Down
56 changes: 56 additions & 0 deletions src/ext/Util/ca/utilca.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,59 @@
// Copyright (c) .NET Foundation and contributors. All rights reserved. Licensed under the Microsoft Reciprocal License. See LICENSE.TXT file in the project root for full license information.

#include "precomp.h"

HRESULT GetDomainFromServerName(
__inout_z LPWSTR* psczDomainName,
__in_z LPCWSTR wzServerName,
__in DWORD dwFlags
)
{
HRESULT hr = S_OK;
DWORD er = ERROR_SUCCESS;
PDOMAIN_CONTROLLER_INFOW pDomainControllerInfo = NULL;
LPCWSTR wz = wzServerName ? wzServerName : L""; // initialize the domain to the provided server name (or empty string).

// If the server name was not empty, try to get the domain name out of it.
if (*wz)
{
er = ::DsGetDcNameW(NULL, wz, NULL, NULL, dwFlags, &pDomainControllerInfo);
if (RPC_S_SERVER_UNAVAILABLE == er)
{
// MSDN says, if we get the above error code, try again with the "DS_FORCE_REDISCOVERY" flag.
er = ::DsGetDcNameW(NULL, wz, NULL, NULL, dwFlags | DS_FORCE_REDISCOVERY, &pDomainControllerInfo);
}
ExitOnWin32Error(er, hr, "Could not get domain name from server name: %ls", wz);

if (pDomainControllerInfo->DomainControllerName)
{
// Skip the \\ prefix if present.
if ('\\' == *pDomainControllerInfo->DomainControllerName && '\\' == *pDomainControllerInfo->DomainControllerName + 1)
{
wz = pDomainControllerInfo->DomainControllerName + 2;
}
else
{
wz = pDomainControllerInfo->DomainControllerName;
}
}
}

LExit:
// Note: we overwrite the error code here as failure to contact domain controller above is not a fatal error.
if (wz && *wz)
{
hr = StrAllocString(psczDomainName, wz, 0);
}
else // return NULL the server name ended up empty.
{
ReleaseNullStr(psczDomainName);
hr = S_OK;
}

if (pDomainControllerInfo)
{
::NetApiBufferFree((LPVOID)pDomainControllerInfo);
}

return hr;
}
8 changes: 8 additions & 0 deletions src/ext/Util/ca/utilca.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#pragma once
// Copyright (c) .NET Foundation and contributors. All rights reserved. Licensed under the Microsoft Reciprocal License. See LICENSE.TXT file in the project root for full license information.

HRESULT GetDomainFromServerName(
__inout_z LPWSTR* psczDomain,
__in_z LPCWSTR wzServerName,
__in DWORD dwFlags
);

0 comments on commit 9280ac1

Please sign in to comment.