Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(D3 plugin): improve categories using #273

Merged
merged 3 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/plugins/d3/__stories__/bar-x/category.stories.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ const Template: Story = () => {
visible: true,
data: [
{
category: 'A',
label: 10,
x: 'A',
y: 100,
},
{
category: 'B',
label: 12,
x: 'B',
y: 80,
},
],
Expand All @@ -39,8 +39,7 @@ const Template: Story = () => {
visible: true,
data: [
{
category: 'C',
x: 95.5,
x: 'C',
y: 120,
},
],
Expand Down
25 changes: 18 additions & 7 deletions src/plugins/d3/__stories__/scatter/LinearCategories.stories.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,23 @@ const shapeScatterSeriesData = (args: {data: Record<string, any>[]; groupBy: str
acc[seriesName] = [];
}

acc[seriesName].push({
x: d[map.x],
y: d[map.y],
radius: random(3, 6),
...(map.category && {category: d[map.category]}),
});
const categoriesType = map.categoriesType as 'x' | 'y' | 'none' | undefined;
const isCategorical = categoriesType === 'x' || categoriesType === 'y';

if (isCategorical && map.category) {
acc[seriesName].push({
x: d[map.x],
y: d[map.y],
radius: random(3, 6),
[map.categoriesType]: d[map.category],
});
} else if (!isCategorical) {
acc[seriesName].push({
x: d[map.x],
y: d[map.y],
radius: random(3, 6),
});
}

return acc;
}, {});
Expand Down Expand Up @@ -133,7 +144,7 @@ const Template: Story = () => {
const shapedScatterSeriesData = shapeScatterSeriesData({
data: penguins,
groupBy,
map: {x, y, category},
map: {x, y, category, categoriesType},
});
const shapedScatterSeries = shapeScatterSeries(shapedScatterSeriesData);
const data = shapeScatterChartData(shapedScatterSeries, categoriesType, categories);
Expand Down
36 changes: 25 additions & 11 deletions src/plugins/d3/renderer/components/Tooltip/DefaultContent.tsx
Original file line number Diff line number Diff line change
@@ -1,27 +1,41 @@
import React from 'react';
import get from 'lodash/get';

import type {
ScatterSeriesData,
BarXSeriesData,
TooltipHoveredData,
} from '../../../../../types/widget-data';
import type {ChartKitWidgetSeriesData, TooltipHoveredData} from '../../../../../types/widget-data';

import type {PreparedAxis} from '../../hooks';
import {getDataCategoryValue} from '../../utils';

type Props = {
hovered: TooltipHoveredData;
xAxis: PreparedAxis;
yAxis: PreparedAxis;
};

const getXRowData = (xAxis: PreparedAxis, data: ChartKitWidgetSeriesData) => {
const categories = get(xAxis, 'categories', [] as string[]);

return xAxis.type === 'category'
? getDataCategoryValue({axisDirection: 'x', categories, data})
: (data as {x: number}).x;
};

const getYRowData = (yAxis: PreparedAxis, data: ChartKitWidgetSeriesData) => {
const categories = get(yAxis, 'categories', [] as string[]);

return yAxis.type === 'category'
? getDataCategoryValue({axisDirection: 'y', categories, data})
: (data as {y: number}).y;
};

export const DefaultContent = ({hovered, xAxis, yAxis}: Props) => {
const {data, series} = hovered;

switch (series.type) {
case 'scatter': {
const scatterData = data as ScatterSeriesData;
const xRow = xAxis.type === 'category' ? scatterData.category : scatterData.x;
const yRow = yAxis.type === 'category' ? scatterData.category : scatterData.y;
const xRow = getXRowData(xAxis, data);
const yRow = getYRowData(yAxis, data);

return (
<div>
<div>
Expand All @@ -36,9 +50,9 @@ export const DefaultContent = ({hovered, xAxis, yAxis}: Props) => {
);
}
case 'bar-x': {
const barXData = data as BarXSeriesData;
const xRow = xAxis.type === 'category' ? barXData.category : barXData.x;
const yRow = yAxis.type === 'category' ? barXData.category : barXData.y;
const xRow = getXRowData(xAxis, data);
const yRow = getYRowData(yAxis, data);

return (
<div>
<div>{xRow}</div>
Expand Down
38 changes: 27 additions & 11 deletions src/plugins/d3/renderer/hooks/useAxisScales/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import get from 'lodash/get';
import type {ChartOptions} from '../useChartOptions/types';
import {
getOnlyVisibleSeries,
getDataCategoryValue,
getDomainDataYBySeries,
isAxisRelatedSeries,
getDomainDataXBySeries,
isAxisRelatedSeries,
isSeriesWithCategoryValues,
} from '../../utils';
import type {AxisDirection} from '../../utils';
import {PreparedSeries} from '../useSeries/types';

export type ChartScale =
Expand All @@ -35,10 +37,22 @@ const isNumericalArrayData = (data: unknown[]): data is number[] => {
return data.every((d) => typeof d === 'number' || d === null);
};

const filterCategoriesByVisibleSeries = (categories: string[], series: PreparedSeries[]) => {
const filterCategoriesByVisibleSeries = (args: {
axisDirection: AxisDirection;
categories: string[];
series: PreparedSeries[];
}) => {
const {axisDirection, categories, series} = args;

return categories.filter((category) => {
return series.some((s) => {
return isSeriesWithCategoryValues(s) && s.data.some((d) => d.category === category);
return (
isSeriesWithCategoryValues(s) &&
s.data.some((d) => {
const dataCategory = getDataCategoryValue({axisDirection, categories, data: d});
return dataCategory === category;
})
);
});
});
};
Expand Down Expand Up @@ -75,10 +89,11 @@ const createScales = (args: Args) => {
}
case 'category': {
if (xCategories) {
const filteredCategories = filterCategoriesByVisibleSeries(
xCategories,
visibleSeries,
);
const filteredCategories = filterCategoriesByVisibleSeries({
axisDirection: 'x',
categories: xCategories,
series: visibleSeries,
});
xScale = scaleBand().domain(filteredCategories).range([0, boundsWidth]);
}

Expand Down Expand Up @@ -122,10 +137,11 @@ const createScales = (args: Args) => {
}
case 'category': {
if (yCategories) {
const filteredCategories = filterCategoriesByVisibleSeries(
yCategories,
visibleSeries,
);
const filteredCategories = filterCategoriesByVisibleSeries({
axisDirection: 'y',
categories: yCategories,
series: visibleSeries,
});
yScale = scaleBand().domain(filteredCategories).range([boundsHeight, 0]);
}

Expand Down
21 changes: 14 additions & 7 deletions src/plugins/d3/renderer/hooks/useShapes/bar-x.tsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import React from 'react';
import {ChartOptions} from '../useChartOptions/types';
import {ChartScale} from '../useAxisScales';
import {OnSeriesMouseLeave, OnSeriesMouseMove} from '../useTooltip/types';
import {BarXSeriesData} from '../../../../../types/widget-data';
import {group, pointer, select} from 'd3';
import type {ScaleBand, ScaleLinear, ScaleTime} from 'd3';
import get from 'lodash/get';

import type {BarXSeriesData} from '../../../../../types/widget-data';
import {block} from '../../../../../utils/cn';
import {group, pointer, ScaleBand, ScaleLinear, ScaleTime, select} from 'd3';
import {PreparedBarXSeries} from '../useSeries/types';

import {getDataCategoryValue} from '../../utils';
import type {ChartScale} from '../useAxisScales';
import type {ChartOptions} from '../useChartOptions/types';
import type {OnSeriesMouseLeave, OnSeriesMouseMove} from '../useTooltip/types';
import type {PreparedBarXSeries} from '../useSeries/types';

const DEFAULT_BAR_RECT_WIDTH = 50;
const DEFAULT_LINEAR_BAR_RECT_WIDTH = 20;
Expand Down Expand Up @@ -44,8 +49,10 @@ const getRectProperties = (args: {
if (xAxis.type === 'category') {
const xBandScale = xScale as ScaleBand<string>;
const maxWidth = xBandScale.bandwidth() - MIN_RECT_GAP;
const categories = get(xAxis, 'categories', [] as string[]);
const dataCategory = getDataCategoryValue({axisDirection: 'x', categories, data: point});
width = Math.min(maxWidth, DEFAULT_BAR_RECT_WIDTH);
cx = (xBandScale(point.category as string) || 0) + xBandScale.step() / 2 - width / 2;
cx = (xBandScale(dataCategory) || 0) + xBandScale.step() / 2 - width / 2;
} else {
const xLinearScale = xScale as ScaleLinear<number, number> | ScaleTime<number, number>;
const [min, max] = xLinearScale.domain();
Expand Down
48 changes: 22 additions & 26 deletions src/plugins/d3/renderer/hooks/useShapes/scatter.tsx
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import React from 'react';
import {pointer, select} from 'd3';
import type {ScaleBand, ScaleLinear, ScaleTime} from 'd3';
import React from 'react';
import {ChartOptions} from '../useChartOptions/types';
import {ChartScale} from '../useAxisScales';
import {OnSeriesMouseLeave, OnSeriesMouseMove} from '../useTooltip/types';
import {ScatterSeries, ScatterSeriesData} from '../../../../../types/widget-data';
import get from 'lodash/get';

import type {ScatterSeries, ScatterSeriesData} from '../../../../../types/widget-data';
import {block} from '../../../../../utils/cn';

import {getDataCategoryValue} from '../../utils';
import type {ChartScale} from '../useAxisScales';
import type {PreparedAxis} from '../useChartOptions/types';
import type {OnSeriesMouseLeave, OnSeriesMouseMove} from '../useTooltip/types';

type ScatterSeriesShapeProps = {
top: number;
left: number;
series: ScatterSeries;
xAxis: ChartOptions['xAxis'];
xAxis: PreparedAxis;
xScale: ChartScale;
yAxis: ChartOptions['yAxis'];
yAxis: PreparedAxis[];
yScale: ChartScale;
svgContainer: SVGSVGElement | null;
onSeriesMouseMove?: OnSeriesMouseMove;
Expand All @@ -23,26 +27,20 @@ type ScatterSeriesShapeProps = {
const b = block('d3-scatter');
const DEFAULT_SCATTER_POINT_RADIUS = 4;

const prepareCategoricalScatterData = (data: ScatterSeriesData[]) => {
return data.filter((d) => typeof d.category === 'string');
};

const prepareLinearScatterData = (data: ScatterSeriesData[]) => {
return data.filter((d) => typeof d.x === 'number' && typeof d.y === 'number');
};

const getCxAttr = (args: {
point: ScatterSeriesData;
xAxis: ChartOptions['xAxis'];
xScale: ChartScale;
}) => {
const getCxAttr = (args: {point: ScatterSeriesData; xAxis: PreparedAxis; xScale: ChartScale}) => {
const {point, xAxis, xScale} = args;

let cx: number;

if (xAxis.type === 'category') {
const xBandScale = xScale as ScaleBand<string>;
cx = (xBandScale(point.category as string) || 0) + xBandScale.step() / 2;
const categories = get(xAxis, 'categories', [] as string[]);
const dataCategory = getDataCategoryValue({axisDirection: 'x', categories, data: point});
cx = (xBandScale(dataCategory) || 0) + xBandScale.step() / 2;
} else {
const xLinearScale = xScale as ScaleLinear<number, number> | ScaleTime<number, number>;
cx = xLinearScale(point.x as number);
Expand All @@ -51,18 +49,16 @@ const getCxAttr = (args: {
return cx;
};

const getCyAttr = (args: {
point: ScatterSeriesData;
yAxis: ChartOptions['yAxis'];
yScale: ChartScale;
}) => {
const getCyAttr = (args: {point: ScatterSeriesData; yAxis: PreparedAxis; yScale: ChartScale}) => {
const {point, yAxis, yScale} = args;

let cy: number;

if (yAxis[0].type === 'category') {
if (yAxis.type === 'category') {
const yBandScale = yScale as ScaleBand<string>;
cy = (yBandScale(point.category as string) || 0) + yBandScale.step() / 2;
const categories = get(yAxis, 'categories', [] as string[]);
const dataCategory = getDataCategoryValue({axisDirection: 'y', categories, data: point});
cy = (yBandScale(dataCategory) || 0) + yBandScale.step() / 2;
} else {
const yLinearScale = yScale as ScaleLinear<number, number> | ScaleTime<number, number>;
cy = yLinearScale(point.y as number);
Expand Down Expand Up @@ -95,7 +91,7 @@ export function ScatterSeriesShape(props: ScatterSeriesShapeProps) {
svgElement.selectAll('*').remove();
const preparedData =
xAxis.type === 'category' || yAxis[0]?.type === 'category'
? prepareCategoricalScatterData(series.data)
? series.data
: prepareLinearScatterData(series.data);

svgElement
Expand All @@ -107,7 +103,7 @@ export function ScatterSeriesShape(props: ScatterSeriesShapeProps) {
.attr('fill', (d) => d.color || series.color || '')
.attr('r', (d) => d.radius || DEFAULT_SCATTER_POINT_RADIUS)
.attr('cx', (d) => getCxAttr({point: d, xAxis, xScale}))
.attr('cy', (d) => getCyAttr({point: d, yAxis, yScale}))
.attr('cy', (d) => getCyAttr({point: d, yAxis: yAxis[0], yScale}))
.on('mousemove', (e, d) => {
const [x, y] = pointer(e, svgContainer);
onSeriesMouseMove?.({
Expand Down
Loading