Skip to content

Commit 4fe9ec0

Browse files
mmaterarocky
authored andcommitted
Implement CoefficientArrays
1 parent 3ac1338 commit 4fe9ec0

6 files changed

Lines changed: 317 additions & 9 deletions

File tree

CHANGES.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ New variables and builtins
2424
* ``StringReverse``
2525
* ``$SystemMemory``
2626
* Add all of the named colors, e.g. ``Brown`` or ``LighterMagenta``.
27-
27+
* ``Collect``
2828

2929

3030
Enhancements

mathics/builtin/numbers/algebra.py

Lines changed: 309 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Integer,
1414
Integer0,
1515
Integer1,
16+
RationalOneHalf,
1617
Number,
1718
Symbol,
1819
SymbolFalse,
@@ -21,7 +22,7 @@
2122
SymbolTrue,
2223
)
2324
from mathics.core.convert import from_sympy, sympy_symbol_prefix
24-
from mathics.core.rules import Pattern
25+
from mathics.core.pattern import Pattern
2526
from mathics.builtin.scoping import dynamic_scoping
2627
from mathics.builtin.inference import evaluate_predicate
2728

@@ -213,10 +214,12 @@ def unconvert_subexprs(expr):
213214
)
214215

215216
sympy_expr = convert_sympy(expr)
216-
217217
if deep:
218218
# thread over everything
219-
for (i, sub_expr,) in enumerate(sub_exprs):
219+
for (
220+
i,
221+
sub_expr,
222+
) in enumerate(sub_exprs):
220223
if not sub_expr.is_atom():
221224
head = _expand(sub_expr.head) # also expand head
222225
leaves = sub_expr.get_leaves()
@@ -270,7 +273,6 @@ def unconvert_subexprs(expr):
270273
sympy_expr = sympy_expr.expand(**hints)
271274
result = from_sympy(sympy_expr)
272275
result = unconvert_subexprs(result)
273-
274276
return result
275277

276278

