Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 43 additions & 82 deletions meshmode/discretization/connection/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,6 @@ def _global_point_pick_info(

def __call__(
self, ary: ArrayOrContainerT, *,
_force_use_loopy: bool = False,
_force_no_merged_batches: bool = False,
) -> ArrayOrContainerT:
"""
Expand All @@ -577,8 +576,8 @@ def __call__(
coefficient data on :attr:`from_discr`.

"""
# _force_use_loopy, _force_no_merged_batches:
# private arguments only used to ensure test coverage of all code paths.
# _force_no_merged_batches:
# private argument only used to ensure test coverage of all code paths.

# {{{ recurse into array containers

Expand All @@ -593,7 +592,6 @@ def __call__(
else:
return deserialize_container(ary, [
(key, self(subary,
_force_use_loopy=_force_use_loopy,
_force_no_merged_batches=_force_no_merged_batches))
for key, subary in iterable
])
Expand Down Expand Up @@ -706,6 +704,9 @@ def group_pick_knl(is_surjective: bool):
"idof": ConcurrentDOFInameTag()})

# }}}
if not actx.permits_advanced_indexing:
raise ValueError("Array context does not allow advanced indexing. "
"This is no longer supported.")

group_arrays = []
for i_tgrp, (cgrp, group_pick_info) in enumerate(
Expand All @@ -719,51 +720,33 @@ def group_pick_knl(is_surjective: bool):
if group_pick_info is not None:
group_array_contributions = []

if actx.permits_advanced_indexing and not _force_use_loopy:
for fgpd in group_pick_info:
from_element_indices = actx.thaw(fgpd.from_element_indices)

if ary[fgpd.from_group_index].size:
grp_ary_contrib = ary[fgpd.from_group_index][
_reshape_and_preserve_tags(
actx, from_element_indices, (-1, 1)),
actx.thaw(fgpd.dof_pick_lists)[
actx.thaw(fgpd.dof_pick_list_indices)]
]

if not fgpd.is_surjective:
from_el_present = actx.thaw(fgpd.from_el_present)
grp_ary_contrib = actx.np.where(
for fgpd in group_pick_info:
from_element_indices = actx.thaw(fgpd.from_element_indices)

if ary[fgpd.from_group_index].size:
grp_ary_contrib = ary[fgpd.from_group_index][
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1)),
grp_ary_contrib,
0)

# attach metadata
grp_ary_contrib = tag_axes(
actx,
{0: DiscretizationElementAxisTag(),
1: DiscretizationDOFAxisTag()},
grp_ary_contrib)

group_array_contributions.append(grp_ary_contrib)
else:
for fgpd in group_pick_info:
group_knl_kwargs = {}
actx, from_element_indices, (-1, 1)),
actx.thaw(fgpd.dof_pick_lists)[
actx.thaw(fgpd.dof_pick_list_indices)]
]

if not fgpd.is_surjective:
group_knl_kwargs["from_el_present"] = \
fgpd.from_el_present

group_array_contributions.append(
actx.call_loopy(
group_pick_knl(fgpd.is_surjective),
dof_pick_lists=fgpd.dof_pick_lists,
dof_pick_list_indices=fgpd.dof_pick_list_indices,
ary=ary[fgpd.from_group_index],
from_element_indices=fgpd.from_element_indices,
nunit_dofs_tgt=(
self.to_discr.groups[i_tgrp].nunit_dofs),
**group_knl_kwargs)["result"])
from_el_present = actx.thaw(fgpd.from_el_present)
grp_ary_contrib = actx.np.where(
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1)),
grp_ary_contrib,
0)

# attach metadata
grp_ary_contrib = tag_axes(
actx,
{0: DiscretizationElementAxisTag(),
1: DiscretizationDOFAxisTag()},
grp_ary_contrib)

group_array_contributions.append(grp_ary_contrib)

group_array = sum(group_array_contributions)
elif cgrp.batches:
Expand All @@ -783,47 +766,25 @@ def group_pick_knl(is_surjective: bool):
if point_pick_indices is None:
grp_ary = ary[batch.from_group_index]
mat = self._resample_matrix(actx, i_tgrp, i_batch)
if actx.permits_advanced_indexing and not _force_use_loopy:
batch_result = actx.np.where(
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1)),
actx.einsum("ij,ej->ei",
mat, grp_ary[from_element_indices]),
0)
else:
batch_result = actx.call_loopy(
batch_mat_knl(),
resample_mat=mat,
ary=grp_ary,
from_el_present=from_el_present,
from_element_indices=from_element_indices,
nunit_dofs_tgt=(
self.to_discr.groups[i_tgrp].nunit_dofs)
)["result"]
batch_result = actx.np.where(
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1)),
actx.einsum("ij,ej->ei",
mat, grp_ary[from_element_indices]),
0)

else:
from_vec = ary[batch.from_group_index]
pick_list = actx.thaw(point_pick_indices)

if actx.permits_advanced_indexing and not _force_use_loopy:
batch_result = actx.np.where(
batch_result = actx.np.where(
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1)),
from_vec[
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1)),
from_vec[
_reshape_and_preserve_tags(
actx, from_element_indices, (-1, 1)),
pick_list],
0)
else:
batch_result = actx.call_loopy(
batch_pick_knl(),
pick_list=pick_list,
ary=from_vec,
from_el_present=from_el_present,
from_element_indices=from_element_indices,
nunit_dofs_tgt=(
self.to_discr.groups[i_tgrp].nunit_dofs)
)["result"]
actx, from_element_indices, (-1, 1)),
pick_list],
0)

# attach metadata
batch_result = tag_axes(actx,
Expand Down
14 changes: 2 additions & 12 deletions test/test_meshmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,13 +455,8 @@ def f(x):
bdry_f_2 = opp_face(bdry_f)

# Ensure test coverage for alternate modes in DirectConnection
for force_loopy, force_no_merged_batches in [
(False, True),
(True, False),
(True, True),
]:
for force_no_merged_batches in [False, True]:
bdry_f_2_alt = opp_face(bdry_f,
_force_use_loopy=force_loopy,
_force_no_merged_batches=force_no_merged_batches)
assert actx.to_numpy(flat_norm(bdry_f_2 - bdry_f_2_alt, np.inf)) < 1e-14

Expand Down Expand Up @@ -994,13 +989,8 @@ def grp_factory(mesh_el_group: MeshElementGroup):
op_bdry_f = opposite(bdry_f)

# Ensure test coverage for alternate modes in DirectConnection
for force_loopy, force_no_merged_batches in [
(False, True),
(True, False),
(True, True),
]:
for force_no_merged_batches in [False, True]:
op_bdry_f_2 = opposite(bdry_f,
_force_use_loopy=force_loopy,
_force_no_merged_batches=force_no_merged_batches)
error = flat_norm(op_bdry_f - op_bdry_f_2, np.inf)
assert actx.to_numpy(error) < 1e-15
Expand Down