Skip to content

Proxy Independent Verification #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cfr/climate.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def compare(self, ref, timespan=None, stat='corr', interp_target='ref', interp=T
fd_rg.da.values,
coords={'time': ref_rg.da.time, 'lat': fd_rg.da.lat, 'lon': fd_rg.da.lon}
)

if stat == 'corr':
stat_da = xr.corr(fd_rg.da, ref_rg.da, dim='time')
stat_da = stat_da.expand_dims({'time': [1]})
Expand Down
198 changes: 161 additions & 37 deletions cfr/reconres.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,144 @@
import matplotlib.pyplot as plt
from matplotlib import gridspec
from .visual import CartopySettings
from .reconjob import ReconJob
import pandas as pd
from . import utils,visual


class ReconRes:
''' The class for reconstruction results '''
"""The class for reconstruction results"""

def __init__(self, dirpath, load_num=None, verbose=False):
    """Initialize a reconstruction result object.

    Args:
        dirpath (str): the directory path where the reconstruction results are stored.
        load_num (int): the number of ensemble members to load; None loads all.
        verbose (bool, optional): print verbose information. Defaults to False.

    Raises:
        ValueError: if the reconstruction result files cannot be listed from `dirpath`.
    """
    try:
        recon_paths = sorted(glob.glob(os.path.join(dirpath, "job_r*_recon.nc")))
        if load_num is not None:
            recon_paths = recon_paths[:load_num]
        self.paths = recon_paths
    except Exception as exc:
        # bugfix: was a bare `except:` raising ValueError('No ""') — an empty,
        # uninformative message; keep the ValueError type for callers but
        # explain what failed and chain the original cause.
        raise ValueError(f"Failed to locate reconstruction results under: {dirpath}") from exc

    if verbose:
        p_header(">>> res.paths:")
        print(self.paths)

    # vn -> EnsTS/ClimateField and vn -> xarray.DataArray, filled by load()
    self.recons = {}
    self.da = {}

def load_proxylabels(self, verbose=False):
    """Load the proxy labels stored with each reconstruction result.

    A proxy labeled "assim" was assimilated; one labeled "eval" was held
    out for evaluation. The global attributes of each ensemble-member file
    are collected, in the same order as ``self.paths``, into
    ``self.proxy_labels``.

    Args:
        verbose (bool, optional): print verbose information. Defaults to False.
    """
    labels = []
    for member_path in self.paths:
        # the labels live in the netCDF global attributes of each member
        with xr.open_dataset(member_path) as member_ds:
            labels.append(member_ds.attrs)

    self.proxy_labels = labels
    if verbose:
        p_success(f">>> ReconRes.proxy_labels created")

