Skip to content
This repository was archived by the owner on Mar 10, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/scream-core/src/core/models/residue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl ResidueType {
"TRP" => Some(ResidueType::Tryptophan),
"TYR" => Some(ResidueType::Tyrosine),
"ASN" => Some(ResidueType::Asparagine),
"CYS" => Some(ResidueType::Cysteine),
"CYS" | "CYX" => Some(ResidueType::Cysteine),
"GLN" => Some(ResidueType::Glutamine),
"SER" => Some(ResidueType::Serine),
"THR" => Some(ResidueType::Threonine),
Expand Down
236 changes: 235 additions & 1 deletion crates/scream-core/src/core/models/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use super::ids::{AtomId, ChainId, ResidueId};
use super::residue::{Residue, ResidueType};
use super::topology::{Bond, BondOrder};
use slotmap::{SecondaryMap, SlotMap};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

const CYSTEINE_SULFUR_GAMMA_ATOM_NAME: &str = "SG";

/// Represents a complete molecular system with atoms, residues, chains, and bonds.
///
Expand Down Expand Up @@ -470,6 +472,68 @@ impl MolecularSystem {
pub fn background_atom_ids(&self) -> Vec<AtomId> {
self.background_atoms().map(|(id, _)| id).collect()
}

/// Detects and returns the residue IDs of all Cysteine residues involved in disulfide bonds.
///
/// A disulfide bond is identified by a covalent bond between the Sulfur-Gamma (SG)
/// atoms of two different Cysteine residues. This method correctly handles both
/// `CYS` and `CYX` residue names by relying on the `ResidueType`.
///
/// # Returns
///
/// A `HashSet<ResidueId>` containing the IDs of all residues participating in
/// any disulfide bond within the system.
pub fn find_disulfide_bonded_residues(&self) -> HashSet<ResidueId> {
let mut bonded_residue_ids = HashSet::new();

// 1. Collect all Cysteine residues and their SG atoms
let cysteine_sg_atoms: HashMap<ResidueId, AtomId> = self
.residues_iter()
.filter_map(|(res_id, residue)| {
if matches!(residue.residue_type, Some(ResidueType::Cysteine)) {
residue
.get_first_atom_id_by_name(CYSTEINE_SULFUR_GAMMA_ATOM_NAME)
.map(|sg_id| (res_id, sg_id))
} else {
None
}
})
.collect();

// If there are fewer than two Cysteines, no disulfide bonds are possible
if cysteine_sg_atoms.len() < 2 {
return bonded_residue_ids;
}

// Create a reverse map from SG AtomId to ResidueId for quick lookups
let sg_atom_to_residue: HashMap<AtomId, ResidueId> = cysteine_sg_atoms
.iter()
.map(|(&res_id, &atom_id)| (atom_id, res_id))
.collect();

// 2. Iterate through the collected SG atoms and check their bonds
for (res_id_a, sg_atom_id_a) in &cysteine_sg_atoms {
if bonded_residue_ids.contains(res_id_a) {
continue;
}

if let Some(neighbors) = self.get_bonded_neighbors(*sg_atom_id_a) {
for &neighbor_atom_id in neighbors {
// 3. Check if the neighbor is also an SG atom of another Cysteine
if let Some(res_id_b) = sg_atom_to_residue.get(&neighbor_atom_id) {
if res_id_a != res_id_b {
// Found a disulfide bond!
bonded_residue_ids.insert(*res_id_a);
bonded_residue_ids.insert(*res_id_b);
break;
}
}
}
}
}

bonded_residue_ids
}
}

#[cfg(test)]
Expand Down Expand Up @@ -897,4 +961,174 @@ mod tests {
assert!(background_ids.contains(id_map.get("UNKNOWN").unwrap()));
}
}

