Source code for atompack.ase_bridge

# Copyright 2026 Entalpic
"""ASE conversion helpers for atompack."""

from __future__ import annotations

from collections.abc import Mapping

import numpy as np

from ._atompack_rs import PyAtomDatabase as Database
from ._atompack_rs import PyMolecule as Molecule

_BUILTIN_FIELDS = {
    "energy",
    "forces",
    "charges",
    "velocities",
    "cell",
    "stress",
    "pbc",
}

_ASE_RESERVED_ARRAYS = {"numbers", "positions"}
_ASE_TYPES = None
_CALC_MODES = {"singlepoint", "nocopy", "none"}
_UNSUPPORTED_PROPERTY = object()
_SUPPORTED_ARRAY_DTYPES = {
    np.dtype(np.float32),
    np.dtype(np.float64),
    np.dtype(np.int32),
    np.dtype(np.int64),
}
_SUPPORTED_CUSTOM_PROPERTY_TEXT = (
    "supported values are None, scalar int/float/bool, str, or ndarray-like "
    "values with dtype float32, float64, int32, or int64"
)


def _voigt6_to_mat3x3(stress):
    stress = np.asarray(stress, dtype=np.float64)
    if stress.shape != (6,):
        raise ValueError("Voigt stress must have shape (6,)")
    xx, yy, zz, yz, xz, xy = stress
    return np.array([[xx, xy, xz], [xy, yy, yz], [xz, yz, zz]], dtype=np.float64)


def _get_stress(atoms):
    if not hasattr(atoms, "get_stress"):
        return None
    try:
        stress = atoms.get_stress(voigt=False)
    except TypeError:
        try:
            stress = atoms.get_stress()
        except Exception:
            return None
    except Exception:
        return None

    stress = np.asarray(stress)
    if stress.shape == (3, 3) and stress.dtype.kind == "f":
        return stress.astype(np.float64, copy=False)
    if stress.shape == (6,) and stress.dtype.kind == "f":
        return _voigt6_to_mat3x3(stress)
    return None


def _coerce_property(value):
    if value is None:
        return None
    if isinstance(value, str):
        return value
    if isinstance(value, (bool, int, float, np.integer, np.floating)):
        if isinstance(value, (bool, int, np.integer)):
            return int(value)
        return float(value)

    try:
        arr = np.asarray(value)
    except Exception:
        return _UNSUPPORTED_PROPERTY

    if arr.ndim == 0 and arr.dtype.kind in {"b", "i", "u", "f"}:
        return arr.item()
    if arr.dtype in _SUPPORTED_ARRAY_DTYPES:
        return arr.astype(arr.dtype, copy=False)
    return _UNSUPPORTED_PROPERTY


def _unsupported_property_reason(value):
    try:
        arr = np.asarray(value)
    except Exception as exc:
        return f"could not convert value to an ndarray: {exc}"
    return (
        f"got value of type {type(value).__name__} with ndarray dtype "
        f"{arr.dtype} and shape {arr.shape}; {_SUPPORTED_CUSTOM_PROPERTY_TEXT}"
    )


def _coerce_custom_property(key, value, source):
    """Coerce one custom property value, raising ``TypeError`` if unsupported.

    ``source`` names where the value came from (e.g. ``"atoms.info"``) and is
    used only in the error message.
    """
    coerced = _coerce_property(value)
    if coerced is not _UNSUPPORTED_PROPERTY:
        return coerced
    reason = _unsupported_property_reason(value)
    raise TypeError(f"Unsupported ASE custom property {key!r} from {source}: {reason}")


