[PATCH 3/3] refactor sympy.solvers.solvers.guess_solve_strategy

1 view
Skip to first unread message

Fabian Pedregosa

unread,
May 7, 2009, 12:37:08 PM (5/7/09)
to sympy-...@googlegroups.com, Fabian Pedregosa
Some redundant code was removed. Also fixed a bug (3**x - 10 was wrongly
classified as a polynomial equation) and added a test for it.

In the tests, I replaced some occurrences of tsolve with solve. Since solve
calls tsolve anyway, this also checks that solve correctly recognizes the
equations as transcendental.

Thanks to smichr for the report.
---
sympy/solvers/solvers.py | 35 ++++--------
sympy/solvers/tests/test_solvers.py | 102 ++++++++++++++++++++---------------
2 files changed, 70 insertions(+), 67 deletions(-)

diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py
index 061874f..2ad256c 100644
--- a/sympy/solvers/solvers.py
+++ b/sympy/solvers/solvers.py
@@ -38,7 +38,7 @@
# Codes for guess solve strategy
GS_POLY = 0
GS_RATIONAL = 1
-GS_POLY_CV_1 = 2 # can be converted to a polynomial equation via the change of variable y -> x**n
+GS_POLY_CV_1 = 2 # can be converted to a polynomial equation via the change of variable y -> x**a, a real
GS_POLY_CV_2 = 3 # can be converted to a polynomial equation multiplying on both sides by x**m
# for example, x + 1/x == 0. Multiplying by x yields x**2 + x == 0
GS_RATIONAL_CV_1 = 4 # can be converted to a rational equation via the change of variable y -> x**n
@@ -65,23 +65,7 @@ def guess_solve_strategy(expr, symbol):
"""
eq_type = -1
if expr.is_Add:
- items = expr.args
- for item in items:
- if item.is_Number or item.is_Symbol:
- eq_type = max(eq_type, GS_POLY)
- elif item.is_Mul:
- for arg in item.args:
- eq_type = max(guess_solve_strategy(arg, symbol), eq_type)
- elif item.is_Pow and item.base.has(symbol):
- if item.exp.is_Integer:
- if item.exp > 0:
- eq_type = max(eq_type, GS_POLY)
- else:
- eq_type = max(eq_type, GS_POLY_CV_2)
- elif item.exp.is_Rational:
- eq_type = max(eq_type, GS_POLY_CV_1)
- elif item.is_Function:
- return GS_TRASCENDENTAL
+ return max([guess_solve_strategy(i, symbol) for i in expr.args])

elif expr.is_Mul:
# check for rational functions
@@ -96,7 +80,7 @@ def guess_solve_strategy(expr, symbol):
else:
raise NotImplementedError
else:
- return max(map(guess_solve_strategy, expr.args, [symbol]*len(expr.args)))
+ return max([guess_solve_strategy(i, symbol) for i in expr.args])

elif expr.is_Symbol:
return GS_POLY
@@ -104,11 +88,15 @@ def guess_solve_strategy(expr, symbol):
elif expr.is_Pow:
if expr.exp.has(symbol):
return GS_TRASCENDENTAL
- elif expr.exp.is_Number and expr.base.has(symbol):
- if expr.exp.is_Integer:
+ elif not expr.exp.has(symbol) and expr.base.has(symbol):
+ if expr.exp.is_Integer and expr.exp > 0:
eq_type = max(eq_type, GS_POLY)
- else:
+ elif expr.exp.is_Integer and expr.exp < 0:
+ eq_type = max(eq_type, GS_POLY_CV_2)
+ elif expr.exp.is_Rational:
eq_type = max(eq_type, GS_POLY_CV_1)
+ else:
+ return GS_TRASCENDENTAL

elif expr.is_Function and expr.has(symbol):
return GS_TRASCENDENTAL
@@ -158,6 +146,7 @@ def solve(f, *symbols, **flags):
symbols = symbols[0]

symbols = map(sympify, symbols)
+ result = list()

if any(not s.is_Symbol for s in symbols):
raise TypeError('not a Symbol')
@@ -211,12 +200,10 @@ def solve(f, *symbols, **flags):
if guess_solve_strategy(f_, t) != GS_POLY:
raise TypeError("Could not convert to a polynomial equation: %s" % f_)
cv_sols = solve(f_, t)
- result = list()
for sol in cv_sols:
result.append(sol**m)

elif isinstance(f, Mul):
- result = []
for mul_arg in args:
result.extend(solve(mul_arg, symbol))

diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py
index 4ac368e..6e40b96 100644
--- a/sympy/solvers/tests/test_solvers.py
+++ b/sympy/solvers/tests/test_solvers.py
@@ -4,20 +4,21 @@
from sympy.solvers import solve_linear_system, solve_linear_system_LU,dsolve,\
tsolve, deriv_degree

-from sympy.solvers.solvers import guess_solve_strategy, GS_POLY, GS_POLY_CV_1, GS_TRASCENDENTAL, \
- GS_RATIONAL, GS_RATIONAL_CV_1
+from sympy.solvers.solvers import guess_solve_strategy, GS_POLY, GS_POLY_CV_1, GS_POLY_CV_2,\
+ GS_TRASCENDENTAL, GS_RATIONAL, GS_RATIONAL_CV_1

from sympy.utilities.pytest import XFAIL

-def test_guess_strategy():
+def test_guess_poly():
"""
See solvers._guess_solve_strategy
"""
- x, y = symbols('xy')
+ x, y, a = symbols('xya')

# polynomial equations
assert guess_solve_strategy( S(4), x ) == GS_POLY
assert guess_solve_strategy( x, x ) == GS_POLY
+ assert guess_solve_strategy( x + a, x ) == GS_POLY
assert guess_solve_strategy( 2*x, x ) == GS_POLY
assert guess_solve_strategy( x + sqrt(2), x) == GS_POLY
assert guess_solve_strategy( x + 2**Rational(1,4), x) == GS_POLY
@@ -27,15 +28,19 @@ def test_guess_strategy():
assert guess_solve_strategy( x*exp(y) + y, x) == GS_POLY
assert guess_solve_strategy( (x - y**3)/(y**2*(1 - y**2)**(S(1)/2)), x) == GS_POLY

+def test_guess_poly_cv():
+ x, y = symbols('xy')
# polynomial equations via a change of variable
assert guess_solve_strategy( x**Rational(1,2) + 1, x ) == GS_POLY_CV_1
assert guess_solve_strategy( x**Rational(1,3) + x**Rational(1,2) + 1, x ) == GS_POLY_CV_1
assert guess_solve_strategy( 4*x*(1 - sqrt(x)), x ) == GS_POLY_CV_1

# polynomial equation multiplying both sides by x**n
- assert guess_solve_strategy( x + 1/x + y, x )
+ assert guess_solve_strategy( x + 1/x + y, x ) == GS_POLY_CV_2

+def test_guess_rational_cv():
# rational functions
+ x, y = symbols('xy')
assert guess_solve_strategy( (x+1)/(x**2 + 2), x) == GS_RATIONAL
assert guess_solve_strategy( (x - y**3)/(y**2*(1 - y**2)**(S(1)/2)), y) == GS_RATIONAL_CV_1

@@ -43,10 +48,16 @@ def test_guess_strategy():
assert guess_solve_strategy( (x**Rational(1,2) + 1)/(x**Rational(1,3) + x**Rational(1,2) + 1), x ) \
== GS_RATIONAL_CV_1

+def test_guess_trascendental():
+ x, y, a, b = symbols('xyab')
#trascendental functions
assert guess_solve_strategy( exp(x) + 1, x ) == GS_TRASCENDENTAL
assert guess_solve_strategy( 2*cos(x)-y, x ) == GS_TRASCENDENTAL
assert guess_solve_strategy( exp(x) + exp(-x) - y, x ) == GS_TRASCENDENTAL
+ assert guess_solve_strategy(3**x-10, x) == GS_TRASCENDENTAL
+ assert guess_solve_strategy(-3**x+10, x) == GS_TRASCENDENTAL
+
+ assert guess_solve_strategy(a*x**b-y, x) == GS_TRASCENDENTAL

def test_solve_polynomial1():
x, y = map(Symbol, 'xy')
@@ -94,8 +105,6 @@ def test_solve_polynomial_cv_1a():
"""