mod disulfide_bond_detection {
use super::*;
use crate::core::models::topology::BondOrder;
use nalgebra::Point3;

fn create_disulfide_test_system()
-> (MolecularSystem, ResidueId, ResidueId, ResidueId, ResidueId) {
let mut system = MolecularSystem::new();
let chain_a_id = system.add_chain('A', ChainType::Protein);

let cys1_id = system
.add_residue(chain_a_id, 1, "CYS", Some(ResidueType::Cysteine))
.unwrap();
let cys1_sg_id = system
.add_atom_to_residue(
cys1_id,
Atom::new(
CYSTEINE_SULFUR_GAMMA_ATOM_NAME,
cys1_id,
Point3::new(0.0, 0.0, 0.0),
),
)
.unwrap();

let cys2_id = system
.add_residue(chain_a_id, 2, "CYX", Some(ResidueType::Cysteine))
.unwrap();
let cys2_sg_id = system
.add_atom_to_residue(
cys2_id,
Atom::new(
CYSTEINE_SULFUR_GAMMA_ATOM_NAME,
cys2_id,
Point3::new(2.0, 0.0, 0.0),
),
)
.unwrap();

let cys3_id = system
.add_residue(chain_a_id, 3, "CYS", Some(ResidueType::Cysteine))
.unwrap();
system
.add_atom_to_residue(
cys3_id,
Atom::new(
CYSTEINE_SULFUR_GAMMA_ATOM_NAME,
cys3_id,
Point3::new(10.0, 0.0, 0.0),
),
)
.unwrap();

let ala4_id = system
.add_residue(chain_a_id, 4, "ALA", Some(ResidueType::Alanine))
.unwrap();
system
.add_atom_to_residue(
ala4_id,
Atom::new("CB", ala4_id, Point3::new(12.0, 0.0, 0.0)),
)
.unwrap();

system
.add_bond(cys1_sg_id, cys2_sg_id, BondOrder::Single)
.unwrap();

(system, cys1_id, cys2_id, cys3_id, ala4_id)
}

#[test]
fn find_disulfide_bonded_residues_identifies_correct_pair() {
let (system, cys1_id, cys2_id, cys3_id, ala4_id) = create_disulfide_test_system();

let bonded_residues = system.find_disulfide_bonded_residues();

assert_eq!(bonded_residues.len(), 2);
assert!(bonded_residues.contains(&cys1_id));
assert!(bonded_residues.contains(&cys2_id));
assert!(!bonded_residues.contains(&cys3_id));
assert!(!bonded_residues.contains(&ala4_id));
}

#[test]
fn find_disulfide_bonded_residues_returns_empty_for_no_bonds() {
let (mut system, cys1_id, cys2_id, _, _) = create_disulfide_test_system();
let cys1_sg_id = system
.residue(cys1_id)
.unwrap()
.get_first_atom_id_by_name(CYSTEINE_SULFUR_GAMMA_ATOM_NAME)
.unwrap();
let cys2_sg_id = system
.residue(cys2_id)
.unwrap()
.get_first_atom_id_by_name(CYSTEINE_SULFUR_GAMMA_ATOM_NAME)
.unwrap();

system
.bonds
.retain(|bond| !(bond.contains(cys1_sg_id) && bond.contains(cys2_sg_id)));
system.bond_adjacency.get_mut(cys1_sg_id).unwrap().clear();
system.bond_adjacency.get_mut(cys2_sg_id).unwrap().clear();

let bonded_residues = system.find_disulfide_bonded_residues();
assert!(
bonded_residues.is_empty(),
"Expected no bonded residues after removing the bond"
);
}

#[test]
fn find_disulfide_bonded_residues_returns_empty_for_no_cysteines() {
let mut system = MolecularSystem::new();
let chain_a_id = system.add_chain('A', ChainType::Protein);
system
.add_residue(chain_a_id, 4, "ALA", Some(ResidueType::Alanine))
.unwrap();

let bonded_residues = system.find_disulfide_bonded_residues();
assert!(bonded_residues.is_empty());
}

#[test]
fn find_disulfide_bonded_residues_handles_multiple_bonds() {
let (mut system, cys1_id, cys2_id, cys3_id, ala4_id) = create_disulfide_test_system();

let chain_b_id = system.add_chain('B', ChainType::Protein);
let cys10_id = system
.add_residue(chain_b_id, 10, "CYS", Some(ResidueType::Cysteine))
.unwrap();
let cys10_sg_id = system
.add_atom_to_residue(
cys10_id,
Atom::new(
CYSTEINE_SULFUR_GAMMA_ATOM_NAME,
cys10_id,
Point3::new(50.0, 0.0, 0.0),
),
)
.unwrap();

let cys20_id = system
.add_residue(chain_b_id, 20, "CYS", Some(ResidueType::Cysteine))
.unwrap();
let cys20_sg_id = system
.add_atom_to_residue(
cys20_id,
Atom::new(
CYSTEINE_SULFUR_GAMMA_ATOM_NAME,
cys20_id,
Point3::new(52.0, 0.0, 0.0),
),
)
.unwrap();

system
.add_bond(cys10_sg_id, cys20_sg_id, BondOrder::Single)
.unwrap();

let bonded_residues = system.find_disulfide_bonded_residues();

assert_eq!(bonded_residues.len(), 4);
assert!(bonded_residues.contains(&cys1_id));
assert!(bonded_residues.contains(&cys2_id));
assert!(!bonded_residues.contains(&cys3_id));
assert!(!bonded_residues.contains(&ala4_id));
assert!(bonded_residues.contains(&cys10_id));
assert!(bonded_residues.contains(&cys20_id));
}
}
}
56 changes: 28 additions & 28 deletions crates/scream-core/src/engine/energy_grid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use crate::core::forcefield::term::EnergyTerm;
use crate::core::models::atom::AtomRole;
use crate::core::models::ids::ResidueId;
use crate::core::models::system::MolecularSystem;
use itertools::Itertools;
use std::collections::{HashMap, HashSet};
use tracing::{info, trace};