def _merge_properties(properties, builtins, values, source):
    """Merge ``values`` (``atoms.info`` or an info override) into a record.

    Keys that are not builtin field names are coerced with
    ``_coerce_custom_property`` and stored in ``properties``. Unsupported
    values raise ``TypeError``.

    Builtin-named keys never become custom properties, because that would
    duplicate (and possibly contradict) builtin state. Instead, ``"energy"``
    (a real scalar) and ``"stress"`` (a floating-point ``(3, 3)`` array) fill
    the matching entry of ``builtins``, but only while it is still ``None``.
    Values supplied as explicit arguments or by an attached calculator
    therefore take precedence. This lets ``to_ase(..., attach_calc=False)``
    output, which stores energy and stress in ``atoms.info``, round-trip.
    Other builtin-named keys, and values of the wrong shape/dtype, are ignored.
    """
    for key, value in values.items():
        if key not in _BUILTIN_FIELDS:
            properties[key] = _coerce_custom_property(key, value, source)
            continue
        if key not in ("energy", "stress") or builtins.get(key) is not None:
            continue
        try:
            arr = np.asarray(value)
        except (TypeError, ValueError):
            continue  # not array-like: ignore, as with any malformed builtin
        if key == "energy":
            if arr.shape == () and arr.dtype.kind in {"i", "u", "f"}:
                builtins["energy"] = float(arr)
        elif arr.shape == (3, 3) and arr.dtype.kind == "f":
            builtins["stress"] = arr.astype(np.float64, copy=False)


def _extract_ase_record(
    atoms,
    *,
    energy=None,
    forces=None,
    charges=None,
    velocities=None,
    cell=None,
    stress=None,
    copy_info=True,
    copy_arrays=True,
    info=None,
):
    positions = np.asarray(atoms.get_positions(), dtype=np.float32)
    atomic_numbers = np.asarray(atoms.get_atomic_numbers(), dtype=np.uint8)
    n_atoms = len(atomic_numbers)

    builtins = {
        "energy": None,
        "forces": None,
        "charges": None,
        "velocities": None,
        "cell": None,
        "stress": None,
        "pbc": None,
    }

    if energy is not None:
        builtins["energy"] = float(energy)
    else:
        try:
            builtins["energy"] = float(atoms.get_potential_energy())
        except Exception:
            pass

    if forces is not None:
        builtins["forces"] = np.asarray(forces, dtype=np.float32)
    else:
        try:
            builtins["forces"] = np.asarray(atoms.get_forces(), dtype=np.float32)
        except Exception:
            pass

    if charges is not None:
        builtins["charges"] = np.asarray(charges, dtype=np.float64)
    elif hasattr(atoms, "get_charges"):
        try:
            value = atoms.get_charges()
            if value is not None:
                builtins["charges"] = np.asarray(value, dtype=np.float64)
        except Exception:
            pass

    if velocities is not None:
        builtins["velocities"] = np.asarray(velocities, dtype=np.float32)
    else:
        try:
            value = atoms.get_velocities()
            if value is not None:
                builtins["velocities"] = np.asarray(value, dtype=np.float32)
        except Exception:
            pass

    try:
        pbc = np.asarray(getattr(atoms, "pbc", None), dtype=bool)
        if pbc.shape == (3,):
            builtins["pbc"] = tuple(bool(x) for x in pbc)
            if cell is None and pbc.any():
                builtins["cell"] = np.asarray(atoms.get_cell(), dtype=np.float64)
    except Exception:
        pass

    if cell is not None:
        builtins["cell"] = np.asarray(cell, dtype=np.float64)
    if stress is not None:
        builtins["stress"] = np.asarray(stress, dtype=np.float64)
    else:
        builtins["stress"] = _get_stress(atoms)

    properties = {}

    arrays = getattr(atoms, "arrays", None)
    if copy_arrays and isinstance(arrays, dict):
        for key, value in arrays.items():
            # Skip both ASE-reserved geometry keys ("positions", "numbers")
            # and atompack builtin field names. A user who stashes "forces"
            # in atoms.arrays must not have it duplicated into both
            # builtins["forces"] (from get_forces()) and properties["forces"].
            if key in _ASE_RESERVED_ARRAYS or key in _BUILTIN_FIELDS:
                continue
            properties[key] = _coerce_custom_property(key, value, "atoms.arrays")

    calc = getattr(atoms, "calc", None)
    results = getattr(calc, "results", None)
    if isinstance(results, dict):
        for key, value in results.items():
            if key not in _BUILTIN_FIELDS:
                properties[key] = _coerce_custom_property(key, value, "atoms.calc.results")

    if copy_info and getattr(atoms, "info", None):
        _merge_properties(properties, builtins, atoms.info, "atoms.info")
    if info is not None:
        _merge_properties(properties, builtins, info, "info override")

    return {
        "positions": positions,
        "atomic_numbers": atomic_numbers,
        "n_atoms": n_atoms,
        "builtins": builtins,
        "properties": properties,
    }


