-
Notifications
You must be signed in to change notification settings - Fork 0
/
median.c
299 lines (253 loc) · 7.58 KB
/
median.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
#include <postgres.h>
#include <stdint.h>
#include <string.h>
#include <fmgr.h>
#include <utils/timestamp.h>
#include <utils/datum.h>
#include <utils/lsyscache.h>
#include <catalog/pg_type.h>
#include "datum_comparator.h"
/* Emit the module magic block only when the included headers define the macro. */
#ifdef PG_MODULE_MAGIC
PG_MODULE_MAGIC;
#endif
/* Capacity of State.values: the largest number of row values kept for one median computation. */
#define WINDOW_LEN 1000
/* aggregate median:
* median(state, anyelement) returns the median element of the given list of unsorted row values
*
* Usage:
* SELECT median(col_name) FROM table
*
*
* Description:
* Median returns the median value of the specified list of row values. The supported data types
* are: smallint (int2), integer (int4), bigint (int8), real (float4), double precision (float8),
* text (only for an odd number of rows), and timestamp with time zone (timestamptz).
*
* The macro WINDOW_LEN defines the maximum number of rows on which median calculation is performed.
* If the number of input rows is greater than WINDOW_LEN, median calculation is done only on the
* last WINDOW_LEN row values.
*/
/*
 * Aggregate transition state: a fixed-capacity window of input values.
 * median_transfn and median_finalfn access it as the payload (VARDATA) of a
 * bytea value.
 */
typedef struct state {
/* Index in values[] where the next element is written. Until the window has
 * filled once it is also the number of stored elements; median_transfn resets
 * it to 0 when it reaches WINDOW_LEN. */
int length;
/* Stored input values; slots are overwritten starting from index 0 once the
 * window has filled. */
Datum values[WINDOW_LEN];
/* Set to true by median_transfn once WINDOW_LEN values have been stored. */
bool maxed;
} State;
/* Declare the two entry points defined below through PG_FUNCTION_INFO_V1. */
PG_FUNCTION_INFO_V1( median_transfn);
PG_FUNCTION_INFO_V1( median_finalfn);
/*
 * Looks up the comparator used to order Datums of the given type.
 *
 * oid: type OID of the values to be compared.
 *
 * Returns the matching cmp_dimension_id_* function. Any OID outside the
 * supported set is reported through elog(ERROR).
 */
static comparison_fn_t get_compare_function(Oid oid) {
    switch (oid) {
    case INT2OID:
        return cmp_dimension_id_int16;
    case INT4OID:
        return cmp_dimension_id_int32;
    case INT8OID:
        return cmp_dimension_id_int64;
    case FLOAT4OID:
        return cmp_dimension_id_float4;
    case FLOAT8OID:
        return cmp_dimension_id_float8;
    case TIMESTAMPTZOID:
        return cmp_dimension_id_timestamptz;
    case TEXTOID:
        return cmp_dimension_id_varchar;
    default:
        break;
    }

    elog(ERROR, "unsupported data type for comparison");
    /* Control is not expected to get here; the return keeps compilers quiet. */
    return NULL;
}
/*
 * Midpoint of two signed 64-bit integers, defined as a + trunc((b - a) / 2).
 *
 * The signed difference b - a can overflow (for example INT64_MIN and
 * INT64_MAX), which is undefined behavior in C. The distance between the two
 * values always fits in uint64_t and unsigned subtraction wraps modulo 2^64,
 * so the magnitude is computed in unsigned arithmetic and then halved. Half
 * of the distance always fits in int64_t, and stepping from a toward b by
 * that amount never leaves the range [min(a,b), max(a,b)].
 */
static int64_t midpoint_int64(int64_t a, int64_t b) {
    if (b >= a) {
        uint64_t gap = (uint64_t) b - (uint64_t) a;

        return a + (int64_t) (gap / 2);
    } else {
        uint64_t gap = (uint64_t) a - (uint64_t) b;

        return a - (int64_t) (gap / 2);
    }
}