x = Symbol('x')
-
-
assert solve( x**Rational(1,2) - 1, x) == [1]
assert solve( x**Rational(1,2) - 2, x) == [4]
assert solve( x**Rational(1,4) - 2, x) == [16]
@@ -121,10 +130,10 @@ def test_solve_polynomial_cv_2():
[ Rational(1,2) - I*sqrt(3)/2, Rational(1,2) + I*sqrt(3)/2]]

def test_solve_rational():
- x = Symbol('x')
- y = Symbol('y')
-
- solve( ( x - y**3 )/( (y**2)*sqrt(1 - y**2) ), x) == [x**Rational(1,3)]
+ """Test solve for rational functions"""
+ x, y, a, b = symbols('xyab')
+ assert solve( ( x - y**3 )/( (y**2)*sqrt(1 - y**2) ), x) == [y**3]
+ assert solve(y-b/(1+a*x), x) == [(b - y)/(a*y)]

def test_linear_system():
x, y, z, t, n = map(Symbol, 'xyztn')
@@ -190,7 +199,7 @@ def test_deriv_degree():
# Note: multiple solutions exist for some of these equations, so the tests
# should be expected to break if the implementation of the solver changes
# in such a way that a different branch is chosen
-def test_tsolve():
+def test_tsolve_1():
x = Symbol('x')
y = Symbol('y')
z = Symbol('z')
@@ -202,36 +211,43 @@ def test_tsolve():
# XXX in the following test, log(2*y + 2*...) should -> log(2) + log(y +...)
assert solve(exp(x)+exp(-x)-y,x) == [-log(4) + log(2*y + 2*(-4 + y**2)**Rational(1,2)),
-log(4) + log(2*y - 2*(-4 + y**2)**Rational(1,2))]
- assert tsolve(exp(x)-3, x) == [log(3)]
- assert tsolve(Eq(exp(x), 3), x) == [log(3)]
- assert tsolve(log(x)-3, x) == [exp(3)]
- assert tsolve(sqrt(3*x)-4, x) == [Rational(16,3)]
- assert tsolve(3**(x+2), x) == [-oo]
- assert tsolve(3**(2-x), x) == [oo]
- assert tsolve(4*3**(5*x+2)-7, x) == [(log(Rational(7,4))-2*log(3))/(5*log(3))]
- assert tsolve(x+2**x, x) == [-LambertW(log(2))/log(2)]
- assert tsolve(3*x+5+2**(-5*x+3), x) == \
- [-Rational(5,3) + LambertW(-10240*2**Rational(1,3)*log(2)/3)/(5*log(2))]
- assert tsolve(5*x-1+3*exp(2-7*x), x) == \
+ assert solve(exp(x)-3, x) == [log(3)]
+ assert solve(Eq(exp(x), 3), x) == [log(3)]
+ assert solve(log(x)-3, x) == [exp(3)]
+ assert solve(sqrt(3*x)-4, x) == [Rational(16,3)]
+ assert solve(3**(x+2), x) == [-oo]
+ assert solve(3**(2-x), x) == [oo]
+ assert solve(4*3**(5*x+2)-7, x) == [(-log(4) - 2*log(3) + log(7))/(5*log(3))]
+ assert solve(x+2**x, x) == [-LambertW(log(2))/log(2)]
+ assert solve(3*x+5+2**(-5*x+3), x) in \
+ [[-Rational(5,3) + LambertW(-10240*2**Rational(1,3)*log(2)/3)/(5*log(2))],\
+ [(-25*log(2) + 3*LambertW(-10240*2**(Rational(1, 3))*log(2)/3))/(15*log(2))]]
+ assert solve(5*x-1+3*exp(2-7*x), x) == \
[Rational(1,5) + LambertW(-21*exp(Rational(3,5))/5)/7]
- assert tsolve(2*x+5+log(3*x-2), x) == \
+ assert solve(2*x+5+log(3*x-2), x) == \
[Rational(2,3) + LambertW(2*exp(-Rational(19,3))/3)/2]
- assert tsolve(3*x+log(4*x), x) == [LambertW(Rational(3,4))/3]
- assert tsolve((2*x+8)*(8+exp(x)), x) == [-4]
- assert tsolve(2*exp(3*x+4)-3, x) == [-Rational(4,3)+log(Rational(3,2))/3]
- assert tsolve(2*log(3*x+4)-3, x) == [(exp(Rational(3,2))-4)/3]
- assert tsolve(exp(x)+1, x) == [pi*I]
- assert tsolve(x**2 - 2**x, x) == [2]
- assert tsolve(x**3 - 3**x, x) == [-3/log(3)*LambertW(-log(3)/3)]
- assert tsolve(2*(3*x+4)**5 - 6*7**(3*x+9), x) == \
- [Rational(-4,3) - 5/log(7)/3*LambertW(-7*2**Rational(4,5)*6**Rational(1,5)*log(7)/10)]
-
- assert tsolve(z*cos(x)-y, x) == [acos(y/z)]
- assert tsolve(z*cos(2*x)-y, x) == [acos(y/z)/2]
- assert tsolve(z*cos(sin(x))-y, x) == [asin(acos(y/z))]
-
- assert tsolve(z*cos(x), x) == [acos(0)]
-
- assert tsolve(exp(x)+exp(-x)-y, x)== [log(y/2 + Rational(1,2)*(-4 + y**2)**Rational(1,2)),
- log(y/2 - Rational(1,2)*(-4 + y**2)**Rational(1,2))]
-
+ assert solve(3*x+log(4*x), x) == [LambertW(Rational(3,4))/3]
+ assert solve((2*x+8)*(8+exp(x)), x) == [-4]
+ assert solve(2*exp(3*x+4)-3, x) in [ [-Rational(4,3)+log(Rational(3,2))/3],\
+ [Rational(-4, 3) - log(2)/3 + log(3)/3]]
+ assert solve(2*log(3*x+4)-3, x) == [(exp(Rational(3,2))-4)/3]
+ assert solve(exp(x)+1, x) == [pi*I]
+ assert solve(x**2 - 2**x, x) == [2]
+ assert solve(x**3 - 3**x, x) == [-3/log(3)*LambertW(-log(3)/3)]
+ assert solve(2*(3*x+4)**5 - 6*7**(3*x+9), x) in \
+ [[Rational(-4,3) - 5/log(7)/3*LambertW(-7*2**Rational(4,5)*6**Rational(1,5)*log(7)/10)],\
+ [(-5*LambertW(-7*2**(Rational(4, 5))*6**(Rational(1, 5))*log(7)/10) - 4*log(7))/(3*log(7))]]
+
+ assert solve(z*cos(x)-y, x) == [acos(y/z)]
+ assert solve(z*cos(2*x)-y, x) == [acos(y/z)/2]
+ assert solve(z*cos(sin(x))-y, x) == [asin(acos(y/z))]
+
+ assert solve(z*cos(x), x) == [acos(0)]
+
+ assert solve(exp(x)+exp(-x)-y, x)== [-log(4) + log(2*y + 2*(-4 + y**2)**(Rational(1, 2))),
+ -log(4) + log(2*y - 2*(-4 + y**2)**(Rational(1, 2)))]
+
+
+def test_tsolve_2():
+ x, y, a, b = symbols('xyab')
+ assert solve(y-a*x**b, x) == [y**(1/b)*(1/a)**(1/b)]
\ No newline at end of file
--
1.6.1.2

