Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for 'array' datatype #224

Merged
merged 3 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions polars/dataframe.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1292,13 +1292,7 @@ export interface DataFrame
* ```
*/
shiftAndFill(n: number, fillValue: number): DataFrame;
shiftAndFill({
n,
fillValue,
}: {
n: number;
fillValue: number;
}): DataFrame;
shiftAndFill({ n, fillValue }: { n: number; fillValue: number }): DataFrame;
/**
* Shrink memory usage of this DataFrame to fit the exact capacity needed to hold the data.
*/
Expand Down
213 changes: 126 additions & 87 deletions polars/datatypes/datatype.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { Field } from "./field";

export abstract class DataType {
get variant() {
return this.constructor.name.slice(1);
return this.constructor.name;
}
protected identity = "DataType";
protected get inner(): null | any[] {
Expand All @@ -18,67 +18,67 @@ export abstract class DataType {

/** Null type */
public static get Null(): DataType {
return new _Null();
return new Null();
}
/** `true` and `false`. */
public static get Bool(): DataType {
return new _Bool();
return new Bool();
}
/** An `i8` */
public static get Int8(): DataType {
return new _Int8();
return new Int8();
}
/** An `i16` */
public static get Int16(): DataType {
return new _Int16();
return new Int16();
}
/** An `i32` */
public static get Int32(): DataType {
return new _Int32();
return new Int32();
}
/** An `i64` */
public static get Int64(): DataType {
return new _Int64();
return new Int64();
}
/** A `u8` */
public static get UInt8(): DataType {
return new _UInt8();
return new UInt8();
}
/** A `u16` */
public static get UInt16(): DataType {
return new _UInt16();
return new UInt16();
}
/** A `u32` */
public static get UInt32(): DataType {
return new _UInt32();
return new UInt32();
}
/** A `u64` */
public static get UInt64(): DataType {
return new _UInt64();
return new UInt64();
}

/** A `f32` */
public static get Float32(): DataType {
return new _Float32();
return new Float32();
}
/** A `f64` */
public static get Float64(): DataType {
return new _Float64();
return new Float64();
}
public static get Date(): DataType {
return new _Date();
return new Date();
}
/** Time of day type */
public static get Time(): DataType {
return new _Time();
return new Time();
}
/** Type for wrapping arbitrary JS objects */
public static get Object(): DataType {
return new _Object();
return new Object_();
}
/** A categorical encoding of a set of strings */
public static get Categorical(): DataType {
return new _Categorical();
return new Categorical();
}

/**
Expand All @@ -93,7 +93,7 @@ export abstract class DataType {
timeUnit,
timeZone: string | null | undefined = null,
): DataType {
return new _Datetime(timeUnit, timeZone as any);
return new Datetime(timeUnit, timeZone as any);
}
/**
* Nested list/array type
Expand All @@ -102,7 +102,15 @@ export abstract class DataType {
*
*/
public static List(inner: DataType): DataType {
return new _List(inner);
return new List(inner);
}
/**
 * Fixed-length list type.
 *
 * Other polars implementations name this type `Array`; since `Array` is already
 * ubiquitous in JS, it is exposed here as `FixedSizeList` instead.
 *
 * @param inner - datatype of the list elements
 * @param listSize - the fixed length of the list
 * @returns a new `FixedSizeList` datatype wrapping `inner` and `listSize`
 */
public static FixedSizeList(inner: DataType, listSize: number): DataType {
  const fixedSizeList: DataType = new FixedSizeList(inner, listSize);
  return fixedSizeList;
}
/**
* Struct type
Expand All @@ -112,15 +120,15 @@ export abstract class DataType {
public static Struct(
fields: Field[] | { [key: string]: DataType },
): DataType {
return new _Struct(fields);
return new Struct(fields);
}
/** A variable-length UTF-8 encoded string whose offsets are represented as `i64`. */
public static get Utf8(): DataType {
return new _Utf8();
return new Utf8();
}

public static get String(): DataType {
return new _String();
return new String();
}

toString() {
Expand All @@ -131,7 +139,6 @@ export abstract class DataType {
}
toJSON() {
const inner = (this as any).inner;

if (inner) {
return {
[this.identity]: {
Expand All @@ -149,32 +156,40 @@ export abstract class DataType {
static from(obj): DataType {
return null as any;
}
/**
 * Narrows this datatype to a `FixedSizeList`.
 *
 * @returns this same instance when it is a `FixedSizeList`, otherwise `null`
 */
asFixedSizeList() {
  return this instanceof FixedSizeList ? this : null;
}
}

class _Null extends DataType {}
class _Bool extends DataType {}
class _Int8 extends DataType {}
class _Int16 extends DataType {}
class _Int32 extends DataType {}
class _Int64 extends DataType {}
class _UInt8 extends DataType {}
class _UInt16 extends DataType {}
class _UInt32 extends DataType {}
class _UInt64 extends DataType {}
class _Float32 extends DataType {}
class _Float64 extends DataType {}
class _Date extends DataType {}
class _Time extends DataType {}
class _Object extends DataType {}
class _Utf8 extends DataType {}
class _String extends DataType {}
export class Null extends DataType {}
export class Bool extends DataType {}
export class Int8 extends DataType {}
export class Int16 extends DataType {}
export class Int32 extends DataType {}
export class Int64 extends DataType {}
export class UInt8 extends DataType {}
export class UInt16 extends DataType {}
export class UInt32 extends DataType {}
export class UInt64 extends DataType {}
export class Float32 extends DataType {}
export class Float64 extends DataType {}
// biome-ignore lint/suspicious/noShadowRestrictedNames: the class intentionally shares its name with the polars `Date` dtype
export class Date extends DataType {}
export class Time extends DataType {}
export class Object_ extends DataType {}
export class Utf8 extends DataType {}
// biome-ignore lint/suspicious/noShadowRestrictedNames: the class intentionally shares its name with the polars `String` dtype
export class String extends DataType {}

class _Categorical extends DataType {}
export class Categorical extends DataType {}

/**
* Datetime type
*/
class _Datetime extends DataType {
export class Datetime extends DataType {
constructor(
private timeUnit: TimeUnit,
private timeZone?: string,
Expand All @@ -188,15 +203,15 @@ class _Datetime extends DataType {
override equals(other: DataType): boolean {
if (other.variant === this.variant) {
return (
this.timeUnit === (other as _Datetime).timeUnit &&
this.timeZone === (other as _Datetime).timeZone
this.timeUnit === (other as Datetime).timeUnit &&
this.timeZone === (other as Datetime).timeZone
);
}
return false;
}
}

class _List extends DataType {
export class List extends DataType {
constructor(protected __inner: DataType) {
super();
}
Expand All @@ -205,13 +220,50 @@ class _List extends DataType {
}
override equals(other: DataType): boolean {
if (other.variant === this.variant) {
return this.inner[0].equals((other as _List).inner[0]);
return this.inner[0].equals((other as List).inner[0]);
}
return false;
}
}

/**
 * Datatype describing a list with a fixed number of elements.
 *
 * Stores the element datatype together with the list size.
 */
export class FixedSizeList extends DataType {
  /**
   * @param __inner - datatype of the list elements
   * @param listSize - the fixed length of the list
   */
  constructor(
    protected __inner: DataType,
    protected listSize: number,
  ) {
    super();
  }

  /** Variant name; always the literal `"FixedSizeList"`. */
  override get variant() {
    return "FixedSizeList";
  }

  /** Tuple of the element datatype followed by the list size. */
  override get inner(): [DataType, number] {
    return [this.__inner, this.listSize];
  }

  /**
   * Compares this datatype with another one.
   *
   * @param other - datatype to compare against
   * @returns `true` only when `other` reports the same variant, its element
   *   datatype equals this one's, and both list sizes are strictly equal
   */
  override equals(other: DataType): boolean {
    // Different variants can never be equal; bail out before touching `inner`.
    if (other.variant !== this.variant) {
      return false;
    }
    const [elementType, size] = this.inner;
    const [otherElementType, otherSize] = (other as FixedSizeList).inner;
    return elementType.equals(otherElementType) && size === otherSize;
  }

  /**
   * Serializes to a nested object keyed by `identity`, then by `variant`,
   * holding the serialized element datatype under `type` and the list size
   * under `size`.
   */
  override toJSON() {
    const [elementType, size] = this.inner;
    return {
      [this.identity]: {
        [this.variant]: {
          type: elementType.toJSON(),
          size,
        },
      },
    };
  }
}

class _Struct extends DataType {
export class Struct extends DataType {
private fields: Field[];

constructor(
Expand All @@ -235,7 +287,7 @@ class _Struct extends DataType {
if (other.variant === this.variant) {
return this.inner
.map((fld, idx) => {
const otherfld = (other as _Struct).fields[idx];
const otherfld = (other as Struct).fields[idx];

return otherfld.name === fld.name && otherfld.dtype.equals(fld.dtype);
})
Expand Down Expand Up @@ -275,45 +327,28 @@ export namespace TimeUnit {
* Datatype namespace
*/
export namespace DataType {
/** Null */
export type Null = _Null;
/** Boolean */
export type Bool = _Bool;
/** Int8 */
export type Int8 = _Int8;
/** Int16 */
export type Int16 = _Int16;
/** Int32 */
export type Int32 = _Int32;
/** Int64 */
export type Int64 = _Int64;
/** UInt8 */
export type UInt8 = _UInt8;
/** UInt16 */
export type UInt16 = _UInt16;
/** UInt32 */
export type UInt32 = _UInt32;
/** UInt64 */
export type UInt64 = _UInt64;
/** Float32 */
export type Float32 = _Float32;
/** Float64 */
export type Float64 = _Float64;
/** Date dtype */
export type Date = _Date;
/** Datetime */
export type Datetime = _Datetime;
/** Utf8 */
export type Utf8 = _Utf8;
/** String */
export type String = _String;
/** Categorical */
export type Categorical = _Categorical;
/** List */
export type List = _List;
/** Struct */
export type Struct = _Struct;

export type Categorical = import(".").Categorical;
export type Int8 = import(".").Int8;
export type Int16 = import(".").Int16;
export type Int32 = import(".").Int32;
export type Int64 = import(".").Int64;
export type UInt8 = import(".").UInt8;
export type UInt16 = import(".").UInt16;
export type UInt32 = import(".").UInt32;
export type UInt64 = import(".").UInt64;
export type Float32 = import(".").Float32;
export type Float64 = import(".").Float64;
export type Bool = import(".").Bool;
export type Utf8 = import(".").Utf8;
export type String = import(".").String;
export type List = import(".").List;
export type FixedSizeList = import(".").FixedSizeList;
export type Date = import(".").Date;
export type Datetime = import(".").Datetime;
export type Time = import(".").Time;
export type Object = import(".").Object_;
export type Null = import(".").Null;
export type Struct = import(".").Struct;
/**
* deserializes a datatype from the serde output of rust polars `DataType`
* @param dtype dtype object
Expand All @@ -333,6 +368,10 @@ export namespace DataType {
inner = [deserialize(inner[0])];
}

if (variant === "FixedSizeList") {
inner = [deserialize(inner[0]), inner[1]];
}

return DataType[variant](...inner);
}
}
5 changes: 3 additions & 2 deletions polars/datatypes/index.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { DataType, TimeUnit } from "./datatype";
export { DataType, TimeUnit };
export * from "./datatype";
export { Field } from "./field";

import pli from "../internals/polars_internal";
// biome-ignore lint/style/useImportType: <explanation>
import { type DataType } from "./datatype";

/** @ignore */
export type TypedArray =
Expand Down
Loading