diff --git a/library/agent/hooks/wrapExport.ts b/library/agent/hooks/wrapExport.ts index 4c9624975..fd9fd8e1f 100644 --- a/library/agent/hooks/wrapExport.ts +++ b/library/agent/hooks/wrapExport.ts @@ -154,7 +154,16 @@ function inspectArgs( module: pkgInfo.name, }); } + onInspectionInterceptorResult(context, agent, result, pkgInfo, start); +} +export function onInspectionInterceptorResult( + context: ReturnType, + agent: Agent, + result: InterceptorResult, + pkgInfo: WrapPackageInfo, + start: number +) { const end = performance.now(); agent.getInspectionStatistics().onInspectedCall({ sink: pkgInfo.name, diff --git a/library/agent/hooks/wrapNewInstance.ts b/library/agent/hooks/wrapNewInstance.ts index 3923de4c8..c916d406c 100644 --- a/library/agent/hooks/wrapNewInstance.ts +++ b/library/agent/hooks/wrapNewInstance.ts @@ -9,7 +9,7 @@ export function wrapNewInstance( subject: unknown, className: string | undefined, pkgInfo: WrapPackageInfo, - interceptor: (exports: any) => void + interceptor: (exports: any) => void | unknown ) { const agent = getInstance(); if (!agent) { @@ -28,7 +28,17 @@ export function wrapNewInstance( // @ts-expect-error It's a constructor const newInstance = new original(...args); - interceptor(newInstance); + try { + const returnVal = interceptor(newInstance); + if (returnVal) { + return returnVal; + } + } catch (error) { + agent.onFailedToWrapMethod( + pkgInfo.name, + className || "default export" + ); + } return newInstance; }; diff --git a/library/sinks/Prisma.ts b/library/sinks/Prisma.ts index 325d30949..b0c4a5409 100644 --- a/library/sinks/Prisma.ts +++ b/library/sinks/Prisma.ts @@ -1,6 +1,5 @@ import type { Hooks } from "../agent/hooks/Hooks"; import { Wrapper } from "../agent/Wrapper"; -import { wrapExport } from "../agent/hooks/wrapExport"; import { wrapNewInstance } from "../agent/hooks/wrapNewInstance"; import { SQLDialect } from "../vulnerabilities/sql-injection/dialects/SQLDialect"; import { SQLDialectMySQL } from "../vulnerabilities/sql-injection/dialects/SQLDialectMySQL"; @@ -10,47 +9,71 @@ import { SQLDialectSQLite } from "../vulnerabilities/sql-injection/dialects/SQLD import type { InterceptorResult } from "../agent/hooks/InterceptorResult"; import { checkContextForSqlInjection } from "../vulnerabilities/sql-injection/checkContextForSqlInjection"; import { getContext } from "../agent/Context"; +import type { PrismaPromise } from "@prisma/client"; +import { onInspectionInterceptorResult } from "../agent/hooks/wrapExport"; +import { getInstance } from "../agent/AgentSingleton"; +import type { Agent } from "../agent/Agent"; +import { WrapPackageInfo } from "../agent/hooks/WrapPackageInfo"; + +type AllOperationsQueryExtension = { + model?: string; + operation: string; + args: any; + query: (args: any) => PrismaPromise; +}; export class Prisma implements Wrapper { - private rawSQLMethodsToWrap = ["$queryRawUnsafe", "$executeRawUnsafe"]; + private rawSQLMethodsToProtect = ["$queryRawUnsafe", "$executeRawUnsafe"]; + + // Check if the prisma client is a NoSQL client + private isNoSQLClient(clientInstance: any): boolean { + if ( + !clientInstance || + typeof clientInstance !== "object" || + !("_engineConfig" in clientInstance) || + !clientInstance._engineConfig || + typeof clientInstance._engineConfig !== "object" || + !("activeProvider" in clientInstance._engineConfig) || + typeof clientInstance._engineConfig.activeProvider !== "string" + ) { + return false; + } - private dialect: SQLDialect = new SQLDialectGeneric(); + return clientInstance._engineConfig.activeProvider === "mongodb"; + } // Try to detect the SQL dialect used by the Prisma client, so we can use the correct SQL dialect for the SQL injection detection. - private detectSQLDialect(clientInstance: any) { + private getClientSQLDialect(clientInstance: any): SQLDialect { // https://github.com/prisma/prisma/blob/559988a47e50b4d4655dc45b11ceb9b5c73ef053/packages/generator-helper/src/types.ts#L75 if ( !clientInstance || typeof clientInstance !== "object" || - !("_accelerateEngineConfig" in clientInstance) || - !clientInstance._accelerateEngineConfig || - typeof clientInstance._accelerateEngineConfig !== "object" || - !("activeProvider" in clientInstance._accelerateEngineConfig) || - typeof clientInstance._accelerateEngineConfig.activeProvider !== "string" + !("_engineConfig" in clientInstance) || + !clientInstance._engineConfig || + typeof clientInstance._engineConfig !== "object" || + !("activeProvider" in clientInstance._engineConfig) || + typeof clientInstance._engineConfig.activeProvider !== "string" ) { - return; + return new SQLDialectGeneric(); } - switch (clientInstance._accelerateEngineConfig.activeProvider) { + switch (clientInstance._engineConfig.activeProvider) { case "mysql": - this.dialect = new SQLDialectMySQL(); - break; + return new SQLDialectMySQL(); case "postgresql": case "postgres": - this.dialect = new SQLDialectPostgres(); - break; + return new SQLDialectPostgres(); case "sqlite": - this.dialect = new SQLDialectSQLite(); - break; + return new SQLDialectSQLite(); default: - // Already set to generic - break; + return new SQLDialectGeneric(); } } private inspectSQLQuery( args: unknown[], - operation: string + operation: string, + dialect: SQLDialect ): InterceptorResult { const context = getContext(); @@ -65,32 +88,89 @@ export class Prisma implements Wrapper { sql: sql, context: context, operation: `prisma.${operation}`, - dialect: this.dialect, + dialect: dialect, }); } return undefined; } + private onClientOperation({ + model, + operation, + args, + query, + isNoSQLClient, + sqlDialect, + agent, + pkgInfo, + }: AllOperationsQueryExtension & { + isNoSQLClient: boolean; + sqlDialect?: SQLDialect; + agent: Agent; + pkgInfo: WrapPackageInfo; + }) { + let inspectionResult: InterceptorResult | undefined; + const start = performance.now(); + + if (!isNoSQLClient && this.rawSQLMethodsToProtect.includes(operation)) { + inspectionResult = this.inspectSQLQuery( + args, + operation, + sqlDialect || new SQLDialectGeneric() + ); + } + + if (inspectionResult) { + onInspectionInterceptorResult( + getContext(), + agent, + inspectionResult, + pkgInfo, + start + ); + } + + return query(args); + } + wrap(hooks: Hooks) { hooks .addPackage("@prisma/client") .withVersion("^5.0.0") .onRequire((exports, pkgInfo) => { wrapNewInstance(exports, "PrismaClient", pkgInfo, (instance) => { - this.detectSQLDialect(instance); - - for (const method of this.rawSQLMethodsToWrap) { - if (typeof instance[method] === "function") { - wrapExport(instance, method, pkgInfo, { - inspectArgs: (args) => { - return this.inspectSQLQuery(args, method); - }, - }); - } + const isNoSQLClient = this.isNoSQLClient(instance); + + const agent = getInstance(); + if (!agent) { + return; } - // Todo support mongodb methods + // https://www.prisma.io/docs/orm/prisma-client/client-extensions/query#modify-all-operations-in-all-models-of-your-schema + return instance.$extends({ + query: { + $allOperations: ({ + model, + operation, + args, + query, + }: AllOperationsQueryExtension) => { + return this.onClientOperation({ + model, + operation, + args, + query, + isNoSQLClient, + sqlDialect: !isNoSQLClient + ? this.getClientSQLDialect(instance) + : undefined, + agent, + pkgInfo, + }); + }, + }, + }); }); }); }