Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 76 additions & 7 deletions rf_diffusion/conditions/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ def one_hot_buckets(a, low, high, n, eps=1e-6):
First category absorbs anything below low
Last category absorbs anything above high
'''
step = (high-low) / n
bins = torch.linspace(low+step, high-step, n-1)
cat = torch.bucketize(a, bins).long()
return F.one_hot(cat, num_classes=n)
a = a.float()
buckets = torch.linspace(low, high+eps, n+1)
bucket_idx = torch.searchsorted(buckets, a) - 1
bucket_idx = torch.clamp(bucket_idx, 0, n-1)
return F.one_hot(bucket_idx, n)

def init_radius_of_gyration(indep, feature_conf, feature_inference_conf, **kwargs):
"""
Expand Down Expand Up @@ -118,23 +119,91 @@ def get_radius_of_gyration_inference(indep, feature_conf, feature_inference_conf
ic(out[0:2, :], out[-3:-1, :])
return out

def parse_atomwise_rasa_config(rasa_config, indep, metadata):
"""
Parse the atomwise RASA configuration string and create a per-atom RASA tensor.

Args:
rasa_config (str or float): Either a single float for global RASA, or
a string like "0.0,O7:0.8,C8:1.0,C9:1.0"
indep (Indep): The indep object containing is_sm mask
metadata (dict): Metadata containing ligand_atom_names

Returns:
torch.Tensor: Per-atom RASA values for the entire indep
"""
rasa = torch.full((indep.length(),), 0.0)

# If it's just a number, apply globally to small molecules
if isinstance(rasa_config, (float, int)):
rasa[indep.is_sm] = float(rasa_config)
return rasa

# Parse the string format: "global_value,atom1:value1,atom2:value2,..."
config_str = str(rasa_config)
parts = [p.strip() for p in config_str.split(',')]
global_value = float(parts[0])
rasa[indep.is_sm] = global_value

if not metadata or 'ligand_atom_names' not in metadata:
print("[RASA WARNING] No metadata or ligand_atom_names found, using global RASA")
return rasa

# Build atom name to specific RASA mapping
atom_rasa_map = {}
for part in parts[1:]:
if ':' not in part:
continue
atom_name, value_str = part.split(':', 1)
atom_rasa_map[atom_name.strip()] = float(value_str.strip())
if not atom_rasa_map:
return rasa

# Apply atom-specific values
ligand_atom_names = metadata['ligand_atom_names']
sm_indices = torch.where(indep.is_sm)[0]
n_sm_atoms = len(sm_indices)

# Validate indices
if n_sm_atoms > len(ligand_atom_names):
print(f"[RASA ERROR] More SM atoms ({n_sm_atoms}) than ligand names ({len(ligand_atom_names)})")
return rasa

# Ligand atom names are stored at the end of the array
ligand_names_start = len(ligand_atom_names) - n_sm_atoms
matched_atoms = []
for i, sm_idx in enumerate(sm_indices):
ligand_name_idx = ligand_names_start + i
if ligand_name_idx < len(ligand_atom_names):
atom_name_in_metadata = ligand_atom_names[ligand_name_idx].strip()
if atom_name_in_metadata in atom_rasa_map:
rasa[sm_idx] = atom_rasa_map[atom_name_in_metadata]
matched_atoms.append(f"{atom_name_in_metadata}={atom_rasa_map[atom_name_in_metadata]}")

if matched_atoms:
print(f"[RASA] Set atom-specific RASA for {len(matched_atoms)} atoms: {', '.join(matched_atoms)}")

return rasa

def get_relative_sasa_inference(indep, feature_conf, feature_inference_conf, cache, **kwargs):
"""
Calculates the radius of gyration fature
Calculates the relative SASA feature with support for atom-wise specification

Args:
indep (Indep): The holy indep.
feature_conf (omegaconf): The feature config.
feature_inference_conf (omegaconf): The feature inference config.
cache (dict): data cache
**kwargs: Additional keyword arguments including metadata

Returns:
sasa feature
dict: Dictionary with 't1d' key containing the SASA feature tensor
"""
if not feature_inference_conf.active:
return {'t1d':torch.zeros((indep.length(), feature_conf.n_bins + 1))}
rasa = torch.full((indep.length(),), feature_inference_conf.rasa)

metadata = kwargs.get('metadata', {})
rasa = parse_atomwise_rasa_config(feature_inference_conf.rasa, indep, metadata)
one_hot = one_hot_buckets(rasa, feature_conf.low, feature_conf.high, feature_conf.n_bins)
is_feature_applicable = indep.is_sm
one_hot[~is_feature_applicable] = 0
Expand Down
8 changes: 6 additions & 2 deletions rf_diffusion/inference/model_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def sample_init(self, i_des=0):
"""
indep_uncond, self.indep_orig, self.indep_cond, metadata, self.is_diffused, self.atomizer, contig_map, t_step_input, self.conditions_dict = self.dataset[i_des % len(self.dataset)]
indep = self.indep_cond.clone()
self.metadata = metadata
return indep, contig_map, self.atomizer, t_step_input

def symmetrise_prev_pred(self, px0, seq_in, alpha):
Expand Down Expand Up @@ -236,7 +237,7 @@ def sample_step(self, t, indep, rfo, extra, features_cache):

extra_tXd_names = getattr(self._conf, 'extra_tXd', [])
t_cont = t/self._conf.diffuser.T
indep.extra_t1d, indep.extra_t2d = features.get_extra_tXd_inference(indep, extra_tXd_names, self._conf.extra_tXd_params, self._conf.inference.conditions, t_cont=t_cont, features_cache=features_cache, **self.conditions_dict)
indep.extra_t1d, indep.extra_t2d = features.get_extra_tXd_inference(indep, extra_tXd_names, self._conf.extra_tXd_params, self._conf.inference.conditions, t_cont=t_cont, features_cache=features_cache, metadata=getattr(self, 'metadata', {}), **self.conditions_dict)
rfi = self.model_adaptor.prepro(indep, t, self.is_diffused)

rf2aa.tensor_util.to_device(rfi, self.device)
Expand Down Expand Up @@ -323,7 +324,7 @@ class FlowMatching(Sampler):
def run_model(self, t, indep, rfo, is_diffused, features_cache):
extra_tXd_names = getattr(self._conf, 'extra_tXd', [])
t_cont = t/self._conf.diffuser.T
indep.extra_t1d, indep.extra_t2d = features.get_extra_tXd_inference(indep, extra_tXd_names, self._conf.extra_tXd_params, self._conf.inference.conditions, t_cont=t_cont, features_cache=features_cache, **self.conditions_dict)
indep.extra_t1d, indep.extra_t2d = features.get_extra_tXd_inference(indep, extra_tXd_names, self._conf.extra_tXd_params, self._conf.inference.conditions, t_cont=t_cont, features_cache=features_cache, metadata=getattr(self, 'metadata', {}), **self.conditions_dict)
rfi = self.model_adaptor.prepro(indep, t, is_diffused)
rf2aa.tensor_util.to_device(rfi, self.device)

Expand Down Expand Up @@ -524,12 +525,14 @@ class FlowMatching_make_conditional_diffuse_all(FlowMatching_make_conditional):

def sample_init(self, i_des=0):
indep_uncond, self.indep_orig, self.indep_cond, metadata, self.is_diffused, atomizer, contig_map, t_step_input, self.conditions_dict = self.dataset[i_des % len(self.dataset)]
self.metadata = metadata
return indep_uncond, contig_map, atomizer, t_step_input

class FlowMatching_make_conditional_diffuse_all_xt_unfrozen(FlowMatching):

def sample_init(self, i_des=0):
indep_uncond, self.indep_orig, self.indep_cond, metadata, self.is_diffused, atomizer, contig_map, t_step_input, self.conditions_dict = self.dataset[i_des % len(self.dataset)]
self.metadata = metadata
return indep_uncond, contig_map, atomizer, t_step_input

def sample_step(self, t, indep, rfo, extra, features_cache):
Expand All @@ -548,6 +551,7 @@ class ClassifierFreeGuidance(FlowMatching):
# WIP
def sample_init(self, i_des=0):
indep_uncond, self.indep_orig, self.indep_cond, metadata, self.is_diffused, atomizer, contig_map, t_step_input, self.conditions_dict = self.dataset[i_des % len(self.dataset)]
self.metadata = metadata
return indep_uncond, contig_map, atomizer, t_step_input

def get_grads(self, t, indep_in, indep_t, rfo, is_diffused, features_cache):
Expand Down