def _record_to_molecule(record):
    """Build a ``Molecule`` from a record produced by ``_extract_ase_record``.

    Builtin fields are passed to ``Molecule.from_arrays`` as keyword
    arguments. Every custom property is then attached with ``set_property``.
    """
    fields = record["builtins"]
    builtin_kwargs = {
        name: fields[name]
        for name in ("energy", "forces", "charges", "velocities", "cell", "stress", "pbc")
    }
    molecule = Molecule.from_arrays(
        record["positions"],
        record["atomic_numbers"],
        **builtin_kwargs,
    )
    for name, value in record["properties"].items():
        molecule.set_property(name, value)
    return molecule


def _flush_fast_records(db, records):
    if not records:
        return
    builtins = records[0]["builtins"]
    positions = np.stack([record["positions"] for record in records], axis=0)
    atomic_numbers = np.stack([record["atomic_numbers"] for record in records], axis=0)
    kwargs = {}
    if builtins["energy"] is not None:
        kwargs["energy"] = np.array([r["builtins"]["energy"] for r in records], dtype=np.float64)
    if builtins["forces"] is not None:
        kwargs["forces"] = np.stack([r["builtins"]["forces"] for r in records], axis=0)
    if builtins["charges"] is not None:
        kwargs["charges"] = np.stack([r["builtins"]["charges"] for r in records], axis=0)
    if builtins["velocities"] is not None:
        kwargs["velocities"] = np.stack([r["builtins"]["velocities"] for r in records], axis=0)
    if builtins["cell"] is not None:
        kwargs["cell"] = np.stack([r["builtins"]["cell"] for r in records], axis=0)
    if builtins["stress"] is not None:
        kwargs["stress"] = np.stack([r["builtins"]["stress"] for r in records], axis=0)
    if builtins["pbc"] is not None:
        kwargs["pbc"] = np.array([r["builtins"]["pbc"] for r in records], dtype=bool)
    db.add_arrays_batch(positions, atomic_numbers, **kwargs)


def _import_ase():
    """Import ASE lazily and return its conversion types.

    Returns ``(Atoms, SinglePointCalculator, NoCopySinglePointCalculator)``.
    The tuple is cached in the module-level ``_ASE_TYPES``, so the import and
    class creation happen only once.

    Raises
    ------
    ImportError
        If ASE is not installed.
    """
    global _ASE_TYPES
    if _ASE_TYPES is not None:
        return _ASE_TYPES

    try:
        from ase import Atoms
        from ase.calculators.calculator import Calculator
        from ase.calculators.singlepoint import SinglePointCalculator
    except ImportError as exc:
        raise ImportError(
            "ASE is required for Molecule.to_ase(); install it with `uv add ase`."
        ) from exc

    # Result names stored as given; all others are converted to float arrays.
    scalar_results = {"energy", "magmom", "free_energy"}

    class NoCopySinglePointCalculator(Calculator):
        # Like SinglePointCalculator, but keeps a reference to ``atoms``
        # instead of a copy.
        implemented_properties = ["energy", "forces", "stress", "charges"]

        def __init__(self, atoms, **results):
            Calculator.__init__(self)
            self.results = {
                name: value if name in scalar_results else np.asarray(value, dtype=float)
                for name, value in results.items()
                if value is not None
            }
            self.atoms = atoms

    _ASE_TYPES = (Atoms, SinglePointCalculator, NoCopySinglePointCalculator)
    return _ASE_TYPES


