diff --git a/packages/ml/client/index.tsx b/packages/ml/client/index.tsx index 7ef2674d8..fe7269a88 100644 --- a/packages/ml/client/index.tsx +++ b/packages/ml/client/index.tsx @@ -47,11 +47,17 @@ const createModal = () => { const model = tf.sequential() model.add(tf.layers.dense({ - units: 1, + units: 4, useBias: true, activation: 'linear', inputDim: 1 })) + + const optimizer = tf.train.sgd(0.1) + model.compile({ + loss: "meanSquaredError", + optimizer + }) return model } const tfTensor = async () => { @@ -88,10 +94,13 @@ const tfTensor = async () => { // 分割测试集和训练集 const [trainingFeatureTensor, testingFeatureTensor] = tf.split(normaliseFeatureTensor.tensor, 2) const [trainingLabelTensor, testingLabelTensor] = tf.split(normaliseLabelTensor.tensor, 2) - // 创建模型 + // 创建模型 const modal = createModal() modal.summary() tfvis.show.modelSummary({ name: "Modal summary" }, modal) + // 了解 layer + const layer = modal.getLayer(undefined, 0) + tfvis.show.layer({ name: "Layer 1" }, layer) } const App = () => {