-
Notifications
You must be signed in to change notification settings - Fork 661
/
CSVDataset.java
85 lines (72 loc) · 2.59 KB
/
CSVDataset.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.util.Progress;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import java.io.IOException;
import java.io.Reader;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import java.util.Locale;
/**
 * A {@link RandomAccessDataset} whose records come from a CSV file of URLs.
 *
 * <p>Every CSV record must provide a {@code url} column and an {@code isMalicious} column. The
 * URL is turned into a 26-element vector of letter counts (the feature) and {@code isMalicious}
 * is parsed as a float (the label).
 */
public class CSVDataset extends RandomAccessDataset {

    /** Number of letters {@code a}-{@code z}; the length of every encoded feature vector. */
    private static final int ALPHABET_SIZE = 26;

    /** CSV location used when the builder is not given an explicit path. */
    private static final String DEFAULT_CSV_FILE_PATH = "path/malicious_url_data.csv";

    private final List<CSVRecord> csvRecords;

    private CSVDataset(Builder builder) {
        super(builder);
        csvRecords = builder.csvRecords;
    }

    /**
     * Builds the feature/label pair for the CSV record at {@code index}.
     *
     * @param manager the manager used to create the feature and label arrays
     * @param index position of the record in the parsed CSV data
     * @return a record holding the letter-count vector as data and the parsed
     *     {@code isMalicious} value as label
     * @throws ArithmeticException if {@code index} does not fit in an {@code int}
     * @throws NumberFormatException if the {@code isMalicious} value is not a parsable float
     */
    @Override
    public Record get(NDManager manager, long index) {
        CSVRecord csvRecord = csvRecords.get(Math.toIntExact(index));
        NDArray datum = manager.create(encode(csvRecord.get("url")));
        NDArray label = manager.create(Float.parseFloat(csvRecord.get("isMalicious")));
        return new Record(new NDList(datum), new NDList(label));
    }

    /**
     * Returns the number of records parsed from the CSV file.
     *
     * @return the record count
     */
    @Override
    public long availableSize() {
        return csvRecords.size();
    }

    /**
     * Encodes a URL as the number of occurrences of each letter from {@code a} to {@code z}.
     *
     * <p>Letters are counted case-insensitively; every other character is ignored.
     *
     * @param url the URL to encode
     * @return a 26-element array where element {@code i} is the count of the letter
     *     {@code 'a' + i}
     */
    private static int[] encode(String url) {
        int[] encoding = new int[ALPHABET_SIZE];
        // Locale.ROOT keeps the mapping locale-independent: with a Turkish default locale,
        // 'I' would lowercase to dotless 'ı' and be dropped from the 'i' count.
        for (char ch : url.toLowerCase(Locale.ROOT).toCharArray()) {
            int index = ch - 'a';
            if (index >= 0 && index < ALPHABET_SIZE) {
                encoding[index]++;
            }
        }
        return encoding;
    }

    /** Does nothing: the CSV records are already loaded by {@link Builder#build()}. */
    @Override
    public void prepare(Progress progress) {}

    /**
     * Creates a builder for a {@code CSVDataset}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Builder that reads the CSV file and constructs the {@link CSVDataset}. */
    public static final class Builder extends BaseBuilder<Builder> {

        List<CSVRecord> csvRecords;
        private String csvFilePath = DEFAULT_CSV_FILE_PATH;

        Builder() {}

        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Sets the CSV file to read instead of the default {@code path/malicious_url_data.csv}.
         *
         * @param csvFilePath path of the CSV file; must not be {@code null} or empty
         * @return this builder
         * @throws IllegalArgumentException if {@code csvFilePath} is {@code null} or empty
         */
        public Builder optCsvFilePath(String csvFilePath) {
            if (csvFilePath == null || csvFilePath.isEmpty()) {
                throw new IllegalArgumentException("csvFilePath must not be null or empty");
            }
            this.csvFilePath = csvFilePath;
            return this;
        }

        /**
         * Parses the configured CSV file and builds the dataset from its records.
         *
         * @return the dataset backed by the parsed records
         * @throws IOException if the file cannot be opened or parsed
         */
        CSVDataset build() throws IOException {
            // NOTE(review): withFirstRecordAsHeader() appears to supersede the explicit
            // withHeader(...) call, taking column names from the file's first row - confirm
            // against the Commons CSV version in use before simplifying the chain.
            try (Reader reader = Files.newBufferedReader(Paths.get(csvFilePath));
                    CSVParser csvParser =
                            new CSVParser(
                                    reader,
                                    CSVFormat.DEFAULT
                                            .withHeader("url", "isMalicious")
                                            .withFirstRecordAsHeader()
                                            .withIgnoreHeaderCase()
                                            .withTrim())) {
                // getRecords() reads everything into memory, so the list stays usable
                // after the parser and reader are closed.
                csvRecords = csvParser.getRecords();
            }
            return new CSVDataset(this);
        }
    }
}