Skip to content

Commit

Permalink
Merge pull request #2908 from drizzle-team/count-generator
Browse files Browse the repository at this point in the history
Additional fixes for `$count()` generator
  • Loading branch information
AndriiSherman authored Sep 5, 2024
2 parents 2bc0d1e + 8cf7a61 commit efd821d
Show file tree
Hide file tree
Showing 19 changed files with 130 additions and 67 deletions.
15 changes: 13 additions & 2 deletions drizzle-orm/src/monodriver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,14 @@ export async function drizzle<
& InitializerParams[TClient]
& (TClient extends 'mysql2' ? MySql2DrizzleConfig<TSchema>
: TClient extends 'aws-data-api-pg' ? DrizzleAwsDataApiPgConfig<TSchema>
: TClient extends 'neon-serverless' ? DrizzleConfig<TSchema> & {
ws?: any;
}
: DrizzleConfig<TSchema>),
): Promise<DetermineClient<TClient, TSchema>> {
const { connection, ...drizzleConfig } = params;
const { connection, ws, ...drizzleConfig } = params as typeof params & {
ws?: any;
};

switch (client) {
case 'node-postgres': {
Expand Down Expand Up @@ -320,10 +325,16 @@ export async function drizzle<
return db;
}
case 'neon-serverless': {
const { Pool } = await import('@neondatabase/serverless').catch(() => importError('@neondatabase/serverless'));
const { Pool, neonConfig } = await import('@neondatabase/serverless').catch(() =>
importError('@neondatabase/serverless')
);
const { drizzle } = await import('./neon-serverless');
const instance = new Pool(connection as NeonServerlessConfig);

if (ws) {
neonConfig.webSocketConstructor = ws;
}

const db = drizzle(instance, drizzleConfig) as any;
db.$client = instance;

Expand Down
2 changes: 1 addition & 1 deletion drizzle-orm/src/mysql-core/db.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ export class MySqlDatabase<
source: MySqlTable | MySqlViewBase | SQL | SQLWrapper,
filters?: SQL<unknown>,
) {
return new MySqlCountBuilder({ source, filters, dialect: this.dialect, session: this.session });
return new MySqlCountBuilder({ source, filters, session: this.session });
}

/**
Expand Down
10 changes: 4 additions & 6 deletions drizzle-orm/src/mysql-core/query-builders/count.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { entityKind, sql } from '~/index.ts';
import type { SQLWrapper } from '~/sql/sql.ts';
import { SQL } from '~/sql/sql.ts';
import type { MySqlDialect } from '../dialect.ts';
import type { MySqlSession } from '../session.ts';
import type { MySqlTable } from '../table.ts';
import type { MySqlViewBase } from '../view-base.ts';
Expand All @@ -27,19 +26,20 @@ export class MySqlCountBuilder<
source: MySqlTable | MySqlViewBase | SQL | SQLWrapper,
filters?: SQL<unknown>,
): SQL<number> {
return sql<number>`select count(*) from ${source}${sql.raw(' where ').if(filters)}${filters}`;
return sql<number>`select count(*) as count from ${source}${sql.raw(' where ').if(filters)}${filters}`;
}

constructor(
readonly params: {
source: MySqlTable | MySqlViewBase | SQL | SQLWrapper;
filters?: SQL<unknown>;
dialect: MySqlDialect;
session: TSession;
},
) {
super(MySqlCountBuilder.buildEmbeddedCount(params.source, params.filters).queryChunks);

this.mapWith(Number);

this.session = params.session;

this.sql = MySqlCountBuilder.buildCount(
Expand All @@ -52,9 +52,7 @@ export class MySqlCountBuilder<
onfulfilled?: ((value: number) => TResult1 | PromiseLike<TResult1>) | null | undefined,
onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null | undefined,
): Promise<TResult1 | TResult2> {
return Promise.resolve(this.session.all(this.sql)).then<number>((it) => {
return (<[{ 'count(*)': number }]> it)[0]['count(*)'];
})
return Promise.resolve(this.session.count(this.sql))
.then(
onfulfilled,
onrejected,
Expand Down
8 changes: 8 additions & 0 deletions drizzle-orm/src/mysql-core/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ export abstract class MySqlSession<

abstract all<T = unknown>(query: SQL): Promise<T[]>;

async count(sql: SQL): Promise<number> {
const res = await this.execute<[[{ count: string }]]>(sql);

return Number(
res[0][0]['count'],
);
}

abstract transaction<T>(
transaction: (tx: MySqlTransaction<TQueryResult, TPreparedQueryHKT, TFullSchema, TSchema>) => Promise<T>,
config?: MySqlTransactionConfig,
Expand Down
10 changes: 9 additions & 1 deletion drizzle-orm/src/neon-http/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import type { PgQueryResultHKT, PgTransactionConfig, PreparedQueryConfig } from
import { PgPreparedQuery as PgPreparedQuery, PgSession } from '~/pg-core/session.ts';
import type { RelationalSchemaConfig, TablesRelationalConfig } from '~/relations.ts';
import type { PreparedQuery } from '~/session.ts';
import { fillPlaceholders, type Query } from '~/sql/sql.ts';
import { fillPlaceholders, type Query, type SQL } from '~/sql/sql.ts';
import { mapResultRow } from '~/utils.ts';

export type NeonHttpClient = NeonQueryFunction<any, any>;
Expand Down Expand Up @@ -161,6 +161,14 @@ export class NeonHttpSession<
return this.client(query, params, { arrayMode: false, fullResults: true });
}

override async count(sql: SQL): Promise<number> {
const res = await this.execute<{ rows: [{ count: string }] }>(sql);

return Number(
res['rows'][0]['count'],
);
}

override async transaction<T>(
_transaction: (tx: NeonTransaction<TFullSchema, TSchema>) => Promise<T>,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
Expand Down
9 changes: 8 additions & 1 deletion drizzle-orm/src/node-postgres/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import type { SelectedFieldsOrdered } from '~/pg-core/query-builders/select.type
import type { PgQueryResultHKT, PgTransactionConfig, PreparedQueryConfig } from '~/pg-core/session.ts';
import { PgPreparedQuery, PgSession } from '~/pg-core/session.ts';
import type { RelationalSchemaConfig, TablesRelationalConfig } from '~/relations.ts';
import { fillPlaceholders, type Query, sql } from '~/sql/sql.ts';
import { fillPlaceholders, type Query, type SQL, sql } from '~/sql/sql.ts';
import { tracer } from '~/tracing.ts';
import { type Assume, mapResultRow } from '~/utils.ts';

Expand Down Expand Up @@ -164,6 +164,13 @@ export class NodePgSession<
}
}
}

override async count(sql: SQL): Promise<number> {
const res = await this.execute<{ rows: [{ count: string }] }>(sql);
return Number(
res['rows'][0]['count'],
);
}
}

export class NodePgTransaction<
Expand Down
2 changes: 1 addition & 1 deletion drizzle-orm/src/pg-core/db.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ export class PgDatabase<
source: PgTable | PgViewBase | SQL | SQLWrapper,
filters?: SQL<unknown>,
) {
return new PgCountBuilder({ source, filters, dialect: this.dialect, session: this.session });
return new PgCountBuilder({ source, filters, session: this.session });
}

/**
Expand Down
12 changes: 5 additions & 7 deletions drizzle-orm/src/pg-core/query-builders/count.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { entityKind, sql } from '~/index.ts';
import type { SQLWrapper } from '~/sql/sql.ts';
import { SQL } from '~/sql/sql.ts';
import type { PgDialect } from '../dialect.ts';
import type { PgSession } from '../session.ts';
import type { PgTable } from '../table.ts';

Expand All @@ -19,26 +18,27 @@ export class PgCountBuilder<
source: PgTable | SQL | SQLWrapper,
filters?: SQL<unknown>,
): SQL<number> {
return sql<number>`(select count(*)::int from ${source}${sql.raw(' where ').if(filters)}${filters})`;
return sql<number>`(select count(*) from ${source}${sql.raw(' where ').if(filters)}${filters})`;
}

private static buildCount(
source: PgTable | SQL | SQLWrapper,
filters?: SQL<unknown>,
): SQL<number> {
return sql<number>`select count(*)::int from ${source}${sql.raw(' where ').if(filters)}${filters};`;
return sql<number>`select count(*) as count from ${source}${sql.raw(' where ').if(filters)}${filters};`;
}

constructor(
readonly params: {
source: PgTable | SQL | SQLWrapper;
filters?: SQL<unknown>;
dialect: PgDialect;
session: TSession;
},
) {
super(PgCountBuilder.buildEmbeddedCount(params.source, params.filters).queryChunks);

this.mapWith(Number);

this.session = params.session;

this.sql = PgCountBuilder.buildCount(
Expand All @@ -51,9 +51,7 @@ export class PgCountBuilder<
onfulfilled?: ((value: number) => TResult1 | PromiseLike<TResult1>) | null | undefined,
onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null | undefined,
): Promise<TResult1 | TResult2> {
return Promise.resolve(this.session.all(this.sql)).then<number>((it) => {
return (<[{ count: number }]> it)[0]['count'] as number;
})
return Promise.resolve(this.session.count(this.sql))
.then(
onfulfilled,
onrejected,
Expand Down
8 changes: 8 additions & 0 deletions drizzle-orm/src/pg-core/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ export abstract class PgSession<
).all();
}

async count(sql: SQL): Promise<number> {
const res = await this.execute<[{ count: string }]>(sql);

return Number(
res[0]['count'],
);
}

abstract transaction<T>(
transaction: (tx: PgTransaction<TQueryResult, TFullSchema, TSchema>) => Promise<T>,
config?: PgTransactionConfig,
Expand Down
3 changes: 2 additions & 1 deletion drizzle-orm/src/pg-proxy/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ export class PreparedQuery<T extends PreparedQueryConfig> extends PreparedQueryB
});
}

async all() {}
async all() {
}

/** @internal */
isResponseInArrayMode(): boolean {
Expand Down
9 changes: 8 additions & 1 deletion drizzle-orm/src/pglite/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import type { SelectedFieldsOrdered } from '~/pg-core/query-builders/select.type
import type { PgQueryResultHKT, PgTransactionConfig, PreparedQueryConfig } from '~/pg-core/session.ts';
import { PgPreparedQuery, PgSession } from '~/pg-core/session.ts';
import type { RelationalSchemaConfig, TablesRelationalConfig } from '~/relations.ts';
import { fillPlaceholders, type Query, sql } from '~/sql/sql.ts';
import { fillPlaceholders, type Query, type SQL, sql } from '~/sql/sql.ts';
import { type Assume, mapResultRow } from '~/utils.ts';

import { types } from '@electric-sql/pglite';
Expand Down Expand Up @@ -140,6 +140,13 @@ export class PgliteSession<
return transaction(tx);
}) as Promise<T>;
}

