cut based on file not station
zhuwq0 committed Dec 18, 2024
Parent: cf2d2ee · Commit: a7c28df
Showing 1 changed file with 113 additions and 113 deletions: scripts/cut_templates_cc.py
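The commit replaces the old per-day, per-station cutting loop with one keyed on waveform files: each pick is matched to the miniSEED triplet (E/N/Z channels) that contains it and picks are grouped by that triplet, so every file is opened exactly once. A minimal sketch of the grouping idea, with a made-up picks table (only the column names follow the diff):

import pandas as pd

# Hypothetical picks; "ENZ" holds the comma-joined E/N/Z file paths of the
# instrument-day a pick belongs to, as built further down in the diff.
picks = pd.DataFrame(
    {
        "idx_eve": [0, 0, 1],
        "idx_sta": [3, 3, 7],
        "phase_type": ["P", "S", "P"],
        "ENZ": [
            "2024/001/NC.A..HHE.mseed,2024/001/NC.A..HHN.mseed,2024/001/NC.A..HHZ.mseed",
            "2024/001/NC.A..HHE.mseed,2024/001/NC.A..HHN.mseed,2024/001/NC.A..HHZ.mseed",
            "2024/001/NC.B..HHE.mseed,2024/001/NC.B..HHN.mseed,2024/001/NC.B..HHZ.mseed",
        ],
    }
)

# One group per file triplet: two picks share station NC.A's files, one uses NC.B's.
for enz, group in picks.groupby("ENZ"):
    print(enz.split(",")[0], len(group))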
@@ -3,8 +3,10 @@
import multiprocessing as mp
import os
import sys
from collections import defaultdict
from glob import glob

import fsspec
import matplotlib.pyplot as plt
import numpy as np
import obspy
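Both new imports serve the file-based flow: defaultdict collects channel files per instrument-day, and fsspec lets a worker stream miniSEED from local disk or object storage alike. A minimal read sketch, assuming an anonymously readable S3 object (the URL is illustrative):

import fsspec
import obspy

url = "s3://example-bucket/waveforms/2024/001/NC.A..HHZ.mseed"  # hypothetical path
with fsspec.open(url, "rb", anon=True) as f:  # anon=True: no AWS credentials needed
    stream = obspy.read(f)
print(stream)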
@@ -141,10 +143,8 @@ def extract_template_numpy(
traveltime_fname,
traveltime_index_fname,
traveltime_mask_fname,
mseed_path,
picks_group,
events,
picks,
stations,
config,
lock,
):
@@ -158,80 +158,73 @@
)
traveltime_mask = np.memmap(traveltime_mask_fname, dtype=bool, mode="r+", shape=tuple(config["traveltime_shape"]))
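    # The memmaps attach to files created beforehand (mode "r+" requires them
    # to exist and be sized already); each worker writes only its own idx_pick
    # rows, and the shared lock further down guards the final flush to disk.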

## Load waveforms
waveforms_dict = {}
for i, station in stations.iterrows():
station_id = station["station_id"]
# for c in station["component"]:
for c in ["E", "N", "Z", "1", "2", "3"]:
mseed_name = f"{mseed_path}/{station_id}{c}.mseed"
if os.path.exists(mseed_name):
try:
stream = obspy.read(mseed_name)
stream.merge(fill_value="latest")
if len(stream) > 1:
print(f"More than one trace: {stream}")
trace = stream[0]
if trace.stats.sampling_rate != config["sampling_rate"]:
if trace.stats.sampling_rate % config["sampling_rate"] == 0:
trace.decimate(int(trace.stats.sampling_rate / config["sampling_rate"]))
else:
trace.resample(config["sampling_rate"])
# trace.detrend("linear")
# trace.taper(max_percentage=0.05, type="cosine")
trace.filter("bandpass", freqmin=2.0, freqmax=12.0, corners=4, zerophase=True)
waveforms_dict[f"{station_id}{c}"] = trace
except Exception as e:
print(e)
continue

## Cut templates
for (idx_eve, idx_sta, phase_type), pick in picks.iterrows():

idx_pick = pick["idx_pick"]
phase_timestamp = pick["phase_timestamp"]

station = stations.loc[idx_sta]
station_id = station["station_id"]
event = events.loc[idx_eve]
for picks in picks_group:

# waveforms_dict = {}
picks = picks.set_index(["idx_eve", "idx_sta", "phase_type"])
picks_index = list(picks.index.unique())

## Cut templates
for (idx_eve, idx_sta, phase_type), pick in picks.iterrows():

idx_pick = pick["idx_pick"]
phase_timestamp = pick["phase_timestamp"]

event = events.loc[idx_eve]
ENZ = pick["ENZ"].split(",")

for c in ENZ:
if c not in waveforms_dict:
with fsspec.open(c, "rb", anon=True) as f:
stream = obspy.read(f)
stream.merge(fill_value="latest")
if len(stream) > 1:
print(f"More than one trace: {stream}")
trace = stream[0]
if trace.stats.sampling_rate != config["sampling_rate"]:
if trace.stats.sampling_rate % config["sampling_rate"] == 0:
trace.decimate(int(trace.stats.sampling_rate / config["sampling_rate"]))
else:
trace.resample(config["sampling_rate"])
# trace.detrend("linear")
# trace.taper(max_percentage=0.05, type="cosine")
trace.filter("bandpass", freqmin=2.0, freqmax=12.0, corners=4, zerophase=True)
waveforms_dict[c] = trace
else:
trace = waveforms_dict[c]
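                # waveforms_dict caches each opened channel for the lifetime of
                # this call, so a file shared by many picks is read, merged, and
                # filtered only once. config["component_mapping"] is assumed to
                # map the trailing channel code (E/N/Z or 1/2/3) to a fixed slot
                # on the template's component axis.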

ic = config["component_mapping"][trace.stats.channel[-1]]

# for c in station["component"]:
for c in ["E", "N", "Z", "1", "2", "3"]:
ic = config["component_mapping"][c] # 012 for P, 345 for S

if f"{station_id}{c}" in waveforms_dict:
trace = waveforms_dict[f"{station_id}{c}"]
trace_starttime = (
pd.to_datetime(trace.stats.starttime.datetime, utc=True) - reference_t0
).total_seconds()
else:
continue

begin_time = phase_timestamp - trace_starttime - config[f"time_before_{phase_type.lower()}"]
end_time = phase_timestamp - trace_starttime + config[f"time_after_{phase_type.lower()}"]

if phase_type == "P" and ((idx_eve, idx_sta, "S") in picks.index):
s_begin_time = (
picks.loc[idx_eve, idx_sta, "S"]["phase_timestamp"] - trace_starttime - config[f"time_before_s"]
)
if config["no_overlapping"]:
end_time = min(end_time, s_begin_time)
if phase_type == "P" and ((idx_eve, idx_sta, "S") in picks.index):
s_begin_time = (
picks.loc[idx_eve, idx_sta, "S"]["phase_timestamp"] - trace_starttime - config[f"time_before_s"]
)
if config["no_overlapping"]:
end_time = min(end_time, s_begin_time)

begin_time_index = max(0, int(round(begin_time * config["sampling_rate"])))
end_time_index = max(0, int(round(end_time * config["sampling_rate"])))

## define traveltime at the exact data point of event origin time
traveltime_array[idx_pick, ic, 0] = begin_time_index / config["sampling_rate"] - (
event["event_timestamp"] - trace_starttime - config[f"time_before_{phase_type.lower()}"]
)
traveltime_index_array[idx_pick, ic, 0] = begin_time_index - int(
(event["event_timestamp"] - trace_starttime - config[f"time_before_{phase_type.lower()}"])
* config["sampling_rate"]
)
traveltime_mask[idx_pick, ic, 0] = True
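                # Worked example of the bookkeeping above (made-up numbers): with
                # a 100 Hz sampling rate, time_before_p = 0.5 s, trace_starttime =
                # 10 s, origin time 40 s, and a P pick at 42.37 s (all relative to
                # reference_t0):
                #   begin_time       = 42.37 - 10 - 0.5           = 31.87 s into the trace
                #   begin_time_index = round(31.87 * 100)         = 3187 samples
                #   traveltime       = 3187/100 - (40 - 10 - 0.5) = 2.37 s = pick - origin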

trace_data = trace.data[begin_time_index:end_time_index].astype(np.float32)
template_array[idx_pick, ic, 0, : len(trace_data)] = trace_data

if lock is not None:
with lock:
@@ -240,7 +233,7 @@ def extract_template_numpy(
traveltime_index_array.flush()
traveltime_mask.flush()

return mseed_path
return
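One piece of the loader worth isolating: when the native rate is an integer multiple of the target rate the code decimates (an anti-aliased integer downsample), and only otherwise falls back to resampling. A standalone sketch of that decision, assuming a 100 Hz target:

import numpy as np
import obspy

def normalize_rate(trace: obspy.Trace, target: float = 100.0) -> obspy.Trace:
    """Bring a trace to `target` Hz, preferring decimation over resampling."""
    rate = trace.stats.sampling_rate
    if rate != target:
        if rate % target == 0:
            trace.decimate(int(rate / target))  # integer factor: cheap and anti-aliased
        else:
            trace.resample(target)  # non-integer factor: Fourier-domain resampling
    return trace

tr = obspy.Trace(data=np.random.randn(2000).astype(np.float32))
tr.stats.sampling_rate = 200.0  # toy trace at 200 Hz
print(normalize_rate(tr).stats.sampling_rate)  # -> 100.0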


# %%
@@ -508,65 +501,74 @@ def cut_templates(root_path, region, config):
config["reference_t0"] = reference_t0
events = events[["idx_eve", "x_km", "y_km", "z_km", "event_index", "event_time", "event_timestamp"]]
stations = stations[["idx_sta", "x_km", "y_km", "z_km", "station_id", "component", "network", "station"]]
picks = picks[["idx_eve", "idx_sta", "phase_type", "phase_score", "phase_time", "phase_timestamp", "phase_source"]]
picks = picks[
[
"idx_eve",
"idx_sta",
"phase_type",
"phase_score",
"phase_time",
"phase_timestamp",
"phase_source",
"station_id",
]
]
events.set_index("idx_eve", inplace=True)
stations.set_index("idx_sta", inplace=True)
picks.sort_values(by=["idx_eve", "idx_sta", "phase_type"], inplace=True)
picks["idx_pick"] = np.arange(len(picks))

picks.to_csv(f"{root_path}/{result_path}/cctorch_picks.csv", index=False)

## By hour
# dirs = sorted(glob(f"{root_path}/{region}/waveforms/????/???/??"), reverse=True)
## By day
dirs = sorted(glob(f"{root_path}/{region}/waveforms/????/???"), reverse=True)
## Find mseed files
mseed_list = sorted(glob(f"{root_path}/{region}/waveforms/????/???/*.mseed"))
subdir = 2

mseed_3c = defaultdict(list)
for mseed in mseed_list:
key = "/".join(mseed.replace(".mseed", "").split("/")[-subdir - 1 :])
key = key[:-1] ## remove the channel suffix
mseed_3c[key].append(mseed)
print(f"Number of mseed files: {len(mseed_3c)}")

def parse_key(key):
year, jday, name = key.split("/")
network, station, location, instrument = name.split(".")
return [year, jday, network, station, location, instrument]

mseeds = [parse_key(k) + [",".join(sorted(mseed_3c[k]))] for k in mseed_3c]
mseeds = pd.DataFrame(mseeds, columns=["year", "jday", "network", "station", "location", "instrument", "ENZ"])
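    # Example of the grouping key, assuming the year/jday/file layout (subdir=2):
    #   ".../waveforms/2024/001/NC.ABC..HHZ.mseed" -> "2024/001/NC.ABC..HHZ"
    #   -> key "2024/001/NC.ABC..HH" after dropping the channel letter,
    # so the HHE/HHN/HHZ files of one instrument-day collapse into one row.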

## Match picks with mseed files
picks["network"] = picks["station_id"].apply(lambda x: x.split(".")[0])
picks["station"] = picks["station_id"].apply(lambda x: x.split(".")[1])
picks["location"] = picks["station_id"].apply(lambda x: x.split(".")[2])
picks["instrument"] = picks["station_id"].apply(lambda x: x.split(".")[3])
picks["year"] = picks["phase_time"].dt.strftime("%Y")
picks["jday"] = picks["phase_time"].dt.strftime("%j")
picks = picks.merge(mseeds, on=["network", "station", "location", "instrument", "year", "jday"])
picks.drop(columns=["station_id", "network", "location", "instrument", "year", "jday"], inplace=True)
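    # Note: pd.merge defaults to an inner join, so picks whose instrument-day
    # has no matching miniSEED triplet are dropped here rather than failing later.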

picks_group = picks.copy()
picks_group = picks_group.groupby("ENZ")

ncpu = min(16, mp.cpu_count())
nsplit = min(ncpu * 2, len(picks_group))
print(f"Using {ncpu} cores")

pbar = tqdm(total=len(dirs), desc="Cutting templates")

def pbar_update(x):
"""
x: the return value of extract_template_numpy
"""
pbar.update()
pbar.set_description(f"Cutting templates: {'/'.join(x.split('/')[-3:])}")
pbar = tqdm(total=nsplit, desc="Cutting templates")

ctx = mp.get_context("spawn")
picks_group = picks.copy()
## By hour
# picks_group["year_jday_hour"] = picks_group["phase_time"].dt.strftime("%Y-%jT%H")
# picks_group = picks_group.groupby("year_jday_hour")
## By day
picks_group["year_jday"] = picks_group["phase_time"].dt.strftime("%Y-%j")
picks_group = picks_group.groupby("year_jday")

with ctx.Manager() as manager:
lock = manager.Lock()
with ctx.Pool(ncpu) as pool:
jobs = []
for d in dirs:

tmp = d.split("/")
## By hour
# year, jday, hour = tmp[-3:]
## By day
year, jday = tmp[-2:]

## By hour
# if f"{year}-{jday}T{hour}" not in picks_group.groups:
## By day
if f"{year}-{jday}" not in picks_group.groups:
pbar_update(d)
continue

## By hour
# picks_ = picks_group.get_group(f"{year}-{jday}T{hour}")
## By day
picks_ = picks_group.get_group(f"{year}-{jday}")
events_ = events.loc[picks_["idx_eve"].unique()]
picks_ = picks_.set_index(["idx_eve", "idx_sta", "phase_type"])
group_chunk = np.array_split(list(picks_group.groups.keys()), nsplit)
picks_group_chunk = [[picks_group.get_group(g) for g in group] for group in group_chunk]
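            # The group keys (one per ENZ triplet) are split into nsplit roughly
            # equal chunks; each worker call receives a list of per-file pick
            # DataFrames.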

for picks_group in picks_group_chunk:

job = pool.apply_async(
extract_template_numpy,
@@ -575,14 +577,12 @@ def pbar_update(x):
traveltime_fname,
traveltime_index_fname,
traveltime_mask_fname,
d,
events_,
picks_,
stations,
picks_group,
events,
config,
lock,
),
callback=pbar_update,
callback=lambda x: pbar.update(),
)
jobs.append(job)
pool.close()
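For reference, the process layout the script converges on, reduced to a runnable skeleton (the worker body and names here are placeholders, not the script's API):

import multiprocessing as mp

import numpy as np
from tqdm import tqdm

def worker(chunk, lock):
    # stand-in for extract_template_numpy: do the cutting, then flush under the lock
    with lock:
        pass
    return len(chunk)

if __name__ == "__main__":
    groups = [f"file_{i}" for i in range(37)]  # stand-in for the picks_group keys
    nsplit = min(mp.cpu_count() * 2, len(groups))
    chunks = np.array_split(groups, nsplit)

    pbar = tqdm(total=nsplit, desc="Cutting templates")
    ctx = mp.get_context("spawn")  # same start method as the script
    with ctx.Manager() as manager:
        lock = manager.Lock()
        with ctx.Pool(min(16, mp.cpu_count())) as pool:
            jobs = [
                pool.apply_async(worker, (list(c), lock), callback=lambda _: pbar.update())
                for c in chunks
            ]
            pool.close()
            pool.join()
    pbar.close()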
