Add returning to Update (issue 20920043)

19 views
Skip to first unread message

Cédric Krier

unread,
Nov 1, 2013, 3:18:45 PM (11/1/13)
to python-sql
Please review this at http://codereview.appspot.com/20920043/

Affected files (+49, -11 lines):
M CHANGELOG
M sql/__init__.py
M sql/tests/test_delete.py
M sql/tests/test_insert.py
M sql/tests/test_update.py


Index: CHANGELOG
===================================================================

--- a/CHANGELOG
+++ b/CHANGELOG
@@ -1,3 +1,5 @@
+* Add returning to Update
+* Add missing params for returning
* Add columns definitions to Function
* Fix AtTimeZone mapping


Index: sql/__init__.py
===================================================================

--- a/sql/__init__.py
+++ b/sql/__init__.py
@@ -470,26 +470,29 @@

@property
def params(self):
+ p = []
if isinstance(self.values, list):
- p = []
for values in self.values:
for value in values:
if isinstance(value, (Expression, Select)):
p.extend(value.params)
else:
p.append(value)
- return tuple(p)
elif isinstance(self.values, Select):
- return self.values.params
- else:
- return ()
+ p.extend(self.values.params)
+ if self.returning:
+ for exp in self.returning:
+ p.extend(exp.params)
+ return tuple(p)


class Update(Insert):
__slots__ = ('__where', '__values', 'from_')

- def __init__(self, table, columns, values, from_=None, where=None):
- super(Update, self).__init__(table, columns=columns, values=values)
+ def __init__(self, table, columns, values, from_=None, where=None,
+ returning=None):
+ super(Update, self).__init__(table, columns=columns, values=values,
+ returning=returning)
self.__where = None
self.from_ = From(from_) if from_ else None
self.where = where
@@ -534,7 +537,11 @@
where = ''
if self.where:
where = ' WHERE ' + str(self.where)
- return 'UPDATE %s SET ' % table + values + from_ + where
+ returning = ''
+ if self.returning:
+ returning = ' RETURNING ' + ', '.join(map(str, self.returning))
+ return ('UPDATE %s SET ' % table + values + from_ + where
+ + returning)

@property
def params(self):
@@ -548,6 +555,9 @@
p.extend(self.from_.params)
if self.where:
p.extend(self.where.params)
+ if self.returning:
+ for exp in self.returning:
+ p.extend(exp.params)
return tuple(p)


@@ -607,7 +617,13 @@

@property
def params(self):
- return self.where.params if self.where else ()
+ p = []
+ if self.where:
+ p.extend(self.where.params)
+ if self.returning:
+ for exp in self.returning:
+ p.extend(exp.params)
+ return tuple(p)


class CombiningQuery(Query, FromItem, _SelectQueryMixin):
@@ -670,9 +686,9 @@
return Insert(self, columns=columns, values=values,
returning=returning)

- def update(self, columns, values, from_=None, where=None):
+ def update(self, columns, values, from_=None, where=None, returning=None):
return Update(self, columns=columns, values=values, from_=from_,
- where=where)
+ where=where, returning=returning)

def delete(self, only=False, using=None, where=None, returning=None):
return Delete(self, only=only, using=using, where=where,

Index: sql/tests/test_delete.py
===================================================================

--- a/sql/tests/test_delete.py
+++ b/sql/tests/test_delete.py
@@ -52,3 +52,8 @@
'DELETE FROM "t1" WHERE ("c" IN ('
'SELECT "a"."c" FROM "t2" AS "a"))')
self.assertEqual(query.params, ())
+
+ def test_delete_returning(self):
+ query = self.table.delete(returning=[self.table.c])
+ self.assertEqual(str(query), 'DELETE FROM "t" RETURNING "c"')
+ self.assertEqual(query.params, ())

Index: sql/tests/test_insert.py
===================================================================

--- a/sql/tests/test_insert.py
+++ b/sql/tests/test_insert.py
@@ -70,3 +70,11 @@
self.assertEqual(str(query),
'INSERT INTO "t" ("c") VALUES (ABS(%s))')
self.assertEqual(query.params, (-1,))
+
+ def test_insert_returning(self):
+ query = self.table.insert([self.table.c1, self.table.c2],
+ [['foo', 'bar']], returning=[self.table.c1, self.table.c2])
+ self.assertEqual(str(query),
+ 'INSERT INTO "t" ("c1", "c2") VALUES (%s, %s) '
+ 'RETURNING "c1", "c2"')
+ self.assertEqual(query.params, ('foo', 'bar'))

Index: sql/tests/test_update.py
===================================================================

--- a/sql/tests/test_update.py
+++ b/sql/tests/test_update.py
@@ -61,3 +61,10 @@
'UPDATE "t1" SET "c" = ('
'SELECT "b"."c" FROM "t2" AS "b" WHERE ("b"."i" = "t1"."i"))')
self.assertEqual(query.params, ())
+
+ def test_update_returning(self):
+ query = self.table.update([self.table.c], ['foo'],
+ returning=[self.table.c])
+ self.assertEqual(str(query),
+ 'UPDATE "t" SET "c" = %s RETURNING "t"."c"')
+ self.assertEqual(query.params, ('foo',))




--
Cédric Krier - B2CK SPRL
Email/Jabber: cedric...@b2ck.com
Tel: +32 472 54 46 59
Website: http://www.b2ck.com/
Reply all
Reply to author
Forward
0 new messages