Expand Down Expand Up @@ -105,42 +104,43 @@ impl EnergyGrid {

let scorer = Scorer::new(system, forcefield);

for pair in active_residues.iter().combinations(2) {
let res_a_id = *pair[0];
let res_b_id = *pair[1];
let active_residue_vec: Vec<_> = active_residues.iter().cloned().collect();
for i in 0..active_residue_vec.len() {
for j in (i + 1)..active_residue_vec.len() {
let res_a_id = active_residue_vec[i];
let res_b_id = active_residue_vec[j];

let atoms_a = collect_active_sidechain_atoms(system, &HashSet::from([res_a_id]));
let atoms_b = collect_active_sidechain_atoms(system, &HashSet::from([res_b_id]));
let atoms_a = collect_active_sidechain_atoms(system, &HashSet::from([res_a_id]));
let atoms_b = collect_active_sidechain_atoms(system, &HashSet::from([res_b_id]));

let atoms_a_slice = atoms_a
.get(&res_a_id)
.map_or([].as_slice(), |v| v.as_slice());
let atoms_b_slice = atoms_b
.get(&res_b_id)
.map_or([].as_slice(), |v| v.as_slice());
let atoms_a_slice = atoms_a
.get(&res_a_id)
.map_or([].as_slice(), |v| v.as_slice());
let atoms_b_slice = atoms_b
.get(&res_b_id)
.map_or([].as_slice(), |v| v.as_slice());

if atoms_a_slice.is_empty() || atoms_b_slice.is_empty() {
continue;
}
if atoms_a_slice.is_empty() || atoms_b_slice.is_empty() {
continue;
}

let interaction = scorer.score_interaction(atoms_a_slice, atoms_b_slice)?;
let interaction = scorer.score_interaction(atoms_a_slice, atoms_b_slice)?;

let key = if res_a_id < res_b_id {
(res_a_id, res_b_id)
} else {
(res_b_id, res_a_id)
};
pair_interactions.insert(key, interaction);
let key = (res_a_id, res_b_id);
pair_interactions.insert(key, interaction);

*total_residue_interactions.get_mut(&res_a_id).unwrap() += interaction;
*total_residue_interactions.get_mut(&res_b_id).unwrap() += interaction;
}
}

*total_residue_interactions.get_mut(&res_a_id).unwrap() += interaction;
*total_residue_interactions.get_mut(&res_b_id).unwrap() += interaction;
for term in total_residue_interactions.values_mut() {
*term = *term * 0.5;
}

// The total interaction energy is double-counted in the sum above, so divide by 2
let total_interaction_energy = total_residue_interactions
let total_interaction_energy = pair_interactions
.values()
.fold(EnergyTerm::default(), |acc, term| acc + *term)
* 0.5;
.fold(EnergyTerm::default(), |acc, term| acc + *term);

let mut current_el_energies = HashMap::with_capacity(active_residues.len());
let mut total_el_energy = EnergyTerm::default();
Expand Down
Loading