Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 41 additions & 18 deletions pytential/symbolic/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Literal, TypeVar, cast

Expand All @@ -33,7 +33,12 @@
from pymbolic import ArithmeticExpression
from pytools.obj_array import ObjectArray, ObjectArray1D, ShapeT, from_numpy

from pytential.symbolic.mappers import CachedIdentityMapper, DependencyMapper
from pytential.symbolic.mappers import (
CachedIdentityMapper,
DependencyMapper,
PrettyStringifyMapper,
StringifyMapper,
)
from pytential.symbolic.primitives import (
DOFDescriptor,
IntG,
Expand All @@ -44,6 +49,7 @@

if TYPE_CHECKING:
from collections.abc import (
Callable,
Collection,
Hashable,
Iterator,
Expand Down Expand Up @@ -76,7 +82,7 @@
# {{{ statements

@dataclass(frozen=True, eq=False)
class Statement:
class Statement(ABC):
"""
.. autoattribute:: names
.. autoattribute:: exprs
Expand All @@ -93,23 +99,25 @@ class Statement:
priority: int
"""The priority of the statement."""

@abstractmethod
def get_assignees(self) -> set[str]:
"""
:returns: names of variables that are assigned to in this statement.
"""
raise NotImplementedError(
f"get_assignees for '{self.__class__.__name__}'")

@abstractmethod
def get_dependencies(self, dep_mapper: DependencyMapper) -> set[prim.Variable]:
"""
:returns: variables that are dependencies of the assignees.
"""
raise NotImplementedError(
f"get_dependencies for '{self.__class__.__name__}'")

@abstractmethod
def stringify(self, expr_mapper: Callable[[Expression | Operand], str]) -> str:
...

@override
def __str__(self) -> str:
raise NotImplementedError
def __str__(self):
return self.stringify(StringifyMapper())


@dataclass(frozen=True, eq=False)
Expand Down Expand Up @@ -152,14 +160,17 @@ def get_dependencies(self, dep_mapper: DependencyMapper) -> set[prim.Variable]:
return deps

@override
def __str__(self) -> str:
def stringify(self, expr_mapper: Callable[[Expression | Operand], str]) -> str:
comment = self.comment

if len(self.names) == 1:
if comment:
comment = f"/* {comment} */ "

return "{} <- {}{}".format(self.names[0], comment, self.exprs[0])
return "{} <- {}{}".format(
self.names[0],
comment,
expr_mapper(self.exprs[0]))
else:
do_not_return = self.do_not_return
if do_not_return is None:
Expand All @@ -176,7 +187,7 @@ def __str__(self) -> str:
else:
dnr_indicator = ""

lines.append(f" {n} <{dnr_indicator}- {e}")
lines.append(f" {n} <{dnr_indicator}- {expr_mapper(e)}")
lines.append("}")

return "\n".join(lines)
Expand Down Expand Up @@ -266,14 +277,12 @@ def get_dependencies(self, dep_mapper: DependencyMapper) -> set[prim.Variable]:
return result

@override
def __str__(self) -> str:
def stringify(self, expr_mapper: Callable[[Expression | Operand], str]) -> str:
args = [f"source={self.source}"]
for i, density in enumerate(self.densities):
args.append(f"density{i}={density}")

from pytential.symbolic.mappers import StringifyMapper, stringify_where
strify = StringifyMapper()

from pytential.symbolic.mappers import stringify_where
lines: list[str] = []
for o in self.outputs:
if o.target_name != self.source:
Expand Down Expand Up @@ -308,7 +317,7 @@ def __str__(self) -> str:
lines.append(line)

for arg_name, arg_expr in self.kernel_arguments.items():
arg_expr_lines = strify(arg_expr).split("\n")
arg_expr_lines = expr_mapper(arg_expr).split("\n")
lines.append(" {} = {}".format(arg_name, arg_expr_lines[0]))
lines.extend(" " + s for s in arg_expr_lines[1:])

Expand Down Expand Up @@ -417,9 +426,23 @@ def statements(self) -> list[Statement]:

@override
def __str__(self) -> str:
strify_mapper = PrettyStringifyMapper()
lines: list[str] = []
for insn in self.statements:
lines.extend(str(insn).split("\n"))
lines.extend(insn.stringify(strify_mapper).split("\n"))

if strify_mapper.cse_name_list:
# FIXME: There's potential here for name clashes between the 'code'
# and 'discretization CSE' parts. It's just presentation, so if it's
# bothersome, near here is the place to fix it.
lines = [
"DISCRETIZATION-LEVEL COMMON SUBEXPRESSIONS:",
*[
f"{name} <- {cse_expr_str}"
for name, cse_expr_str in strify_mapper.cse_name_list],
"-"*75,
*lines]

lines.append(f"RESULT: {self.result}")

return "\n".join(lines)
Expand Down
76 changes: 64 additions & 12 deletions pytential/symbolic/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,9 @@
:attr:`~pytential.symbolic.dof_desc.DOFDescriptor.discr_stage`.
"""

def __init__(self, discr_stage):
discr_stage: DiscretizationStage

def __init__(self, discr_stage: DiscretizationStage):
if discr_stage not in {
pp.QBX_SOURCE_STAGE1,
pp.QBX_SOURCE_STAGE2,
Expand All @@ -581,7 +583,8 @@

self.discr_stage = discr_stage

def map_node_coordinate_component(self, expr):
@override
def map_node_coordinate_component(self, expr: pp.NodeCoordinateComponent):
dofdesc = expr.dofdesc
if dofdesc.discr_stage == self.discr_stage:
return expr
Expand All @@ -590,7 +593,8 @@
expr.ambient_axis,
dofdesc.copy(discr_stage=self.discr_stage))

def map_num_reference_derivative(self, expr):
@override
def map_num_reference_derivative(self, expr: pp.NumReferenceDerivative):
dofdesc = expr.dofdesc
if dofdesc.discr_stage == self.discr_stage:
return expr
Expand Down Expand Up @@ -740,6 +744,55 @@

# {{{ InterpolationPreprocessor

@dataclass
class EarlyInterpolationAdder(
# This is deliberately inheriting from the pymbolic mapper,
# based on the assumption that all the pymbolic-defined operations
# will apply elementwise. Pytential nodes will end up in
# handle_unsupported_expression below.
IdentityMapperBase[[]],
CSECachingMapperMixin[Expression, []]):
"""Used from within :class:`InterpolationPreprocessor`. Rather than
interpolate the result of a computation, push interpolation as far
'upstream' as possible, to minimize aliasing error.
"""
from_dd: DOFDescriptor
to_dd: DOFDescriptor

@override
def map_variable(self, expr: p.Variable):
return pp.interpolate(expr, self.from_dd, self.to_dd)

@override
def map_call(self,
expr: p.Call,
) -> Expression:
parameters = tuple(self.rec(child) for child in expr.parameters)
if all(child is orig_child for child, orig_child in
zip(expr.parameters, parameters, strict=True)):
return expr

return type(expr)(expr.function, parameters)

@override
def handle_unsupported_expression(self, expr: p.ExpressionNode) -> Expression:
return pp.interpolate(expr, self.from_dd, self.to_dd)

@override
def map_common_subexpression_uncached(self,
expr: p.CommonSubexpression, /,
) -> Expression:
result = self.rec(expr.child)
if result is expr.child:
return expr

return type(expr)(
result,
expr.prefix,
expr.scope,
**expr.get_extra_properties())


class InterpolationPreprocessor(IdentityMapper):
"""Handle expressions that require upsampling or downsampling by inserting
a :class:`~pytential.symbolic.primitives.Interpolation`. This is used to
Expand Down Expand Up @@ -801,16 +854,18 @@

from_dd = expr.source.to_stage1()
to_dd = from_dd.to_quad_stage2()
interp_adder = EarlyInterpolationAdder(from_dd, to_dd)
densities = tuple(
pp.interpolate(self.rec_arith(density), from_dd, to_dd)
interp_adder.rec_arith(self.rec_arith(density))
for density in expr.densities)

from_dd = from_dd.copy(discr_stage=self.from_discr_stage)
interp_adder = EarlyInterpolationAdder(from_dd, to_dd)
kernel_arguments = constantdict({
name: componentwise(
lambda aexpr: pp.interpolate(
lambda aexpr: interp_adder.rec_arith(
self.rec_arith(
self.tagger.rec_arith(aexpr)), from_dd, to_dd),
self.tagger.rec_arith(aexpr))),
arg_expr)
for name, arg_expr in expr.kernel_arguments.items()})

Expand Down Expand Up @@ -889,7 +944,7 @@
return str(pp.as_dofdesc(where))


class StringifyMapper(BaseStringifyMapper):
class StringifyMapper(BaseStringifyMapper[[]]):

def map_ones(self, expr: pp.Ones, enclosing_prec: int):
return "Ones[%s]" % stringify_where(expr.dofdesc)
Expand Down Expand Up @@ -1025,9 +1080,6 @@
# {{{ graphviz

class GraphvizMapper(GraphvizMapperBase):
def __init__(self):
super().__init__()

def map_pytential_leaf(self, expr):
self.lines.append(
'{} [label="{}", shape=box];'.format(
Expand All @@ -1039,7 +1091,7 @@

map_ones = map_pytential_leaf

def map_map_node_sum(self, expr):
def map_map_node_sum(self, expr: pp.NodeSum):
self.lines.append(
'{} [label="{}",shape=circle];'.format(
self.get_id(expr), type(expr).__name__))
Expand All @@ -1055,7 +1107,7 @@

map_q_weight = map_pytential_leaf

def map_int_g(self, expr):
def map_int_g(self, expr: pp.IntG):
descr = "Int[%s->%s]@(%d) (%s)" % (
stringify_where(expr.source),
stringify_where(expr.target),
Expand All @@ -1069,7 +1121,7 @@

self.rec(expr.densities)
for arg_expr in expr.kernel_arguments.values():
self.rec(arg_expr)

Check failure on line 1124 in pytential/symbolic/mappers.py

View workflow job for this annotation

GitHub Actions / basedpyright

Argument of type "Operand" cannot be assigned to parameter "expr" of type "Expression" in function "__call__"   Type "Operand" is not assignable to type "Expression"     Type "MultiVector[ArithmeticExpression]" is not assignable to type "Expression"       "MultiVector[ArithmeticExpression]" is not assignable to "int"       "MultiVector[ArithmeticExpression]" is not assignable to "integer[Any]"       "MultiVector[ArithmeticExpression]" is not assignable to "float"       "MultiVector[ArithmeticExpression]" is not assignable to "complex"       "MultiVector[ArithmeticExpression]" is not assignable to "inexact[Any, float | complex]"       "MultiVector[ArithmeticExpression]" is not assignable to "bool" ... (reportArgumentType)

self.post_visit(expr)

Expand Down
Loading
Loading