Skip to content

Commit

Permalink
Add additional generic types to DataFrame methods (#302)
Browse files Browse the repository at this point in the history
Adding generic types to a few more methods beyond what was added in #293
by @scarf005

Focusing mostly on adding identity types to methods which I believe
don’t change the original type of the dataframe. I added “identity” type
signatures to the following methods:
> extend, fillNull, filter, interpolate, limit, max, mean, median, min,
quantile, rechunk, shiftAndFill, shrinkToFit, slice, sort, std, sum,
tail, unique, var, vstack, where, upsample

These previously returned `DataFrame<any>`, even when called on a
well-typed DataFrame, but now return `DataFrame<T>` (the original type)

---

I also added better types for a few slightly more complex ones:
- map
  - improved return type based on the function passed, but unimproved
    parameter type
- nullCount
- toRecords
- toSeries
  - for now, returning a broad union type, rather than identifying the
    specific column by index
- withColumn

---

Along the way, I added minor fixes for the types of:
1. `pl.intRange`
[[1]](890bf21)
which had overloads in the wrong order leading to incorrect return
types, and
2. the `pl.Series(name, values, dtype)` constructor
[[2]](a2635bd),
whose strongly-typed overload was failing to apply in simple cases like
`pl.Series("index", [0, 1, 2, 3, 4], pl.Int64)` when the input array
used `number`s instead of `BigInt`s
  • Loading branch information
controversial authored Dec 17, 2024
1 parent 62a70dc commit 8816b46
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 65 deletions.
22 changes: 12 additions & 10 deletions __tests__/expr.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -525,20 +525,22 @@ describe("expr", () => {
a: [1, 2, 3, 3, 3],
b: ["a", "a", "b", "a", "a"],
});
let actual = df.select(pl.len());
let expected = pl.DataFrame({ len: [5] });
const actual = df.select(pl.len());
const expected = pl.DataFrame({ len: [5] });
expect(actual).toFrameEqual(expected);

actual = df.withColumn(pl.len());
expected = df.withColumn(pl.lit(5).alias("len"));
expect(actual).toFrameEqual(expected);
const actual2 = df.withColumn(pl.len());
const expected2 = df.withColumn(pl.lit(5).alias("len"));
expect(actual2).toFrameEqual(expected2);

actual = df.withColumn(pl.intRange(pl.len()).alias("index"));
expected = df.withColumn(pl.Series("index", [0, 1, 2, 3, 4], pl.Int64));
expect(actual).toFrameEqual(expected);
const actual3 = df.withColumn(pl.intRange(pl.len()).alias("index"));
const expected3 = df.withColumn(
pl.Series("index", [0, 1, 2, 3, 4], pl.Int64),
);
expect(actual3).toFrameEqual(expected3);

actual = df.groupBy("b").agg(pl.len());
expect(actual.shape).toEqual({ height: 2, width: 2 });
const actual4 = df.groupBy("b").agg(pl.len());
expect(actual4.shape).toEqual({ height: 2, width: 2 });
});
test("list", () => {
const df = pl.DataFrame({
Expand Down
92 changes: 52 additions & 40 deletions polars/dataframe.ts
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* @param other DataFrame to vertically add.
*/
extend(other: DataFrame): DataFrame;
extend(other: DataFrame<T>): DataFrame<T>;
/**
* Fill null/missing values by a filling strategy
*
Expand All @@ -480,7 +480,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* - "one"
* @returns DataFrame with None replaced with the filling strategy.
*/
fillNull(strategy: FillNullStrategy): DataFrame;
fillNull(strategy: FillNullStrategy): DataFrame<T>;
/**
* Filter the rows in the DataFrame based on a predicate expression.
* ___
Expand Down Expand Up @@ -519,7 +519,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* └─────┴─────┴─────┘
* ```
*/
filter(predicate: any): DataFrame;
filter(predicate: any): DataFrame<T>;
/**
* Find the index of a column by name.
* ___
Expand Down Expand Up @@ -764,7 +764,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
/**
* Interpolate intermediate values. The interpolation method is linear.
*/
interpolate(): DataFrame;
interpolate(): DataFrame<T>;
/**
* Get a mask of all duplicated rows in this DataFrame.
*/
Expand Down Expand Up @@ -937,8 +937,11 @@ export interface DataFrame<T extends Record<string, Series> = any>
* Get first N rows as DataFrame.
* @see {@link head}
*/
limit(length?: number): DataFrame;
map(func: (...args: any[]) => any): any[];
limit(length?: number): DataFrame<T>;
map<ReturnT>(
// TODO: strong types for the mapping function
func: (row: any[], i: number, arr: any[][]) => ReturnT,
): ReturnT[];

/**
* Aggregate the columns of this DataFrame to their maximum value.
Expand All @@ -962,8 +965,8 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────┴─────┴──────╯
* ```
*/
max(): DataFrame;
max(axis: 0): DataFrame;
max(): DataFrame<T>;
max(axis: 0): DataFrame<T>;
max(axis: 1): Series;
/**
* Aggregate the columns of this DataFrame to their mean value.
Expand All @@ -972,8 +975,8 @@ export interface DataFrame<T extends Record<string, Series> = any>
* @param axis - either 0 or 1
* @param nullStrategy - this argument is only used if axis == 1
*/
mean(): DataFrame;
mean(axis: 0): DataFrame;
mean(): DataFrame<T>;
mean(axis: 0): DataFrame<T>;
mean(axis: 1): Series;
mean(axis: 1, nullStrategy?: "ignore" | "propagate"): Series;
/**
Expand All @@ -997,7 +1000,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────┴─────┴──────╯
* ```
*/
median(): DataFrame;
median(): DataFrame<T>;
/**
* Unpivot a DataFrame from wide to long format.
* @deprecated *since 0.13.0* use {@link unpivot}
Expand Down Expand Up @@ -1059,8 +1062,8 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────┴─────┴──────╯
* ```
*/
min(): DataFrame;
min(axis: 0): DataFrame;
min(): DataFrame<T>;
min(axis: 0): DataFrame<T>;
min(axis: 1): Series;
/**
* Get number of chunks used by the ChunkedArrays of this DataFrame.
Expand All @@ -1087,12 +1090,14 @@ export interface DataFrame<T extends Record<string, Series> = any>
* └─────┴─────┴─────┘
* ```
*/
nullCount(): DataFrame;
nullCount(): DataFrame<{
[K in keyof T]: Series<JsToDtype<number>, K & string>;
}>;
partitionBy(
cols: string | string[],
stable?: boolean,
includeKey?: boolean,
): DataFrame[];
): DataFrame<T>[];
partitionBy<T>(
cols: string | string[],
stable: boolean,
Expand Down Expand Up @@ -1210,13 +1215,13 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────┴─────┴──────╯
* ```
*/
quantile(quantile: number): DataFrame;
quantile(quantile: number): DataFrame<T>;
/**
* __Rechunk the data in this DataFrame to a contiguous allocation.__
*
* This will make sure all subsequent operations have optimal and predictable performance.
*/
rechunk(): DataFrame;
rechunk(): DataFrame<T>;
/**
* __Rename column names.__
* ___
Expand Down Expand Up @@ -1443,12 +1448,15 @@ export interface DataFrame<T extends Record<string, Series> = any>
* └─────┴─────┴─────┘
* ```
*/
shiftAndFill(n: number, fillValue: number): DataFrame;
shiftAndFill({ n, fillValue }: { n: number; fillValue: number }): DataFrame;
shiftAndFill(n: number, fillValue: number): DataFrame<T>;
shiftAndFill({
n,
fillValue,
}: { n: number; fillValue: number }): DataFrame<T>;
/**
* Shrink memory usage of this DataFrame to fit the exact capacity needed to hold the data.
*/
shrinkToFit(): DataFrame;
shrinkToFit(): DataFrame<T>;
shrinkToFit(inPlace: true): void;
shrinkToFit({ inPlace }: { inPlace: true }): void;
/**
Expand Down Expand Up @@ -1477,8 +1485,8 @@ export interface DataFrame<T extends Record<string, Series> = any>
* └─────┴─────┴─────┘
* ```
*/
slice({ offset, length }: { offset: number; length: number }): DataFrame;
slice(offset: number, length: number): DataFrame;
slice({ offset, length }: { offset: number; length: number }): DataFrame<T>;
slice(offset: number, length: number): DataFrame<T>;
/**
* Sort the DataFrame by column.
* ___
Expand All @@ -1493,7 +1501,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
descending?: boolean,
nullsLast?: boolean,
maintainOrder?: boolean,
): DataFrame;
): DataFrame<T>;
sort({
by,
reverse, // deprecated
Expand All @@ -1504,7 +1512,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
reverse?: boolean; // deprecated
nullsLast?: boolean;
maintainOrder?: boolean;
}): DataFrame;
}): DataFrame<T>;
sort({
by,
descending,
Expand All @@ -1514,7 +1522,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
descending?: boolean;
nullsLast?: boolean;
maintainOrder?: boolean;
}): DataFrame;
}): DataFrame<T>;
/**
* Aggregate the columns of this DataFrame to their standard deviation value.
* ___
Expand All @@ -1536,16 +1544,16 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────┴─────┴──────╯
* ```
*/
std(): DataFrame;
std(): DataFrame<T>;
/**
* Aggregate the columns of this DataFrame to their sum value.
* ___
*
* @param axis - either 0 or 1
* @param nullStrategy - this argument is only used if axis == 1
*/
sum(): DataFrame;
sum(axis: 0): DataFrame;
sum(): DataFrame<T>;
sum(axis: 0): DataFrame<T>;
sum(axis: 1): Series;
sum(axis: 1, nullStrategy?: "ignore" | "propagate"): Series;
/**
Expand Down Expand Up @@ -1595,7 +1603,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────────┴─────╯
* ```
*/
tail(length?: number): DataFrame;
tail(length?: number): DataFrame<T>;
/**
* @deprecated *since 0.4.0* use {@link writeCSV}
* @category Deprecated
Expand All @@ -1614,7 +1622,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ```
* @category IO
*/
toRecords(): Record<string, any>[];
toRecords(): { [K in keyof T]: DTypeToJs<T[K]["dtype"]> | null }[];

/**
* compat with `JSON.stringify`
Expand Down Expand Up @@ -1644,7 +1652,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ```
* @category IO
*/
toObject(): { [K in keyof T]: DTypeToJs<T[K]["dtype"]>[] };
// Returns one array per column, keyed by column name. `| null` is applied to
// the *result* of `DTypeToJs`, not to its argument: `DTypeToJs` is a
// distributive conditional type, so a `null` passed into it resolves to
// `never` and the nullability would be silently dropped. This mirrors the
// element type used by `toRecords()`.
toObject(): { [K in keyof T]: (DTypeToJs<T[K]["dtype"]> | null)[] };

/**
* @deprecated *since 0.4.0* use {@link writeIPC}
Expand All @@ -1656,7 +1664,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* @category IO Deprecated
*/
toParquet(destination?, options?);
toSeries(index?: number): Series;
toSeries(index?: number): T[keyof T];
toString(): string;
/**
* Convert a ``DataFrame`` to a ``Series`` of type ``Struct``
Expand Down Expand Up @@ -1768,12 +1776,12 @@ export interface DataFrame<T extends Record<string, Series> = any>
maintainOrder?: boolean,
subset?: ColumnSelection,
keep?: "first" | "last",
): DataFrame;
): DataFrame<T>;
unique(opts: {
maintainOrder?: boolean;
subset?: ColumnSelection;
keep?: "first" | "last";
}): DataFrame;
}): DataFrame<T>;
/**
Decompose a struct into its fields. The fields will be inserted in to the `DataFrame` on the
location of the `struct` type.
Expand Down Expand Up @@ -1833,7 +1841,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────┴─────┴──────╯
* ```
*/
var(): DataFrame;
var(): DataFrame<T>;
/**
* Grow this DataFrame vertically by stacking a DataFrame to it.
* @param df - DataFrame to stack.
Expand Down Expand Up @@ -1866,12 +1874,16 @@ export interface DataFrame<T extends Record<string, Series> = any>
* ╰─────┴─────┴─────╯
* ```
*/
vstack(df: DataFrame): DataFrame;
vstack(df: DataFrame<T>): DataFrame<T>;
/**
* Return a new DataFrame with the column added or replaced.
* @param column - Series, where the name of the Series refers to the column in the DataFrame.
*/
withColumn(column: Series | Expr): DataFrame;
withColumn<SeriesTypeT extends DataType, SeriesNameT extends string>(
column: Series<SeriesTypeT, SeriesNameT>,
): DataFrame<
Simplify<T & { [K in SeriesNameT]: Series<SeriesTypeT, SeriesNameT> }>
>;
withColumn(column: Series | Expr): DataFrame;
withColumns(...columns: (Expr | Series)[]): DataFrame;
/**
Expand All @@ -1896,7 +1908,7 @@ export interface DataFrame<T extends Record<string, Series> = any>
*/
withRowCount(name?: string): DataFrame;
/** @see {@link filter} */
where(predicate: any): DataFrame;
where(predicate: any): DataFrame<T>;
/**
Upsample a DataFrame at a regular frequency.
Expand Down Expand Up @@ -1972,13 +1984,13 @@ shape: (7, 3)
every: string,
by?: string | string[],
maintainOrder?: boolean,
): DataFrame;
): DataFrame<T>;
upsample(opts: {
timeColumn: string;
every: string;
by?: string | string[];
maintainOrder?: boolean;
}): DataFrame;
}): DataFrame<T>;
}

function prepareOtherArg(anyValue: any): Series {
Expand Down
14 changes: 14 additions & 0 deletions polars/datatypes/datatype.ts
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,20 @@ export type DTypeToJs<T> = T extends DataType.Decimal
: T extends DataType.Utf8
? string
: never;
/**
 * Maps a polars `DataType` to the JS value type that may be supplied when
 * constructing a value of that dtype.
 *
 * Some objects can be constructed with a looser JS type than they'd return
 * when converted back to JS. In this mapping `Decimal`, `Float64`, `Int64`
 * and `Int32` all resolve to `number | bigint`, `Bool` resolves to `boolean`,
 * `Utf8` resolves to `string`, and any other type resolves to `never`.
 */
export type DTypeToJsLoose<T> = T extends DataType.Decimal
? number | bigint
: T extends DataType.Float64
? number | bigint
: T extends DataType.Int64
? number | bigint
: T extends DataType.Int32
? number | bigint
: T extends DataType.Bool
? boolean
: T extends DataType.Utf8
? string
: never;
export type DtypeToJsName<T> = T extends DataType.Decimal
? "Decimal"
: T extends DataType.Float64
Expand Down
Loading

0 comments on commit 8816b46

Please sign in to comment.