/*
 * Calculates and returns the mean of two datum values.
 *
 * oid:  type OID describing how val1 and val2 are to be interpreted.
 * val1: first value; the result is computed as val1 plus half the distance
 *       to val2, so for integer types the half-distance is truncated toward
 *       zero before being applied to val1.
 * val2: second value.
 *
 * Returns the mean as a Datum of the same type. For TEXTOID no mean exists
 * and (Datum) 0 is returned. Any other unsupported OID is reported through
 * elog(ERROR).
 */
static Datum get_mean(Oid oid, Datum val1, Datum val2) {
    Datum res;

    switch (oid) {
    case INT2OID:
        res = Int16GetDatum(
                (int16_t) midpoint_int64(DatumGetInt16(val1),
                                         DatumGetInt16(val2)));
        break;
    case INT4OID:
        /* Widening to 64 bits keeps the difference of two int32 values exact. */
        res = Int32GetDatum(
                (int32_t) midpoint_int64(DatumGetInt32(val1),
                                         DatumGetInt32(val2)));
        break;
    case INT8OID:
        res = Int64GetDatum(
                midpoint_int64(DatumGetInt64(val1), DatumGetInt64(val2)));
        break;
    case FLOAT4OID:
        res = Float4GetDatum(
                DatumGetFloat4(val1)
                        + (DatumGetFloat4(val2) - DatumGetFloat4(val1)) / 2.0);
        break;
    case FLOAT8OID:
        res = Float8GetDatum(
                DatumGetFloat8(val1)
                        + (DatumGetFloat8(val2) - DatumGetFloat8(val1)) / 2.0);
        break;
    case TIMESTAMPTZOID:
        /*
         * Dividing by 2.0 makes this a double-precision computation whose
         * result is converted back to TimestampTz.
         * NOTE(review): if TimestampTz is a 64-bit integer in this build,
         * very large values could lose precision here - confirm whether an
         * integer midpoint is wanted instead.
         */
        res = TimestampTzGetDatum(
                DatumGetTimestampTz(val1)
                        + (DatumGetTimestampTz(val2)
                                - DatumGetTimestampTz(val1)) / 2.0);
        break;
    case TEXTOID:
        /* Not possible to evaluate the mean of two text values. */
        res = (Datum) 0;
        break;
    default:
        elog(ERROR, "unknown input datatype");
        res = (Datum) 0;
    }
    return res;
}
/* Exchanges the elements at positions i and j of arr (a no-op when i == j). */
static void swap(Datum *arr, int i, int j) {
    Datum *first = &arr[i];
    Datum *second = &arr[j];
    Datum held = *first;

    *first = *second;
    *second = held;
}
/*
 * Quicksort-style partitioning of arr[l..r] around the rightmost element.
 *
 * Every element that compares less than or equal to the pivot is moved in
 * front of it; the pivot is then placed directly after them.
 *
 * arr: array being rearranged in place.
 * l:   index of the first element of the range.
 * r:   index of the last element of the range, which serves as the pivot.
 * cmp: comparator, called with pointers to two array elements.
 *
 * Returns the index at which the pivot ends up, i.e. the position it would
 * occupy if the range were sorted.
 */
static int partition(Datum *arr, int l, int r, comparison_fn_t cmp)
{
    int boundary = l;   /* first slot not yet known to hold a value <= pivot */
    int scan;

    for (scan = l; scan < r; scan++) {
        if (cmp(&arr[scan], &arr[r]) > 0)
            continue;
        swap(arr, boundary, scan);
        boundary++;
    }
    swap(arr, boundary, r);
    return boundary;
}
/**
 * Quickselect: finds the kth smallest element of arr[l..r] without fully
 * sorting it. Average-case time is O(n), compared with O(n log n) for a
 * quicksort-based approach.
 *
 * arr: array to search; its elements are rearranged in place.
 * l:   index of the first element of the range.
 * r:   index of the last element of the range.
 * k:   1-based rank of the wanted element within the range.
 * cmp: comparator handed on to partition().
 *
 * Returns the kth smallest element, or (Datum) 0 when k is not within
 * 1 .. (r - l + 1).
 */
