Skip to content

Commit

Permalink
revert back to old difference
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwq0 committed Jan 7, 2025
1 parent 6cb793c commit c563845
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 39 deletions.
118 changes: 94 additions & 24 deletions adloc/adloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,49 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

from .eikonal2d import _interp


class CalcTravelTime(Function):
    """Differentiable travel-time lookup backed by a precomputed eikonal table.

    Forward interpolates the travel time ``t(r, z)`` from ``timetable`` at the
    given source-receiver offsets; backward interpolates the precomputed
    gradient tables ``timetable_grad_r`` / ``timetable_grad_z`` instead of
    differentiating the interpolation itself.  Uses the new-style
    ``forward``/``setup_context`` autograd API.
    """

    @staticmethod
    def forward(r, z, timetable, timetable_grad_r, timetable_grad_z, rgrid0, zgrid0, nr, nz, h):
        # Interpolation runs on numpy arrays; wrap the result back as a tensor.
        travel_time = _interp(timetable, r.numpy(), z.numpy(), rgrid0, zgrid0, nr, nz, h)
        return torch.from_numpy(travel_time)

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Unpack all ten forward inputs; only r and z are tensors that need
        # to go through save_for_backward, the rest are plain attributes.
        (r, z, timetable, timetable_grad_r, timetable_grad_z, rgrid0, zgrid0, nr, nz, h) = inputs
        ctx.save_for_backward(r, z)
        ctx.timetable = timetable
        ctx.timetable_grad_r = timetable_grad_r
        ctx.timetable_grad_z = timetable_grad_z
        ctx.rgrid0 = rgrid0
        ctx.zgrid0 = zgrid0
        ctx.nr = nr
        ctx.nz = nz
        ctx.h = h

    @staticmethod
    def backward(ctx, grad_output):
        r, z = ctx.saved_tensors
        r_np, z_np = r.numpy(), z.numpy()
        grid_args = (ctx.rgrid0, ctx.zgrid0, ctx.nr, ctx.nz, ctx.h)

        # Chain rule: dt/dr and dt/dz come from the precomputed gradient
        # tables, scaled by the incoming gradient.
        grad_r = torch.from_numpy(_interp(ctx.timetable_grad_r, r_np, z_np, *grid_args)) * grad_output
        grad_z = torch.from_numpy(_interp(ctx.timetable_grad_z, r_np, z_np, *grid_args)) * grad_output

        # One gradient slot per forward input; only r and z are differentiable.
        return (grad_r, grad_z) + (None,) * 8