def independent_verify(self, job_path, verbose=False, calib_period=(1850, 2000), min_verify_len=10):
    """Perform independent verification of the reconstruction.

    For each ensemble member, the reconstructed field is loaded back as a
    "prior", the PSMs are forwarded to produce pseudo-proxy estimates, and
    those estimates are compared against the observed proxy records over
    three time masks: 'all', 'in' (within the calibration period), and
    'before' (prior to the calibration period).

    Args:
        job_path (str): the path to the saved ReconJob that produced the results.
        verbose (bool, optional): print verbose information. Defaults to False.
        calib_period (tuple, optional): (start, end) years of the calibration
            period. Defaults to (1850, 2000).
        min_verify_len (int, optional): minimum number of overlapping time
            points required to compute statistics; below this, NaN is recorded.

    Returns:
        pandas.DataFrame: one row per calibrated proxy per ensemble member,
        with 'name', 'seasonality', 'assim', and '{all,in,before}_{corr,ce}'
        columns. Also stored as ``self.independent_info_list``.

    Raises:
        ValueError: if a calibrated proxy is in neither the 'pids_assim' nor
            the 'pids_eval' label list.
    """
    # load the reconstructions for the "prior"
    job = ReconJob()
    job.load(job_path)
    independent_info_list = []
    for path_index, path in enumerate(self.paths):
        proxy_labels = self.proxy_labels[path_index]
        job.load_clim(
            tag="prior",
            path_dict={
                "tas": path,
            },
            anom_period=(1951, 1980),
        )
        job.forward_psms(verbose=verbose)
        if verbose:
            p_success(f">>> Prior loaded from {path}")
        # compare the pseudo-proxy records with the real records
        calib_PDB = job.proxydb.filter(by="tag", keys=["calibrated"])
        for pname, proxy in calib_PDB.records.items():
            detail = proxy.psm.calib_details
            attr_dict = {
                'name': pname,
                'seasonality': detail['seasonality'],
            }
            if pname in proxy_labels['pids_assim']:
                attr_dict['assim'] = True
            elif pname in proxy_labels['pids_eval']:
                attr_dict['assim'] = False
            else:
                raise ValueError(f"Proxy {pname} is not labeled as assim or eval. Please check the proxy labels.")
            reconstructed = pd.DataFrame(
                {
                    "time": proxy.pseudo.time,
                    "estimated": proxy.pseudo.value,
                }
            )
            real = pd.DataFrame(
                {
                    "time": proxy.time,
                    "observed": proxy.value,
                }
            )
            # inner-join on time so only overlapping, non-missing points remain
            Df = real.dropna().merge(reconstructed, on="time", how="inner")
            Df.set_index("time", drop=True, inplace=True)
            Df.sort_index(inplace=True)
            # bugfix: astype returns a copy — the original discarded the result,
            # leaving Df unconverted; rebind so the float conversion takes effect.
            Df = Df.astype(float)
            masks = {
                "all": None,
                "in": (Df.index >= calib_period[0]) & (Df.index <= calib_period[1]),  # in the calibration period
                "before": (Df.index < calib_period[0]),  # before the calibration period
            }
            for mask_name, mask in masks.items():
                Df_masked = Df if mask is None else Df[mask]
                if len(Df_masked) < min_verify_len:
                    # too few points for a meaningful statistic
                    corr = np.nan
                    ce = np.nan
                else:
                    corr = Df_masked.corr().iloc[0, 1]
                    ce = utils.coefficient_efficiency(
                        Df_masked.observed.values, Df_masked.estimated.values
                    )
                attr_dict[mask_name + '_corr'] = corr
                attr_dict[mask_name + '_ce'] = ce
            independent_info_list.append(attr_dict)
    independent_info_list = pd.DataFrame(independent_info_list)
    self.independent_info_list = independent_info_list
    if verbose:
        p_success(f">>> Independent verification completed, results stored in ReconRes.independent_info_list")
        p_success(f">>> Records Number: {len(independent_info_list)}")
    return independent_info_list

def plot_independent_verify(self):
    """Plot the distributions from the independent verification.

    Requires that ``independent_verify`` has already been run so that
    ``self.independent_info_list`` exists.

    Returns:
        tuple: the matplotlib figure and the dict of axes.
    """
    return visual.plot_independent_distribution(self.independent_info_list)





def load(self, vn_list, verbose=False):
''' Load reconstruction results.
"""Load reconstruction results.

Args:
vn_list (list): list of variable names; supported names, taking 'tas' as an example:
Expand All @@ -51,7 +161,7 @@ def load(self, vn_list, verbose=False):
* ensemble timeseries: 'tas_gm', 'tas_nhm', 'tas_shm'

verbose (bool, optional): print verbose information. Defaults to False.
'''
"""
if type(vn_list) is str:
vn_list = [vn_list]

Expand All @@ -61,24 +171,23 @@ def load(self, vn_list, verbose=False):
with xr.open_dataset(path) as ds_tmp:
da_list.append(ds_tmp[vn])

da = xr.concat(da_list, dim='ens')
if 'ens' not in da.coords:
da.coords['ens'] = np.arange(len(self.paths))
da = da.transpose('time', 'ens', ...)
da = xr.concat(da_list, dim="ens")
if "ens" not in da.coords:
da.coords["ens"] = np.arange(len(self.paths))
da = da.transpose("time", "ens", ...)

self.da[vn] = da
if 'lat' not in da.coords and 'lon' not in da.coords:
if "lat" not in da.coords and "lon" not in da.coords:
self.recons[vn] = EnsTS(time=da.time, value=da.values, value_name=vn)
else:
self.recons[vn] = ClimateField(da.mean(dim='ens'))
self.recons[vn] = ClimateField(da.mean(dim="ens"))

if verbose:
p_success(f'>>> ReconRes.recons["{vn}"] created')
p_success(f'>>> ReconRes.da["{vn}"] created')


def valid(self, target_dict, stat=['corr'], timespan=None, verbose=False):
''' Validate against a target dictionary
def valid(self, target_dict, stat=["corr"], timespan=None, verbose=False):
"""Validate against a target dictionary