static Datum get_k_smallest(Datum *arr, int l, int r, int k, comparison_fn_t cmp)
{
    while (k > 0 && k <= r - l + 1) {
        int pivot = partition(arr, l, r, cmp);
        int rank = pivot - l + 1;   /* 1-based rank of the pivot inside [l, r] */

        if (rank == k) {
            return arr[pivot];
        }
        if (rank > k) {
            /* The wanted element lies left of the pivot; its rank is unchanged. */
            r = pivot - 1;
        } else {
            /* Skip the pivot and everything before it. */
            k -= rank;
            l = pivot + 1;
        }
    }
    /* k was outside the valid range for the remaining span. */
    return (Datum) 0;
}
/*
* Median state transfer function.
*
* This function is called for every value in the set that we are calculating
* the median for. On first call, the aggregate state, if any, needs to be
* initialized.
*/
Datum median_transfn( PG_FUNCTION_ARGS) {
MemoryContext agg_context;
bytea *pg_state = (PG_ARGISNULL(0) ? NULL : PG_GETARG_BYTEA_P(0));
Oid element_type = get_fn_expr_argtype(fcinfo->flinfo, 1);
Datum element;
State *state;
if (!AggCheckCallContext(fcinfo, &agg_context))
elog(ERROR, "median_transfn called in non-aggregate context");
if (!OidIsValid(element_type))
elog(ERROR, "could not determine data type of input");
if (PG_ARGISNULL(1))
PG_RETURN_BYTEA_P(pg_state); //ignore null values, and return current state without change
else
element = PG_GETARG_DATUM(1);
if (!pg_state) {
/* Allocate memory for the state and add the first node*/
int size = sizeof(State);
pg_state = (bytea*) palloc(VARHDRSZ + size);
SET_VARSIZE(pg_state, size);
state = (State*) VARDATA(pg_state);
state->length = 1;
state->values[0] = element;
state->maxed = false;
} else {
state = (State*) VARDATA(pg_state);
/* Add the new datum node to the list*/
state->values[state->length] = element;
state->length += 1;
if(state->length == WINDOW_LEN){
state->maxed = true;
/*roll over, start over-writing oldest elements in the window array*/
state->length = 0;
}
}
PG_RETURN_BYTEA_P(pg_state);
}
/*
* Median final function.
*
* This function is called after all values in the median set has been
* processed by the state transfer function. It should perform any necessary
* post processing and clean up any temporary state.
*/
Datum median_finalfn( PG_FUNCTION_ARGS) {
MemoryContext agg_context;
bytea *pg_state = (PG_ARGISNULL(0) ? NULL : PG_GETARG_BYTEA_P(0));
Oid element_type = get_fn_expr_argtype(fcinfo->flinfo, 1);
Datum ret;
State *state;
comparison_fn_t cmp;
int mid;
if (pg_state == NULL)
PG_RETURN_NULL();
if (!AggCheckCallContext(fcinfo, &agg_context))
elog(ERROR, "median_finalfn called in non-aggregate context");
if (!OidIsValid(element_type))
elog(ERROR, "could not determine data type of input");
if (element_type != INT2OID && element_type != INT4OID
&& element_type != INT8OID && element_type != FLOAT4OID
&& element_type != FLOAT8OID && element_type != TIMESTAMPTZOID
&& element_type != TEXTOID) {
elog(ERROR, "input data type not supported for median determination");
PG_RETURN_NULL();
}
state = (State*) VARDATA(pg_state);
/* if there were more rows than WINDOW_LEN, then perform median
* on all WINDOW_LEN elements*/
if(state->maxed)
state->length = WINDOW_LEN;
/*return null if no elements in the list*/
if (state->length == 0)
PG_RETURN_NULL();
if (element_type == TEXTOID && state->length % 2 == 0) {
elog(ERROR, "median for even number of text inputs not supported");
PG_RETURN_NULL();
}
/* find compare function for the input data type */
cmp = get_compare_function(element_type);
mid = state->length / 2;
/*get the kth smallest element without actually sorting the array*/
ret = get_k_smallest(state->values, 0, state->length -1, mid+1, cmp);
if(state->length % 2 != 0){
//odd entries,return the median
} else{
Datum median2;
//get (mid-1)th smallest element
median2 = get_k_smallest(state->values, 0, state->length -1, mid, cmp);
ret = get_mean(element_type, ret, median2);
}
PG_RETURN_DATUM(ret);
}