Skip to content

Commit 95b980f

Browse files
committed
wip
1 parent 013e079 commit 95b980f

File tree

2 files changed

+44
-15
lines changed

2 files changed

+44
-15
lines changed

galgebra/dop.py

+42-10
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,6 @@ def _eval_derivative_n_times_terms(terms, x, n):
7171

7272
################ Scalar Partial Differential Operator Class ############
7373

74-
def _diff_op_apply(d, x):
75-
if not isinstance(d, DiffOpExpr):
76-
d = d * DiffOpPartial({})
77-
return d._diff_op_apply(x)
78-
79-
8074
class _BaseDop(object):
8175
""" Base class for differential operators - used to avoid accidental promotion """
8276
pass
@@ -116,6 +110,20 @@ def __rsub__(self, x):
116110
def __call__(self, x):
117111
return self._diff_op_apply(x)
118112

113+
def _diff_op_ify(x):
114+
if isinstance(x, DiffOpExpr):
115+
return x
116+
elif isinstance(x, sympy.Add):
117+
return DiffOpAdd(*(a for a in x.args))
118+
elif isinstance(x, sympy.Mul):
119+
return DiffOpMul(*(a for a in x.args))
120+
else:
121+
return x * DiffOpPartial({})
122+
123+
def _diff_op_apply(d, x):
124+
if not isinstance(d, DiffOpExpr):
125+
d = d * DiffOpPartial({})
126+
return _diff_op_ify(d)._diff_op_apply(x)
119127

120128

121129
def Sdop(*args):
@@ -176,11 +184,11 @@ def sort_key(self, order=None):
176184
# lower order derivatives first
177185
self.order,
178186
# sorted by symbol after that, after expansion
179-
sorted([
187+
tuple(sorted([
180188
x.sort_key(order)
181189
for x, k in self.pdiffs.items()
182190
for i in range(k)
183-
])
191+
]))
184192
)
185193

186194
def __new__(cls, __arg):
@@ -210,7 +218,7 @@ def __new__(cls, __arg):
210218
return self
211219

212220
def _eval_derivative(self, x):
213-
self._eval_derivative_n_times(x, 1)
221+
return self._eval_derivative_n_times(x, 1)
214222

215223
def _eval_derivative_n_times(self, x, n) -> 'Pdop': # pdiff(self)
216224
# d is partial derivative
@@ -272,6 +280,9 @@ def __str__(self):
272280
def __repr__(self):
273281
return str(self)
274282

283+
def __srepr__(self):
284+
return '{}({})'.format(type(self).__name__, self.pdiffs)
285+
275286
def _hashable_content(self):
276287
from sympy.utilities import default_sort_key
277288
sorted_items = sorted(
@@ -349,15 +360,36 @@ def _diff_op_apply(self, x):
349360
x = r._diff_op_apply(x)
350361
return coeff * x
351362

363+
def diff(self, *args, **kwargs):
364+
return super().diff(*args, simplify=False, **kwargs)
365+
366+
def _eval_derivative(self, x):
367+
coeff, diff = self.args
368+
return sympy.diff(coeff, x) * diff + coeff * sympy.diff(diff, x)
369+
352370

353371
class DiffOpAdd(DiffOpExpr, sympy.Add):
354372
identity = DiffOpZero()
355373

374+
@classmethod
375+
def _from_args(self, args, is_commutative):
376+
args = [_diff_op_ify(arg) for arg in args]
377+
return super()._from_args(args, is_commutative)
378+
356379
def _diff_op_apply(self, x):
357380
args = self.args
358381
assert args
359382
# avoid `sympy.Add` so that this works on multivectors
360383
return functools.reduce(operator.add, (_diff_op_apply(a, x) for a in args))
361384

362385
def _eval_derivative_n_times(self, x, n):
363-
return sympy.Add._eval_derivative_n_times(self, x, n)
386+
return DiffOpAdd(*(a.diff(x, n) for a in self.args))
387+
388+
def _eval_derivative(self, x):
389+
return DiffOpAdd(*(a.diff(x) for a in self.args))
390+
391+
def _eval_simplify(self, *args, **kwargs):
392+
return self
393+
394+
def diff(self, *args, **kwargs):
395+
return super().diff(*args, simplify=False, **kwargs)

galgebra/mv.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1676,10 +1676,7 @@ def is_scalar(self):
16761676

16771677
def components(self):
16781678
return tuple(
1679-
Dop(dop._consolidate_terms(
1680-
(Mv(coef * base, ga=self.Ga), pdiff)
1681-
for (coef, pdiff) in sdop.terms
1682-
), ga=self.Ga)
1679+
Dop([(base, sdop)], ga=self.Ga)
16831680
for (sdop, base) in self.Dop_mv_expand()
16841681
)
16851682

@@ -1710,7 +1707,7 @@ def Dop_mv_expand(self, modes=None):
17101707
coefs.append(dop.Sdop([(mv_coef, pdiff)]))
17111708
if modes is not None:
17121709
for i in range(len(coefs)):
1713-
coefs[i] = coefs[i].simplify(modes)
1710+
coefs[i] = coefs[i].simplify(modes=modes)
17141711
terms = list(zip(coefs, bases))
17151712
return sorted(terms, key=lambda x: self.Ga._all_blades_lst.index(x[1]))
17161713

0 commit comments

Comments
 (0)