
Commit

add filtering
zhuwq0 committed Sep 3, 2024
1 parent 8a47824 commit df1aa4f
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions docs/run_adloc_dd.py
@@ -55,13 +55,13 @@
 # result_path = f"results/{region}"
 # figure_path = f"figures/{region}"

-# picks_file = os.path.join(data_path, "gamma_picks.csv")
-# events_file = os.path.join(data_path, "gamma_events.csv")
-# stations_file = os.path.join(data_path, "stations.csv")
+# # picks_file = os.path.join(data_path, "gamma_picks.csv")
+# # events_file = os.path.join(data_path, "gamma_events.csv")
+# # stations_file = os.path.join(data_path, "stations.csv")

-# # picks_file = os.path.join(result_path, "ransac_picks_sst.csv")
-# # events_file = os.path.join(result_path, "ransac_events_sst.csv")
-# # stations_file = os.path.join(result_path, "ransac_stations_sst.csv")
+# picks_file = os.path.join(result_path, "ransac_picks_sst.csv")
+# events_file = os.path.join(result_path, "ransac_events_sst.csv")
+# stations_file = os.path.join(result_path, "ransac_stations_sst.csv")

 # # %% generate the double-difference pair file
 # if ddp_local_rank == 0:
@@ -79,6 +79,8 @@
 # events = pd.read_csv(os.path.join(result_path, "pair_events.csv"), parse_dates=["time"])
 # stations = pd.read_csv(os.path.join(result_path, "pair_stations.csv"))
 # picks = pd.read_csv(os.path.join(result_path, "pair_picks.csv"), parse_dates=["phase_time"])
+# if "adloc_mask" in picks.columns:
+#     picks = picks[picks["adloc_mask"] == 1]
 # dtypes = pickle.load(open(os.path.join(result_path, "pair_dtypes.pkl"), "rb"))
 # pairs = np.memmap(os.path.join(result_path, "pair_dt.dat"), mode="r", dtype=dtypes)
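
Note: the two `+` lines above are the new pick-level filter. When the pick table carries an `adloc_mask` column, only rows with `adloc_mask == 1` are kept before the differential-time pairs are used. A minimal sketch of the same pattern on made-up data (the column values below are hypothetical, not taken from the repository):

```python
import pandas as pd

# Hypothetical pick table; only the adloc_mask column matters for this filter.
picks = pd.DataFrame(
    {
        "phase_time": ["2024-09-03T00:00:01", "2024-09-03T00:00:02", "2024-09-03T00:00:03"],
        "adloc_mask": [1, 0, 1],
    }
)

# Keep only picks that passed the earlier quality selection, if the column exists.
if "adloc_mask" in picks.columns:
    picks = picks[picks["adloc_mask"] == 1]

print(len(picks))  # 2
```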

@@ -269,7 +271,7 @@
 ######################################################################################################
 optimizer = optim.Adam(params=travel_time.parameters(), lr=0.1)
 valid_index = np.ones(len(pairs), dtype=bool)
-EPOCHS = 30
+EPOCHS = 100
 for epoch in range(EPOCHS):
     loss = 0
     optimizer.zero_grad()
@@ -317,6 +319,14 @@

 pred_time = np.concatenate(pred_time)
 valid_index = np.abs(pred_time - pairs["phase_dtime"]) < np.std((pred_time - pairs["phase_dtime"])[valid_index]) * 3.0  # * (np.cos(epoch * np.pi / EPOCHS) + 2.0)  # 3std -> 1std
+
+pairs_df = pd.DataFrame({"event_index1": pairs["event_index1"], "event_index2": pairs["event_index2"], "station_index": pairs["station_index"]})
+pairs_df = pairs_df[valid_index]
+config["MIN_OBS"] = 8
+pairs_df = pairs_df.groupby(["event_index1", "event_index2"], as_index=False, group_keys=False).filter(lambda x: len(x) >= config["MIN_OBS"])
+valid_index = np.zeros(len(pairs), dtype=bool)
+valid_index[pairs_df.index] = True
+
 phase_dataset.valid_index = valid_index

 invert_event_loc = raw_travel_time.event_loc.weight.clone().detach().numpy()
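
Note: the eight `+` lines above are the pair-level filter that gives this commit its name. After the 3-sigma residual cut updates `valid_index`, the surviving pairs are regrouped by `(event_index1, event_index2)`, and any event pair left with fewer than `config["MIN_OBS"] = 8` differential-time observations is dropped before the mask is handed back to `phase_dataset`. A minimal standalone sketch of that group-count filter, assuming toy data and a smaller threshold (column names follow the diff; the values are invented):

```python
import numpy as np
import pandas as pd

MIN_OBS = 3  # the script uses 8; a smaller value keeps the toy example short

# Toy differential-time pairs: one row per (event1, event2, station) observation.
pairs_df = pd.DataFrame(
    {
        "event_index1": [0, 0, 0, 0, 1, 1],
        "event_index2": [1, 1, 1, 1, 2, 2],
        "station_index": [0, 1, 2, 3, 0, 1],
    }
)

# Drop event pairs that retain fewer than MIN_OBS observations.
filtered = pairs_df.groupby(["event_index1", "event_index2"], group_keys=False).filter(
    lambda g: len(g) >= MIN_OBS
)

# Rebuild the boolean mask over the original rows, as the script does with pairs_df.index.
valid_index = np.zeros(len(pairs_df), dtype=bool)
valid_index[filtered.index] = True
print(valid_index)  # [ True  True  True  True False False]
```

Because `pairs_df` keeps its original integer index through both the boolean mask and `groupby(...).filter(...)`, the retained row positions can be written straight back into the full-length `valid_index` array.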