Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwq0 authored Oct 29, 2023
1 parent 40d607b commit b35db74
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 57 deletions.
44 changes: 22 additions & 22 deletions slurm/run_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,19 +149,19 @@ def run_gamma(
index=False,
float_format="%.3f",
date_format="%Y-%m-%dT%H:%M:%S.%f",
columns=[
"time",
"magnitude",
"longitude",
"latitude",
# "depth(m)",
"depth_km",
"sigma_time",
"sigma_amp",
"cov_time_amp",
"event_index",
"gamma_score",
],
# columns=[
# "time",
# "magnitude",
# "longitude",
# "latitude",
# # "depth(m)",
# "depth_km",
# "sigma_time",
# "sigma_amp",
# "cov_time_amp",
# "event_index",
# "gamma_score",
# ],
)
# events = events[['time', 'magnitude', 'longitude', 'latitude', 'depth(m)', 'sigma_time', 'sigma_amp', 'gamma_score']]

Expand All @@ -173,15 +173,15 @@ def run_gamma(
fp,
index=False,
date_format="%Y-%m-%dT%H:%M:%S.%f",
columns=[
"station_id",
"phase_time",
"phase_type",
"phase_score",
"phase_amplitude",
"event_index",
"gamma_score",
],
# columns=[
# "station_id",
# "phase_time",
# "phase_type",
# "phase_score",
# "phase_amplitude",
# "event_index",
# "gamma_score",
# ],
)

if protocol != "file":
Expand Down
142 changes: 107 additions & 35 deletions slurm/run_phasenet_v2.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,114 @@
# %%
from pathlib import Path
import os
import torch
from typing import Dict, List

# %%
region = "BayArea"
root_path = Path(region)
data_path = root_path / "obspy"
result_path = root_path / "phasenet"
if not result_path.exists():
result_path.mkdir()
from kfp import dsl

# %%
waveform_path = data_path / "waveforms"
mseed_list = sorted(list(waveform_path.rglob("*.mseed")))
file_list = []
for f in mseed_list:
file_list.append(str(f))

file_list = sorted(list(set(file_list)))
with open(result_path / "data_list.txt", "w") as fp:
fp.write("\n".join(file_list))
@dsl.component()
def run_phasenet(
root_path: str,
region: str,
config: Dict,
rank: int = 0,
model_path: str = "../PhaseNet/",
data_type: str = "continuous",
mseed_list: List = None,
protocol: str = "file",
bucket: str = "",
token: Dict = None,
) -> str:
import os
from glob import glob

import fsspec
import torch

# %%
num_gpu = torch.cuda.device_count()
if num_gpu == 0:
cmd = f"python ../EQNet/predict.py --model phasenet --add_polarity --add_event --format mseed --data_list {result_path/'data_list.txt'} --batch_size 1 --result_path {result_path} --device=cpu"
elif num_gpu == 1:
cmd = f"python ../EQNet/predict.py --model phasenet --add_polarity --add_event --format mseed --data_list {result_path/'data_list.txt'} --batch_size 1 --result_path {result_path}"
else:
cmd = f"torchrun --standalone --nproc_per_node {num_gpu} ../EQNet/predict.py --model phasenet --add_polarity --add_event --format mseed --data_list {result_path/'data_list.txt'} --batch_size 1 --result_path {result_path}"
print(cmd)
os.system(cmd)

# # %%
# cmd = f"gsutil -m cp -r {result_path}/picks_phasenet {protocol}{bucket}/{folder}/phasenet/picks"
# print(cmd)
# os.system(cmd)
# %%
fs = fsspec.filesystem(protocol=protocol, token=token)

# %%
# %%
# result_path = f"{region}/phasenet/{rank:03d}"
result_path = f"{region}/phasenet"
if not os.path.exists(f"{root_path}/{result_path}"):
os.makedirs(f"{root_path}/{result_path}", exist_ok=True)

# %%
waveform_dir = f"{region}/waveforms"
if not os.path.exists(f"{root_path}/{waveform_dir}"):
if protocol != "file":
fs.get(f"{bucket}/{waveform_dir}/", f"{root_path}/{waveform_dir}/", recursive=True)

if mseed_list is None:
if data_type == "continuous":
mseed_list = sorted(glob(f"{root_path}/{waveform_dir}/????-???/??/*.mseed"))
elif data_type == "event":
mseed_list = sorted(glob(f"{root_path}/{waveform_dir}/*.mseed"))
else:
raise ValueError("data_type must be either continuous or event")

# %% group channels into stations
if data_type == "continuous":
mseed_list = list(set([x.split(".mseed")[0][:-1] + "*.mseed" for x in mseed_list]))
mseed_list = sorted(mseed_list)

# %%
if protocol != "file":
fs.get(f"{bucket}/{region}/obspy/inventory.xml", f"{root_path}/{region}/obspy/inventory.xml")

# %%
with open(f"{root_path}/{result_path}/mseed_list_{rank:03d}.csv", "w") as fp:
fp.write("\n".join(mseed_list))

# %%
if data_type == "continuous":
folder_depth = 3
elif data_type == "event":
folder_depth = 1
num_gpu = torch.cuda.device_count()
print(f"num_gpu = {num_gpu}")
base_cmd = f"../EQNet/predict.py --model phasenet --add_polarity --add_event --format mseed --data_list={root_path}/{result_path}/mseed_list.csv --response_xml={root_path}/{region}/obspy/inventory.xml --result_path={root_path}/{result_path} --batch_size 1 --workers 1 --folder_depth {folder_depth}"
if num_gpu == 0:
cmd = f"python {base_cmd} --device=cpu"
elif num_gpu == 1:
cmd = f"python {base_cmd}"
else:
cmd = f"torchrun --standalone --nproc_per_node {num_gpu} "
print(cmd)
os.system(cmd)

os.system(
f"cp {root_path}/{result_path}/picks_phasenet.csv {root_path}/{result_path}/phasenet_picks_{rank:03d}.csv"
)
os.system(
f"cp {root_path}/{result_path}/events_phasenet.csv {root_path}/{result_path}/phasenet_events_{rank:03d}.csv",
)

if protocol != "file":
fs.put(f"{root_path}/{result_path}/", f"{bucket}/{result_path}/", recursive=True)

return f"{result_path}/phasenet_picks_{rank:03d}.csv"

if __name__ == "__main__":
import json
import os
import sys

root_path = "local"
region = "demo"
if len(sys.argv) > 1:
root_path = sys.argv[1]
region = sys.argv[2]

with open(f"{root_path}/{region}/config.json", "r") as fp:
config = json.load(fp)

run_phasenet.python_func(root_path, region=region, config=config)# , data_type="event")

if config["num_nodes"] == 1:
os.system(f"mv {root_path}/{region}/phasenet/mseed_list_000.csv {root_path}/{region}/phasenet/mseed_list.csv")
os.system(
f"mv {root_path}/{region}/phasenet/phasenet_picks_000.csv {root_path}/{region}/phasenet/phasenet_picks.csv"
)
os.system(
f"mv {root_path}/{region}/phasenet/phasenet_events_000.csv {root_path}/{region}/phasenet/phasenet_events.csv"
)

0 comments on commit b35db74

Please sign in to comment.