def _normalize_calc_mode(attach_calc, calc_mode):
    """Validate ``calc_mode`` and return the effective calculator mode.

    ``None`` means ``"singlepoint"``. The mode is validated even when
    ``attach_calc`` is false; in that case ``"none"`` is returned.

    Raises
    ------
    ValueError
        If ``calc_mode`` is not one of ``_CALC_MODES``.
    """
    mode = "singlepoint" if calc_mode is None else calc_mode
    if mode not in _CALC_MODES:
        raise ValueError(f"Invalid calc_mode {mode!r}; expected one of {sorted(_CALC_MODES)}")
    return mode if attach_calc else "none"


def _build_ase_atoms(payload, atoms_cls, calc_factory):
    if isinstance(payload, tuple):
        return _build_ase_atoms_from_tuple(payload, atoms_cls, calc_factory)

    kwargs = {
        "numbers": payload["numbers"],
        "positions": payload["positions"],
    }
    cell = payload.get("cell")
    if cell is not None:
        kwargs["cell"] = cell
    pbc = payload.get("pbc")
    if pbc is not None:
        kwargs["pbc"] = pbc

    atoms = atoms_cls(**kwargs)

    velocities = payload.get("velocities")
    if velocities is not None:
        atoms.set_velocities(velocities)

    calc_results = payload.get("calc")
    if calc_factory is not None:
        if calc_results:
            atoms.calc = calc_factory(atoms, **calc_results)
    elif calc_results:
        energy = calc_results.get("energy")
        if energy is not None:
            atoms.info["energy"] = float(energy)

        stress = calc_results.get("stress")
        if stress is not None:
            atoms.info["stress"] = np.asarray(stress)

        forces = calc_results.get("forces")
        if forces is not None:
            atoms.set_array("forces", np.asarray(forces))

        charges = calc_results.get("charges")
        if charges is not None:
            atoms.set_array("charges", np.asarray(charges))

    for key, value in payload.get("arrays", {}).items():
        atoms.set_array(key, np.asarray(value))
    atoms.info.update(payload.get("info", {}))

    return atoms


def _build_ase_atoms_from_tuple(payload, atoms_cls, calc_factory):
    (
        numbers,
        positions,
        cell,
        pbc,
        velocities,
        energy,
        forces,
        stress,
        charges,
        arrays,
        info,
    ) = payload

    kwargs = {
        "numbers": numbers,
        "positions": positions,
    }
    if cell is not None:
        kwargs["cell"] = cell
    if pbc is not None:
        kwargs["pbc"] = pbc

    atoms = atoms_cls(**kwargs)

    if velocities is not None:
        atoms.set_velocities(velocities)

    if calc_factory is not None:
        calc_results = {}
        if energy is not None:
            calc_results["energy"] = energy
        if forces is not None:
            calc_results["forces"] = forces
        if stress is not None:
            calc_results["stress"] = stress
        if charges is not None:
            calc_results["charges"] = charges
        if calc_results:
            atoms.calc = calc_factory(atoms, **calc_results)
    else:
        if energy is not None:
            atoms.info["energy"] = float(energy)
        if stress is not None:
            atoms.info["stress"] = np.asarray(stress)
        if forces is not None:
            atoms.set_array("forces", np.asarray(forces))
        if charges is not None:
            atoms.set_array("charges", np.asarray(charges))

    if arrays is not None:
        for key, value in arrays.items():
            atoms.set_array(key, np.asarray(value))
    if info is not None:
        atoms.info.update(info)

    return atoms


