Skip to content

Commit

Permalink
Rewrite Prisma sink using client extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
timokoessler committed Nov 25, 2024
1 parent 83f6b03 commit 31059d9
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 34 deletions.
9 changes: 9 additions & 0 deletions library/agent/hooks/wrapExport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,16 @@ function inspectArgs(
module: pkgInfo.name,
});
}
onInspectionInterceptorResult(context, agent, result, pkgInfo, start);
}

export function onInspectionInterceptorResult(
context: ReturnType<typeof getContext>,
agent: Agent,
result: InterceptorResult,
pkgInfo: WrapPackageInfo,
start: number
) {
const end = performance.now();
agent.getInspectionStatistics().onInspectedCall({
sink: pkgInfo.name,
Expand Down
14 changes: 12 additions & 2 deletions library/agent/hooks/wrapNewInstance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export function wrapNewInstance(
subject: unknown,
className: string | undefined,
pkgInfo: WrapPackageInfo,
interceptor: (exports: any) => void
interceptor: (exports: any) => void | unknown
) {
const agent = getInstance();
if (!agent) {
Expand All @@ -28,7 +28,17 @@ export function wrapNewInstance(
// @ts-expect-error It's a constructor
const newInstance = new original(...args);

interceptor(newInstance);
try {
const returnVal = interceptor(newInstance);
if (returnVal) {
return returnVal;
}
} catch (error) {
agent.onFailedToWrapMethod(
pkgInfo.name,
className || "default export"
);
}

return newInstance;
};
Expand Down
144 changes: 112 additions & 32 deletions library/sinks/Prisma.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import type { Hooks } from "../agent/hooks/Hooks";
import { Wrapper } from "../agent/Wrapper";
import { wrapExport } from "../agent/hooks/wrapExport";
import { wrapNewInstance } from "../agent/hooks/wrapNewInstance";
import { SQLDialect } from "../vulnerabilities/sql-injection/dialects/SQLDialect";
import { SQLDialectMySQL } from "../vulnerabilities/sql-injection/dialects/SQLDialectMySQL";
Expand All @@ -10,47 +9,71 @@ import { SQLDialectSQLite } from "../vulnerabilities/sql-injection/dialects/SQLD
import type { InterceptorResult } from "../agent/hooks/InterceptorResult";
import { checkContextForSqlInjection } from "../vulnerabilities/sql-injection/checkContextForSqlInjection";
import { getContext } from "../agent/Context";
import type { PrismaPromise } from "@prisma/client";
import { onInspectionInterceptorResult } from "../agent/hooks/wrapExport";
import { getInstance } from "../agent/AgentSingleton";
import type { Agent } from "../agent/Agent";
import { WrapPackageInfo } from "../agent/hooks/WrapPackageInfo";

type AllOperationsQueryExtension = {
model?: string;
operation: string;
args: any;
query: (args: any) => PrismaPromise<any>;
};

export class Prisma implements Wrapper {
private rawSQLMethodsToWrap = ["$queryRawUnsafe", "$executeRawUnsafe"];
private rawSQLMethodsToProtect = ["$queryRawUnsafe", "$executeRawUnsafe"];

// Check if the prisma client is a NoSQL client
private isNoSQLClient(clientInstance: any): boolean {
if (
!clientInstance ||
typeof clientInstance !== "object" ||
!("_engineConfig" in clientInstance) ||
!clientInstance._engineConfig ||
typeof clientInstance._engineConfig !== "object" ||
!("activeProvider" in clientInstance._engineConfig) ||
typeof clientInstance._engineConfig.activeProvider !== "string"
) {
return false;
}

private dialect: SQLDialect = new SQLDialectGeneric();
return clientInstance._engineConfig.activeProvider === "mongodb";
}

// Try to detect the SQL dialect used by the Prisma client, so we can use the correct SQL dialect for the SQL injection detection.
private detectSQLDialect(clientInstance: any) {
private getClientSQLDialect(clientInstance: any): SQLDialect {
// https://github.com/prisma/prisma/blob/559988a47e50b4d4655dc45b11ceb9b5c73ef053/packages/generator-helper/src/types.ts#L75
if (
!clientInstance ||
typeof clientInstance !== "object" ||
!("_accelerateEngineConfig" in clientInstance) ||
!clientInstance._accelerateEngineConfig ||
typeof clientInstance._accelerateEngineConfig !== "object" ||
!("activeProvider" in clientInstance._accelerateEngineConfig) ||
typeof clientInstance._accelerateEngineConfig.activeProvider !== "string"
!("_engineConfig" in clientInstance) ||
!clientInstance._engineConfig ||
typeof clientInstance._engineConfig !== "object" ||
!("activeProvider" in clientInstance._engineConfig) ||
typeof clientInstance._engineConfig.activeProvider !== "string"
) {
return;
return new SQLDialectGeneric();
}

switch (clientInstance._accelerateEngineConfig.activeProvider) {
switch (clientInstance._engineConfig.activeProvider) {
case "mysql":
this.dialect = new SQLDialectMySQL();
break;
return new SQLDialectMySQL();
case "postgresql":
case "postgres":
this.dialect = new SQLDialectPostgres();
break;
return new SQLDialectPostgres();
case "sqlite":
this.dialect = new SQLDialectSQLite();
break;
return new SQLDialectSQLite();
default:
// Already set to generic
break;
return new SQLDialectGeneric();
}
}

private inspectSQLQuery(
args: unknown[],
operation: string
operation: string,
dialect: SQLDialect
): InterceptorResult {
const context = getContext();

Expand All @@ -65,32 +88,89 @@ export class Prisma implements Wrapper {
sql: sql,
context: context,
operation: `prisma.${operation}`,
dialect: this.dialect,
dialect: dialect,
});
}

return undefined;
}

private onClientOperation({
model,
operation,
args,
query,
isNoSQLClient,
sqlDialect,
agent,
pkgInfo,
}: AllOperationsQueryExtension & {
isNoSQLClient: boolean;
sqlDialect?: SQLDialect;
agent: Agent;
pkgInfo: WrapPackageInfo;
}) {
let inspectionResult: InterceptorResult | undefined;
const start = performance.now();

if (!isNoSQLClient && this.rawSQLMethodsToProtect.includes(operation)) {
inspectionResult = this.inspectSQLQuery(
args,
operation,
sqlDialect || new SQLDialectGeneric()
);
}

if (inspectionResult) {
onInspectionInterceptorResult(
getContext(),
agent,
inspectionResult,
pkgInfo,
start
);
}

return query(args);
}

wrap(hooks: Hooks) {
hooks
.addPackage("@prisma/client")
.withVersion("^5.0.0")
.onRequire((exports, pkgInfo) => {
wrapNewInstance(exports, "PrismaClient", pkgInfo, (instance) => {
this.detectSQLDialect(instance);

for (const method of this.rawSQLMethodsToWrap) {
if (typeof instance[method] === "function") {
wrapExport(instance, method, pkgInfo, {
inspectArgs: (args) => {
return this.inspectSQLQuery(args, method);
},
});
}
const isNoSQLClient = this.isNoSQLClient(instance);

const agent = getInstance();
if (!agent) {
return;
}

// Todo support mongodb methods
// https://www.prisma.io/docs/orm/prisma-client/client-extensions/query#modify-all-operations-in-all-models-of-your-schema
return instance.$extends({
query: {
$allOperations: ({
model,
operation,
args,
query,
}: AllOperationsQueryExtension) => {
return this.onClientOperation({
model,
operation,
args,
query,
isNoSQLClient,
sqlDialect: !isNoSQLClient
? this.getClientSQLDialect(instance)
: undefined,
agent,
pkgInfo,
});
},
},
});
});
});
}
Expand Down

0 comments on commit 31059d9

Please sign in to comment.