@@ -1606,3 +1608,306 @@ def apply(self, expr, form, h, evaluation):
16061608
return Expression(
16071609
"List", *[Expression(h, *[i for i in s]) for s in exponents]
16081610
)
1611+
1612+
1613+
class CoefficientArrays(Builtin):
1614+
"""
1615+
<dl>
1616+
<dt>'CoefficientArrays[$polys$, $vars$]'
1617+
<dd>returns a list of arrays of coefficients of the variables $vars$ in the polynomial $poly$.
1618+
</dl>
1619+
"""
1620+
1621+
options = {
1622+
"Symmetric": "False",
1623+
}
1624+
messages = {
1625+
"poly": "`1` is not a polynomial",
1626+
}
1627+
1628+
def apply_list(self, polys, varlist, expression, options):
1629+
"%(name)s[polys_list, varlist_, OptionsPattern[]]"
1630+
return
1631+
if polys.has_form("List", None):
1632+
polys = polys.leaves
1633+
else:
1634+
polys = [polys]
1635+
1636+
# Expand all the polynomials before start
1637+
polys = [Expression("ExpandAll", poly).evaluate(evaluation) for poly in polys]
1638+
1639+
if varlist.has_form("List", None):
1640+
varpat = varlist.leaves
1641+
else:
1642+
varpat = [varlist]
1643+
1644+
degree = 0
1645+
1646+
def isvar(var):
1647+
# TODO: check also patterns
1648+
if term.is_atom():
1649+
return var in varpat
1650+
# if the expression do not match, and
1651+
# is not atomic, do not decide.
1652+
return
1653+
1654+
def term_degree(term):
1655+
degree = 0
1656+
linear = isvar(term)
1657+
if not (linear is None):
1658+
return linear
1659+
1660+
if term.get_head_name() == "System`Times":
1661+
for factor in term.leaves:
1662+
q = factor_degree(factor)
1663+
if factor is None:
1664+
return None
1665+
degree += q
1666+
elif term.get_head_name() == "System`Power":
1667+
return power_degree(term)
1668+
1669+
def factor_degree(factor):
1670+
linear = isvar(factor)
1671+
if not (islinear is None):
1672+
return linear
1673+
if factor.get_head_name() == "System`Power":
1674+
return power_degree(factor)
1675+
return None
1676+
1677+
def power_degree(factor):
1678+
if not isvar(factor.leaves[0]):
1679+
return 0
1680+
if not isinstance(factor.leaves[1], Integer):
1681+
return None
1682+
return factor.leaves[1].get_int_value()
1683+
1684+
for poly in polys:
1685+
if poly.is_atom():
1686+
if degree == 0 and poly in varpat:
1687+
degree = 1
1688+
# TODO: handle patterns
1689+
continue
1690+
elif poly.get_head_name() == "System`Plus":
1691+
for term in poly.leaves:
1692+
curr_degree = 0
1693+
if term.get_head_name() == "System`Times":
1694+
for factor in term.leaves:
1695+
if factor.get_head_name() == "System`Power":
1696+
if isinstance(factor.leaves[1], Integer):
1697+
curr_degree = factor.leaves[1].get_int_value()
1698+
else:
1699+
evaluation.message(
1700+
"CoefficientArrays", "poly", poly
1701+
)
1702+
1703+
elif term.get_head_name() == "System`Power":
1704+
if term.leaves[0] in varpat:
1705+
if isinstance(term.leaves[1], Integer):
1706+
curr_degree = term.leaves[1].get_int_value()
1707+
else:
1708+
evaluation.message("CoefficientArrays", "poly", poly)
1709+
elif term in vars:
1710+
curr_degree = 1
1711+
elif poly.get_head_name() not in (
1712+
"System`Plus",
1713+
"System`Times",
1714+
"System`Power",
1715+
):
1716+
evaluation.message("CoefficientArrays", "poly", poly)
1717+
return
1718+
1719+
1720+
class Collect(Builtin):
1721+
"""
1722+
<dl>
1723+
<dt>'Collect[$expr$, $x$]'
1724+
<dd> Expands $expr$ and collect together terms having the same power of $x$.
1725+
<dt>'Collect[$expr$, {$x_1$, $x_2$, ...}]'
1726+
<dd> Expands $expr$ and collect together terms having the same powers of
1727+
$x_1$, $x_2$, ....
1728+
<dt>'Collect[$expr$, {$x_1$, $x_2$, ...}, $filter$]'
1729+
<dd> After collect the terms, applies $filter$ to each coefficient.
1730+
</dl>
1731+
1732+
>> Collect[(x+y)^3, y]
1733+
= x ^ 3 + 3 x ^ 2 y + 3 x y ^ 2 + y ^ 3
1734+
>> Collect[2 Sin[x z] (x+2 y^2 + Sin[y] x), y]
1735+
= 2 x Sin[x z] + 2 x Sin[x z] Sin[y] + 4 y ^ 2 Sin[x z]
1736+
>> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, y]
1737+
= 4 x Sin[x z] + x ^ 3 + y (3 x + 3 x ^ 2) + y ^ 2 (3 x + 4 Sin[x z]) + y ^ 3
1738+
>> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}]
1739+
= 4 x Sin[x z] + x ^ 3 + 3 x y + 3 x ^ 2 y + 4 y ^ 2 Sin[x z] + 3 x y ^ 2 + y ^ 3
1740+
>> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}, h]
1741+
= x h[4 Sin[x z]] + x ^ 3 h[1] + x y h[3] + x ^ 2 y h[3] + y ^ 2 h[4 Sin[x z]] + x y ^ 2 h[3] + y ^ 3 h[1]
1742+
"""
1743+
1744+
rules = {
1745+
"Collect[expr_, varlst_]": "Collect[expr, varlst, Identity]",
1746+
}
1747+
1748+
def apply_var_filter(self, expr, varlst, filt, evaluation):
1749+
"""Collect[expr_, varlst_, filt_]"""
1750+
from mathics.builtin.patterns import match
1751+
1752+
if varlst.is_symbol():
1753+
var_exprs = [varlst]
1754+
elif varlst.has_form("List", None):
1755+
var_exprs = varlst.get_leaves()
1756+
else:
1757+
var_exprs = [varlst]
1758+
1759+
if len(var_exprs) > 1:
1760+
target_pat = Pattern.create(Expression("Alternatives", *var_exprs))
1761+
var_pats = [Pattern.create(var) for var in var_exprs]
1762+
else:
1763+
target_pat = Pattern.create(varlst)
1764+
var_pats = [target_pat]
1765+
1766+
expr = expand(
1767+
expr,
1768+
numer=True,
1769+
denom=False,
1770+
deep=False,
1771+
trig=False,
1772+
modulus=None,
1773+
target_pat=target_pat,
1774+
)
1775+
if filt == Symbol("Identity"):
1776+
filt = None
1777+
1778+
def key_powers(lst):
1779+
key = Expression("Plus", *lst)
1780+
key = key.evaluate(evaluation)
1781+
if key.is_numeric():
1782+
return key.to_python()
1783+
return 0
1784+
1785+
def powers_list(pf):
1786+
powers = [Integer0 for i, p in enumerate(var_pats)]
1787+
if pf is None:
1788+
return powers
1789+
if pf.is_symbol():
1790+
for i, pat in enumerate(var_pats):
1791+
if match(pf, pat, evaluation):
1792+
powers[i] = Integer(1)
1793+
return powers
1794+
if pf.has_form("Sqrt", 1):
1795+
for i, pat in enumerate(var_pats):
1796+
if match(pf._leaves[0], pat, evaluation):
1797+
powers[i] = RationalOneHalf
1798+
return powers
1799+
if pf.has_form("Power", 2):
1800+
for i, pat in enumerate(var_pats):
1801+
matchval = match(pf._leaves[0], pat, evaluation)
1802+
if matchval:
1803+
powers[i] = pf._leaves[1]
1804+
return powers
1805+
if pf.has_form("Times", None):
1806+
contrib = [powers_list(factor) for factor in pf._leaves]
1807+
for i in range(len(var_pats)):
1808+
powers[i] = Expression("Plus", *[c[i] for c in contrib]).evaluate(
1809+
evaluation
1810+
)
1811+
return powers
1812+
return powers
1813+
1814+
def split_coeff_pow(term: Expression):
1815+
"""
1816+
This function factorizes term in a coefficent free
1817+
of powers of the target variables, and a factor with
1818+
that powers.
1819+
"""
1820+
coeffs = []
1821+
powers = []
1822+
# First, split factors on those which are powers of the variables
1823+
# and the rest.
1824+
if term.is_free(target_pat, evaluation):
1825+
coeffs.append(term)
1826+
elif (
1827+
term.is_symbol()
1828+
or term.has_form("Power", 2)
1829+
or term.has_form("Sqrt", 1)
1830+
):
1831+
powers.append(term)
1832+
elif term.has_form("Times", None):
1833+
for factor in term.leaves:
1834+
if factor.is_free(target_pat, evaluation):
1835+
coeffs.append(factor)
1836+
elif match(factor, target_pat, evaluation):
1837+
powers.append(factor)
1838+
elif (
1839+
factor.has_form("Power", 2) or factor.has_form("Sqrt", 1)
1840+
) and match(factor._leaves[0], target_pat, evaluation):
1841+
powers.append(factor)
1842+
else:
1843+
coeffs.append(factor)
1844+
else:
1845+
coeffs.append(term)
1846+
# Now, rebuild both factors
1847+
if len(coeffs) == 0:
1848+
coeffs = None
1849+
elif len(coeffs) == 1:
1850+
coeffs = coeffs[0]
1851+
else:
1852+
coeffs = Expression("Times", *coeffs)
1853+
if len(powers) == 0:
1854+
powers = None
1855+
elif len(powers) == 1:
1856+
powers = powers[0]
1857+
else:
1858+
powers = Expression("Times", *sorted(powers))
1859+
return coeffs, powers
1860+
1861+
if expr.is_free(target_pat, evaluation):
1862+
if filt:
1863+
return Expression(filt, expr).evaluate(evaluation)
1864+
else:
1865+
return expr
1866+
elif expr.is_symbol() or expr.has_form("Power", 2) or expr.has_form("Sqrt", 1):
1867+
if filt:
1868+
return Expression(
1869+
"Times", Expression(filt, Integer1).evaluate(evaluation), expr
1870+
)
1871+
else:
1872+
return expr
1873+
elif expr.has_form("Plus", None):
1874+
coeff_dict = {}
1875+
powers_dict = {}
1876+
powers_order = {}
1877+
for term in expr._leaves:
1878+
coeff, powers = split_coeff_pow(term)
1879+
pl = powers_list(powers)
1880+
key = str(pl)
1881+
if not key in powers_dict:
1882+
powers_dict[key] = powers
1883+
coeff_dict[key] = []
1884+
powers_order[key] = key_powers(pl)
1885+
1886+
coeff_dict[key].append(Integer1 if coeff is None else coeff)
1887+
1888+
terms = []
1889+
for key in sorted(
1890+
coeff_dict, key=lambda kv: powers_order[kv], reverse=False
1891+
):
1892+
val = coeff_dict[key]
1893+
if len(val) == 0:
1894+
continue
1895+
elif len(val) == 1:
1896+
coeff = val[0]
1897+
else:
1898+
coeff = Expression("Plus", *val)
1899+
if filt:
1900+
coeff = Expression(filt, coeff).evaluate(evaluation)
1901+
1902+
powerfactor = powers_dict[key]
1903+
if powerfactor:
1904+
terms.append(Expression("Times", coeff, powerfactor))
1905+
else:
1906+
terms.append(coeff)
1907+
1908+
return Expression("Plus", *terms)
1909+
else:
1910+
if filt:
1911+
return Expression(filt, expr).evaluate(evaluation)
1912+
else:
1913+
return expr