# %%
Expand Down Expand Up @@ -64,6 +107,8 @@ def __init__(
station_dt=None,
event_loc=None,
event_time=None,
event_loc0=None,
event_time0=None,
velocity={"P": 6.0, "S": 6.0 / 1.73},
eikonal=None,
zlim=[0, 30],
Expand All @@ -72,11 +117,25 @@ def __init__(
):
super().__init__()
self.num_event = num_event
self.event_loc0 = nn.Embedding(num_event, 3)
self.event_time0 = nn.Embedding(num_event, 1)
self.event_loc = nn.Embedding(num_event, 3)
self.event_time = nn.Embedding(num_event, 1)
self.station_loc = nn.Embedding(num_station, 3)
self.station_dt = nn.Embedding(num_station, 1) # same statioin term for P and S

## initialize event_loc0 and event_time0
if event_loc0 is not None:
event_loc0 = torch.tensor(event_loc0, dtype=dtype).contiguous()
else:
event_loc0 = torch.zeros(num_event, 3, dtype=dtype).contiguous()
if event_time0 is not None:
event_time0 = torch.tensor(event_time0, dtype=dtype).contiguous()
else:
event_time0 = torch.zeros(num_event, 1, dtype=dtype).contiguous()
self.event_loc0.weight = torch.nn.Parameter(event_loc0, requires_grad=False)
self.event_time0.weight = torch.nn.Parameter(event_time0, requires_grad=False)

## check initialization
station_loc = torch.tensor(station_loc, dtype=dtype)
if station_dt is not None:
Expand All @@ -97,7 +156,8 @@ def __init__(
self.event_loc.weight = torch.nn.Parameter(event_loc, requires_grad=True)
self.event_time.weight = torch.nn.Parameter(event_time, requires_grad=True)

self.velocity = [velocity["P"], velocity["S"]]
# self.velocity = [velocity["P"], velocity["S"]]
self.velocity = velocity
self.eikonal = eikonal
self.zlim = zlim
self.grad_type = grad_type
Expand Down Expand Up @@ -125,25 +185,25 @@ def calc_time(self, event_loc, station_loc, phase_type):
z = event_loc[:, 2] - station_loc[:, 2]
r = torch.sqrt(x**2 + y**2)

# timetable = self.eikonal["up"] if phase_type == 0 else self.eikonal["us"]
# timetable_grad = self.eikonal["grad_up"] if phase_type == 0 else self.eikonal["grad_us"]
# timetable_grad_r = timetable_grad[0]
# timetable_grad_z = timetable_grad[1]
# rgrid0 = self.eikonal["rgrid"][0]
# zgrid0 = self.eikonal["zgrid"][0]
# nr = self.eikonal["nr"]
# nz = self.eikonal["nz"]
# h = self.eikonal["h"]
# tt = CalcTravelTime.apply(r, z, timetable, timetable_grad_r, timetable_grad_z, rgrid0, zgrid0, nr, nz, h)

if phase_type in [0, "P"]:
timetable = self.timetable_p
elif phase_type in [1, "S"]:
timetable = self.timetable_s
else:
raise ValueError("phase_type should be 0 or 1. for P and S, respectively.")

tt = interp2d(timetable, r, z, self.rgrid, self.zgrid, self.h)
timetable = self.eikonal["up"] if phase_type == 0 else self.eikonal["us"]
timetable_grad = self.eikonal["grad_up"] if phase_type == 0 else self.eikonal["grad_us"]
timetable_grad_r = timetable_grad[0]
timetable_grad_z = timetable_grad[1]
rgrid0 = self.eikonal["rgrid"][0]
zgrid0 = self.eikonal["zgrid"][0]
nr = self.eikonal["nr"]
nz = self.eikonal["nz"]
h = self.eikonal["h"]
tt = CalcTravelTime.apply(r, z, timetable, timetable_grad_r, timetable_grad_z, rgrid0, zgrid0, nr, nz, h)

# if phase_type in [0, "P"]:
# timetable = self.timetable_p
# elif phase_type in [1, "S"]:
# timetable = self.timetable_s
# else:
# raise ValueError("phase_type should be 0 or 1. for P and S, respectively.")

# tt = interp2d(timetable, r, z, self.rgrid, self.zgrid, self.h)

tt = tt.float().unsqueeze(-1)

Expand Down Expand Up @@ -205,6 +265,8 @@ def __init__(
station_dt=None,
event_loc=None,
event_time=None,
event_loc0=None,
event_time0=None,
velocity={"P": 6.0, "S": 6.0 / 1.73},
eikonal=None,
zlim=[0, 30],
Expand All @@ -218,6 +280,8 @@ def __init__(
station_dt=station_dt,
event_loc=event_loc,
event_time=event_time,
event_loc0=event_loc0,
event_time0=event_time0,
velocity=velocity,
eikonal=eikonal,
zlim=zlim,
Expand Down Expand Up @@ -297,12 +361,18 @@ def forward(
station_loc_ = self.station_loc(station_index_) # (nb, 3)
station_dt_ = self.station_dt(station_index_) # (nb, 1)

event_loc_ = self.event_loc(event_index_) # (nb, 2, 3)
event_time_ = self.event_time(event_index_) # (nb, 2, 1)
delta_event_loc_ = self.event_loc(event_index_) # (nb, 2, 3)
delta_event_time_ = self.event_time(event_index_) # (nb, 2, 1)

event_loc0_ = self.event_loc0(event_index_) # (nb, 2, 3)
event_time0_ = self.event_time0(event_index_) # (nb, 2, 1)

station_loc_ = station_loc_.unsqueeze(1) # (nb, 1, 3)
station_dt_ = station_dt_.unsqueeze(1) # (nb, 1, 1)

event_loc_ = event_loc0_ + delta_event_loc_
event_time_ = event_time0_ + delta_event_time_

tt_ = self.calc_time(event_loc_, station_loc_, type) # (nb, 2)

t_ = event_time_ + tt_ + station_dt_ # (nb, 2, 1)
Expand Down Expand Up @@ -529,8 +599,8 @@ def forward(
num_station,
stations[["x_km", "y_km", "z_km"]].values,
stations[["dt_s"]].values,
event_loc,
event_time,
event_loc0=event_loc,
event_time0=event_time,
eikonal=eikonal_config,
)

Expand Down
51 changes: 36 additions & 15 deletions docs/run_adloc_dd.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from tqdm.auto import tqdm

from adloc.adloc import TravelTimeDD
from adloc.data import PhaseDatasetDD
from adloc.data import PhaseDatasetDT, PhaseDatasetDTCC
from adloc.eikonal2d import init_eikonal2d
from adloc.inversion import optimize_dd
from utils import plotting_dd
Expand Down Expand Up @@ -244,7 +244,7 @@
)

# %%
phase_dataset = PhaseDatasetDD(pairs, picks, events, stations, rank=ddp_local_rank, world_size=ddp_world_size)
phase_dataset = PhaseDatasetDT(pairs, picks, events, stations, rank=ddp_local_rank, world_size=ddp_world_size)
data_loader = DataLoader(phase_dataset, batch_size=None, shuffle=False, num_workers=0, drop_last=False)

# %%
Expand All @@ -256,7 +256,7 @@
num_event,
num_station,
station_loc=station_loc,
event_loc=event_loc,
event_loc0=event_loc,
# event_time=event_time,
eikonal=config["eikonal"],
)
Expand All @@ -270,6 +270,7 @@
## invert loss
######################################################################################################
optimizer = optim.Adam(params=travel_time.parameters(), lr=0.1)
# optimizer = optim.AdamW(params=travel_time.parameters(), lr=0.1, weight_decay=1.0)
valid_index = np.ones(len(pairs), dtype=bool)
EPOCHS = 100
for epoch in range(EPOCHS):
Expand All @@ -296,11 +297,11 @@

# torch.nn.utils.clip_grad_norm_(travel_time.parameters(), 1.0)
optimizer.step()
with torch.no_grad():
raw_travel_time.event_loc.weight.data[:, 2].clamp_(
min=config["zlim_km"][0] + 0.1, max=config["zlim_km"][1] - 0.1
)
raw_travel_time.event_loc.weight.data[torch.isnan(raw_travel_time.event_loc.weight)] = 0.0
# with torch.no_grad():
# raw_travel_time.event_loc.weight.data[:, 2].clamp_(
# min=config["zlim_km"][0] + 0.1, max=config["zlim_km"][1] - 0.1
# )
# raw_travel_time.event_loc.weight.data[torch.isnan(raw_travel_time.event_loc.weight)] = 0.0
if ddp_local_rank == 0:
print(f"Epoch {epoch}: loss {loss:.6e} of {np.sum(valid_index)} picks, {loss / np.sum(valid_index):.6e}")

Expand All @@ -318,21 +319,37 @@
pred_time.append(meta["phase_time"].detach().numpy())

pred_time = np.concatenate(pred_time)
valid_index = np.abs(pred_time - pairs["phase_dtime"]) < np.std((pred_time - pairs["phase_dtime"])[valid_index]) * 3.0 #* (np.cos(epoch * np.pi / EPOCHS) + 2.0) # 3std -> 1std

pairs_df = pd.DataFrame({"event_index1": pairs["event_index1"], "event_index2": pairs["event_index2"], "station_index": pairs["station_index"]})
valid_index = (
np.abs(pred_time - pairs["phase_dtime"]) < np.std((pred_time - pairs["phase_dtime"])[valid_index]) * 3.0
) # * (np.cos(epoch * np.pi / EPOCHS) + 2.0) # 3std -> 1std

pairs_df = pd.DataFrame(
{
"event_index1": pairs["event_index1"],
"event_index2": pairs["event_index2"],
"station_index": pairs["station_index"],
}
)
pairs_df = pairs_df[valid_index]
config["MIN_OBS"] = 8
pairs_df = pairs_df.groupby(["event_index1", "event_index2"], as_index=False, group_keys=False).filter(lambda x: len(x) >= config["MIN_OBS"])
pairs_df = pairs_df.groupby(["event_index1", "event_index2"], as_index=False, group_keys=False).filter(
lambda x: len(x) >= config["MIN_OBS"]
)
valid_index = np.zeros(len(pairs), dtype=bool)
valid_index[pairs_df.index] = True

phase_dataset.valid_index = valid_index

invert_event_loc = raw_travel_time.event_loc.weight.clone().detach().numpy()
invert_event_time = raw_travel_time.event_time.weight.clone().detach().numpy()
valid_event_index = np.unique(pairs["event_index1"][valid_index])
valid_event_index = np.concatenate([np.unique(pairs["event_index1"][valid_index]), np.unique(pairs["event_index2"][valid_index])])
invert_event_loc0 = raw_travel_time.event_loc0.weight.clone().detach().numpy()
invert_event_time0 = raw_travel_time.event_time0.weight.clone().detach().numpy()
invert_event_loc = invert_event_loc0 + invert_event_loc
invert_event_time = invert_event_time0 + invert_event_time
valid_event_index = np.unique(pairs["event_index1"][valid_index])
valid_event_index = np.concatenate(
[np.unique(pairs["event_index1"][valid_index]), np.unique(pairs["event_index2"][valid_index])]
)
valid_event_index = np.sort(np.unique(valid_event_index))

if ddp_local_rank == 0 and (epoch % 10 == 0):
Expand Down Expand Up @@ -397,6 +414,10 @@

invert_event_loc = raw_travel_time.event_loc.weight.clone().detach().numpy()
invert_event_time = raw_travel_time.event_time.weight.clone().detach().numpy()
invert_event_loc0 = raw_travel_time.event_loc0.weight.clone().detach().numpy()
invert_event_time0 = raw_travel_time.event_time0.weight.clone().detach().numpy()
invert_event_loc = invert_event_loc0 + invert_event_loc
invert_event_time = invert_event_time0 + invert_event_time

events = events_init.copy()
events["time"] = events["time"] + pd.to_timedelta(np.squeeze(invert_event_time), unit="s")
Expand Down

0 comments on commit c563845

Please sign in to comment.