Vinzent Steinberg

unread,
May 7, 2009, 12:59:41 PM (5/7/09)
to sympy-...@googlegroups.com
Shouldn't it be TRANSCENDENTAL instead of TRASCENDENTAL?

Otherwise it looks fine to me, thank you!

Vinzent

2009/5/7 Fabian Pedregosa <fab...@fseoane.net>

Ondrej Certik

unread,
May 8, 2009, 9:43:39 AM (5/8/09)
to sympy-...@googlegroups.com
On Thu, May 7, 2009 at 9:59 AM, Vinzent Steinberg
<vinzent....@googlemail.com> wrote:
> Shouldn't it be TRANSCENDENTAL instead of TRASCENDENTAL?
>
> Otherwise it looks fine to me, thank you!

Looks good to me too, thanks! Fix the TRASCENDENTAL thing, it's
Spanish, isn't it? :)

Ondrej

Fabian Pedregosa

unread,
May 9, 2009, 7:15:46 AM (5/9/09)
to sympy-...@googlegroups.com
It is :-), and I've always written "trascendental" without noticing... so
others (http://plaes.org/blog/2009/4/11/some-sympy-hacking) have had to
fix my spelling errors...

Thanks for the review; I've pushed the changes.

>
> Ondrej
>
> >
>

Ondrej Certik

unread,
May 9, 2009, 1:51:29 PM (5/9/09)
to sympy-...@googlegroups.com


Thanks.

Ondrej

Reply all
Reply to author
Forward
0 new messages