|
13 | 13 | Integer, |
14 | 14 | Integer0, |
15 | 15 | Integer1, |
| 16 | + RationalOneHalf, |
16 | 17 | Number, |
17 | 18 | Symbol, |
18 | 19 | SymbolFalse, |
|
21 | 22 | SymbolTrue, |
22 | 23 | ) |
23 | 24 | from mathics.core.convert import from_sympy, sympy_symbol_prefix |
24 | | -from mathics.core.rules import Pattern |
| 25 | +from mathics.core.pattern import Pattern |
25 | 26 | from mathics.builtin.scoping import dynamic_scoping |
26 | 27 | from mathics.builtin.inference import evaluate_predicate |
27 | 28 |
|
@@ -213,10 +214,12 @@ def unconvert_subexprs(expr): |
213 | 214 | ) |
214 | 215 |
|
215 | 216 | sympy_expr = convert_sympy(expr) |
216 | | - |
217 | 217 | if deep: |
218 | 218 | # thread over everything |
219 | | - for (i, sub_expr,) in enumerate(sub_exprs): |
| 219 | + for ( |
| 220 | + i, |
| 221 | + sub_expr, |
| 222 | + ) in enumerate(sub_exprs): |
220 | 223 | if not sub_expr.is_atom(): |
221 | 224 | head = _expand(sub_expr.head) # also expand head |
222 | 225 | leaves = sub_expr.get_leaves() |
@@ -270,7 +273,6 @@ def unconvert_subexprs(expr): |
270 | 273 | sympy_expr = sympy_expr.expand(**hints) |
271 | 274 | result = from_sympy(sympy_expr) |
272 | 275 | result = unconvert_subexprs(result) |
273 | | - |
274 | 276 | return result |
275 | 277 |
|
276 | 278 |
|
@@ -1606,3 +1608,306 @@ def apply(self, expr, form, h, evaluation): |
1606 | 1608 | return Expression( |
1607 | 1609 | "List", *[Expression(h, *[i for i in s]) for s in exponents] |
1608 | 1610 | ) |
| 1611 | + |
| 1612 | + |
| 1613 | +class CoefficientArrays(Builtin): |
| 1614 | + """ |
| 1615 | + <dl> |
| 1616 | + <dt>'CoefficientArrays[$polys$, $vars$]' |
| 1617 | + <dd>returns a list of arrays of coefficients of the variables $vars$ in the polynomial $poly$. |
| 1618 | + </dl> |
| 1619 | + """ |
| 1620 | + |
| 1621 | + options = { |
| 1622 | + "Symmetric": "False", |
| 1623 | + } |
| 1624 | + messages = { |
| 1625 | + "poly": "`1` is not a polynomial", |
| 1626 | + } |
| 1627 | + |
| 1628 | + def apply_list(self, polys, varlist, expression, options): |
| 1629 | + "%(name)s[polys_list, varlist_, OptionsPattern[]]" |
| 1630 | + return |
| 1631 | + if polys.has_form("List", None): |
| 1632 | + polys = polys.leaves |
| 1633 | + else: |
| 1634 | + polys = [polys] |
| 1635 | + |
| 1636 | + # Expand all the polynomials before start |
| 1637 | + polys = [Expression("ExpandAll", poly).evaluate(evaluation) for poly in polys] |
| 1638 | + |
| 1639 | + if varlist.has_form("List", None): |
| 1640 | + varpat = varlist.leaves |
| 1641 | + else: |
| 1642 | + varpat = [varlist] |
| 1643 | + |
| 1644 | + degree = 0 |
| 1645 | + |
| 1646 | + def isvar(var): |
| 1647 | + # TODO: check also patterns |
| 1648 | + if term.is_atom(): |
| 1649 | + return var in varpat |
| 1650 | + # if the expression do not match, and |
| 1651 | + # is not atomic, do not decide. |
| 1652 | + return |
| 1653 | + |
| 1654 | + def term_degree(term): |
| 1655 | + degree = 0 |
| 1656 | + linear = isvar(term) |
| 1657 | + if not (linear is None): |
| 1658 | + return linear |
| 1659 | + |
| 1660 | + if term.get_head_name() == "System`Times": |
| 1661 | + for factor in term.leaves: |
| 1662 | + q = factor_degree(factor) |
| 1663 | + if factor is None: |
| 1664 | + return None |
| 1665 | + degree += q |
| 1666 | + elif term.get_head_name() == "System`Power": |
| 1667 | + return power_degree(term) |
| 1668 | + |
| 1669 | + def factor_degree(factor): |
| 1670 | + linear = isvar(factor) |
| 1671 | + if not (islinear is None): |
| 1672 | + return linear |
| 1673 | + if factor.get_head_name() == "System`Power": |
| 1674 | + return power_degree(factor) |
| 1675 | + return None |
| 1676 | + |
| 1677 | + def power_degree(factor): |
| 1678 | + if not isvar(factor.leaves[0]): |
| 1679 | + return 0 |
| 1680 | + if not isinstance(factor.leaves[1], Integer): |
| 1681 | + return None |
| 1682 | + return factor.leaves[1].get_int_value() |
| 1683 | + |
| 1684 | + for poly in polys: |
| 1685 | + if poly.is_atom(): |
| 1686 | + if degree == 0 and poly in varpat: |
| 1687 | + degree = 1 |
| 1688 | + # TODO: handle patterns |
| 1689 | + continue |
| 1690 | + elif poly.get_head_name() == "System`Plus": |
| 1691 | + for term in poly.leaves: |
| 1692 | + curr_degree = 0 |
| 1693 | + if term.get_head_name() == "System`Times": |
| 1694 | + for factor in term.leaves: |
| 1695 | + if factor.get_head_name() == "System`Power": |
| 1696 | + if isinstance(factor.leaves[1], Integer): |
| 1697 | + curr_degree = factor.leaves[1].get_int_value() |
| 1698 | + else: |
| 1699 | + evaluation.message( |
| 1700 | + "CoefficientArrays", "poly", poly |
| 1701 | + ) |
| 1702 | + |
| 1703 | + elif term.get_head_name() == "System`Power": |
| 1704 | + if term.leaves[0] in varpat: |
| 1705 | + if isinstance(term.leaves[1], Integer): |
| 1706 | + curr_degree = term.leaves[1].get_int_value() |
| 1707 | + else: |
| 1708 | + evaluation.message("CoefficientArrays", "poly", poly) |
| 1709 | + elif term in vars: |
| 1710 | + curr_degree = 1 |
| 1711 | + elif poly.get_head_name() not in ( |
| 1712 | + "System`Plus", |
| 1713 | + "System`Times", |
| 1714 | + "System`Power", |
| 1715 | + ): |
| 1716 | + evaluation.message("CoefficientArrays", "poly", poly) |
| 1717 | + return |
| 1718 | + |
| 1719 | + |
| 1720 | +class Collect(Builtin): |
| 1721 | + """ |
| 1722 | + <dl> |
| 1723 | + <dt>'Collect[$expr$, $x$]' |
| 1724 | + <dd> Expands $expr$ and collect together terms having the same power of $x$. |
| 1725 | + <dt>'Collect[$expr$, {$x_1$, $x_2$, ...}]' |
| 1726 | + <dd> Expands $expr$ and collect together terms having the same powers of |
| 1727 | + $x_1$, $x_2$, .... |
| 1728 | + <dt>'Collect[$expr$, {$x_1$, $x_2$, ...}, $filter$]' |
| 1729 | + <dd> After collect the terms, applies $filter$ to each coefficient. |
| 1730 | + </dl> |
| 1731 | +
|
| 1732 | + >> Collect[(x+y)^3, y] |
| 1733 | + = x ^ 3 + 3 x ^ 2 y + 3 x y ^ 2 + y ^ 3 |
| 1734 | + >> Collect[2 Sin[x z] (x+2 y^2 + Sin[y] x), y] |
| 1735 | + = 2 x Sin[x z] + 2 x Sin[x z] Sin[y] + 4 y ^ 2 Sin[x z] |
| 1736 | + >> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, y] |
| 1737 | + = 4 x Sin[x z] + x ^ 3 + y (3 x + 3 x ^ 2) + y ^ 2 (3 x + 4 Sin[x z]) + y ^ 3 |
| 1738 | + >> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}] |
| 1739 | + = 4 x Sin[x z] + x ^ 3 + 3 x y + 3 x ^ 2 y + 4 y ^ 2 Sin[x z] + 3 x y ^ 2 + y ^ 3 |
| 1740 | + >> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}, h] |
| 1741 | + = x h[4 Sin[x z]] + x ^ 3 h[1] + x y h[3] + x ^ 2 y h[3] + y ^ 2 h[4 Sin[x z]] + x y ^ 2 h[3] + y ^ 3 h[1] |
| 1742 | + """ |
| 1743 | + |
| 1744 | + rules = { |
| 1745 | + "Collect[expr_, varlst_]": "Collect[expr, varlst, Identity]", |
| 1746 | + } |
| 1747 | + |
| 1748 | + def apply_var_filter(self, expr, varlst, filt, evaluation): |
| 1749 | + """Collect[expr_, varlst_, filt_]""" |
| 1750 | + from mathics.builtin.patterns import match |
| 1751 | + |
| 1752 | + if varlst.is_symbol(): |
| 1753 | + var_exprs = [varlst] |
| 1754 | + elif varlst.has_form("List", None): |
| 1755 | + var_exprs = varlst.get_leaves() |
| 1756 | + else: |
| 1757 | + var_exprs = [varlst] |
| 1758 | + |
| 1759 | + if len(var_exprs) > 1: |
| 1760 | + target_pat = Pattern.create(Expression("Alternatives", *var_exprs)) |
| 1761 | + var_pats = [Pattern.create(var) for var in var_exprs] |
| 1762 | + else: |
| 1763 | + target_pat = Pattern.create(varlst) |
| 1764 | + var_pats = [target_pat] |
| 1765 | + |
| 1766 | + expr = expand( |
| 1767 | + expr, |
| 1768 | + numer=True, |
| 1769 | + denom=False, |
| 1770 | + deep=False, |
| 1771 | + trig=False, |
| 1772 | + modulus=None, |
| 1773 | + target_pat=target_pat, |
| 1774 | + ) |
| 1775 | + if filt == Symbol("Identity"): |
| 1776 | + filt = None |
| 1777 | + |
| 1778 | + def key_powers(lst): |
| 1779 | + key = Expression("Plus", *lst) |
| 1780 | + key = key.evaluate(evaluation) |
| 1781 | + if key.is_numeric(): |
| 1782 | + return key.to_python() |
| 1783 | + return 0 |
| 1784 | + |
| 1785 | + def powers_list(pf): |
| 1786 | + powers = [Integer0 for i, p in enumerate(var_pats)] |
| 1787 | + if pf is None: |
| 1788 | + return powers |
| 1789 | + if pf.is_symbol(): |
| 1790 | + for i, pat in enumerate(var_pats): |
| 1791 | + if match(pf, pat, evaluation): |
| 1792 | + powers[i] = Integer(1) |
| 1793 | + return powers |
| 1794 | + if pf.has_form("Sqrt", 1): |
| 1795 | + for i, pat in enumerate(var_pats): |
| 1796 | + if match(pf._leaves[0], pat, evaluation): |
| 1797 | + powers[i] = RationalOneHalf |
| 1798 | + return powers |
| 1799 | + if pf.has_form("Power", 2): |
| 1800 | + for i, pat in enumerate(var_pats): |
| 1801 | + matchval = match(pf._leaves[0], pat, evaluation) |
| 1802 | + if matchval: |
| 1803 | + powers[i] = pf._leaves[1] |
| 1804 | + return powers |
| 1805 | + if pf.has_form("Times", None): |
| 1806 | + contrib = [powers_list(factor) for factor in pf._leaves] |
| 1807 | + for i in range(len(var_pats)): |
| 1808 | + powers[i] = Expression("Plus", *[c[i] for c in contrib]).evaluate( |
| 1809 | + evaluation |
| 1810 | + ) |
| 1811 | + return powers |
| 1812 | + return powers |
| 1813 | + |
| 1814 | + def split_coeff_pow(term: Expression): |
| 1815 | + """ |
| 1816 | + This function factorizes term in a coefficent free |
| 1817 | + of powers of the target variables, and a factor with |
| 1818 | + that powers. |
| 1819 | + """ |
| 1820 | + coeffs = [] |
| 1821 | + powers = [] |
| 1822 | + # First, split factors on those which are powers of the variables |
| 1823 | + # and the rest. |
| 1824 | + if term.is_free(target_pat, evaluation): |
| 1825 | + coeffs.append(term) |
| 1826 | + elif ( |
| 1827 | + term.is_symbol() |
| 1828 | + or term.has_form("Power", 2) |
| 1829 | + or term.has_form("Sqrt", 1) |
| 1830 | + ): |
| 1831 | + powers.append(term) |
| 1832 | + elif term.has_form("Times", None): |
| 1833 | + for factor in term.leaves: |
| 1834 | + if factor.is_free(target_pat, evaluation): |
| 1835 | + coeffs.append(factor) |
| 1836 | + elif match(factor, target_pat, evaluation): |
| 1837 | + powers.append(factor) |
| 1838 | + elif ( |
| 1839 | + factor.has_form("Power", 2) or factor.has_form("Sqrt", 1) |
| 1840 | + ) and match(factor._leaves[0], target_pat, evaluation): |
| 1841 | + powers.append(factor) |
| 1842 | + else: |
| 1843 | + coeffs.append(factor) |
| 1844 | + else: |
| 1845 | + coeffs.append(term) |
| 1846 | + # Now, rebuild both factors |
| 1847 | + if len(coeffs) == 0: |
| 1848 | + coeffs = None |
| 1849 | + elif len(coeffs) == 1: |
| 1850 | + coeffs = coeffs[0] |
| 1851 | + else: |
| 1852 | + coeffs = Expression("Times", *coeffs) |
| 1853 | + if len(powers) == 0: |
| 1854 | + powers = None |
| 1855 | + elif len(powers) == 1: |
| 1856 | + powers = powers[0] |
| 1857 | + else: |
| 1858 | + powers = Expression("Times", *sorted(powers)) |
| 1859 | + return coeffs, powers |
| 1860 | + |
| 1861 | + if expr.is_free(target_pat, evaluation): |
| 1862 | + if filt: |
| 1863 | + return Expression(filt, expr).evaluate(evaluation) |
| 1864 | + else: |
| 1865 | + return expr |
| 1866 | + elif expr.is_symbol() or expr.has_form("Power", 2) or expr.has_form("Sqrt", 1): |
| 1867 | + if filt: |
| 1868 | + return Expression( |
| 1869 | + "Times", Expression(filt, Integer1).evaluate(evaluation), expr |
| 1870 | + ) |
| 1871 | + else: |
| 1872 | + return expr |
| 1873 | + elif expr.has_form("Plus", None): |
| 1874 | + coeff_dict = {} |
| 1875 | + powers_dict = {} |
| 1876 | + powers_order = {} |
| 1877 | + for term in expr._leaves: |
| 1878 | + coeff, powers = split_coeff_pow(term) |
| 1879 | + pl = powers_list(powers) |
| 1880 | + key = str(pl) |
| 1881 | + if not key in powers_dict: |
| 1882 | + powers_dict[key] = powers |
| 1883 | + coeff_dict[key] = [] |
| 1884 | + powers_order[key] = key_powers(pl) |
| 1885 | + |
| 1886 | + coeff_dict[key].append(Integer1 if coeff is None else coeff) |
| 1887 | + |
| 1888 | + terms = [] |
| 1889 | + for key in sorted( |
| 1890 | + coeff_dict, key=lambda kv: powers_order[kv], reverse=False |
| 1891 | + ): |
| 1892 | + val = coeff_dict[key] |
| 1893 | + if len(val) == 0: |
| 1894 | + continue |
| 1895 | + elif len(val) == 1: |
| 1896 | + coeff = val[0] |
| 1897 | + else: |
| 1898 | + coeff = Expression("Plus", *val) |
| 1899 | + if filt: |
| 1900 | + coeff = Expression(filt, coeff).evaluate(evaluation) |
| 1901 | + |
| 1902 | + powerfactor = powers_dict[key] |
| 1903 | + if powerfactor: |
| 1904 | + terms.append(Expression("Times", coeff, powerfactor)) |
| 1905 | + else: |
| 1906 | + terms.append(coeff) |
| 1907 | + |
| 1908 | + return Expression("Plus", *terms) |
| 1909 | + else: |
| 1910 | + if filt: |
| 1911 | + return Expression(filt, expr).evaluate(evaluation) |
| 1912 | + else: |
| 1913 | + return expr |
0 commit comments