def _normalize_indices(db, indices):
    if indices is None:
        return list(range(len(db)))
    return list(indices)


def _copy_flat_properties(payload, flat, index, start, stop, copy_info, copy_arrays):
    if copy_arrays:
        arrays = {}
        for key, values in flat.get("atom_properties", {}).items():
            if key in _ASE_RESERVED_ARRAYS:
                continue
            arrays[key] = np.asarray(values[start:stop])
        if arrays:
            payload["arrays"] = arrays


def _molecule_to_ase_payload(molecule, *, copy_info, copy_arrays):
    return molecule._ase_builtin_tuple_fast(copy_info=copy_info, copy_arrays=copy_arrays)


def _db_to_ase_batch(
    db,
    indices,
    *,
    calc_mode,
    copy_info,
    copy_arrays,
    atoms_cls,
    calc_factory,
):
    if not indices:
        return []

    flat = db.get_molecules_flat(indices)
    if (copy_info or copy_arrays) and flat.get("properties"):
        return to_ase_batch(
            db.get_molecules(indices),
            calc_mode=calc_mode,
            copy_info=copy_info,
            copy_arrays=copy_arrays,
        )

    n_atoms = np.asarray(flat["n_atoms"], dtype=np.uint32)
    offsets = np.empty(len(n_atoms) + 1, dtype=np.int64)
    offsets[0] = 0
    np.cumsum(n_atoms, dtype=np.int64, out=offsets[1:])

    positions = flat["positions"]
    atomic_numbers = flat["atomic_numbers"]
    cells = flat.get("cell")
    pbc = flat.get("pbc")
    velocities = flat.get("velocities")
    energies = flat.get("energy")
    forces = flat.get("forces")
    stress = flat.get("stress")
    charges = flat.get("charges")

    atoms_list = []
    for i in range(len(indices)):
        start = int(offsets[i])
        stop = int(offsets[i + 1])
        payload = {
            "numbers": atomic_numbers[start:stop],
            "positions": positions[start:stop],
        }
        if cells is not None:
            payload["cell"] = cells[i]
        if pbc is not None:
            payload["pbc"] = pbc[i]
        if velocities is not None:
            payload["velocities"] = velocities[start:stop]

        calc = {}
        if energies is not None:
            calc["energy"] = float(energies[i])
        if forces is not None:
            calc["forces"] = forces[start:stop]
        if stress is not None:
            calc["stress"] = stress[i]
        if charges is not None:
            calc["charges"] = charges[start:stop]
        if calc:
            payload["calc"] = calc

        _copy_flat_properties(payload, flat, i, start, stop, copy_info, copy_arrays)
        atoms_list.append(_build_ase_atoms(payload, atoms_cls, calc_factory))

    return atoms_list


