-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
37 lines (31 loc) · 1.16 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from __future__ import annotations
import math
from torch import empty, Tensor
# Centre of the disc that separates the two classes.
disc_center = (0.5, 0.5)
# Radius 1/sqrt(2*pi) (math.tau == 2*pi). With this radius the disc has area
# pi * r**2 = 1/2, i.e. it covers half of the unit square [0, 1]^2.
disc_radius = 1 / math.sqrt(math.tau)
def compute_label(point: tuple[float, float]) -> int:
    """
    Tell whether a point of the plane lies in the reference disc.

    :param point: tuple (x, y) representing a point in the Euclidean plane
    :returns: 1 if point is inside the disc centered at disc_center (0.5, 0.5) of radius
        disc_radius (1/sqrt(2*pi)), boundary included; 0 otherwise
    """
    distance_to_center = math.dist(point, disc_center)
    # Points lying exactly on the circle count as inside (non-strict comparison).
    if distance_to_center <= disc_radius:
        return 1
    return 0
def generate_samples(n: int) -> tuple[Tensor, Tensor]:
    """
    Draw n random points and label each of them with compute_label.

    :param n: number of samples
    :returns:
        inputs - tensor of shape (n, 2), n points sampled uniformly from [0,1]^2
        labels - tensor of shape (n, 1), the 0/1 label of each point (see function compute_label)
    """
    inputs = empty(n, 2).uniform_()
    # Iterating over a 2-D tensor yields its rows, i.e. one (x, y) point at a time.
    point_labels = [compute_label(point) for point in inputs]
    # new_tensor gives the labels the same dtype and device as the inputs;
    # view(-1, 1) turns the flat vector into a single column.
    labels = inputs.new_tensor(point_labels).view(-1, 1)
    return inputs, labels
def generate_data(n: int = 1000) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]:
    """
    Build a training set and a test set, each drawn independently with generate_samples.

    :param n: number of samples in each of the two sets
    :returns: the pair ((train_inputs, train_labels), (test_inputs, test_labels))
        train_inputs - n training points
        train_labels - n training labels
        test_inputs - n test points
        test_labels - n test labels
    """
    # The training set is drawn first, then the test set.
    train_set = generate_samples(n)
    test_set = generate_samples(n)
    return train_set, test_set