diff --git a/drizzle-orm/src/monodriver.ts b/drizzle-orm/src/monodriver.ts index 47e1881e1..0706f271a 100644 --- a/drizzle-orm/src/monodriver.ts +++ b/drizzle-orm/src/monodriver.ts @@ -209,9 +209,14 @@ export async function drizzle< & InitializerParams[TClient] & (TClient extends 'mysql2' ? MySql2DrizzleConfig : TClient extends 'aws-data-api-pg' ? DrizzleAwsDataApiPgConfig + : TClient extends 'neon-serverless' ? DrizzleConfig & { + ws?: any; + } : DrizzleConfig), ): Promise> { - const { connection, ...drizzleConfig } = params; + const { connection, ws, ...drizzleConfig } = params as typeof params & { + ws?: any; + }; switch (client) { case 'node-postgres': { @@ -320,10 +325,16 @@ export async function drizzle< return db; } case 'neon-serverless': { - const { Pool } = await import('@neondatabase/serverless').catch(() => importError('@neondatabase/serverless')); + const { Pool, neonConfig } = await import('@neondatabase/serverless').catch(() => + importError('@neondatabase/serverless') + ); const { drizzle } = await import('./neon-serverless'); const instance = new Pool(connection as NeonServerlessConfig); + if (ws) { + neonConfig.webSocketConstructor = ws; + } + const db = drizzle(instance, drizzleConfig) as any; db.$client = instance; diff --git a/drizzle-orm/src/mysql-core/db.ts b/drizzle-orm/src/mysql-core/db.ts index 419359022..8934c0edf 100644 --- a/drizzle-orm/src/mysql-core/db.ts +++ b/drizzle-orm/src/mysql-core/db.ts @@ -140,7 +140,7 @@ export class MySqlDatabase< source: MySqlTable | MySqlViewBase | SQL | SQLWrapper, filters?: SQL, ) { - return new MySqlCountBuilder({ source, filters, dialect: this.dialect, session: this.session }); + return new MySqlCountBuilder({ source, filters, session: this.session }); } /** diff --git a/drizzle-orm/src/mysql-core/query-builders/count.ts b/drizzle-orm/src/mysql-core/query-builders/count.ts index 751ba61c7..645bb4753 100644 --- a/drizzle-orm/src/mysql-core/query-builders/count.ts +++ b/drizzle-orm/src/mysql-core/query-builders/count.ts @@ -1,7 +1,6 @@ import { entityKind, sql } from '~/index.ts'; import type { SQLWrapper } from '~/sql/sql.ts'; import { SQL } from '~/sql/sql.ts'; -import type { MySqlDialect } from '../dialect.ts'; import type { MySqlSession } from '../session.ts'; import type { MySqlTable } from '../table.ts'; import type { MySqlViewBase } from '../view-base.ts'; @@ -27,19 +26,20 @@ export class MySqlCountBuilder< source: MySqlTable | MySqlViewBase | SQL | SQLWrapper, filters?: SQL, ): SQL { - return sql`select count(*) from ${source}${sql.raw(' where ').if(filters)}${filters}`; + return sql`select count(*) as count from ${source}${sql.raw(' where ').if(filters)}${filters}`; } constructor( readonly params: { source: MySqlTable | MySqlViewBase | SQL | SQLWrapper; filters?: SQL; - dialect: MySqlDialect; session: TSession; }, ) { super(MySqlCountBuilder.buildEmbeddedCount(params.source, params.filters).queryChunks); + this.mapWith(Number); + this.session = params.session; this.sql = MySqlCountBuilder.buildCount( @@ -52,9 +52,7 @@ export class MySqlCountBuilder< onfulfilled?: ((value: number) => TResult1 | PromiseLike) | null | undefined, onrejected?: ((reason: any) => TResult2 | PromiseLike) | null | undefined, ): Promise { - return Promise.resolve(this.session.all(this.sql)).then((it) => { - return (<[{ 'count(*)': number }]> it)[0]['count(*)']; - }) + return Promise.resolve(this.session.count(this.sql)) .then( onfulfilled, onrejected, diff --git a/drizzle-orm/src/mysql-core/session.ts b/drizzle-orm/src/mysql-core/session.ts index 6b6269639..021d4276d 100644 --- a/drizzle-orm/src/mysql-core/session.ts +++ b/drizzle-orm/src/mysql-core/session.ts @@ -86,6 +86,14 @@ export abstract class MySqlSession< abstract all(query: SQL): Promise; + async count(sql: SQL): Promise { + const res = await this.execute<[[{ count: string }]]>(sql); + + return Number( + res[0][0]['count'], + ); + } + abstract transaction( transaction: (tx: MySqlTransaction) => Promise, config?: MySqlTransactionConfig, diff --git a/drizzle-orm/src/neon-http/session.ts b/drizzle-orm/src/neon-http/session.ts index 6d7685116..4dd768d3e 100644 --- a/drizzle-orm/src/neon-http/session.ts +++ b/drizzle-orm/src/neon-http/session.ts @@ -10,7 +10,7 @@ import type { PgQueryResultHKT, PgTransactionConfig, PreparedQueryConfig } from import { PgPreparedQuery as PgPreparedQuery, PgSession } from '~/pg-core/session.ts'; import type { RelationalSchemaConfig, TablesRelationalConfig } from '~/relations.ts'; import type { PreparedQuery } from '~/session.ts'; -import { fillPlaceholders, type Query } from '~/sql/sql.ts'; +import { fillPlaceholders, type Query, type SQL } from '~/sql/sql.ts'; import { mapResultRow } from '~/utils.ts'; export type NeonHttpClient = NeonQueryFunction; @@ -161,6 +161,14 @@ export class NeonHttpSession< return this.client(query, params, { arrayMode: false, fullResults: true }); } + override async count(sql: SQL): Promise { + const res = await this.execute<{ rows: [{ count: string }] }>(sql); + + return Number( + res['rows'][0]['count'], + ); + } + override async transaction( _transaction: (tx: NeonTransaction) => Promise, // eslint-disable-next-line @typescript-eslint/no-unused-vars diff --git a/drizzle-orm/src/node-postgres/session.ts b/drizzle-orm/src/node-postgres/session.ts index 91a21312a..ef6779354 100644 --- a/drizzle-orm/src/node-postgres/session.ts +++ b/drizzle-orm/src/node-postgres/session.ts @@ -8,7 +8,7 @@ import type { SelectedFieldsOrdered } from '~/pg-core/query-builders/select.type import type { PgQueryResultHKT, PgTransactionConfig, PreparedQueryConfig } from '~/pg-core/session.ts'; import { PgPreparedQuery, PgSession } from '~/pg-core/session.ts'; import type { RelationalSchemaConfig, TablesRelationalConfig } from '~/relations.ts'; -import { fillPlaceholders, type Query, sql } from '~/sql/sql.ts'; +import { fillPlaceholders, type Query, type SQL, sql } from '~/sql/sql.ts'; import { tracer } from '~/tracing.ts'; import { type Assume, mapResultRow } from '~/utils.ts'; @@ -164,6 +164,13 @@ export class NodePgSession< } } } + + override async count(sql: SQL): Promise { + const res = await this.execute<{ rows: [{ count: string }] }>(sql); + return Number( + res['rows'][0]['count'], + ); + } } export class NodePgTransaction< diff --git a/drizzle-orm/src/pg-core/db.ts b/drizzle-orm/src/pg-core/db.ts index 85dc797c9..62b64fb8f 100644 --- a/drizzle-orm/src/pg-core/db.ts +++ b/drizzle-orm/src/pg-core/db.ts @@ -141,7 +141,7 @@ export class PgDatabase< source: PgTable | PgViewBase | SQL | SQLWrapper, filters?: SQL, ) { - return new PgCountBuilder({ source, filters, dialect: this.dialect, session: this.session }); + return new PgCountBuilder({ source, filters, session: this.session }); } /** diff --git a/drizzle-orm/src/pg-core/query-builders/count.ts b/drizzle-orm/src/pg-core/query-builders/count.ts index 7ccd722a0..c93cbb18d 100644 --- a/drizzle-orm/src/pg-core/query-builders/count.ts +++ b/drizzle-orm/src/pg-core/query-builders/count.ts @@ -1,7 +1,6 @@ import { entityKind, sql } from '~/index.ts'; import type { SQLWrapper } from '~/sql/sql.ts'; import { SQL } from '~/sql/sql.ts'; -import type { PgDialect } from '../dialect.ts'; import type { PgSession } from '../session.ts'; import type { PgTable } from '../table.ts'; @@ -19,26 +18,27 @@ export class PgCountBuilder< source: PgTable | SQL | SQLWrapper, filters?: SQL, ): SQL { - return sql`(select count(*)::int from ${source}${sql.raw(' where ').if(filters)}${filters})`; + return sql`(select count(*) from ${source}${sql.raw(' where ').if(filters)}${filters})`; } private static buildCount( source: PgTable | SQL | SQLWrapper, filters?: SQL, ): SQL { - return sql`select count(*)::int from ${source}${sql.raw(' where ').if(filters)}${filters};`; + return sql`select count(*) as count from ${source}${sql.raw(' where ').if(filters)}${filters};`; } constructor( readonly params: { source: PgTable | SQL | SQLWrapper; filters?: SQL; - dialect: PgDialect; session: TSession; }, ) { super(PgCountBuilder.buildEmbeddedCount(params.source, params.filters).queryChunks); + this.mapWith(Number); + this.session = params.session; this.sql = PgCountBuilder.buildCount( @@ -51,9 +51,7 @@ export class PgCountBuilder< onfulfilled?: ((value: number) => TResult1 | PromiseLike) | null | undefined, onrejected?: ((reason: any) => TResult2 | PromiseLike) | null | undefined, ): Promise { - return Promise.resolve(this.session.all(this.sql)).then((it) => { - return (<[{ count: number }]> it)[0]['count'] as number; - }) + return Promise.resolve(this.session.count(this.sql)) .then( onfulfilled, onrejected, diff --git a/drizzle-orm/src/pg-core/session.ts b/drizzle-orm/src/pg-core/session.ts index 434ebc086..ea820f2d8 100644 --- a/drizzle-orm/src/pg-core/session.ts +++ b/drizzle-orm/src/pg-core/session.ts @@ -86,6 +86,14 @@ export abstract class PgSession< ).all(); } + async count(sql: SQL): Promise { + const res = await this.execute<[{ count: string }]>(sql); + + return Number( + res[0]['count'], + ); + } + abstract transaction( transaction: (tx: PgTransaction) => Promise, config?: PgTransactionConfig, diff --git a/drizzle-orm/src/pg-proxy/session.ts b/drizzle-orm/src/pg-proxy/session.ts index eb6a1b1a3..1a30c0a3c 100644 --- a/drizzle-orm/src/pg-proxy/session.ts +++ b/drizzle-orm/src/pg-proxy/session.ts @@ -130,7 +130,8 @@ export class PreparedQuery extends PreparedQueryB }); } - async all() {} + async all() { + } /** @internal */ isResponseInArrayMode(): boolean { diff --git a/drizzle-orm/src/pglite/session.ts b/drizzle-orm/src/pglite/session.ts index c7a1dbb5d..ebf7701a6 100644 --- a/drizzle-orm/src/pglite/session.ts +++ b/drizzle-orm/src/pglite/session.ts @@ -7,7 +7,7 @@ import type { SelectedFieldsOrdered } from '~/pg-core/query-builders/select.type import type { PgQueryResultHKT, PgTransactionConfig, PreparedQueryConfig } from '~/pg-core/session.ts'; import { PgPreparedQuery, PgSession } from '~/pg-core/session.ts'; import type { RelationalSchemaConfig, TablesRelationalConfig } from '~/relations.ts'; -import { fillPlaceholders, type Query, sql } from '~/sql/sql.ts'; +import { fillPlaceholders, type Query, type SQL, sql } from '~/sql/sql.ts'; import { type Assume, mapResultRow } from '~/utils.ts'; import { types } from '@electric-sql/pglite'; @@ -140,6 +140,13 @@ export class PgliteSession< return transaction(tx); }) as Promise; } + + override async count(sql: SQL): Promise { + const res = await this.execute<{ rows: [{ count: string }] }>(sql); + return Number( + res['rows'][0]['count'], + ); + } } export class PgliteTransaction< diff --git a/drizzle-orm/src/planetscale-serverless/session.ts b/drizzle-orm/src/planetscale-serverless/session.ts index f2275b7f2..987529d7c 100644 --- a/drizzle-orm/src/planetscale-serverless/session.ts +++ b/drizzle-orm/src/planetscale-serverless/session.ts @@ -164,6 +164,14 @@ export class PlanetscaleSession< ) => eQuery.rows as T[]); } + override async count(sql: SQL): Promise { + const res = await this.execute<{ rows: [{ count: string }] }>(sql); + + return Number( + res['rows'][0]['count'], + ); + } + override transaction( transaction: (tx: PlanetScaleTransaction) => Promise, ): Promise { diff --git a/drizzle-orm/src/sqlite-core/db.ts b/drizzle-orm/src/sqlite-core/db.ts index 75b088f6d..7ae2736e0 100644 --- a/drizzle-orm/src/sqlite-core/db.ts +++ b/drizzle-orm/src/sqlite-core/db.ts @@ -140,7 +140,7 @@ export class BaseSQLiteDatabase< source: SQLiteTable | SQLiteViewBase | SQL | SQLWrapper, filters?: SQL, ) { - return new SQLiteCountBuilder({ source, filters, dialect: this.dialect, session: this.session }); + return new SQLiteCountBuilder({ source, filters, session: this.session }); } /** diff --git a/drizzle-orm/src/sqlite-core/query-builders/count.ts b/drizzle-orm/src/sqlite-core/query-builders/count.ts index ed6cd9a1d..1b19eed07 100644 --- a/drizzle-orm/src/sqlite-core/query-builders/count.ts +++ b/drizzle-orm/src/sqlite-core/query-builders/count.ts @@ -1,7 +1,6 @@ import { entityKind, sql } from '~/index.ts'; import type { SQLWrapper } from '~/sql/sql.ts'; import { SQL } from '~/sql/sql.ts'; -import type { SQLiteDialect } from '../dialect.ts'; import type { SQLiteSession } from '../session.ts'; import type { SQLiteTable } from '../table.ts'; import type { SQLiteView } from '../view.ts'; @@ -34,7 +33,6 @@ export class SQLiteCountBuilder< readonly params: { source: SQLiteTable | SQLiteView | SQL | SQLWrapper; filters?: SQL; - dialect: SQLiteDialect; session: TSession; }, ) { @@ -52,7 +50,7 @@ export class SQLiteCountBuilder< onfulfilled?: ((value: number) => TResult1 | PromiseLike) | null | undefined, onrejected?: ((reason: any) => TResult2 | PromiseLike) | null | undefined, ): Promise { - return Promise.resolve(this.session.values(this.sql)).then((it) => it[0]![0] as number).then( + return Promise.resolve(this.session.count(this.sql)).then( onfulfilled, onrejected, ); diff --git a/drizzle-orm/src/sqlite-core/session.ts b/drizzle-orm/src/sqlite-core/session.ts index 4ac987b4a..d291b6fdf 100644 --- a/drizzle-orm/src/sqlite-core/session.ts +++ b/drizzle-orm/src/sqlite-core/session.ts @@ -187,6 +187,12 @@ export abstract class SQLiteSession< >; } + async count(sql: SQL) { + const result = await this.values(sql) as [[number]]; + + return result[0][0]; + } + /** @internal */ extractRawValuesValueFromBatchResult(_result: unknown): unknown { throw new Error('Not implemented'); diff --git a/drizzle-orm/src/tidb-serverless/session.ts b/drizzle-orm/src/tidb-serverless/session.ts index 64a8d61d7..b01b9f948 100644 --- a/drizzle-orm/src/tidb-serverless/session.ts +++ b/drizzle-orm/src/tidb-serverless/session.ts @@ -139,6 +139,14 @@ export class TiDBServerlessSession< return this.client.execute(querySql.sql, querySql.params) as Promise; } + override async count(sql: SQL): Promise { + const res = await this.execute<{ rows: [{ count: string }] }>(sql); + + return Number( + res['rows'][0]['count'], + ); + } + override async transaction( transaction: (tx: TiDBServerlessTransaction) => Promise, ): Promise { diff --git a/integration-tests/tests/mysql/mysql-common.ts b/integration-tests/tests/mysql/mysql-common.ts index 05c69cada..8a2fb768b 100644 --- a/integration-tests/tests/mysql/mysql-common.ts +++ b/integration-tests/tests/mysql/mysql-common.ts @@ -3581,7 +3581,7 @@ export function tests(driver?: string) { test('$count separate', async (ctx) => { const { db } = ctx.mysql; - const countTestTable = mysqlTable('users_distinct', { + const countTestTable = mysqlTable('count_test', { id: int('id').notNull(), name: text('name').notNull(), }); @@ -3606,7 +3606,7 @@ export function tests(driver?: string) { test('$count embedded', async (ctx) => { const { db } = ctx.mysql; - const countTestTable = mysqlTable('users_distinct', { + const countTestTable = mysqlTable('count_test', { id: int('id').notNull(), name: text('name').notNull(), }); @@ -3638,7 +3638,7 @@ export function tests(driver?: string) { test('$count separate reuse', async (ctx) => { const { db } = ctx.mysql; - const countTestTable = mysqlTable('users_distinct', { + const countTestTable = mysqlTable('count_test', { id: int('id').notNull(), name: text('name').notNull(), }); @@ -3675,7 +3675,7 @@ export function tests(driver?: string) { test('$count embedded reuse', async (ctx) => { const { db } = ctx.mysql; - const countTestTable = mysqlTable('users_distinct', { + const countTestTable = mysqlTable('count_test', { id: int('id').notNull(), name: text('name').notNull(), }); @@ -3706,8 +3706,6 @@ export function tests(driver?: string) { await db.execute(sql`drop table ${countTestTable}`); - await db.execute(sql`drop table ${countTestTable}`); - expect(count1).toStrictEqual([ { count: 4 }, { count: 4 }, @@ -3734,7 +3732,7 @@ export function tests(driver?: string) { test('$count separate with filters', async (ctx) => { const { db } = ctx.mysql; - const countTestTable = mysqlTable('users_distinct', { + const countTestTable = mysqlTable('count_test', { id: int('id').notNull(), name: text('name').notNull(), }); @@ -3759,7 +3757,7 @@ export function tests(driver?: string) { test('$count embedded with filters', async (ctx) => { const { db } = ctx.mysql; - const countTestTable = mysqlTable('users_distinct', { + const countTestTable = mysqlTable('count_test', { id: int('id').notNull(), name: text('name').notNull(), }); @@ -3784,31 +3782,32 @@ export function tests(driver?: string) { { count: 3 }, { count: 3 }, { count: 3 }, + { count: 3 }, ]); }); - }); - test('limit 0', async (ctx) => { - const { db } = ctx.mysql; + test('limit 0', async (ctx) => { + const { db } = ctx.mysql; - await db.insert(usersTable).values({ name: 'John' }); - const users = await db - .select() - .from(usersTable) - .limit(0); + await db.insert(usersTable).values({ name: 'John' }); + const users = await db + .select() + .from(usersTable) + .limit(0); - expect(users).toEqual([]); - }); + expect(users).toEqual([]); + }); - test('limit -1', async (ctx) => { - const { db } = ctx.mysql; + test('limit -1', async (ctx) => { + const { db } = ctx.mysql; - await db.insert(usersTable).values({ name: 'John' }); - const users = await db - .select() - .from(usersTable) - .limit(-1); + await db.insert(usersTable).values({ name: 'John' }); + const users = await db + .select() + .from(usersTable) + .limit(-1); - expect(users.length).toBeGreaterThan(0); + expect(users.length).toBeGreaterThan(0); + }); }); } diff --git a/integration-tests/tests/pg/pg-common.ts b/integration-tests/tests/pg/pg-common.ts index 3f3dd75cc..b4afdcf64 100644 --- a/integration-tests/tests/pg/pg-common.ts +++ b/integration-tests/tests/pg/pg-common.ts @@ -4664,7 +4664,7 @@ export function tests() { test('$count separate', async (ctx) => { const { db } = ctx.pg; - const countTestTable = pgTable('users_distinct', { + const countTestTable = pgTable('count_test', { id: integer('id').notNull(), name: text('name').notNull(), }); @@ -4689,7 +4689,7 @@ export function tests() { test('$count embedded', async (ctx) => { const { db } = ctx.pg; - const countTestTable = pgTable('users_distinct', { + const countTestTable = pgTable('count_test', { id: integer('id').notNull(), name: text('name').notNull(), }); @@ -4721,7 +4721,7 @@ export function tests() { test('$count separate reuse', async (ctx) => { const { db } = ctx.pg; - const countTestTable = pgTable('users_distinct', { + const countTestTable = pgTable('count_test', { id: integer('id').notNull(), name: text('name').notNull(), }); @@ -4758,7 +4758,7 @@ export function tests() { test('$count embedded reuse', async (ctx) => { const { db } = ctx.pg; - const countTestTable = pgTable('users_distinct', { + const countTestTable = pgTable('count_test', { id: integer('id').notNull(), name: text('name').notNull(), }); @@ -4789,8 +4789,6 @@ export function tests() { await db.execute(sql`drop table ${countTestTable}`); - await db.execute(sql`drop table ${countTestTable}`); - expect(count1).toStrictEqual([ { count: 4 }, { count: 4 }, @@ -4817,7 +4815,7 @@ export function tests() { test('$count separate with filters', async (ctx) => { const { db } = ctx.pg; - const countTestTable = pgTable('users_distinct', { + const countTestTable = pgTable('count_test', { id: integer('id').notNull(), name: text('name').notNull(), }); @@ -4842,7 +4840,7 @@ export function tests() { test('$count embedded with filters', async (ctx) => { const { db } = ctx.pg; - const countTestTable = pgTable('users_distinct', { + const countTestTable = pgTable('count_test', { id: integer('id').notNull(), name: text('name').notNull(), }); @@ -4867,6 +4865,7 @@ export function tests() { { count: 3 }, { count: 3 }, { count: 3 }, + { count: 3 }, ]); }); }); diff --git a/integration-tests/tests/sqlite/sqlite-common.ts b/integration-tests/tests/sqlite/sqlite-common.ts index ed13f5b7b..e8ddb86e6 100644 --- a/integration-tests/tests/sqlite/sqlite-common.ts +++ b/integration-tests/tests/sqlite/sqlite-common.ts @@ -2683,7 +2683,7 @@ export function tests() { test('$count separate', async (ctx) => { const { db } = ctx.sqlite; - const countTestTable = sqliteTable('users_distinct', { + const countTestTable = sqliteTable('count_test', { id: int('id').notNull(), name: text('name').notNull(), }); @@ -2708,7 +2708,7 @@ export function tests() { test('$count embedded', async (ctx) => { const { db } = ctx.sqlite; - const countTestTable = sqliteTable('users_distinct', { + const countTestTable = sqliteTable('count_test', { id: int('id').notNull(), name: text('name').notNull(), }); @@ -2740,7 +2740,7 @@ export function tests() { test('$count separate reuse', async (ctx) => { const { db } = ctx.sqlite; - const countTestTable = sqliteTable('users_distinct', { + const countTestTable = sqliteTable('count_test', { id: int('id').notNull(), name: text('name').notNull(), }); @@ -2777,7 +2777,7 @@ export function tests() { test('$count embedded reuse', async (ctx) => { const { db } = ctx.sqlite; - const countTestTable = sqliteTable('users_distinct', { + const countTestTable = sqliteTable('count_test', { id: int('id').notNull(), name: text('name').notNull(), }); @@ -2808,8 +2808,6 @@ export function tests() { await db.run(sql`drop table ${countTestTable}`); - await db.run(sql`drop table ${countTestTable}`); - expect(count1).toStrictEqual([ { count: 4 }, { count: 4 }, @@ -2836,7 +2834,7 @@ export function tests() { test('$count separate with filters', async (ctx) => { const { db } = ctx.sqlite; - const countTestTable = sqliteTable('users_distinct', { + const countTestTable = sqliteTable('count_test', { id: int('id').notNull(), name: text('name').notNull(), }); @@ -2861,7 +2859,7 @@ export function tests() { test('$count embedded with filters', async (ctx) => { const { db } = ctx.sqlite; - const countTestTable = sqliteTable('users_distinct', { + const countTestTable = sqliteTable('count_test', { id: int('id').notNull(), name: text('name').notNull(), }); @@ -2886,6 +2884,7 @@ export function tests() { { count: 3 }, { count: 3 }, { count: 3 }, + { count: 3 }, ]); }); });