Skip to content

Commit

Permalink
Improve proxy visualization funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
fzhu2e committed Jan 16, 2024
1 parent ab50004 commit 4a56dce
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
13 changes: 10 additions & 3 deletions cfr/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def plotly(self, **kwargs):
return fig

def plot(self, figsize=[12, 4], legend=False, ms=200, stock_img=True, edge_clr='w',
wspace=0.1, hspace=0.1, plot_map=True, **kwargs):
wspace=0.1, hspace=0.1, plot_map=True, p=visual.STYLE, **kwargs):
''' Visualize the ProxyRecord
Args:
Expand All @@ -635,7 +635,10 @@ def plot(self, figsize=[12, 4], legend=False, ms=200, stock_img=True, edge_clr='
plot_map (bool): if True, plot the record on a map. Defaults to True.
'''
if 'color' not in kwargs and 'c' not in kwargs:
kwargs['color'] = visual.STYLE.colors_dict[self.ptype]
if self.ptype in p.colors_dict:
kwargs['color'] = p.colors_dict[self.ptype]
else:
kwargs['color'] = 'tab:blue'

fig = plt.figure(figsize=figsize)

Expand Down Expand Up @@ -671,9 +674,13 @@ def plot(self, figsize=[12, 4], legend=False, ms=200, stock_img=True, edge_clr='
if stock_img:
ax['map'].stock_img()

if self.ptype in p.markers_dict:
marker = p.markers_dict[self.ptype]
else:
marker = 'o'
transform=ccrs.PlateCarree()
ax['map'].scatter(
self.lon, self.lat, marker=visual.STYLE.markers_dict[self.ptype],
self.lon, self.lat, marker=marker,
s=ms, c=kwargs['color'], edgecolor=edge_clr, transform=transform,
)

Expand Down
4 changes: 4 additions & 0 deletions cfr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,10 @@ def colored_noise_2regimes(alpha1, alpha2, f_break, t, f0=None, m=None, seed=Non

return y

def arr_str2np(arr):
arr = np.array([float(s) for s in arr[1:-1].split(',')])
return arr

def is_numeric(obj):
attrs = ['__add__', '__sub__', '__mul__', '__truediv__', '__pow__']
return all(hasattr(obj, attr) for attr in attrs)
Expand Down
11 changes: 9 additions & 2 deletions cfr/visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,10 +555,17 @@ def plot_proxies(df, year=np.arange(2001), lon_col='lon', lat_col='lat', type_co
type_names.append(f'{ptype} (n={max_count[-1]})')
lons = list(df[selector][lon_col])
lats = list(df[selector][lat_col])
if ptype in markers_dict:
marker = markers_dict[ptype]
color = colors_dict[ptype]
else:
marker = 'o'
color = 'tab:blue'

s_plots.append(
ax['map'].scatter(
lons, lats, marker=markers_dict[ptype],
c=colors_dict[ptype], edgecolor='k', s=markersize, transform=ccrs.PlateCarree()
lons, lats, marker=marker, c=color,
edgecolor='k', s=markersize, transform=ccrs.PlateCarree()
)
)

Expand Down

0 comments on commit 4a56dce

Please sign in to comment.