update phasenet plus
zhuwq0 committed Nov 7, 2024
1 parent 099f1d8 commit b7d2a38
Showing 6 changed files with 171 additions and 56 deletions.
3 changes: 3 additions & 0 deletions scripts/download_waveform_v2.py
@@ -211,10 +211,13 @@ def download_waveform(
client = obspy.clients.fdsn.Client(provider)

DELTATIME = "1H" # 1H or 1D
# DELTATIME = "1D"
if DELTATIME == "1H":
start = datetime.fromisoformat(config["starttime"]).strftime("%Y-%m-%dT%H")
elif DELTATIME == "1D":
start = datetime.fromisoformat(config["starttime"]).strftime("%Y-%m-%d")
else:
raise ValueError("Invalid interval")
starttimes = pd.date_range(start, config["endtime"], freq=DELTATIME, tz="UTC", inclusive="left").to_list()
starttimes = np.array_split(starttimes, num_nodes)[rank]
print(f"rank {rank}: {len(starttimes) = }, {starttimes[0]}, {starttimes[-1]}")
2 changes: 2 additions & 0 deletions scripts/download_waveform_v3.py
@@ -228,6 +228,8 @@ def download_waveform(
start = datetime.fromisoformat(config["starttime"]).strftime("%Y-%m-%dT%H")
elif DELTATIME == "1D":
start = datetime.fromisoformat(config["starttime"]).strftime("%Y-%m-%d")
else:
raise ValueError("Invalid interval")
starttimes = pd.date_range(start, config["endtime"], freq=DELTATIME, tz="UTC", inclusive="left").to_list()
starttimes = np.array_split(starttimes, num_nodes)[node_rank]
print(f"rank {node_rank}/{num_nodes}: {len(starttimes) = }, {starttimes[0]}, {starttimes[-1]}")
10 changes: 7 additions & 3 deletions scripts/merge_phasenet_picks.py
@@ -4,18 +4,18 @@
import os
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta, timezone
from glob import glob
from threading import Lock, Thread

import fsspec
import numpy as np
import pandas as pd
import pyproj
from args import parse_args
from obspy import read_inventory
from obspy.clients.fdsn import Client
from sklearn.cluster import DBSCAN
from tqdm import tqdm
from args import parse_args
from glob import glob


def scan_csv(year, root_path, region, model, fs=None, bucket=None, protocol="file"):
@@ -31,6 +31,7 @@ def scan_csv(year, root_path, region, model, fs=None, bucket=None, protocol="fil
csvs = fs.glob(f"{jday}/??/*.csv")
else:
csvs = glob(f"{root_path}/{region}/{model}/picks/{year}/{jday}/??/*.csv")
# csvs = glob(f"{root_path}/{region}/{model}/picks/{year}/{jday}/*.csv")

csv_list.extend([[year, jday, csv] for csv in csvs])

@@ -89,7 +90,7 @@ def read_csv(rows, region, model, year, jday, root_path, fs=None, bucket=None):

# %%
# years = os.listdir(f"{root_path}/{region}/{model}/picks_{model}")
years = glob(f"{root_path}/{region}/{model}/picks_{model}/????/")
years = glob(f"{root_path}/{region}/{model}/picks/????/")
years = [year.rstrip("/").split("/")[-1] for year in years]
print(f"Years: {years}")

@@ -132,6 +133,9 @@ def read_csv(rows, region, model, year, jday, root_path, fs=None, bucket=None):
for csv in tqdm(csvs, desc="Merge csv files"):
picks.append(pd.read_csv(csv, dtype=str))
picks = pd.concat(picks, ignore_index=True)
print(f"Number of picks: {len(picks):,}")
print(f"Number of P picks: {len(picks[picks['phase_type'] == 'P']):,}")
print(f"Number of S picks: {len(picks[picks['phase_type'] == 'S']):,}")
picks.to_csv(f"{root_path}/{region}/{model}/{model}_picks.csv", index=False)

# %%
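This hunk points the scan at the `picks` directory instead of `picks_{model}` and prints P/S pick counts after merging. A hedged sketch of the merge-and-count step follows; the two in-memory DataFrames stand in for per-hour pick CSVs, and the station IDs are made up.

```python
# Sketch of the merge-and-count step (the DataFrames stand in for per-file pick CSVs).
import pandas as pd

csv_parts = [
    pd.DataFrame({"station_id": ["CI.CCC.", "CI.CCC."], "phase_type": ["P", "S"],
                  "phase_time": ["2024-01-01T00:00:01", "2024-01-01T00:00:03"]}),
    pd.DataFrame({"station_id": ["CI.JRC2."], "phase_type": ["P"],
                  "phase_time": ["2024-01-01T00:00:02"]}),
]
picks = pd.concat(csv_parts, ignore_index=True)

print(f"Number of picks: {len(picks):,}")
print(f"Number of P picks: {len(picks[picks['phase_type'] == 'P']):,}")
print(f"Number of S picks: {len(picks[picks['phase_type'] == 'S']):,}")
picks.to_csv("phasenet_picks.csv", index=False)
```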
162 changes: 128 additions & 34 deletions scripts/run_event_association.py
@@ -18,6 +18,79 @@
from tqdm import tqdm


def plotting_debug(xt, hist, topk_index, topk_score, picks, events, stations, config):

# timestamp0 = config["timestamp0"]
# events_compare = pd.read_csv("local/Ridgecrest_debug5/adloc_gamma/ransac_events.csv")
# picks_compare = pd.read_csv("local/Ridgecrest_debug5/adloc_gamma/ransac_picks.csv")
# # events_compare = pd.read_csv("local/Ridgecrest_debug5/adloc_plus2/ransac_events_sst_0.csv")
# # picks_compare = pd.read_csv("local/Ridgecrest_debug5/adloc_plus2/ransac_picks_sst_0.csv")
# events_compare["time"] = pd.to_datetime(events_compare["time"])
# events_compare["timestamp"] = events_compare["time"].apply(lambda x: (x - timestamp0).total_seconds())
# picks_compare["phase_time"] = pd.to_datetime(picks_compare["phase_time"])
# picks_compare["timestamp"] = picks_compare["phase_time"].apply(lambda x: (x - timestamp0).total_seconds())

DT = config["DT"]
MIN_STATION = config["MIN_STATION"]

# map station_id to int
stations["xy"] = stations["longitude"] - stations["latitude"]
stations.sort_values(by="xy", inplace=True)
mapping_id = {v: i for i, v in enumerate(stations["station_id"])}
mapping_color = {v: f"C{i}" if v != -1 else "k" for i, v in enumerate(events["event_index"].unique())}

NX = 100
for i in tqdm(range(0, len(hist), NX)):
bins = np.arange(i, i + NX, DT)

fig, ax = plt.subplots(2, 1, figsize=(15, 10), sharex=True)

# plot hist
idx = (xt > i) & (xt < i + NX)
ax[0].bar(xt[idx], hist[idx], width=DT)

ylim = ax[0].get_ylim()
idx = (xt[topk_index] > i) & (xt[topk_index] < i + NX)
ax[0].vlines(xt[topk_index][idx], ylim[0], ylim[1], color="k", linewidth=1)

# idx = (events_compare["timestamp"] > i) & (events_compare["timestamp"] < i + NX)
# ax[0].vlines(events_compare["timestamp"][idx], ylim[0], ylim[1], color="r", linewidth=1, linestyle="--")

# plot picks-events match
idx = (events["timestamp"] > i) & (events["timestamp"] < i + NX)
ax[1].scatter(
events["timestamp"][idx],
events["station_id"][idx].map(mapping_id),
c=events["event_index"][idx].map(mapping_color),
marker=".",
s=30,
)

idx = (picks["timestamp"] > i) & (picks["timestamp"] < i + NX)
ax[1].scatter(
picks["timestamp"][idx],
picks["station_id"][idx].map(mapping_id),
c=picks["event_index"][idx].map(mapping_color),
marker="x",
linewidth=0.5,
s=10,
)

# idx = (picks_compare["timestamp"] > i) & (picks_compare["timestamp"] < i + NX)
# ax[1].scatter(
# picks_compare["timestamp"][idx],
# picks_compare["station_id"][idx].map(mapping_id),
# facecolors="none",
# edgecolors="r",
# linewidths=0.1,
# s=30,
# )

if not os.path.exists(f"figures"):
os.makedirs(f"figures")
plt.savefig(f"figures/debug_{i:04d}.png", dpi=300, bbox_inches="tight")


def associate(
picks: pd.DataFrame,
events: pd.DataFrame,
@@ -27,63 +100,68 @@ def associate(

VPVS_RATIO = config["VPVS_RATIO"]
VP = config["VP"]
DT = 1.0 # seconds
DT = 2.0 # seconds
MIN_STATION = 3

# %%
t0 = min(events["event_time"].min(), picks["phase_time"].min())
events["timestamp"] = events["event_time"].apply(lambda x: (x - t0).total_seconds())
events["timestamp_center"] = events["center_time"].apply(lambda x: (x - t0).total_seconds())
picks["timestamp"] = picks["phase_time"].apply(lambda x: (x - t0).total_seconds())
timestamp0 = min(events["event_time"].min(), picks["phase_time"].min())

# proj = Proj(proj="merc", datum="WGS84", units="km")
# stations[["x_km", "y_km"]] = stations.apply(lambda x: pd.Series(proj(x.longitude, x.latitude)), axis=1)
events["timestamp"] = events["event_time"].apply(lambda x: (x - timestamp0).total_seconds())
events["timestamp_center"] = events["center_time"].apply(lambda x: (x - timestamp0).total_seconds())
picks["timestamp"] = picks["phase_time"].apply(lambda x: (x - timestamp0).total_seconds())

# dist_matrix = squareform(pdist(stations[["x_km", "y_km"]].values))
# mst = minimum_spanning_tree(dist_matrix)
# dx = np.median(mst.data[mst.data > 0])
# print(f"dx: {dx:.3f}")
# eps_t = dx / VP * 2.0
# eps_t = 6.0
# eps_xy = eps_t * VP * 2 / (1.0 + VPVS_RATIO)
# print(f"eps_t: {eps_t:.3f}, eps_xy: {eps_xy:.3f}")
# eps_xy = 30.0
# print(f"eps_xy: {eps_xy:.3f}")
t0 = min(events["timestamp"].min(), picks["timestamp"].min())
t1 = max(events["timestamp"].max(), picks["timestamp"].max())

# %% Using DBSCAN to cluster events
# proj = Proj(proj="merc", datum="WGS84", units="km")
# stations[["x_km", "y_km"]] = stations.apply(lambda x: pd.Series(proj(x.longitude, x.latitude)), axis=1)
# events = events.merge(stations[["station_id", "x_km", "y_km"]], on="station_id", how="left")

# scaling = np.array([1.0, 1.0 / eps_xy, 1.0 / eps_xy])
# clustering = DBSCAN(eps=2.0, min_samples=4).fit(events[["timestamp", "x_km", "y_km"]] * scaling)
# # clustering = DBSCAN(eps=2.0, min_samples=4).fit(events[["timestamp"]])
# # clustering = DBSCAN(eps=3.0, min_samples=3).fit(events[["timestamp"]])
# # clustering = DBSCAN(eps=1.0, min_samples=3).fit(events[["timestamp"]])
# events["event_index"] = clustering.labels_

## Using histogram to cluster events
events["event_index"] = -1
t = np.arange(events["timestamp"].min(), events["timestamp"].max(), DT)
hist, _ = np.histogram(events["timestamp"], bins=t)
# retrieve picks using max_pool of kernel size 5 seconds
t = np.arange(t0, t1, DT)
hist, edge = np.histogram(events["timestamp"], bins=t, weights=events["event_score"])
xt = (edge[:-1] + edge[1:]) / 2 # center of the bin
# hist_numpy = hist.copy()

hist = torch.from_numpy(hist).float().unsqueeze(0).unsqueeze(0)
hist_pool = F.max_pool1d(hist, kernel_size=5, padding=2, stride=1)
# find the index of the maximum value in hist_pool
hist_pool = F.max_pool1d(hist, kernel_size=3, padding=1, stride=1)
mask = hist_pool == hist
hist = hist * mask
K = int((t[-1] - t[0]) / 10) # assume max 1 event per 10 seconds on average
hist = hist.squeeze(0).squeeze(0)
K = int((t[-1] - t[0]) / 5) # assume max 1 event per 10 seconds on average
topk_score, topk_index = torch.topk(hist, k=K)
topk_index = topk_index[topk_score > MIN_STATION] # min 3 stations
topk_index = topk_index.squeeze().numpy()
topk_index = topk_index[topk_score >= MIN_STATION] # min 3 stations
topk_score = topk_score[topk_score >= MIN_STATION]
topk_index = topk_index.numpy()
topk_score = topk_score.numpy()
num_events = len(topk_index)
# assign timestamp to events based on the topk_index within 2 DT
t0 = (topk_index - 2) * DT
t1 = (topk_index + 2) * DT
t00 = xt[topk_index - 1]
t11 = xt[topk_index + 1]
timestamp = events["timestamp"].values
for i in tqdm(range(num_events), desc="Assigning event index"):
mask = (timestamp >= t0[i]) & (timestamp <= t1[i])
mask = (timestamp >= t00[i]) & (timestamp <= t11[i])
events.loc[mask, "event_index"] = i

print(f"Number of associated events: {len(events['event_index'].unique())}")
events["num_picks"] = events.groupby("event_index").size()

# # refine event index using DBSCAN
# events["group_index"] = -1
# for group_id, event in tqdm(events.groupby("event_index"), desc="DBSCAN clustering"):
# if len(event) < MIN_STATION:
# events.loc[event.index, "event_index"] = -1
# clustering = DBSCAN(eps=20, min_samples=MIN_STATION).fit(event[["x_km", "y_km"]])
# events.loc[event.index, "group_index"] = clustering.labels_
# events["dummy_index"] = events["event_index"].astype(str) + "." + events["group_index"].astype(str)
# mapping = {v: i for i, v in enumerate(events["dummy_index"].unique())}
# events["dummy_index"] = events["dummy_index"].map(mapping)
# events.loc[(events["event_index"] == -1) | (events["group_index"] == -1), "dummy_index"] = -1
# events["event_index"] = events["dummy_index"]
# events.drop(columns=["dummy_index"], inplace=True)

# %% link picks to events
picks["event_index"] = -1
@@ -92,6 +170,8 @@ def associate(
for group_id, event in tqdm(events.groupby("station_id"), desc="Linking picks to events"):
# travel time tt = (tp + ts) / 2 = (1 + ps_ratio)/2 * tp => tp = tt * 2 / (1 + ps_ratio)
# (ts - tp) = (ps_ratio - 1) tp = tt * 2 * (ps_ratio - 1) / (ps_ratio + 1)

event = event.sort_values(by="num_picks", ascending=True)
ps_delta = event["travel_time_s"].values * 2 * (VPVS_RATIO - 1) / (VPVS_RATIO + 1)
t1 = event["timestamp_center"].values - ps_delta * 1.2
t2 = event["timestamp_center"].values + ps_delta * 1.2
Expand All @@ -107,6 +187,17 @@ def associate(

picks.reset_index(inplace=True)

# plotting_debug(
# xt,
# hist_numpy,
# topk_index,
# topk_score,
# picks,
# events,
# stations,
# {"DT": DT, "MIN_STATION": MIN_STATION, "timestamp0": timestamp0},
# )

picks.drop(columns=["timestamp"], inplace=True)
events.drop(columns=["timestamp", "timestamp_center"], inplace=True)
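The linking step in this hunk converts each candidate's mean travel time into a P/S window using the relations in the comments: with tt = (tp + ts) / 2 and ts = VPVS_RATIO * tp, it follows that tp = 2 * tt / (1 + VPVS_RATIO) and ps_delta = ts - tp = 2 * tt * (VPVS_RATIO - 1) / (VPVS_RATIO + 1); picks at the same station inside the padded window [center - 1.2 * ps_delta, center + 1.2 * ps_delta] are assigned to the event. A small numeric check with an assumed travel time:

```python
# Numeric sketch of the P/S window used when linking picks to events (values illustrative).
VPVS_RATIO = 1.73
tt = 4.0  # mean travel time (tp + ts) / 2 in seconds, from the candidate event

tp = tt * 2 / (1 + VPVS_RATIO)                            # P travel time
ts = VPVS_RATIO * tp                                      # S travel time
ps_delta = tt * 2 * (VPVS_RATIO - 1) / (VPVS_RATIO + 1)   # ts - tp

center = 100.0  # timestamp_center of the candidate (s)
t1, t2 = center - ps_delta * 1.2, center + ps_delta * 1.2  # padded pick window
print(f"tp={tp:.2f}s ts={ts:.2f}s ps_delta={ps_delta:.2f}s window=({t1:.2f}, {t2:.2f})")
```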

@@ -127,6 +218,9 @@ def associate(
# drop event index -1
events = events[events["event_index"] != -1]

print(f"Number of associated events: {len(events['event_index'].unique()):,}")
print(f"Number of associated picks: {len(picks[picks['event_index'] != -1]):,} / {len(picks):,}")

return events, picks


24 changes: 14 additions & 10 deletions scripts/run_phasenet_plus.py
@@ -1,17 +1,16 @@
# %%
from typing import Dict, List
import json
import os
import sys
from args import parse_args
import os
from collections import defaultdict
from glob import glob
from typing import Dict, List

import fsspec
import torch
from collections import defaultdict
import numpy as np
import pandas as pd
import torch
from args import parse_args
from run_event_association import associate


@@ -31,6 +30,7 @@ def run_phasenet(
# %%
if data_type == "continuous":
subdir = 3
# subdir = 2
elif data_type == "event":
subdir = 1

@@ -50,6 +50,7 @@

if data_type == "continuous":
mseed_list = sorted(glob(f"{root_path}/{waveform_dir}/????/???/??/*.mseed"))
# mseed_list = sorted(glob(f"{root_path}/{waveform_dir}/????/???/*.mseed"))
elif data_type == "event":
mseed_list = sorted(glob(f"{root_path}/{waveform_dir}/*.mseed"))
else:
@@ -59,16 +60,17 @@
mseed_3c = defaultdict(list)
for mseed in mseed_list:
# key = mseed.replace(f"{root_path}/{waveform_dir}/", "").replace(".mseed", "")
key = "/".join(mseed.replace(".mseed", "").split("/")[-subdir:])
key = "/".join(mseed.replace(".mseed", "").split("/")[-subdir - 1 :])
if data_type == "continuous":
key = key[:-1]
mseed_3c[key].append(mseed)
print(f"Number of mseed files: {len(mseed_3c)}")

# %% skip processed files
if not overwrite:
processed = sorted(glob(f"{root_path}/{result_path}/picks_phasenet_plus/????/???/??/*.csv"))
processed = ["/".join(f.replace(".csv", "").split("/")[-subdir:]) for f in processed]
processed = sorted(glob(f"{root_path}/{result_path}/picks_phasenet_plus/????/???/*.csv"))
# processed = sorted(glob(f"{root_path}/{result_path}/picks_phasenet_plus/????/???/*.csv"))
processed = ["/".join(f.replace(".csv", "").split("/")[-subdir - 1 :]) for f in processed]
processed = [p[:-1] for p in processed] ## remove the channel suffix
print(f"Number of processed files: {len(processed)}")

@@ -93,6 +95,7 @@ def run_phasenet(
num_gpu = torch.cuda.device_count()
print(f"num_gpu = {num_gpu}")
base_cmd = f"../EQNet/predict.py --model phasenet_plus --add_polarity --add_event --format mseed --data_list={root_path}/{result_path}/mseed_list_{node_rank:03d}_{num_nodes:03d}.csv --response_path={root_path}/{response_path} --result_path={root_path}/{result_path} --batch_size 1 --workers 1 --subdir_level {subdir}"
# base_cmd += " --resume ../../QuakeFlow/EQNet/model_phasenet_plus_0630/model_99.pth"
if num_gpu == 0:
cmd = f"python {base_cmd} --device=cpu"
elif num_gpu == 1:
@@ -116,6 +119,9 @@

run_phasenet(root_path=root_path, region=region, config=config)

if num_nodes == 1:
os.system(f"python merge_phasenet_plus_picks.py --region {region}")

if num_nodes == 1:
config.update({"VPVS_RATIO": 1.73, "VP": 6.0})
stations = pd.read_json(f"{root_path}/{region}/obspy/stations.json", orient="index")
@@ -125,8 +131,6 @@ def run_phasenet(
)
picks = pd.read_csv(f"{root_path}/{region}/phasenet_plus/picks_phasenet_plus.csv", parse_dates=["phase_time"])
events, picks = associate(picks, events, stations, config)
print(f"Number of picks: {len(picks):,}")
print(f"Number of associated events: {len(events['event_index'].unique()):,}")
events.to_csv(f"{root_path}/{region}/phasenet_plus/phasenet_plus_events_associated.csv", index=False)
picks.to_csv(f"{root_path}/{region}/phasenet_plus/phasenet_plus_picks_associated.csv", index=False)