Args:
target_dict (dict): a dictionary of multiple variables for validation.
Expand All @@ -90,61 +199,76 @@ def valid(self, target_dict, stat=['corr'], timespan=None, verbose=False):

timespan (list or tuple): the timespan over which to perform the validation.
verbose (bool, optional): print verbose information. Defaults to False.
'''
if type(stat) is not list: stat = [stat]
"""
if type(stat) is not list:
stat = [stat]
vn_list = target_dict.keys()
self.load(vn_list, verbose=verbose)
valid_fd, valid_ts = {}, {}
for vn in vn_list:
p_header(f'>>> Validating variable: {vn} ...')
p_header(f">>> Validating variable: {vn} ...")
if isinstance(self.recons[vn], ClimateField):
for st in stat:
valid_fd[f'{vn}_{st}'] = self.recons[vn].compare(target_dict[vn], stat=st, timespan=timespan)
valid_fd[f'{vn}_{st}'].plot_kwargs.update({'cbar_orientation': 'horizontal', 'cbar_pad': 0.1})
if verbose: p_success(f'>>> ReconRes.valid_fd[{vn}_{st}] created')
valid_fd[f"{vn}_{st}"] = self.recons[vn].compare(
target_dict[vn], stat=st, timespan=timespan
)
valid_fd[f"{vn}_{st}"].plot_kwargs.update(
{"cbar_orientation": "horizontal", "cbar_pad": 0.1}
)
if verbose:
p_success(f">>> ReconRes.valid_fd[{vn}_{st}] created")
elif isinstance(self.recons[vn], EnsTS):
valid_ts[vn] = self.recons[vn].compare(target_dict[vn], timespan=timespan)
if verbose: p_success(f'>>> ReconRes.valid_ts[{vn}] created')
valid_ts[vn] = self.recons[vn].compare(
target_dict[vn], timespan=timespan
)
if verbose:
p_success(f">>> ReconRes.valid_ts[{vn}] created")

self.valid_fd = valid_fd
self.valid_ts = valid_ts


def plot_valid(
    self,
    recon_name_dict=None,
    target_name_dict=None,
    valid_ts_kws=None,
    valid_fd_kws=None,
):
    """Plot the validation result

    Args:
        recon_name_dict (dict): the dictionary for variable names in the reconstruction. For example, {'tas': 'LMR/tas', 'nino3.4': 'NINO3.4 [K]'}.
        target_name_dict (dict): the dictionary for variable names in the validation target. For example, {'tas': '20CRv3', 'nino3.4': 'BC09'}.
        valid_ts_kws (dict): the dictionary of keyword arguments for validating the timeseries.
        valid_fd_kws (dict): the dictionary of keyword arguments for validating the field.

    Returns:
        tuple: (fig, ax) dicts keyed the same way as self.valid_fd / self.valid_ts.
    """
    valid_fd_kws = {} if valid_fd_kws is None else valid_fd_kws
    valid_ts_kws = {} if valid_ts_kws is None else valid_ts_kws
    target_name_dict = {} if target_name_dict is None else target_name_dict
    recon_name_dict = {} if recon_name_dict is None else recon_name_dict

    if "latlon_range" in valid_fd_kws:
        lat_min, lat_max, lon_min, lon_max = valid_fd_kws["latlon_range"]
    else:
        lat_min, lat_max, lon_min, lon_max = -90, 90, 0, 360

    fig, ax = {}, {}
    for k, v in self.valid_fd.items():
        # bugfix: keys are built as f"{vn}_{st}" and vn itself may contain
        # underscores (e.g. 'tas_gm'), so split from the right; plain
        # split("_") would raise "too many values to unpack" in that case.
        vn, st = k.rsplit("_", 1)
        if vn not in target_name_dict:
            target_name_dict[vn] = "obs"
        fig[k], ax[k] = v.plot(
            title=f"{st}({recon_name_dict[vn]}, {target_name_dict[vn]}), mean={v.geo_mean(lat_min=lat_min, lat_max=lat_max, lon_min=lon_min, lon_max=lon_max).value[0,0]:.2f}",
            **valid_fd_kws,
        )

    for k, v in self.valid_ts.items():
        v.ref_name = target_name_dict[k]
        if v.value.shape[-1] > 1:
            # multi-member ensemble: plot the quantile envelope
            fig[k], ax[k] = v.plot_qs(**valid_ts_kws)
        else:
            fig[k], ax[k] = v.plot(label="recon", **valid_ts_kws)
        ax[k].set_ylabel(recon_name_dict[k])

    return fig, ax
28 changes: 28 additions & 0 deletions cfr/visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,34 @@ def plot_field_map(field_var, lat, lon, levels=50, add_cyclic_point=True,

return fig, ax

def plot_independent_distribution(independent_info_list: pd.DataFrame, calib_period=(1880, 2000)):
    """Plot histograms of the independent-verification statistics.

    Draws a 2x2 grid: rows are the metrics (correlation, coefficient of
    efficiency), columns split assimilated vs. non-assimilated proxies.
    Each panel overlays the in-calibration-period distribution (blue)
    against the pre-calibration-period distribution (red).

    Args:
        independent_info_list (pd.DataFrame): output of
            ``ReconRes.independent_verify``, with 'assim' and
            '{in,before}_{corr,ce}' columns.
        calib_period (tuple, optional): (start, end) years used only for the
            legend labels. Defaults to (1880, 2000).
            NOTE(review): this default differs from the (1850, 2000) default
            used when computing the statistics — confirm callers pass it
            explicitly so the labels match the computation.

    Returns:
        tuple: (fig, axs) where axs maps f"{assim_flag}{metric}" to each Axes.
    """
    fig = plt.figure(figsize=[20, 10])
    gs = gridspec.GridSpec(2, 2)
    gs.update(wspace=0.2, hspace=0.2)
    bins = np.linspace(-1, 1, 41)  # fixed bins over the valid [-1, 1] range
    axs = {}
    fs = 20
    for ind_y, metric in enumerate(['corr', 'ce']):
        for i, label in enumerate([True, False]):
            ax = fig.add_subplot(gs[ind_y, i])
            axs[str(label) + metric] = ax
            table_use = independent_info_list[independent_info_list['assim'] == label]
            ax.hist(
                table_use[f'in_{metric}'], bins=bins, alpha=0.5,
                label=f'{calib_period[0]} to {calib_period[1]}', color='blue', density=True,
            )
            ax.hist(
                table_use[f'before_{metric}'], bins=bins, alpha=0.5,
                label=f'before {calib_period[0]}', density=True, color='red',
            )
            ax.legend(loc='upper left', fontsize=fs)
            ax.axvline(x=0, color='black', linestyle='--')
            if ind_y == 0:
                # column titles only on the top row
                ax.set_title(['Assimilated Proxies', 'Non-assimilated Proxies'][i], fontsize=fs)
            if metric == 'corr':
                ax.set_xlabel('Correlation', fontsize=fs)
            else:
                # bugfix: was mislabeled "Correlation Efficiency"; the metric
                # is the (Nash-Sutcliffe) coefficient of efficiency.
                ax.set_xlabel('Coefficient of Efficiency', fontsize=fs)
            ax.set_ylabel('Density', fontsize=fs)
    return fig, axs


def plot_proxies(df, year=np.arange(2001), lon_col='lon', lat_col='lat', type_col='type', time_col='time',
title=None, title_weight='normal', markers_dict=None, colors_dict=None,
plot_timespan=None, plot_xticks=[850, 1000, 1200, 1400, 1600, 1800, 2000],
Expand Down
Loading
Loading