def to_ase(
    molecule,
    *,
    attach_calc=True,
    calc_mode="singlepoint",
    copy_info=True,
    copy_arrays=True,
):
    """Convert an atompack molecule into an ``ase.Atoms`` object.

    Values are read directly from the molecule's getters rather than through
    ``molecule.atoms()``. Owned and view-backed (lazy SOA) molecules therefore
    take the same path, although building the ASE object still allocates
    Python/NumPy memory.

    Field mapping:

    - ``positions`` and ``atomic_numbers`` define the ASE geometry.
    - ``cell`` and ``pbc`` are copied when present.
    - ``velocities`` are set with ``atoms.set_velocities(...)``.
    - ``energy``, ``forces``, ``stress`` and ``charges`` are exposed through an
      ASE calculator when ``attach_calc=True``. ``calc_mode="singlepoint"``
      keeps ASE's snapshot semantics; ``calc_mode="nocopy"`` is faster but
      does not snapshot the atoms state. Without a calculator, energy and
      stress are written to ``atoms.info`` and forces and charges to
      ``atoms.arrays``.
    - With ``copy_arrays=True``, custom properties shaped like per-atom
      arrays are stored in ``atoms.arrays``.
    - With ``copy_info=True``, the remaining custom properties are stored in
      ``atoms.info``.

    Parameters
    ----------
    molecule : atompack.Molecule
        Molecule to convert.
    attach_calc : bool, default=True
        Attach supported builtin results through an ASE calculator.
    calc_mode : {"singlepoint", "nocopy", "none"}, default="singlepoint"
        Calculator flavour. ``"none"`` suppresses calculator attachment.
    copy_info : bool, default=True
        Copy non-array custom properties into ``atoms.info``.
    copy_arrays : bool, default=True
        Copy per-atom custom arrays into ``atoms.arrays``.

    Returns
    -------
    ase.Atoms
        The converted structure.

    Raises
    ------
    ImportError
        If ASE is not installed.
    ValueError
        If ``calc_mode`` is not a recognised mode.
    """
    atoms_cls, singlepoint_cls, nocopy_cls = _import_ase()
    mode = _normalize_calc_mode(attach_calc, calc_mode)
    if mode == "singlepoint":
        calc_factory = singlepoint_cls
    elif mode == "nocopy":
        calc_factory = nocopy_cls
    else:
        calc_factory = None

    payload = _molecule_to_ase_payload(
        molecule,
        copy_info=copy_info,
        copy_arrays=copy_arrays,
    )
    return _build_ase_atoms(payload, atoms_cls, calc_factory)
def to_ase_batch(
    source,
    indices=None,
    *,
    attach_calc=True,
    calc_mode="singlepoint",
    copy_info=True,
    copy_arrays=True,
):
    """Convert many atompack molecules to ``ase.Atoms`` objects.

    Parameters
    ----------
    source : atompack.Database or iterable of atompack.Molecule
        Anything exposing ``get_molecules_flat`` is treated as a database and
        read in one flat batch. Otherwise ``source`` is iterated (or indexed)
        as a collection of molecules.
    indices : iterable of int, optional
        Entries to convert. ``None`` converts all of them. For non-database
        sources each entry is fetched with ``source[index]``.
    attach_calc, calc_mode, copy_info, copy_arrays
        Same meaning as in ``to_ase``.

    Returns
    -------
    list of ase.Atoms
        One ASE object per converted molecule, in input order.
    """
    atoms_cls, singlepoint_cls, nocopy_cls = _import_ase()
    mode = _normalize_calc_mode(attach_calc, calc_mode)
    # ``"none"`` is absent from the table, so ``.get`` yields ``None``.
    calc_factory = {"singlepoint": singlepoint_cls, "nocopy": nocopy_cls}.get(mode)

    if hasattr(source, "get_molecules_flat"):
        return _db_to_ase_batch(
            source,
            _normalize_indices(source, indices),
            calc_mode=mode,
            copy_info=copy_info,
            copy_arrays=copy_arrays,
            atoms_cls=atoms_cls,
            calc_factory=calc_factory,
        )

    if indices is None:
        molecules = list(source)
    else:
        molecules = [source[index] for index in indices]

    converted = []
    for molecule in molecules:
        payload = _molecule_to_ase_payload(
            molecule,
            copy_info=copy_info,
            copy_arrays=copy_arrays,
        )
        converted.append(_build_ase_atoms(payload, atoms_cls, calc_factory))
    return converted
