@@ -71,12 +71,6 @@ def _eval_derivative_n_times_terms(terms, x, n):
71
71
72
72
################ Scalar Partial Differential Operator Class ############
73
73
74
- def _diff_op_apply (d , x ):
75
- if not isinstance (d , DiffOpExpr ):
76
- d = d * DiffOpPartial ({})
77
- return d ._diff_op_apply (x )
78
-
79
-
80
74
class _BaseDop (object ):
81
75
""" Base class for differential operators - used to avoid accidental promotion """
82
76
pass
@@ -116,6 +110,20 @@ def __rsub__(self, x):
116
110
def __call__ (self , x ):
117
111
return self ._diff_op_apply (x )
118
112
113
+ def _diff_op_ify (x ):
114
+ if isinstance (x , DiffOpExpr ):
115
+ return x
116
+ elif isinstance (x , sympy .Add ):
117
+ return DiffOpAdd (* (a for a in x .args ))
118
+ elif isinstance (x , sympy .Mul ):
119
+ return DiffOpMul (* (a for a in x .args ))
120
+ else :
121
+ return x * DiffOpPartial ({})
122
+
123
+ def _diff_op_apply (d , x ):
124
+ if not isinstance (d , DiffOpExpr ):
125
+ d = d * DiffOpPartial ({})
126
+ return _diff_op_ify (d )._diff_op_apply (x )
119
127
120
128
121
129
def Sdop (* args ):
@@ -176,11 +184,11 @@ def sort_key(self, order=None):
176
184
# lower order derivatives first
177
185
self .order ,
178
186
# sorted by symbol after that, after expansion
179
- sorted ([
187
+ tuple ( sorted ([
180
188
x .sort_key (order )
181
189
for x , k in self .pdiffs .items ()
182
190
for i in range (k )
183
- ])
191
+ ]))
184
192
)
185
193
186
194
def __new__ (cls , __arg ):
@@ -210,7 +218,7 @@ def __new__(cls, __arg):
210
218
return self
211
219
212
220
def _eval_derivative (self , x ):
213
- self ._eval_derivative_n_times (x , 1 )
221
+ return self ._eval_derivative_n_times (x , 1 )
214
222
215
223
def _eval_derivative_n_times (self , x , n ) -> 'Pdop' : # pdiff(self)
216
224
# d is partial derivative
@@ -272,6 +280,9 @@ def __str__(self):
272
280
def __repr__ (self ):
273
281
return str (self )
274
282
283
+ def __srepr__ (self ):
284
+ return '{}({})' .format (type (self ).__name__ , self .pdiffs )
285
+
275
286
def _hashable_content (self ):
276
287
from sympy .utilities import default_sort_key
277
288
sorted_items = sorted (
@@ -349,15 +360,36 @@ def _diff_op_apply(self, x):
349
360
x = r ._diff_op_apply (x )
350
361
return coeff * x
351
362
363
+ def diff (self , * args , ** kwargs ):
364
+ return super ().diff (* args , simplify = False , ** kwargs )
365
+
366
+ def _eval_derivative (self , x ):
367
+ coeff , diff = self .args
368
+ return sympy .diff (coeff , x ) * diff + coeff * sympy .diff (diff , x )
369
+
352
370
353
371
class DiffOpAdd (DiffOpExpr , sympy .Add ):
354
372
identity = DiffOpZero ()
355
373
374
+ @classmethod
375
+ def _from_args (self , args , is_commutative ):
376
+ args = [_diff_op_ify (arg ) for arg in args ]
377
+ return super ()._from_args (args , is_commutative )
378
+
356
379
def _diff_op_apply (self , x ):
357
380
args = self .args
358
381
assert args
359
382
# avoid `sympy.Add` so that this works on multivectors
360
383
return functools .reduce (operator .add , (_diff_op_apply (a , x ) for a in args ))
361
384
362
385
def _eval_derivative_n_times (self , x , n ):
363
- return sympy .Add ._eval_derivative_n_times (self , x , n )
386
+ return DiffOpAdd (* (a .diff (x , n ) for a in self .args ))
387
+
388
+ def _eval_derivative (self , x ):
389
+ return DiffOpAdd (* (a .diff (x ) for a in self .args ))
390
+
391
+ def _eval_simplify (self , * args , ** kwargs ):
392
+ return self
393
+
394
+ def diff (self , * args , ** kwargs ):
395
+ return super ().diff (* args , simplify = False , ** kwargs )
0 commit comments