# Source code for packing_defect.core.classification

"""
classification.py

Implements the Strategy pattern for atom classification: a default strategy
that reproduces the original hard-coded logic, and an optional rule set
loaded from a JSON file.
"""

from abc import ABC, abstractmethod
from typing import Dict, Tuple
import json


class ClassificationStrategy(ABC):
    """
    Abstract interface for atom classification strategies.

    A concrete strategy subclasses this class and implements
    :meth:`classify`. Because ``classify`` is declared abstract, this
    interface cannot be instantiated directly.
    """

    @abstractmethod
    def classify(self, resname: str, atom_name: str) -> int:
        """
        Given a residue name and atom name, return an integer code.

        :param resname: residue name of the atom.
        :param atom_name: name of the atom within that residue.
        :return: integer classification code; what each code means is
            defined by the concrete strategy.
        """
class DefaultClassification(ClassificationStrategy):
    """
    Default logic (mirrors the original `default_classify`).

    Codes returned by :meth:`classify`:

    - Residue listed in ``PL_resnames`` (POPC, DOPE, SAPI):
      tail atoms → 1, else → -1
    - Residue ``TRIO``: glycerol atoms → 2, else → 3
    - Any other residue → -1, whatever the atom name
    """

    def __init__(self):
        # Atom names counted as "tail" atoms. The numbered names are
        # generated for positions 2..22; the last group lists names that
        # the generated patterns do not produce.
        positions = range(2, 23)
        self.tails = (
            [f"C2{i}" for i in positions]
            + [f"C3{i}" for i in positions]
            + [f"H{i}{s}" for i in positions for s in ['R', 'S', 'X', 'Y']]
            + ['H16Z', 'H18T', 'H91', 'H101', 'H18Z', 'H20T']
        )
        # Atom names counted as "glycerol" atoms of a TRIO residue.
        self.TGglyc = ['O11', 'O21', 'O31', 'O12', 'O22', 'O32',
                       'C1', 'C2', 'C3', 'C11', 'C21', 'C31',
                       'HA', 'HB', 'HS', 'HX', 'HY']
        # Residue names that get the tail / non-tail split.
        self.PL_resnames = ('POPC', 'DOPE', 'SAPI')

    def classify(self, resname: str, atom_name: str) -> int:
        """
        Return the classification code for one atom.

        :param resname: residue name of the atom.
        :param atom_name: name of the atom within that residue.
        :return: 1 or -1 for residues in ``PL_resnames``, 2 or 3 for
            ``TRIO``, and -1 for every other residue.
        """
        if resname in self.PL_resnames:
            return 1 if atom_name in self.tails else -1
        if resname == 'TRIO':
            return 2 if atom_name in self.TGglyc else 3
        # Residues that are neither listed in PL_resnames nor TRIO are
        # never inspected by atom name.
        return -1
class UserDictClassification(ClassificationStrategy):
    """
    Load classification codes from a JSON file of the form::

        {
          "RES1": {"ATOM1": "heads", "ATOM2": "tails", ...},
          "RES2": { ... }
        }

    Labels ("heads", "tails") are automatically mapped to integer codes:
    each distinct label gets the next integer, starting at 1, in the
    order the labels are first encountered in the file.
    """

    def __init__(self, rules: Dict[Tuple[str, str], int],
                 label_to_code: Dict[str, int]):
        """
        :param rules: mapping of ``(resname, atom_name)`` to integer code.
        :param label_to_code: mapping of label to the integer code used
            for it in ``rules``.
        """
        self.rules = rules
        self.label_to_code = label_to_code

    @classmethod
    def from_json(cls, json_file: str) -> 'UserDictClassification':
        """
        Build a strategy from a JSON rule file.

        :param json_file: path of the JSON file to read.
        :return: a new instance holding the parsed rules and the
            label-to-code mapping derived from them.
        :raises OSError: if the file cannot be opened.
        :raises ValueError: if the file is not valid JSON, or if it is
            not an object whose values are themselves objects.
        """
        # JSON is UTF-8 text; do not depend on the platform's locale
        # default encoding when reading it.
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if not isinstance(data, dict):
            raise ValueError(
                f"{json_file}: expected a JSON object mapping residue "
                f"names to atom maps, got {type(data).__name__}"
            )

        label_to_code: Dict[str, int] = {}
        rules: Dict[Tuple[str, str], int] = {}
        for resname, atom_map in data.items():
            if not isinstance(atom_map, dict):
                raise ValueError(
                    f"{json_file}: entry for residue {resname!r} must be "
                    f"a JSON object mapping atom names to labels, got "
                    f"{type(atom_map).__name__}"
                )
            for atom_name, label in atom_map.items():
                # An unseen label takes the next code; codes are handed
                # out consecutively from 1, so that is the count of
                # labels seen so far plus one.
                code = label_to_code.setdefault(label, len(label_to_code) + 1)
                rules[(resname, atom_name)] = code
        return cls(rules, label_to_code)

    def classify(self, resname: str, atom_name: str) -> int:
        """
        Look up the code for one atom.

        :param resname: residue name of the atom.
        :param atom_name: name of the atom within that residue.
        :return: the code stored for ``(resname, atom_name)``, or -1 when
            the pair has no rule.
        """
        return self.rules.get((resname, atom_name), -1)