Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions ensemble/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,17 @@ def __init__(self, keys_en: dict, sim, redund_sim=None):
elif 'controls' in self.keys_en:
self.prior_info = extract.extract_initial_controls(self.keys_en)


# Ensemble size
self.ne = self.keys_en.get('ne', None)

# Calculate initial ensemble if IMPORTSTATICVAR has not been given in init. file.
# Prior info. on state variables must be given by PRIOR_<STATICVAR-name> keyword.
if 'importstaticvar' not in self.keys_en:
self.ne = int(self.keys_en['ne'])
if self.ne is None:
self.ne = 100
else:
self.ne = int(self.ne)

# Generate prior ensemble
self.enX, self.idX, self.cov_prior = entools.generate_prior_ensemble(
Expand All @@ -144,15 +151,18 @@ def __init__(self, keys_en: dict, sim, redund_sim=None):
# State variable imported as a Numpy save file
tmp_load = np.load(self.keys_en['importstaticvar'], allow_pickle=True)

if self.ne is None:
self.ne = tmp_load[key].shape[1]
else:
self.ne = int(self.ne)

# We assume that the user has saved the state dict. as **state (effectively saved all keys in state
# individually).
for key in self.keys_en['staticvar']:
if self.enX is None:
self.enX = tmp_load[key]
self.ne = self.enX.shape[1]
self.enX = tmp_load[key][:,:self.ne]
else:
assert self.ne == tmp_load[key].shape[1], 'Ensemble size of imported state variables do not match!'
self.enX = np.vstack((self.enX, tmp_load[key]))
self.enX = np.vstack((self.enX, tmp_load[key][:,:self.ne]))

# fill in indices
self.idX[key] = (self.enX.shape[0] - tmp_load[key].shape[0], self.enX.shape[0])
Expand Down Expand Up @@ -260,17 +270,11 @@ def calc_prediction(self, enX=None, save_prediction=None):
en_pred = []
pbar = tqdm(enumerate(enX), total=self.ne, **progbar_settings)
for member_index, state in pbar:
en_pred.append(deepcopy(self.sim.run_fwd_sim(state, member_index)))
en_pred.append(self.sim.run_fwd_sim(state, member_index))

# Parallelization on HPC using SLURM
elif self.sim.input_dict.get('hpc', False): # Run prediction in parallel on hpc
en_pred = self.run_on_HPC(enX, batch_size=nparallel)

# Parallellization internal to the simulator (e.g. batch processing on GPU )
elif self.sim.input_dict.get('parallel_internal', False):
# make a single matrix for each state
batch_enX = {key: np.array([d[key] for d in enX]) for key in enX[0].keys()} # key: (b, state)
en_pred = self.sim.run_fwd_sim(batch_enX, member_i=None)

# Parallelization on local machine using p_map
else:
Expand Down
7 changes: 4 additions & 3 deletions pipt/misc_tools/analysis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,11 +1528,12 @@ def truncSVD(matrix, r=None, energy=None, full_matrices=False):
# If not specified rank, energy must be given
if r is None:
if energy is not None:
# If energy is less than 100 we truncate the SVD matrices
# Energy is given as fraction
if energy < 1:
r = np.sum((np.cumsum(S) / sum(S)) <= energy)
r = np.searchsorted(np.cumsum(S)/np.sum(S), energy)
# Energy is given as a percentage
else:
r = np.sum((np.cumsum(S) / sum(S)) <= energy/100)
r = np.searchsorted(np.cumsum(S)/np.sum(S), energy/100)
else:
raise ValueError("Either rank 'r' or 'energy' must be specified for truncSVD.")

Expand Down
6 changes: 3 additions & 3 deletions pipt/update_schemes/enrml.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ def check_convergence(self):
'lambda': self.lam,
'lambda_stop': self.lam >= self.lam_max}

# Log step
self.log_update(success=success)

###############################################
##### update Lambda step-size values ##########
###############################################
Expand Down Expand Up @@ -288,9 +291,6 @@ def check_convergence(self):
self.logger(f'Data misfit increased! λ increased: {self.lam / self.gamma} ──> {self.lam}')
success = False

# Log update results
self.log_update(success=success)

if not success:
# Reset the objective function after report
self.data_misfit = self.prev_data_misfit
Expand Down