Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Vytautas Jancauskas committed Jan 17, 2025
1 parent 589d178 commit a6eba26
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/methane_super_emitters/create_dataset_negative.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def process_tropomi_file(file_path, month_path, day_path, output_dir, input_file
original.mask.sum() < 0.2 * 32 * 32):
print(f"FOUND: {csv_line}")
emitter = True
if not emitter and np.random.random() < 0.001:
if not emitter and np.random.random() < 0.01 and original.mask.sum() < 0.2 * 32 * 32:
negative_path = os.path.join(output_dir, 'negative', f"{uuid.uuid4()}.npz")
np.savez(negative_path, methane=methane_window, lat=lat_window,
lon=lon_window, qa=qa_window, time=parsed_time,
Expand Down
4 changes: 3 additions & 1 deletion src/methane_super_emitters/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
from torch.utils.data import DataLoader, Dataset
import os
import glob
import random
import numpy as np

class TROPOMISuperEmitterDataset(Dataset):
def __init__(self, data_dir):
self.data_dir = data_dir
self.positive = glob.glob(os.path.join(data_dir, 'positive', '*.npz'))
self.negative = glob.glob(os.path.join(data_dir, 'negative', '*.npz'))
negative_samples = glob.glob(os.path.join(data_dir, 'negative', '*.npz'))
self.negative = [random.choice(negative_samples) for _ in range(len(self.positive))]
self.all_samples = self.positive + self.negative

def __len__(self):
Expand Down
19 changes: 9 additions & 10 deletions src/methane_super_emitters/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,21 @@
class SuperEmitterDetector(L.LightningModule):
def __init__(self):
super().__init__()
self.cnv = nn.Conv2d(1, 128, 5, 4)
self.cnv = nn.Conv2d(1, 32, 5, 4)
self.rel = nn.ReLU()
self.bn = nn.BatchNorm2d(128)
self.bn = nn.BatchNorm2d(32)
self.mxpool = nn.MaxPool2d(4)
self.flat = nn.Flatten()
self.fc1 = nn.Linear(128, 64)
self.fc2 = nn.Linear(64, 64)
self.fc3 = nn.Linear(64, 1)
self.fc1 = nn.Linear(32, 16)
self.fc2 = nn.Linear(16, 1)
self.sigmoid = nn.Sigmoid()
self.accuracy = torchmetrics.classification.BinaryAccuracy()

def forward(self, x):
out = self.bn(self.rel(self.cnv(x)))
out = self.flat(self.mxpool(out))
out = self.rel(self.fc1(out))
out = self.rel(self.fc2(out))
out = self.fc3(out)
out = self.fc2(out)
out = self.sigmoid(out)
return out

Expand All @@ -37,6 +35,9 @@ def training_step(self, batch, batch_idx):
label = y.view(-1)
out = self(img)
loss = torch.nn.functional.binary_cross_entropy(out, y)
out = torch.where(out > 0.5, 1.0, 0.0)
accu = self.accuracy(out, y)
self.log('train_accu', accu)
return loss

def validation_step(self, batch, batch_idx):
Expand All @@ -45,10 +46,8 @@ def validation_step(self, batch, batch_idx):
label = y.view(-1)
out = self(img)
loss = torch.nn.functional.binary_cross_entropy(out, y)
#out = nn.Softmax(-1)(out)
#logits = torch.argmax(out, dim=1)
out = torch.where(out > 0.5, 1.0, 0.0)
accu = self.accuracy(out, y)
self.log('accuracy', accu)
self.log('val_accu', accu)
return loss, accu

34 changes: 34 additions & 0 deletions src/methane_super_emitters/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import click
import torch
import lightning as L
import numpy as np
import glob
import os
from methane_super_emitters.model import SuperEmitterDetector

def predict(model, npz_file):
    """Run the detector on a single ``.npz`` sample and return the rounded output.

    The archive must contain a ``methane`` array and a ``mask`` array.
    Entries of ``methane`` selected by ``mask`` are set to 0.0 before the
    array is wrapped in two extra leading dimensions and handed to ``model``
    as a float tensor.

    Args:
        model: Callable taking a float tensor and returning a tensor whose
            ``[0][0]`` element is the prediction.
        npz_file: Path to the ``.npz`` file to classify.

    Returns:
        The model output for the sample rounded to the nearest integer,
        as a plain Python ``int``.
    """
    # NpzFile keeps the underlying file open until closed; the context
    # manager releases the handle even when many files are processed.
    with np.load(npz_file) as data:
        m = data['methane']
        # NOTE(review): assumes 'mask' is a boolean array shaped like
        # 'methane' (True = invalid pixel) - confirm against dataset writer.
        m[data['mask']] = 0.0
    # Add two leading axes: np.array([[m]]) has shape (1, 1, *m.shape).
    m = np.array([[m]])
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        y_hat = model(torch.tensor(m, dtype=torch.float))
    y_hat = y_hat.detach().cpu().numpy()
    # Convert the NumPy scalar to a Python float so round() yields an int.
    return round(float(y_hat[0][0]))

@click.command()
@click.option('-c', '--checkpoint', required=True, help='Checkpoint file')
@click.option('-i', '--input-dir', required=True, help='Directory with files')
def main(checkpoint, input_dir):
    """Classify every ``.npz`` file in a directory and print the mean prediction.

    Loads a ``SuperEmitterDetector`` from ``checkpoint``, switches it to
    evaluation mode, prints the file name and rounded prediction for each
    ``*.npz`` file directly inside ``input_dir``, and finally prints the
    average of those predictions.

    Args:
        checkpoint: Path to the Lightning checkpoint to load.
        input_dir: Directory that is searched (non-recursively) for
            ``*.npz`` files.
    """
    model = SuperEmitterDetector.load_from_checkpoint(checkpoint)
    model.eval()
    total = 0.0
    counter = 0
    for npz_file in glob.glob(os.path.join(input_dir, '*.npz')):
        # Run inference once per file and reuse the result for both the
        # per-file report and the running total.
        y_hat = predict(model, npz_file)
        print(npz_file, y_hat)
        total += y_hat
        counter += 1
    # Without any matching file the mean is undefined; report that instead
    # of failing with ZeroDivisionError.
    if counter == 0:
        print(f"No .npz files found in {input_dir}")
        return
    print(total / counter)

if __name__ == '__main__':
main()

0 comments on commit a6eba26

Please sign in to comment.