Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[8.12] [Security Solution] Fix not complete existing rule overwrite when importing rules (#176166) #177270

Merged
merged 1 commit into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ export interface UpdateRulesOptions {
rulesClient: RulesClient;
existingRule: RuleAlertType | null | undefined;
ruleUpdate: RuleUpdateProps;
allowMissingConnectorSecrets?: boolean;
}

export const updateRules = async ({
rulesClient,
existingRule,
ruleUpdate,
allowMissingConnectorSecrets,
}: UpdateRulesOptions): Promise<PartialRule<RuleParams> | null> => {
if (existingRule == null) {
return null;
Expand Down Expand Up @@ -81,6 +83,7 @@ export const updateRules = async ({
const update = await rulesClient.update({
id: existingRule.id,
data: newInternalRule,
allowMissingConnectorSecrets,
});

if (existingRule.enabled && enabled === false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ import {
getRuleMock,
getEmptyFindResult,
getFindResultWithSingleHit,
getFindResultWithMultiHits,
} from '../../../routes/__mocks__/request_responses';

import { createRules } from '../crud/create_rules';
import { patchRules } from '../crud/patch_rules';
import { updateRules } from '../crud/update_rules';
import { importRules } from './import_rules_utils';

jest.mock('../crud/create_rules');
jest.mock('../crud/patch_rules');
jest.mock('../crud/update_rules');

describe('importRules', () => {
const mlAuthz = {
Expand Down Expand Up @@ -84,7 +85,7 @@ describe('importRules', () => {

expect(result).toEqual([{ rule_id: 'rule-1', status_code: 200 }]);
expect(createRules).toHaveBeenCalled();
expect(patchRules).not.toHaveBeenCalled();
expect(updateRules).not.toHaveBeenCalled();
});

it('reports error if "overwriteRules" is "false" and matching rule found', async () => {
Expand All @@ -106,10 +107,10 @@ describe('importRules', () => {
},
]);
expect(createRules).not.toHaveBeenCalled();
expect(patchRules).not.toHaveBeenCalled();
expect(updateRules).not.toHaveBeenCalled();
});

it('patches rule if "overwriteRules" is "true" and matching rule found', async () => {
it('updates rule if "overwriteRules" is "true" and matching rule found', async () => {
clients.rulesClient.find.mockResolvedValue(getFindResultWithSingleHit());

const result = await importRules({
Expand All @@ -129,7 +130,53 @@ describe('importRules', () => {

expect(result).toEqual([{ rule_id: 'rule-1', status_code: 200 }]);
expect(createRules).not.toHaveBeenCalled();
expect(patchRules).toHaveBeenCalled();
expect(updateRules).toHaveBeenCalled();
});

/**
* Existing rule may have nullable fields set to a value (e.g. `timestamp_override` is set to `some.value`) but
* a rule to import doesn't have these fields set (e.g. `timestamp_override` is NOT present at all in the ndjson file).
* We expect the updated rule won't have such fields preserved (e.g. `timestamp_override` will be removed).
*
* Unit test is only able to check `updateRules()` receives a proper update object.
*/
it('ensures overwritten rule DOES NOT preserve fields missed in the imported rule when "overwriteRules" is "true" and matching rule found', async () => {
const existingRule = getRuleMock(
getQueryRuleParams({
timestampOverride: 'some.value',
})
);

clients.rulesClient.find.mockResolvedValue(
getFindResultWithMultiHits({ data: [existingRule] })
);

const result = await importRules({
ruleChunks: [
[
{
...getImportRulesSchemaMock(),
rule_id: 'rule-1',
},
],
],
rulesResponseAcc: [],
mlAuthz,
overwriteRules: true,
rulesClient: context.alerting.getRulesClient(),
existingLists: {},
});

expect(result).toEqual([{ rule_id: 'rule-1', status_code: 200 }]);
expect(createRules).not.toHaveBeenCalled();
expect(updateRules).toHaveBeenCalledWith(
expect.objectContaining({
ruleUpdate: expect.not.objectContaining({
timestamp_override: expect.anything(),
timestampOverride: expect.anything(),
}),
})
);
});

it('reports error if rulesClient throws', async () => {
Expand All @@ -154,7 +201,7 @@ describe('importRules', () => {
},
]);
expect(createRules).not.toHaveBeenCalled();
expect(patchRules).not.toHaveBeenCalled();
expect(updateRules).not.toHaveBeenCalled();
});

it('reports error if "createRules" throws', async () => {
Expand All @@ -180,8 +227,8 @@ describe('importRules', () => {
]);
});

it('reports error if "patchRules" throws', async () => {
(patchRules as jest.Mock).mockRejectedValue(new Error('error patching rule'));
it('reports error if "updateRules" throws', async () => {
(updateRules as jest.Mock).mockRejectedValue(new Error('import rule error'));
clients.rulesClient.find.mockResolvedValue(getFindResultWithSingleHit());

const result = await importRules({
Expand All @@ -196,7 +243,7 @@ describe('importRules', () => {
expect(result).toEqual([
{
error: {
message: 'error patching rule',
message: 'import rule error',
status_code: 400,
},
rule_id: 'rule-1',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import type { ImportRuleResponse } from '../../../routes/utils';
import { createBulkErrorObject } from '../../../routes/utils';
import { createRules } from '../crud/create_rules';
import { readRules } from '../crud/read_rules';
import { patchRules } from '../crud/patch_rules';
import { updateRules } from '../crud/update_rules';
import type { MlAuthz } from '../../../../machine_learning/authz';
import { throwAuthzError } from '../../../../machine_learning/validation';
import { checkRuleExceptionReferences } from './check_rule_exception_references';
Expand Down Expand Up @@ -68,96 +68,94 @@ export const importRules = async ({
// otherwise we would output we are success importing 0 rules.
if (ruleChunks.length === 0) {
return importRuleResponse;
} else {
while (ruleChunks.length) {
const batchParseObjects = ruleChunks.shift() ?? [];
const newImportRuleResponse = await Promise.all(
batchParseObjects.reduce<Array<Promise<ImportRuleResponse>>>((accum, parsedRule) => {
const importsWorkerPromise = new Promise<ImportRuleResponse>(async (resolve, reject) => {
}

while (ruleChunks.length) {
const batchParseObjects = ruleChunks.shift() ?? [];
const newImportRuleResponse = await Promise.all(
batchParseObjects.reduce<Array<Promise<ImportRuleResponse>>>((accum, parsedRule) => {
const importsWorkerPromise = new Promise<ImportRuleResponse>(async (resolve, reject) => {
try {
if (parsedRule instanceof Error) {
// If the JSON object had a validation or parse error then we return
// early with the error and an (unknown) for the ruleId
resolve(
createBulkErrorObject({
statusCode: 400,
message: parsedRule.message,
})
);
return null;
}

try {
if (parsedRule instanceof Error) {
// If the JSON object had a validation or parse error then we return
// early with the error and an (unknown) for the ruleId
resolve(
createBulkErrorObject({
statusCode: 400,
message: parsedRule.message,
})
);
return null;
}
const [exceptionErrors, exceptions] = checkRuleExceptionReferences({
rule: parsedRule,
existingLists,
});

try {
const [exceptionErrors, exceptions] = checkRuleExceptionReferences({
rule: parsedRule,
existingLists,
});
importRuleResponse = [...importRuleResponse, ...exceptionErrors];

importRuleResponse = [...importRuleResponse, ...exceptionErrors];
throwAuthzError(await mlAuthz.validateRuleType(parsedRule.type));
const rule = await readRules({
rulesClient,
ruleId: parsedRule.rule_id,
id: undefined,
});

throwAuthzError(await mlAuthz.validateRuleType(parsedRule.type));
const rule = await readRules({
if (rule == null) {
await createRules({
rulesClient,
ruleId: parsedRule.rule_id,
id: undefined,
params: {
...parsedRule,
exceptions_list: [...exceptions],
},
allowMissingConnectorSecrets,
});

if (rule == null) {
await createRules({
rulesClient,
params: {
...parsedRule,
exceptions_list: [...exceptions],
},
allowMissingConnectorSecrets,
});
resolve({
rule_id: parsedRule.rule_id,
status_code: 200,
});
} else if (rule != null && overwriteRules) {
await patchRules({
rulesClient,
existingRule: rule,
nextParams: {
...parsedRule,
exceptions_list: [...exceptions],
},
allowMissingConnectorSecrets,
shouldIncrementRevision: false,
});
resolve({
rule_id: parsedRule.rule_id,
status_code: 200,
});
} else if (rule != null) {
resolve(
createBulkErrorObject({
ruleId: parsedRule.rule_id,
statusCode: 409,
message: `rule_id: "${parsedRule.rule_id}" already exists`,
})
);
}
} catch (err) {
resolve({
rule_id: parsedRule.rule_id,
status_code: 200,
});
} else if (rule != null && overwriteRules) {
await updateRules({
rulesClient,
existingRule: rule,
ruleUpdate: {
...parsedRule,
exceptions_list: [...exceptions],
},
});
resolve({
rule_id: parsedRule.rule_id,
status_code: 200,
});
} else if (rule != null) {
resolve(
createBulkErrorObject({
ruleId: parsedRule.rule_id,
statusCode: err.statusCode ?? 400,
message: err.message,
statusCode: 409,
message: `rule_id: "${parsedRule.rule_id}" already exists`,
})
);
}
} catch (error) {
reject(error);
} catch (err) {
resolve(
createBulkErrorObject({
ruleId: parsedRule.rule_id,
statusCode: err.statusCode ?? 400,
message: err.message,
})
);
}
});
return [...accum, importsWorkerPromise];
}, [])
);
importRuleResponse = [...importRuleResponse, ...newImportRuleResponse];
}

return importRuleResponse;
} catch (error) {
reject(error);
}
});
return [...accum, importsWorkerPromise];
}, [])
);
importRuleResponse = [...importRuleResponse, ...newImportRuleResponse];
}

return importRuleResponse;
};
Loading