Sophie

Sophie

distrib > Mandriva > 2010.0 > i586 > media > contrib-release > by-pkgid > 429d9064b451fbf785f6785aa01bfcf8 > files > 61

python-storm-0.13-3mdv2010.0.noarch.rpm

#
# Copyright (c) 2006, 2007 Canonical
#
# Written by Gustavo Niemeyer <gustavo@niemeyer.net>
#
# This file is part of Storm Object Relational Mapper.
#
# Storm is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation; either version 2.1 of
# the License, or (at your option) any later version.
#
# Storm is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
import os
import gc

from storm.database import create_database
from storm.properties import Int, List

from tests.store.base import StoreTest, EmptyResultSetTest, Foo
from tests.helper import TestHelper

from tests.helper import run_this


class Lst1(object):
    """Storm-mapped class for the "lst1" table: an id and a list of ints."""
    __storm_table__ = "lst1"
    id = Int(primary=True)
    # One-dimensional list; PostgresStoreTest.create_tables declares the
    # matching column as INTEGER[].
    ints = List(type=Int())

class Lst2(object):
    """Storm-mapped class for the "lst2" table: an id and nested int lists."""
    __storm_table__ = "lst2"
    id = Int(primary=True)
    # List of lists; PostgresStoreTest.create_tables declares the matching
    # column as INTEGER[][].
    ints = List(type=List(type=Int()))

class FooWithSchema(Foo):
    """Foo subclass whose table name is schema-qualified ("public.foo")."""
    __storm_table__ = "public.foo"


