transformermodel.h
/*
Transformer Screening Project.
P. Karpov, 2020
*/
#ifndef TRANSFORMERMODEL_H
#define TRANSFORMERMODEL_H
#include <map>
#include <stdlib.h>
#include <iostream>
#include <string>
#include <sstream>
#include <vector>
#include <set>
#include <openbabel/mol.h>
#include <cblas.h>
//Used in error estimation for zero-like predictions.
#define TOL 1e-3
//Generate a batch of SMILES corresponding to the molecule.
//If the molecule is out of vocabulary, or some other problem occurred,
//the function returns false.
bool GetRandomSmiles(const std::string & smiles,
                     std::set<std::string> & mols,
                     int &max_n);
//Student's t coefficient for the given number of degrees of freedom (< 30 points).
float student(int freedom);
//Compute the mean of the data and its error estimate.
void calcMeanAndError(const std::vector<float> &data,
                      float * avg,
                      float * err);
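//A minimal sketch of how such a mean/error helper is typically written,
//using the student() coefficient declared above (requires <cmath>).
//This is an illustration only, not necessarily the exact code in the
//accompanying .cpp file:
//
//   void calcMeanAndError(const std::vector<float> &data,
//                         float * avg,
//                         float * err)
//   {
//      if (data.empty()) { *avg = 0.0f; *err = 0.0f; return; }
//
//      float sum = 0.0f;
//      for (float x : data) sum += x;
//      *avg = sum / data.size();
//
//      if (data.size() < 2) { *err = 0.0f; return; }
//
//      float ss = 0.0f;
//      for (float x : data) ss += (x - *avg) * (x - *avg);
//
//      //Standard error of the mean scaled by the Student's t coefficient.
//      float sem = sqrtf(ss / (data.size() - 1) / data.size());
//      *err = student((int)data.size() - 1) * sem;
//   }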
class TransformerModel
{
public:

    static const int MaxBatchSize = 10;   //No more than 10 augmented SMILES per molecule.
    static const int MaxSmilesSize = 128; //The maximum length of a SMILES string (characters, not atoms).

    //These constants come from the Transformer-CNN article.
    static const int EmbeddingSize = 64;
    static const int HeadsCount = 10;
    static const int ConvOffset = 20;
    static const int HiddenSize = 512;

    static const char * vocab;
    static const int vocab_length = 66;   //The length of the vocabulary string.

    TransformerModel(const char * fname, const char * prop = NULL);
    ~TransformerModel();

    struct ResultValue
    {
        bool valid;
        float value[TransformerModel::MaxBatchSize];
        int size;

        ResultValue()
        {
            valid = false;
        }
    };

    bool isGood() const;

    ResultValue predict(std::set<std::string> & mols, int max_n,
                        float * embeddings = NULL);

    float * getSmilesEmbeddings();
    void setSmilesEmbeddings(float * s);

    const std::string & getProp() const;

private:

    std::string m_prop;

    //Mapping from a symbol to its wordId.
    std::map<char, int> char_to_ix;

    int left_mask_id [ MaxBatchSize ];
    int x[MaxBatchSize * (MaxSmilesSize + ConvOffset)];

    char * data; //model
    bool classification;

    float v_min;
    float v_max;

    //Variables for matrices used during the calculations.
    float * smiles_embedding;
    float * pos;
    float *q, *k, *v;
    float *a, *sa, *lc;

    //Variables for the loaded model.
    float *mdl;

    //Attention layers.
    float * K1[3], *Q1[3], *V1[3], *TD1[3], *B1[3];
    float * gamma1[3], *beta1[3], *w1[3], *b1[3], *w2[3], *b2[3];
    float * gamma2[3], *beta2[3];

    //Convolutional filters.
    float * Conv1, * Conv1_B;
    float * Conv2, * Conv2_B;
    float * Conv3, * Conv3_B;
    float * Conv4, * Conv4_B;
    float * Conv5, * Conv5_B;
    float * Conv6, * Conv6_B;
    float * Conv7, * Conv7_B;
    float * Conv8, * Conv8_B;
    float * Conv9, * Conv9_B;
    float * Conv10, * Conv10_B;
    float * Conv15, * Conv15_B;
    float * Conv20, * Conv20_B;

    //HighWay module.
    float * CNN_W, * CNN_WB;
    float * High1, * High1_B;
    float * High2, * High2_B;

    //The final output.
    float * Out_W;
    float * Out_B;

    int NN;
    int batch_size;

    struct ConvInfo
    {
        int conv_number;
        int n_filter;
        float * conv;  // W
        float * bias;  // B
        int start;     // Start position of this filter in the lc array.
    };

    void AttentionLayer(int layer);
};
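//Example of the intended call sequence (a sketch only). The model file name
//("model.trm"), the property name ("logP"), and the input SMILES "CCO" are
//hypothetical placeholders, not values defined by this header:
//
//   std::set<std::string> mols;
//   int max_n = 0;
//   if (GetRandomSmiles("CCO", mols, max_n))
//   {
//      TransformerModel model("model.trm", "logP");
//      if (model.isGood())
//      {
//         TransformerModel::ResultValue res = model.predict(mols, max_n);
//         if (res.valid)
//         {
//            std::vector<float> vals(res.value, res.value + res.size);
//            float avg = 0.0f, err = 0.0f;
//            calcMeanAndError(vals, &avg, &err);
//            std::cout << avg << " +/- " << err << std::endl;
//         }
//      }
//   }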
#endif // TRANSFORMERMODEL_H