Skip to content

Commit

Permalink
feat: add create modal
Browse files Browse the repository at this point in the history
  • Loading branch information
chaxus committed Nov 5, 2023
1 parent 865f80f commit dc9837a
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 26 deletions.
2 changes: 2 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
"classname",
"componentization",
"dataavailable",
"denormalise",
"Deoptimization",
"DOMAPI",
"esbuild",
"IIFE",
"keyof",
"Metiral",
"normalise",
"nums",
"Ovride",
"picocolors",
Expand Down
96 changes: 71 additions & 25 deletions packages/ml/client/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,48 +4,94 @@ import { BrowserRouter } from 'react-router-dom';
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis'
import type { Point2D } from '@tensorflow/tfjs-vis';
import type { TensorContainerObject } from '@tensorflow/tfjs';
import type { Rank, Tensor, TensorContainerObject } from '@tensorflow/tfjs';

const csv = '../assets/dataset/kc_house_data.csv'

interface HouseSaleDataSet extends TensorContainerObject {
price: number,
sqft_living: number
}

const tfTensor = async () => {
const plot = (points: Point2D[], name: string) => {
tfvis.render.scatterplot(
{ name: `${name} vs House Price` },
{ values: [points], series: ["original"] },
{
xLabel: name,
yLabel: "Price"
}
)
const normalise = (tensor: Tensor) => {
const min = tensor.min()
const max = tensor.max()
return {
tensor: tensor.sub(min).div(max.sub(min)),
max,
min
}
}

const denormalise = (tensor: Tensor, max: Tensor<Rank>, min: Tensor<Rank>) => {
return tensor.mul(max.sub(min)).add(min)
}
/**
* @description: 绘制图形
* @param {Point2D} points
* @param {string} name
* @return {*}
*/
const plot = (points: Point2D[], name: string) => {
tfvis.render.scatterplot(
{ name: `${name} vs House Price` },
{ values: [points], series: ["original"] },
{
xLabel: name,
yLabel: "Price"
}
)
}

const createModal = () => {
const model = tf.sequential()

model.add(tf.layers.dense({
units: 1,
useBias: true,
activation: 'linear',
inputDim: 1
}))
return model
}
const tfTensor = async () => {
// 导入数据
const houseSaleDateSet = tf.data.csv(csv)
houseSaleDateSet.take(10).toArray().then(res => {
console.log('a', res);
})
const points = houseSaleDateSet.map((record: HouseSaleDataSet) => {
// 从数据中提取x,y值并绘制图形
const pointsDataSet = houseSaleDateSet.map((record: HouseSaleDataSet) => {
return {
x: record.sqft_living,
y: record.price
}
})
points.toArray().then((res: Point2D[]) => {
plot(res, 'Square feet')
console.log('points', res);
})
// Feature (inputs)
const featureValue = await points.map(p => p.x).toArray()
const points: Point2D[] = await pointsDataSet.toArray()

if (points.length % 2 !== 0) {
// 如果张量是奇数,会导致无法平均分割,需要变成偶数
points.pop()
}
tf.util.shuffle(points)
plot(points, 'Square feet')
// Feature (inputs) 提取特征并存在张量中
const featureValue = points.map(p => p.x)
const featureTensor = tf.tensor2d(featureValue, [featureValue.length, 1])
// Labels (outputs)
const labelValue = await points.map(p => p.y).toArray()
// Labels (outputs) 对标签做同样的操作
const labelValue = points.map(p => p.y)
const labelTensor = tf.tensor2d(labelValue, [labelValue.length, 1])

featureTensor.print()
labelTensor.print()
// 标准化标签和特征
const normaliseFeatureTensor = normalise(featureTensor)
const normaliseLabelTensor = normalise(labelTensor)

normaliseFeatureTensor.tensor.print()
normaliseLabelTensor.tensor.print()
// 分割测试集和训练集
const [trainingFeatureTensor, testingFeatureTensor] = tf.split(normaliseFeatureTensor.tensor, 2)
const [trainingLabelTensor, testingLabelTensor] = tf.split(normaliseLabelTensor.tensor, 2)
// 创建模型
const modal = createModal()
modal.summary()
tfvis.show.modelSummary({ name: "Modal summary" }, modal)
}

const App = () => {
Expand Down
20 changes: 19 additions & 1 deletion packages/ml/client/lib.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import * as tf from '@tensorflow/tfjs';
import type { MemoryInfo } from '@tensorflow/tfjs'

import * as tfvis from '@tensorflow/tfjs-vis'
import type { Point2D } from '@tensorflow/tfjs-vis';
/**
* @description: 测试tf的内存管理
* @return {*}
Expand Down Expand Up @@ -77,4 +78,21 @@ export const tfInfo = (): TfInfo => {

export const csv2DataSet = (path: string): tf.data.CSVDataset => {
return tf.data.csv(path)
}

/**
* @description: 绘制图形
* @param {Point2D} points
* @param {string} name
* @return {*}
*/
export const plot = (points: Point2D[], name: string): void => {
tfvis.render.scatterplot(
{ name: `${name} vs House Price` },
{ values: [points], series: ["original"] },
{
xLabel: name,
yLabel: "Price"
}
)
}

0 comments on commit dc9837a

Please sign in to comment.