class PostgresStoreTest(TestHelper, StoreTest):
    """StoreTest run against PostgreSQL, plus PostgreSQL-specific tests.

    The database URI is read from the STORM_POSTGRES_URI environment
    variable.  The extra tests cover list columns, schema-qualified table
    names and currval()-based identity retrieval.
    """

    def setUp(self):
        TestHelper.setUp(self)
        StoreTest.setUp(self)

    def tearDown(self):
        TestHelper.tearDown(self)
        StoreTest.tearDown(self)

    def is_supported(self):
        """Return True when STORM_POSTGRES_URI is set to a non-empty value."""
        return bool(os.environ.get("STORM_POSTGRES_URI"))

    def create_database(self):
        """Set self.database from the URI in STORM_POSTGRES_URI.

        Raises KeyError if the environment variable is not set.
        """
        self.database = create_database(os.environ["STORM_POSTGRES_URI"])

    def create_tables(self):
        """Create every table used by these tests and commit."""
        connection = self.connection
        connection.execute("CREATE TABLE foo "
                           "(id SERIAL PRIMARY KEY,"
                           " title VARCHAR DEFAULT 'Default Title')")
        # Prevent dynamically created Foos from having conflicting ids.
        connection.execute("SELECT setval('foo_id_seq', 1000)")
        connection.execute("CREATE TABLE bar "
                           "(id SERIAL PRIMARY KEY,"
                           " foo_id INTEGER, title VARCHAR)")
        connection.execute("CREATE TABLE bin "
                           "(id SERIAL PRIMARY KEY, bin BYTEA)")
        connection.execute("CREATE TABLE link "
                           "(foo_id INTEGER, bar_id INTEGER,"
                           " PRIMARY KEY (foo_id, bar_id))")
        connection.execute("CREATE TABLE money "
                           "(id SERIAL PRIMARY KEY, value NUMERIC(6,4))")
        connection.execute("CREATE TABLE selfref "
                           "(id SERIAL PRIMARY KEY, title VARCHAR,"
                           " selfref_id INTEGER REFERENCES selfref(id))")
        # Array-typed tables backing Lst1 and Lst2.
        connection.execute("CREATE TABLE lst1 "
                           "(id SERIAL PRIMARY KEY, ints INTEGER[])")
        connection.execute("CREATE TABLE lst2 "
                           "(id SERIAL PRIMARY KEY, ints INTEGER[][])")
        connection.execute("CREATE TABLE foovalue "
                           "(id SERIAL PRIMARY KEY, foo_id INTEGER,"
                           " value1 INTEGER, value2 INTEGER)")
        connection.commit()

    def drop_tables(self):
        """Drop the tables handled by StoreTest, then lst1 and lst2.

        Dropping lst1/lst2 is best-effort: a failed DROP is rolled back
        and the loop moves on to the next table.
        """
        StoreTest.drop_tables(self)
        for table in ["lst1", "lst2"]:
            try:
                self.connection.execute("DROP TABLE %s" % table)
                self.connection.commit()
            except Exception:
                # Only swallow ordinary errors (e.g. the table is missing);
                # KeyboardInterrupt and SystemExit must still propagate.
                self.connection.rollback()

    def test_list_variable(self):
        # Round-trip a flat list through an INTEGER[] column.
        lst = Lst1()
        lst.id = 1
        lst.ints = [1, 2, 3, 4]

        self.store.add(lst)

        result = self.store.execute("SELECT ints FROM lst1 WHERE id=1")
        self.assertEqual(result.get_one(), ([1, 2, 3, 4],))

        # Drop our reference and collect garbage before looking the row up
        # again -- presumably so find() cannot hand back the live object.
        del lst
        gc.collect()

        lst = self.store.find(Lst1, Lst1.ints == [1, 2, 3, 4]).one()
        self.assertTrue(lst)

        # An in-place mutation of the list must reach the database too.
        lst.ints.append(5)

        result = self.store.execute("SELECT ints FROM lst1 WHERE id=1")
        self.assertEqual(result.get_one(), ([1, 2, 3, 4, 5],))

    def test_list_variable_nested(self):
        # Same round-trip as test_list_variable, with a list of lists
        # stored in an INTEGER[][] column.
        lst = Lst2()
        lst.id = 1
        lst.ints = [[1, 2], [3, 4]]

        self.store.add(lst)

        result = self.store.execute("SELECT ints FROM lst2 WHERE id=1")
        self.assertEqual(result.get_one(), ([[1, 2], [3, 4]],))

        del lst
        gc.collect()

        lst = self.store.find(Lst2, Lst2.ints == [[1, 2], [3, 4]]).one()
        self.assertTrue(lst)

        lst.ints.append([5, 6])

        result = self.store.execute("SELECT ints FROM lst2 WHERE id=1")
        self.assertEqual(result.get_one(), ([[1, 2], [3, 4], [5, 6]],))

    def test_add_find_with_schema(self):
        foo = FooWithSchema()
        foo.title = u"Title"
        self.store.add(foo)
        self.store.flush()
        # We use find() here to actually exercise the backend code.
        # get() would just pick the object from the cache.
        self.assertEqual(self.store.find(FooWithSchema, id=foo.id).one(), foo)

    def test_wb_currval_based_identity(self):
        """
        Ensure that the currval()-based identity retrieval continues
        to work, even if we're currently running on a 8.2+ database.
        """
        # White-box: override the recorded server version with a pre-8.2
        # value so the older identity-retrieval path is exercised.
        self.database._version = (8, 1, 9)
        foo1 = self.store.add(Foo())
        self.store.flush()
        foo2 = self.store.add(Foo())
        self.store.flush()
        self.assertEqual(foo2.id - foo1.id, 1)


class PostgresEmptyResultSetTest(TestHelper, EmptyResultSetTest):
    """EmptyResultSetTest run against the STORM_POSTGRES_URI database."""

    def setUp(self):
        # Helper first, then the shared test case.
        for base in (TestHelper, EmptyResultSetTest):
            base.setUp(self)

    def tearDown(self):
        # Same order as setUp: helper first, then the shared test case.
        for base in (TestHelper, EmptyResultSetTest):
            base.tearDown(self)

    def is_supported(self):
        """Return True when STORM_POSTGRES_URI is set to a non-empty value."""
        uri = os.environ.get("STORM_POSTGRES_URI")
        if uri:
            return True
        return False

    def create_database(self):
        """Set self.database from the URI in STORM_POSTGRES_URI."""
        uri = os.environ["STORM_POSTGRES_URI"]
        self.database = create_database(uri)

    def create_tables(self):
        """Create the foo table and commit."""
        statement = ("CREATE TABLE foo (id SERIAL PRIMARY KEY,"
                     " title VARCHAR DEFAULT 'Default Title')")
        conn = self.connection
        conn.execute(statement)
        conn.commit()