-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
36 lines (31 loc) · 1.09 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#############################################################################
# YOUR GENERATIVE MODEL
# ---------------------
# Should be implemented in the 'generative_model' function
# !! *DO NOT MODIFY THE NAME OF THE FUNCTION* !!
#
# You can store your parameters in any format you want (npy, h5, json, yaml, ...)
# <!> *SAVE YOUR PARAMETERS IN THE parameters/ DIRECTORY* <!>
#
# See below for an example of a generative model:
# G_\theta(Z) = np.maximum(0, \theta \cdot Z)
############################################################################
import functools
import os
import sys

sys.path.append('/parameters/20221203')

import mlflow
import torch
# <!> DO NOT ADD ANY OTHER ARGUMENTS <!>
def generative_model(noise):
"""
Generative model
Parameters
----------
noise : ndarray with shape (n_samples, n_dim)
input noise of the generative model
"""
# See below an example
# ---------------------
latent_variable = torch.Tensor(noise[:, :7])
model = mlflow.pytorch.load_model('parameters/20221203/best_ad_mean')
samples = model.sample(latent_variable).detach().numpy()
return samples