override async count(sql: SQL): Promise<number> {
const res = await this.execute<{ rows: [{ count: string }] }>(sql);
return Number(
res['rows'][0]['count'],
);
}
}

export class PgliteTransaction<
Expand Down
8 changes: 8 additions & 0 deletions drizzle-orm/src/planetscale-serverless/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ export class PlanetscaleSession<
) => eQuery.rows as T[]);
}

override async count(sql: SQL): Promise<number> {
const res = await this.execute<{ rows: [{ count: string }] }>(sql);

return Number(
res['rows'][0]['count'],
);
}

override transaction<T>(
transaction: (tx: PlanetScaleTransaction<TFullSchema, TSchema>) => Promise<T>,
): Promise<T> {
Expand Down
2 changes: 1 addition & 1 deletion drizzle-orm/src/sqlite-core/db.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ export class BaseSQLiteDatabase<
source: SQLiteTable | SQLiteViewBase | SQL | SQLWrapper,
filters?: SQL<unknown>,
) {
return new SQLiteCountBuilder({ source, filters, dialect: this.dialect, session: this.session });
return new SQLiteCountBuilder({ source, filters, session: this.session });
}

/**
Expand Down
4 changes: 1 addition & 3 deletions drizzle-orm/src/sqlite-core/query-builders/count.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { entityKind, sql } from '~/index.ts';
import type { SQLWrapper } from '~/sql/sql.ts';
import { SQL } from '~/sql/sql.ts';
import type { SQLiteDialect } from '../dialect.ts';
import type { SQLiteSession } from '../session.ts';
import type { SQLiteTable } from '../table.ts';
import type { SQLiteView } from '../view.ts';
Expand Down Expand Up @@ -34,7 +33,6 @@ export class SQLiteCountBuilder<
readonly params: {
source: SQLiteTable | SQLiteView | SQL | SQLWrapper;
filters?: SQL<unknown>;
dialect: SQLiteDialect;
session: TSession;
},
) {
Expand All @@ -52,7 +50,7 @@ export class SQLiteCountBuilder<
onfulfilled?: ((value: number) => TResult1 | PromiseLike<TResult1>) | null | undefined,
onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null | undefined,
): Promise<TResult1 | TResult2> {
return Promise.resolve(this.session.values(this.sql)).then<number>((it) => it[0]![0] as number).then(
return Promise.resolve(this.session.count(this.sql)).then(
onfulfilled,
onrejected,
);
Expand Down
6 changes: 6 additions & 0 deletions drizzle-orm/src/sqlite-core/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,12 @@ export abstract class SQLiteSession<
>;
}

async count(sql: SQL) {
const result = await this.values(sql) as [[number]];

return result[0][0];
}

/** @internal */
extractRawValuesValueFromBatchResult(_result: unknown): unknown {
throw new Error('Not implemented');
Expand Down
8 changes: 8 additions & 0 deletions drizzle-orm/src/tidb-serverless/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ export class TiDBServerlessSession<
return this.client.execute(querySql.sql, querySql.params) as Promise<T[]>;
}

override async count(sql: SQL): Promise<number> {
const res = await this.execute<{ rows: [{ count: string }] }>(sql);

return Number(
res['rows'][0]['count'],
);
}

override async transaction<T>(
transaction: (tx: TiDBServerlessTransaction<TFullSchema, TSchema>) => Promise<T>,
): Promise<T> {
Expand Down
Loading

0 comments on commit efd821d

Please sign in to comment.