def _database_to_ase_batch( self, indices=None, *, attach_calc=True, calc_mode="singlepoint", copy_info=True, copy_arrays=True, ): return to_ase_batch( self, indices=indices, attach_calc=attach_calc, calc_mode=calc_mode, copy_info=copy_info, copy_arrays=copy_arrays, ) def _normalize_info_overrides(info, count): if info is None: return [None] * count if isinstance(info, dict): return [info] * count overrides = list(info) if len(overrides) != count: raise ValueError( f"info override length ({len(overrides)}) doesn't match atoms count ({count})" ) return overrides def _fast_key(record): builtins = record["builtins"] return ( record["n_atoms"], builtins["energy"] is not None, builtins["forces"] is not None, builtins["charges"] is not None, builtins["velocities"] is not None, builtins["cell"] is not None, builtins["stress"] is not None, builtins["pbc"] is not None, )
def from_ase(
    atoms,
    energy=None,
    forces=None,
    charges=None,
    velocities=None,
    cell=None,
    stress=None,
    copy_info=True,
    copy_arrays=True,
    info=None,
):
    """Convert one ASE ``Atoms`` object to an atompack ``Molecule``.

    Explicit builtin arguments (``energy``, ``forces``, ``charges``,
    ``velocities``, ``cell``, ``stress``) take precedence over values read
    from ``atoms`` and its calculator. Custom values from ``atoms.info``,
    ``atoms.arrays``, calculator results and the explicit ``info=`` override
    are stored as molecule-scope properties. Array shape is not used to infer
    atom-property scope during ingestion.

    Raises
    ------
    TypeError
        If a custom value has an unsupported type or dtype.
    """
    record = _extract_ase_record(
        atoms,
        energy=energy,
        forces=forces,
        charges=charges,
        velocities=velocities,
        cell=cell,
        stress=stress,
        copy_info=copy_info,
        copy_arrays=copy_arrays,
        info=info,
    )
    return _record_to_molecule(record)
def add_ase_batch(
    db,
    atoms_list,
    *,
    copy_info=True,
    copy_arrays=True,
    info=None,
    batch_size=512,
):
    """Append many ASE ``Atoms`` objects to ``db``, preserving supported metadata.

    Structures without custom properties are grouped into runs that share the
    same atom count and builtin-field layout (``_fast_key``). Each run is
    written with ``db.add_arrays_batch``. Structures with custom properties
    are converted to ``Molecule`` objects and written with
    ``db.add_molecules``.

    Input order is preserved: the pending group of one kind is flushed before
    a record of the other kind is queued. Each group is also flushed once it
    reaches ``batch_size`` records.

    Parameters
    ----------
    db : atompack.Database
        Destination database.
    atoms_list : iterable of ase.Atoms
        Structures to write; consumed once.
    copy_info, copy_arrays : bool, default=True
        Whether ``atoms.info`` / ``atoms.arrays`` custom values are captured.
    info : None, mapping, or iterable of mappings, optional
        Extra info applied to every structure (a single mapping) or one
        override per structure.
    batch_size : int, default=512
        Maximum number of records buffered per write call.
    """
    atoms_list = list(atoms_list)
    if not atoms_list:
        return
    overrides = _normalize_info_overrides(info, len(atoms_list))

    pending_fast = []
    pending_slow = []
    current_key = None

    def drain_fast():
        nonlocal current_key
        if pending_fast:
            _flush_fast_records(db, pending_fast)
            pending_fast.clear()
            current_key = None

    def drain_slow():
        if pending_slow:
            db.add_molecules(pending_slow)
            pending_slow.clear()

    for atoms, override in zip(atoms_list, overrides):
        record = _extract_ase_record(
            atoms,
            copy_info=copy_info,
            copy_arrays=copy_arrays,
            info=override,
        )

        if record["properties"]:
            # Custom properties need the per-molecule path.
            drain_fast()
            pending_slow.append(_record_to_molecule(record))
            if len(pending_slow) >= batch_size:
                drain_slow()
            continue

        drain_slow()
        key = _fast_key(record)
        if key != current_key:
            drain_fast()
            current_key = key
        pending_fast.append(record)
        if len(pending_fast) >= batch_size:
            drain_fast()

    drain_fast()
    drain_slow()
# Expose the converters as methods on the Rust-backed classes.
setattr(Molecule, "to_ase", to_ase)
setattr(Database, "to_ase_batch", _database_to_ase_batch)