diff --git a/ensemble/ensemble.py b/ensemble/ensemble.py index c084d43..d0d06af 100644 --- a/ensemble/ensemble.py +++ b/ensemble/ensemble.py @@ -128,10 +128,17 @@ def __init__(self, keys_en: dict, sim, redund_sim=None): elif 'controls' in self.keys_en: self.prior_info = extract.extract_initial_controls(self.keys_en) + + # Ensemble size + self.ne = self.keys_en.get('ne', None) + # Calculate initial ensemble if IMPORTSTATICVAR has not been given in init. file. # Prior info. on state variables must be given by PRIOR_ keyword. if 'importstaticvar' not in self.keys_en: - self.ne = int(self.keys_en['ne']) + if self.ne is None: + self.ne = 100 + else: + self.ne = int(self.ne) # Generate prior ensemble self.enX, self.idX, self.cov_prior = entools.generate_prior_ensemble( @@ -144,15 +151,18 @@ def __init__(self, keys_en: dict, sim, redund_sim=None): # State variable imported as a Numpy save file tmp_load = np.load(self.keys_en['importstaticvar'], allow_pickle=True) + if self.ne is None: + self.ne = tmp_load[key].shape[1] + else: + self.ne = int(self.ne) + # We assume that the user has saved the state dict. as **state (effectively saved all keys in state # individually). for key in self.keys_en['staticvar']: if self.enX is None: - self.enX = tmp_load[key] - self.ne = self.enX.shape[1] + self.enX = tmp_load[key][:,:self.ne] else: - assert self.ne == tmp_load[key].shape[1], 'Ensemble size of imported state variables do not match!' - self.enX = np.vstack((self.enX, tmp_load[key])) + self.enX = np.vstack((self.enX, tmp_load[key][:,:self.ne])) # fill in indices self.idX[key] = (self.enX.shape[0] - tmp_load[key].shape[0], self.enX.shape[0]) @@ -260,17 +270,11 @@ def calc_prediction(self, enX=None, save_prediction=None): en_pred = [] pbar = tqdm(enumerate(enX), total=self.ne, **progbar_settings) for member_index, state in pbar: - en_pred.append(deepcopy(self.sim.run_fwd_sim(state, member_index))) + en_pred.append(self.sim.run_fwd_sim(state, member_index)) # Parallelization on HPC using SLURM elif self.sim.input_dict.get('hpc', False): # Run prediction in parallel on hpc en_pred = self.run_on_HPC(enX, batch_size=nparallel) - - # Parallellization internal to the simulator (e.g. batch processing on GPU ) - elif self.sim.input_dict.get('parallel_internal', False): - # make a single matrix for each state - batch_enX = {key: np.array([d[key] for d in enX]) for key in enX[0].keys()} # key: (b, state) - en_pred = self.sim.run_fwd_sim(batch_enX, member_i=None) # Parallelization on local machine using p_map else: diff --git a/pipt/misc_tools/analysis_tools.py b/pipt/misc_tools/analysis_tools.py index b30fd5d..a0e4f5d 100644 --- a/pipt/misc_tools/analysis_tools.py +++ b/pipt/misc_tools/analysis_tools.py @@ -1528,11 +1528,12 @@ def truncSVD(matrix, r=None, energy=None, full_matrices=False): # If not specified rank, energy must be given if r is None: if energy is not None: - # If energy is less than 100 we truncate the SVD matrices + # Energy is given as fraction if energy < 1: - r = np.sum((np.cumsum(S) / sum(S)) <= energy) + r = np.searchsorted(np.cumsum(S)/np.sum(S), energy) + # Energy is given as a percentage else: - r = np.sum((np.cumsum(S) / sum(S)) <= energy/100) + r = np.searchsorted(np.cumsum(S)/np.sum(S), energy/100) else: raise ValueError("Either rank 'r' or 'energy' must be specified for truncSVD.") diff --git a/pipt/update_schemes/enrml.py b/pipt/update_schemes/enrml.py index ec6274e..7f81951 100644 --- a/pipt/update_schemes/enrml.py +++ b/pipt/update_schemes/enrml.py @@ -250,6 +250,9 @@ def check_convergence(self): 'lambda': self.lam, 'lambda_stop': self.lam >= self.lam_max} + # Log step + self.log_update(success=success) + ############################################### ##### update Lambda step-size values ########## ############################################### @@ -288,9 +291,6 @@ def check_convergence(self): self.logger(f'Data misfit increased! λ increased: {self.lam / self.gamma} ──> {self.lam}') success = False - # Log update results - self.log_update(success=success) - if not success: # Reset the objective function after report self.data_misfit = self.prev_data_misfit