diff --git a/pytential/qbx/__init__.py b/pytential/qbx/__init__.py index e392844d2..9f287972e 100644 --- a/pytential/qbx/__init__.py +++ b/pytential/qbx/__init__.py @@ -46,6 +46,7 @@ ) from sumpy.expansion.local import LocalExpansionBase +from pytential.qbx.refinement import QBXRefinementMode, QBXRefinementNeededError from pytential.qbx.target_assoc import QBXTargetAssociationFailedError from pytential.source import LayerPotentialSourceBase @@ -86,6 +87,10 @@ .. autoclass:: NonFFTExpansionFactory .. autodata:: FMMBackend + +.. autoclass:: QBXRefinementMode + +.. autoclass:: QBXRefinementNeededError """ @@ -147,7 +152,7 @@ class QBXLayerPotentialSource(LayerPotentialSourceBase): target_association_tolerance: float debug: bool - _disable_refinement: bool + refinement_mode: QBXRefinementMode _expansions_in_tree_have_extent: bool _expansion_stick_out_factor: float _well_sep_is_n_away: int @@ -175,7 +180,8 @@ def __init__( # begin experimental arguments # FIXME default debug=False once everything has matured debug: bool = True, - _disable_refinement: bool = False, + refinement_mode: QBXRefinementMode | None = None, + _disable_refinement: bool | None = None, _expansions_in_tree_have_extent: bool = True, _expansion_stick_out_factor: float = 0.5, _max_leaf_refine_weight: int | None = None, @@ -208,6 +214,8 @@ def __init__( the FMM evaluations. :arg target_association_tolerance: passed on to :func:`pytential.qbx.target_assoc.associate_targets_to_qbx_centers`. + :arg refinement_mode: A :class:`~pytential.qbx.refinement.QBXRefinementMode` + controlling whether and how refinement is performed. Experimental arguments without a promise of forward compatibility: @@ -234,6 +242,7 @@ def __init__( :arg cost_model: Either *None* or an object implementing the :class:`~pytential.qbx.cost.AbstractQBXCostModel` interface, used for gathering modeled costs if provided (experimental). + :arg _disable_refinement: Deprecated. Use *refinement_mode* instead. """ # {{{ argument processing @@ -317,6 +326,21 @@ def fmm_lto(kernel, kernel_args, tree, level): from pytential.qbx.cost import QBXCostModel cost_model = QBXCostModel() + if _disable_refinement is not None: + from warnings import warn + warn( + "'_disable_refinement' is deprecated. " + "Use 'refinement_mode' instead.", + DeprecationWarning, stacklevel=2) + if refinement_mode is None: + refinement_mode = ( + QBXRefinementMode.NO_REFINEMENT + if _disable_refinement + else QBXRefinementMode.REFINE) + + if refinement_mode is None: + refinement_mode = QBXRefinementMode.REFINE + # }}} if density_discr.dim != density_discr.ambient_dim - 1: @@ -335,7 +359,7 @@ def fmm_lto(kernel, kernel_args, tree, level): self.target_association_tolerance = target_association_tolerance self.debug = debug - self._disable_refinement = _disable_refinement + self.refinement_mode = refinement_mode self._expansions_in_tree_have_extent = _expansions_in_tree_have_extent self._expansion_stick_out_factor = _expansion_stick_out_factor self._well_sep_is_n_away = _well_sep_is_n_away @@ -376,6 +400,7 @@ def copy( fmm_backend=None, debug=_not_provided, + refinement_mode=_not_provided, _disable_refinement=_not_provided, ): if target_association_tolerance is _not_provided: @@ -394,6 +419,18 @@ def copy( else: kwargs["fmm_level_to_order"] = self.fmm_level_to_order + if _disable_refinement is not _not_provided: + from warnings import warn + warn( + "'_disable_refinement' is deprecated. " + "Use 'refinement_mode' instead.", + DeprecationWarning, stacklevel=2) + if refinement_mode is _not_provided: + refinement_mode = ( + QBXRefinementMode.NO_REFINEMENT + if _disable_refinement + else QBXRefinementMode.REFINE) + # FIXME Could/should share wrangler and geometry kernels # if no relevant changes have been made. return type(self)( @@ -409,11 +446,10 @@ def copy( debug=( # False is a valid value here debug if debug is not _not_provided else self.debug), - _disable_refinement=( - # False is a valid value here - _disable_refinement - if _disable_refinement is not _not_provided - else self._disable_refinement), + refinement_mode=( + refinement_mode + if refinement_mode is not _not_provided + else self.refinement_mode), _expansions_in_tree_have_extent=( # False is a valid value here _expansions_in_tree_have_extent @@ -569,7 +605,7 @@ def drive_cost_model( def _dispatch_compute_potential_insn(self, actx, insn, bound_expr, evaluate, func, extra_args=None): - if self._disable_refinement: + if self.refinement_mode == QBXRefinementMode.NO_REFINEMENT: from warnings import warn warn( "Executing global QBX without refinement. " @@ -1072,6 +1108,8 @@ def get_flat_strengths_from_densities( "LocalExpansionBase", "QBXDefaultExpansionFactory", "QBXLayerPotentialSource", + "QBXRefinementMode", + "QBXRefinementNeededError", "QBXTargetAssociationFailedError", ) diff --git a/pytential/qbx/refinement.py b/pytential/qbx/refinement.py index 429ac6b4b..a690ff062 100644 --- a/pytential/qbx/refinement.py +++ b/pytential/qbx/refinement.py @@ -27,6 +27,7 @@ """ import logging +from enum import Enum, auto from typing import TYPE_CHECKING, Any, cast import numpy as np @@ -77,8 +78,15 @@ The element size is bounded by a kernel length scale. This applies only to Helmholtz kernels. -Warnings emitted by refinement -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Refinement mode +^^^^^^^^^^^^^^^ + +.. autoclass:: QBXRefinementMode + +Errors and warnings emitted by refinement +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: QBXRefinementNeededError .. autoclass:: RefinerNotConvergedWarning @@ -97,6 +105,48 @@ .. autofunction:: refine_geometry_collection """ + +# {{{ QBXRefinementMode + +class QBXRefinementMode(Enum): + """Controls the refinement behavior of a + :class:`~pytential.qbx.QBXLayerPotentialSource`. + + .. attribute:: REFINE + + Perform refinement as needed. This is the default behavior. + + .. attribute:: NO_REFINEMENT + + Skip refinement entirely. An + :class:`~meshmode.discretization.connection.IdentityDiscretizationConnection` + is returned instead of performing any mesh refinement. + + .. warning:: + + Executing global QBX without refinement is unlikely to give + accurate results. + + .. attribute:: COMPLAIN + + Do not perform any refinement, but raise a + :class:`QBXRefinementNeededError` if stage-1 or stage-2 refinement + would be required to satisfy the QBX refinement criteria. + """ + + REFINE = auto() + NO_REFINEMENT = auto() + COMPLAIN = auto() + + +class QBXRefinementNeededError(RuntimeError): + """Raised when :attr:`QBXRefinementMode.COMPLAIN` is in effect and + refinement would be needed to satisfy the QBX refinement criteria. + """ + +# }}} + + # {{{ kernels # Refinement checker for Condition 1. @@ -616,7 +666,7 @@ def _refine_qbx_stage1(lpot_source, density_discr, expansion_disturbance_tolerance=None, maxiter=None, debug=None, visualize=False): from pytential import bind, sym - if lpot_source._disable_refinement: + if lpot_source.refinement_mode == QBXRefinementMode.NO_REFINEMENT: from meshmode.discretization.connection import IdentityDiscretizationConnection return density_discr, IdentityDiscretizationConnection(density_discr) @@ -717,6 +767,13 @@ def _refine_qbx_stage1(lpot_source, density_discr, if iter_violated_criteria: violated_criteria.append(" and ".join(iter_violated_criteria)) + if lpot_source.refinement_mode == QBXRefinementMode.COMPLAIN: + raise QBXRefinementNeededError( + "Stage-1 QBX refinement is needed but refinement mode is " + f"'{QBXRefinementMode.COMPLAIN.name}'. " + "Criteria requiring refinement: " + + ", ".join(iter_violated_criteria)) + conn = wrangler.refine( stage1_density_discr, refiner, refine_flags, group_factory, debug) @@ -737,7 +794,7 @@ def _refine_qbx_stage2(lpot_source, stage1_density_discr, expansion_disturbance_tolerance=None, force_stage2_uniform_refinement_rounds=None, maxiter=None, debug=None, visualize=False): - if lpot_source._disable_refinement: + if lpot_source.refinement_mode == QBXRefinementMode.NO_REFINEMENT: from meshmode.discretization.connection import IdentityDiscretizationConnection return (stage1_density_discr, IdentityDiscretizationConnection(stage1_density_discr)) @@ -789,6 +846,13 @@ def _refine_qbx_stage2(lpot_source, stage1_density_discr, if iter_violated_criteria: violated_criteria.append(" and ".join(iter_violated_criteria)) + if lpot_source.refinement_mode == QBXRefinementMode.COMPLAIN: + raise QBXRefinementNeededError( + "Stage-2 QBX refinement is needed but refinement mode is " + f"'{QBXRefinementMode.COMPLAIN.name}'. " + "Criteria requiring refinement: " + + ", ".join(iter_violated_criteria)) + conn = wrangler.refine( stage2_density_discr, refiner, refine_flags, group_factory, debug) diff --git a/test/extra_int_eq_data.py b/test/extra_int_eq_data.py index 35675ff01..944ce5a70 100644 --- a/test/extra_int_eq_data.py +++ b/test/extra_int_eq_data.py @@ -39,6 +39,7 @@ from pytential import sym from pytential.qbx import FMMBackend, QBXLayerPotentialSource +from pytential.qbx.refinement import QBXRefinementMode from pytential.source import PointPotentialSource from pytential.target import PointsTarget @@ -261,7 +262,10 @@ def get_layer_potential(self, qbx_order=self.qbx_order, fmm_backend=fmm_backend, **fmm_kwargs, - _disable_refinement=not self.use_refinement, + refinement_mode=( + QBXRefinementMode.REFINE + if self.use_refinement + else QBXRefinementMode.NO_REFINEMENT), _box_extent_norm=self.box_extent_norm, _from_sep_smaller_crit=self.from_sep_smaller_crit, _from_sep_smaller_min_nsources_cumul=30,