[PATCH 1/3] guess_solve_strategy() improved to handle numbers as well

3 views
Skip to first unread message

Ondrej Certik

unread,
Mar 29, 2009, 3:50:28 PM3/29/09
to sympy-...@googlegroups.com, Ondrej Certik
Previously guess_solve_strategy(4, x) returned -1 (= can't guess), but it should
return GS_POLY, as this is a very simple polynomial equation. This has been
fixed and a test has been added.

Signed-off-by: Ondrej Certik <ond...@certik.cz>
---
sympy/solvers/solvers.py | 3 +++
sympy/solvers/tests/test_solvers.py | 1 +
2 files changed, 4 insertions(+), 0 deletions(-)

diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py
index 82fe034..c00d372 100644
--- a/sympy/solvers/solvers.py
+++ b/sympy/solvers/solvers.py
@@ -110,6 +110,9 @@ def guess_solve_strategy(expr, symbol):
elif expr.is_Function and expr.has(symbol):
return GS_TRASCENDENTAL

+ elif not expr.has(symbol):
+ return GS_POLY
+
return eq_type

def solve(f, *symbols, **flags):
diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py
index 026b647..6fab8b3 100644
--- a/sympy/solvers/tests/test_solvers.py
+++ b/sympy/solvers/tests/test_solvers.py
@@ -16,6 +16,7 @@ def test_guess_strategy():
x, y = symbols('xy')

# polynomial equations
+ assert guess_solve_strategy( S(4), x ) == GS_POLY
assert guess_solve_strategy( x, x ) == GS_POLY
assert guess_solve_strategy( 2*x, x ) == GS_POLY
assert guess_solve_strategy( x + sqrt(2), x) == GS_POLY
--
1.6.2

Ondrej Certik

unread,
Mar 29, 2009, 3:50:30 PM3/29/09
to sympy-...@googlegroups.com, Ondrej Certik
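solve(x**Rational(1,2) - 2, x) used to return [sqrt(2)], but the answer is [4].
There was a bug in the solver that no one had noticed so far: it did everything
correctly except at the very end. Because the solver substitutes x -> t**m, each
solution t of the transformed polynomial must be mapped back as x = t**m, i.e.
raised to the power "m". Instead, it was raised to the power "1/m", which is wrong.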

More tests have been added for this.

Signed-off-by: Ondrej Certik <ond...@certik.cz>
---

sympy/solvers/solvers.py | 2 +-
sympy/solvers/tests/test_solvers.py | 10 +++++++---
2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py
index 57bc775..1cf6747 100644
--- a/sympy/solvers/solvers.py
+++ b/sympy/solvers/solvers.py
@@ -210,7 +210,7 @@ def solve(f, *symbols, **flags):
cv_sols = solve(f_, t)
result = list()
for sol in cv_sols:
- result.append(sol**(S.One/m))
+ result.append(sol**m)

elif isinstance(f, Mul):
result = []
diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py
index 5f78f8a..4ac368e 100644
--- a/sympy/solvers/tests/test_solvers.py
+++ b/sympy/solvers/tests/test_solvers.py
@@ -95,14 +95,18 @@ def test_solve_polynomial_cv_1a():

x = Symbol('x')

+
assert solve( x**Rational(1,2) - 1, x) == [1]
- assert solve( x**Rational(1,2) - 2, x) == [sqrt(2)]
+ assert solve( x**Rational(1,2) - 2, x) == [4]
+ assert solve( x**Rational(1,4) - 2, x) == [16]
+ assert solve( x**Rational(1,3) - 3, x) == [27]

def test_solve_polynomial_cv_1b():
x, a = symbols('x a')

- assert set(solve(4*x*(1 - a*x**(S(1)/2)), x)) == \
- set([S(0), (1/a)**(S(1)/2)])
+
+ assert set(solve(4*x*(1 - a*x**(S(1)/2)), x)) == set([S(0), 1/a**2])
+ assert set(solve(x * (x**(S(1)/3) - 3), x)) == set([S(0), S(27)])

def test_solve_polynomial_cv_2():
"""
--
1.6.2

Ondrej Certik

unread,
Mar 29, 2009, 3:50:29 PM3/29/09
to sympy-...@googlegroups.com, Ondrej Certik
Signed-off-by: Ondrej Certik <ond...@certik.cz>
---
sympy/solvers/solvers.py | 44 +++++++++++++++++-----------------
sympy/solvers/tests/test_solvers.py | 14 +++++++++-
2 files changed, 34 insertions(+), 24 deletions(-)

diff --git a/sympy/solvers/solvers.py b/sympy/solvers/solvers.py
index c00d372..57bc775 100644
--- a/sympy/solvers/solvers.py
+++ b/sympy/solvers/solvers.py
@@ -183,11 +183,11 @@ def solve(f, *symbols, **flags):
return solve(P, symbol, **flags)

elif strategy == GS_POLY_CV_1:
- # we must search for a suitable change of variable
- # collect exponents
- exponents_denom = list()
args = list(f.args)
if isinstance(f, Add):
+ # we must search for a suitable change of variable
+ # collect exponents
+ exponents_denom = list()
for arg in args:
if isinstance(arg, Pow):
exponents_denom.append(arg.exp.q)
@@ -195,27 +195,27 @@ def solve(f, *symbols, **flags):
for mul_arg in arg.args:
if isinstance(mul_arg, Pow):
exponents_denom.append(mul_arg.exp.q)
+ assert len(exponents_denom) > 0
+ if len(exponents_denom) == 1:
+ m = exponents_denom[0]
+ else:
+ # get the GCD of the denominators
+ m = ilcm(*exponents_denom)
+ # x -> y**m.
+ # we assume positive for simplification purposes
+ t = Symbol('t', positive=True, dummy=True)
+ f_ = f.subs(symbol, t**m)
+ if guess_solve_strategy(f_, t) != GS_POLY:
+ raise TypeError("Could not convert to a polynomial equation: %s" % f_)
+ cv_sols = solve(f_, t)
+ result = list()
+ for sol in cv_sols:
+ result.append(sol**(S.One/m))
+
elif isinstance(f, Mul):
+ result = []
for mul_arg in args:
- if isinstance(mul_arg, Pow):
- exponents_denom.append(mul_arg.exp.q)
-
- assert len(exponents_denom) > 0
- if len(exponents_denom) == 1:
- m = exponents_denom[0]
- else:
- # get the GCD of the denominators
- m = ilcm(*exponents_denom)
- # x -> y**m.
- # we assume positive for simplification purposes
- t = Symbol('t', positive=True, dummy=True)
- f_ = f.subs(symbol, t**m)
- if guess_solve_strategy(f_, t) != GS_POLY:
- raise TypeError("Could not convert to a polynomial equation: %s" % f_)
- cv_sols = solve(f_, t)
- result = list()
- for sol in cv_sols:
- result.append(sol**(S.One/m))
+ result.extend(solve(mul_arg, symbol))

elif strategy == GS_POLY_CV_2:
m = 0
diff --git a/sympy/solvers/tests/test_solvers.py b/sympy/solvers/tests/test_solvers.py
index 6fab8b3..5f78f8a 100644
--- a/sympy/solvers/tests/test_solvers.py
+++ b/sympy/solvers/tests/test_solvers.py
@@ -48,7 +48,7 @@ def test_guess_strategy():
assert guess_solve_strategy( 2*cos(x)-y, x ) == GS_TRASCENDENTAL
assert guess_solve_strategy( exp(x) + exp(-x) - y, x ) == GS_TRASCENDENTAL

-def test_solve_polynomial():
+def test_solve_polynomial1():
x, y = map(Symbol, 'xy')

assert solve(3*x-2, x) == [Rational(2,3)]
@@ -83,7 +83,11 @@ def test_solve_polynomial():
raises(TypeError, "solve(x**2-pi, pi)")
raises(ValueError, "solve(x**2-pi)")

-def test_solve_polynomial_cv_1():
+def test_solve_polynomial2():
+ x = Symbol('x')
+ assert solve(4, x) == []
+
+def test_solve_polynomial_cv_1a():
"""
Test for solving on equations that can be converted to a polynomial equation
using the change of variable y -> x**Rational(p, q)
@@ -94,6 +98,12 @@ def test_solve_polynomial_cv_1():


assert solve( x**Rational(1,2) - 1, x) == [1]

assert solve( x**Rational(1,2) - 2, x) == [sqrt(2)]

+def test_solve_polynomial_cv_1b():
+ x, a = symbols('x a')
+
+ assert set(solve(4*x*(1 - a*x**(S(1)/2)), x)) == \
+ set([S(0), (1/a)**(S(1)/2)])
+
def test_solve_polynomial_cv_2():
"""
Test for solving on equations that can be converted to a polynomial equation
--
1.6.2

Fabian Seoane

unread,
Mar 29, 2009, 8:03:00 PM3/29/09
to sympy-...@googlegroups.com
Ondrej Certik wrote:
> it used to return [sqrt(2)], but the answer is [4]. There was a bug in the
> solver, that noone has noticed so far: it did everything correctly, only at the
> very end it should power the result to "m", and it used to power it to "1/m",
> which is wrong.

All patches are +1. This bug was probably introduced when I refactored
the solver module...

Ondrej Certik

unread,
Mar 29, 2009, 8:12:26 PM3/29/09
to sympy-...@googlegroups.com
On Sun, Mar 29, 2009 at 5:03 PM, Fabian Seoane <fab...@fseoane.net> wrote:
>
> Ondrej Certik wrote:
>> it used to return [sqrt(2)], but the answer is [4]. There was a bug in the
>> solver, that noone has noticed so far: it did everything correctly, only at the
>> very end it should power the result to "m", and it used to power it to "1/m",
>> which is wrong.
>
> all patches are +1. This bug was probably introduced when i refactored
> the solver module ...

Yes, I think so. Thanks for the review; all patches are in.

Ondrej

Reply all
Reply to author
Forward
0 new messages