Skip to content

Commit

Permalink
calibrate_xaj script for cross val
Browse files Browse the repository at this point in the history
  • Loading branch information
OuyangWenyu committed Mar 26, 2024
1 parent 262ec8e commit c8fc038
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 126 deletions.
1 change: 0 additions & 1 deletion env-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ dependencies:
- sphinx
- black
- flake8
- pytest
# pip
- pip
- pip:
Expand Down
101 changes: 96 additions & 5 deletions hydromodel/datasets/data_preprocess.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2022-10-25 21:16:22
LastEditTime: 2024-03-25 19:54:15
LastEditTime: 2024-03-26 19:18:29
LastEditors: Wenyu Ouyang
Description: preprocess data for models in hydro-model-xaj
FilePath: \hydro-model-xaj\hydromodel\datasets\data_preprocess.py
Expand All @@ -10,15 +10,15 @@

import os
import re
from hydrodataset import Camels
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from collections import OrderedDict
import xarray as xr

from hydroutils import hydro_time, hydro_file
from hydrodata.utils.utils import streamflow_unit_conv

from hydromodel import CACHE_DIR
from hydromodel import CACHE_DIR, SETTING
from hydromodel.datasets import *


Expand Down Expand Up @@ -293,7 +293,8 @@ def split_train_test(ts_data, train_period, test_period):
Returns
-------
None
tuple of xr.Dataset
A tuple of xr.Dataset for training and testing data
"""
# Convert date strings to pandas datetime objects
train_start, train_end = pd.to_datetime(train_period[0]), pd.to_datetime(
Expand Down Expand Up @@ -387,3 +388,93 @@ def cross_valid_data(ts_data, period, warmup, cv_fold, freq="1D"):
train_test_data.append((train_data, test_data))

return train_test_data


def _unify_streamflow_unit(
    ts_data, basin_area, src_flow_name, src_prcp_name, flow_name
):
    """Convert streamflow to the unit of precipitation and store it as ``flow_name``.

    Parameters
    ----------
    ts_data : xr.Dataset
        The time series data; the converted variable is written into it in place
    basin_area
        The basin areas, handed to ``streamflow_unit_conv`` unchanged
    src_flow_name : str
        The name of the streamflow variable to convert
    src_prcp_name : str
        The name of the precipitation variable whose "units" attribute
        defines the target unit
    flow_name : str
        The name under which the converted streamflow is stored

    Returns
    -------
    xr.Dataset
        The same dataset, now holding the converted streamflow variable
    """
    # the target is whatever unit precipitation declares; "unknown" is passed
    # on as-is when the precipitation variable carries no "units" attribute
    target_unit = ts_data[src_prcp_name].attrs.get("units", "unknown")
    converted = streamflow_unit_conv(
        ts_data[[src_flow_name]], basin_area, target_unit=target_unit
    )
    ts_data[flow_name] = converted[src_flow_name]
    ts_data[flow_name].attrs["units"] = target_unit
    return ts_data


def get_ts_from_diffsource(data_type, data_dir, periods, basin_ids):
    """Get time series data from different sources and unify the format and unit of streamflow.

    The streamflow is converted to the unit declared by the precipitation
    variable and stored under the unit-free name derived from FLOW_NAME.

    Parameters
    ----------
    data_type
        The type of the data source, 'camels' or 'owndata'
    data_dir
        The directory of the data source. For 'camels' it is joined below the
        configured "datasets-origin" path; for 'owndata' the files
        "timeseries.nc" and "attributes.nc" are looked up in
        ``os.path.dirname(data_dir)``
    periods
        The periods of the time series data (only used for 'camels')
    basin_ids
        The ids of the basins (only used for 'camels')

    Returns
    -------
    xr.Dataset
        The time series data

    Raises
    ------
    NotImplementedError
        The data type is not 'camels' or 'owndata'
    """
    prcp_name = remove_unit_from_name(PRCP_NAME)
    pet_name = remove_unit_from_name(PET_NAME)
    flow_name = remove_unit_from_name(FLOW_NAME)
    area_name = remove_unit_from_name(AREA_NAME)
    if data_type == "camels":
        camels_data_dir = os.path.join(
            SETTING["local_data_path"]["datasets-origin"], "camels", data_dir
        )
        camels = Camels(camels_data_dir)
        ts_data = camels.read_ts_xrdataset(
            basin_ids, periods, ["prcp", "PET", "streamflow"]
        )
        basin_area = camels.read_area(basin_ids)
        ts_data = _unify_streamflow_unit(
            ts_data, basin_area, "streamflow", "prcp", flow_name
        )
        # NOTE: the source "streamflow" variable is not dropped here; it stays
        # in the dataset unless flow_name is itself "streamflow"
        ts_data = ts_data.rename({"PET": pet_name})
    elif data_type == "owndata":
        # NOTE(review): the files are read from the parent of data_dir unless
        # data_dir ends with a path separator -- confirm this is intended
        own_data_dir = os.path.dirname(data_dir)
        # the time series stays open because it is returned to the caller
        ts_data = xr.open_dataset(os.path.join(own_data_dir, "timeseries.nc"))
        # .values loads the areas into memory, so the attribute file can be
        # closed right away instead of leaking an open handle
        with xr.open_dataset(
            os.path.join(own_data_dir, "attributes.nc")
        ) as attr_data:
            basin_area = attr_data[area_name].values
        ts_data = _unify_streamflow_unit(
            ts_data, basin_area, flow_name, prcp_name, flow_name
        )
    else:
        raise NotImplementedError(
            "You should set the data type as 'camels' or 'owndata'"
        )

    return ts_data


def get_pe_q_from_ts(ts_xr_dataset):
    """Turn a time series dataset into the numpy arrays used for calibration.

    Parameters
    ----------
    ts_xr_dataset : xr.Dataset
        The time series data; it must hold the variables named by
        PRCP_NAME, PET_NAME and FLOW_NAME (with the unit part removed)

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The first array stacks precipitation and potential evapotranspiration
        along its last axis; the second is the observed streamflow with a
        trailing axis of length 1. Both are 3-dimensional
    """
    forcing_names = [
        remove_unit_from_name(PRCP_NAME),
        remove_unit_from_name(PET_NAME),
    ]
    flow_name = remove_unit_from_name(FLOW_NAME)

    # to_array() stacks the variables along a new leading axis; reversing the
    # three axes moves that variable axis to the end
    # presumably the result is (time, basin, variable) -- verify against the
    # dimension order of the datasets produced upstream
    stacked_forcing = ts_xr_dataset[forcing_names].to_array().to_numpy()
    p_and_e = np.transpose(stacked_forcing, (2, 1, 0))

    # swap the two axes of the flow array, then append a singleton axis so it
    # has the same rank as p_and_e
    flow_values = ts_xr_dataset[flow_name].to_numpy()
    qobs = np.transpose(flow_values, (1, 0))[:, :, np.newaxis]

    return p_and_e, qobs
5 changes: 1 addition & 4 deletions hydromodel/trainers/train_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2022-10-25 21:16:22
LastEditTime: 2024-03-22 20:07:08
LastEditTime: 2024-03-26 18:20:57
LastEditors: Wenyu Ouyang
Description: Plots for calibration and testing results
FilePath: \hydro-model-xaj\hydromodel\trainers\train_utils.py
Expand Down Expand Up @@ -64,7 +64,6 @@ def plot_train_iteration(likelihood, save_fig):


def show_calibrate_result(
spot_setup,
sceua_calibrated_file,
warmup_length,
save_dir,
Expand All @@ -79,8 +78,6 @@ def show_calibrate_result(
Parameters
----------
spot_setup
Spotpy's setup class instance
sceua_calibrated_file
the result file saved after optimizing
basin_id
Expand Down
105 changes: 37 additions & 68 deletions scripts/calibrate_xaj.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Author: Wenyu Ouyang
Date: 2022-11-19 17:27:05
LastEditTime: 2024-03-26 16:54:05
LastEditTime: 2024-03-26 18:55:25
LastEditors: Wenyu Ouyang
Description: the script to calibrate a model for CAMELS basin
FilePath: \hydro-model-xaj\scripts\calibrate_xaj.py
Expand All @@ -14,18 +14,18 @@
import sys
import os
from pathlib import Path
import xarray as xr
import yaml

from hydrodataset import Camels
from hydrodata.utils.utils import streamflow_unit_conv


repo_path = os.path.dirname(Path(os.path.abspath(__file__)).parent)
sys.path.append(repo_path)
from hydromodel import SETTING
from hydromodel.datasets import *
from hydromodel.datasets.data_preprocess import cross_valid_data, split_train_test
from hydromodel.datasets.data_preprocess import (
cross_valid_data,
split_train_test,
get_ts_from_diffsource,
get_pe_q_from_ts,
)
from hydromodel.trainers.calibrate_sceua import calibrate_by_sceua


Expand All @@ -42,22 +42,7 @@ def calibrate(args):
model_info = args.model
algo_info = args.algorithm
loss = args.loss
if data_type == "camels":
camels_data_dir = os.path.join(
SETTING["local_data_path"]["datasets-origin"], "camels", data_dir
)
camels = Camels(camels_data_dir)
ts_data = camels.read_ts_xrdataset(
basin_ids, periods, ["prcp", "PET", "streamflow"]
)
elif data_type == "owndata":
ts_data = xr.open_dataset(
os.path.join(os.path.dirname(data_dir), "timeseries.nc")
)
else:
raise NotImplementedError(
"You should set the data type as 'camels' or 'owndata'"
)
ts_data = get_ts_from_diffsource(data_type, data_dir, periods, basin_ids)

where_save = Path(os.path.join(repo_path, "result", exp))
if os.path.exists(where_save) is False:
Expand All @@ -68,55 +53,39 @@ def calibrate(args):
periods = np.sort(
[train_period[0], train_period[1], test_period[0], test_period[1]]
)
if cv_fold > 1:
train_and_test_data = cross_valid_data(ts_data, periods, warmup, cv_fold)
else:
# when using train_test_split, the warmup period is not used
# so you should include the warmup period in the train and test period
train_and_test_data = split_train_test(ts_data, train_period, test_period)
else:
# cross validation
train_and_test_data = cross_valid_data(ts_data, periods, warmup, cv_fold)

print("Start to calibrate the model")

if data_type == "camels":
basin_area = camels.read_area(basin_ids)
p_and_e = (
train_and_test_data[0][["prcp", "PET"]]
.to_array()
.to_numpy()
.transpose(2, 1, 0)
)
# trans unit to mm/day
qobs_ = train_and_test_data[0][["streamflow"]]
r_mmd = streamflow_unit_conv(qobs_, basin_area, target_unit="mm/d")
qobs = np.expand_dims(r_mmd["streamflow"].to_numpy().transpose(1, 0), axis=2)
elif data_type == "owndata":
attr_data = xr.open_dataset(
os.path.join(os.path.dirname(data_dir), "attributes.nc")
)
basin_area = attr_data["area"].values
p_and_e = (
train_and_test_data[0][[PRCP_NAME, PET_NAME]]
.to_array()
.to_numpy()
.transpose(2, 1, 0)
)
qobs = np.expand_dims(
train_and_test_data[0][[FLOW_NAME]].to_array().to_numpy().transpose(1, 0),
axis=2,
if cv_fold <= 1:
p_and_e, qobs = get_pe_q_from_ts(train_and_test_data[0])
calibrate_by_sceua(
basin_ids,
p_and_e,
qobs,
os.path.join(where_save, "sceua_xaj"),
warmup,
model=model_info,
algorithm=algo_info,
loss=loss,
)
else:
raise NotImplementedError(
"You should set the data type as 'camels' or 'owndata'"
)
calibrate_by_sceua(
basin_ids,
p_and_e,
qobs,
os.path.join(where_save, "sceua_xaj"),
warmup,
model=model_info,
algorithm=algo_info,
loss=loss,
)
for i in range(cv_fold):
train_data, _ = train_and_test_data[i]
p_and_e_cv, qobs_cv = get_pe_q_from_ts(train_data)
calibrate_by_sceua(
basin_ids,
p_and_e_cv,
qobs_cv,
os.path.join(where_save, f"sceua_xaj_cv{i+1}"),
warmup,
model=model_info,
algorithm=algo_info,
loss=loss,
)
# Convert the arguments to a dictionary
args_dict = vars(args)
# Save the arguments to a YAML file
Expand Down Expand Up @@ -155,7 +124,7 @@ def calibrate(args):
"--cv_fold",
dest="cv_fold",
help="the number of cross-validation fold",
default=1,
default=2,
type=int,
)
parser.add_argument(
Expand Down
4 changes: 2 additions & 2 deletions scripts/calibrate_xaj_for_multicases.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pathlib import Path

sys.path.append(os.path.dirname(Path(os.path.abspath(__file__)).parent))
from scripts.evaluate_xaj import calibrate
from scripts.evaluate_xaj import evaluate

matplotlib.use("Agg")
exp = "exp61561"
Expand Down Expand Up @@ -49,4 +49,4 @@ def __init__(self, exp, book, source, rep, ngs):
rep = reps[i]
ng = ngs[j]
xaj_calibrate = XAJCalibrateMultiCases(exp, book, source, rep, ng)
calibrate(xaj_calibrate)
evaluate(xaj_calibrate)
Loading

0 comments on commit c8fc038

Please sign in to comment.