mathics/builtin/numbers/calculus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ class Integrate(SympyFunction):
515515
= f[b] - f[a]
516516
>> Integrate[x/Exp[x^2/t], {x, 0, Infinity}]
517517
= ConditionalExpression[t / 2, Abs[Arg[t]] < Pi / 2]
518-
# This should work after merging the more sophisticated predicate_evaluation routine
518+
# This should work after merging the more sophisticated predicate_evaluation routine
519519
# be merged...
520520
# >> Assuming[Abs[Arg[t]] < Pi / 2, Integrate[x/Exp[x^2/t], {x, 0, Infinity}]]
521521
# = t / 2

mathics/builtin/patterns.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,10 @@ class _StopGeneratorMatchQ(StopGenerator):
630630

631631
class Matcher(object):
632632
def __init__(self, form):
633-
self.form = Pattern.create(form)
633+
if isinstance(form, Pattern):
634+
self.form = form
635+
else:
636+
self.form = Pattern.create(form)
634637

635638
def match(self, expr, evaluation):
636639
def yield_func(vars, rest):

mathics/builtin/system.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ class VersionNumber(Predefined):
451451
"""
452452

453453
name = "$VersionNumber"
454-
value = 6.0
454+
value = 10.0
455455

456456
def evaluate(self, evaluation) -> Real:
457457
# Make this be whatever the latest Mathematica release is,

0 commit comments

Comments
 (0)